diff --git a/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py b/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py
index 70d97946a38c3bab7e715a5896596584c5a67b7d..3de317b9e23bad70817b8626fc7085a631c2cdff 100644
--- a/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py
+++ b/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py
@@ -127,32 +127,28 @@ class AvgPoolAndTruncToQuantAvgPool(Transformation):
                                     f"except the first, must be statically "
                                     f"initialized. However, {inp} is not."
                                 )
-                        scale = model.get_initializer(t_node.input[1])
-                        if len(scale.shape) != 0:
-                            raise ValueError(
-                                f"Finn only supports scalar scales of zero dimension "
-                                f"for the Trunc node, it currently is {scale}."
-                            )
                         zero_pt = model.get_initializer(t_node.input[2])
                         if len(zero_pt.shape) != 0 or zero_pt != 0:
                             raise ValueError(
                                 f"Finn only supports 0 as the zero point for "
                                 f"the Trunc node, it currently is {zero_pt}."
                             )
-                        trunc_in_bits = model.get_initializer(t_node.input[3])
-                        trunc_out_bits = model.get_initializer(t_node.input[4])
+                        trunc_in_bits = model.get_initializer(t_node.input[3]).flatten()
+                        trunc_out_bits = model.get_initializer(
+                            t_node.input[4]
+                        ).flatten()
                         if (
-                            len(trunc_in_bits.shape) != 0
-                            or len(trunc_out_bits.shape) != 0
+                            len(trunc_in_bits.shape) != 1
+                            or len(trunc_out_bits.shape) != 1
                         ):
                             raise ValueError(
-                                f"Finn only supports scalar bit widths of zero "
-                                f"dimension for the Trunc node. The input bit width "
+                                f"Finn only supports scalar bit widths "
+                                f"for the Trunc node. The input bit width "
                                 f"currently is: {trunc_in_bits}, "
                                 f"while the output bit width is: {trunc_out_bits}."
                             )
-                        trunc_in_bits = int(trunc_in_bits)
-                        trunc_out_bits = int(trunc_out_bits)
+                        trunc_in_bits = int(trunc_in_bits[0])
+                        trunc_out_bits = int(trunc_out_bits[0])
 
                         # Calculate parameters for the QuantAvgPool2d node,
                         # Calculate input bit width. Basically this backwards:
@@ -170,6 +166,7 @@ class AvgPoolAndTruncToQuantAvgPool(Transformation):
                         data_layout = "NCHW"
 
                         # Insert scale nodes, QuantAvgPool2d node and required tensors
+                        scale = model.get_initializer(t_node.input[1])
                         scale_div_tensor = helper.make_tensor_value_info(
                             model.make_new_valueinfo_name(),
                             TensorProto.FLOAT,