From 8a6fb6ff00b975f5af0f4dbc493d974d337b381c Mon Sep 17 00:00:00 2001
From: Hendrik Borras <hendrikborras@web.de>
Date: Tue, 19 Oct 2021 19:17:15 +0100
Subject: [PATCH] Enabled non-scalar scales for the Trunc node and made bit
 width readout more robust.

---
 .../qonnx/infer_quant_avg_pool_2d.py          | 25 ++++++++-----------
 1 file changed, 11 insertions(+), 14 deletions(-)

diff --git a/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py b/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py
index 70d97946a..3de317b9e 100644
--- a/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py
+++ b/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py
@@ -127,32 +127,28 @@ class AvgPoolAndTruncToQuantAvgPool(Transformation):
                                     f"except the first, must be statically "
                                     f"initialized. However, {inp} is not."
                                 )
-                        scale = model.get_initializer(t_node.input[1])
-                        if len(scale.shape) != 0:
-                            raise ValueError(
-                                f"Finn only supports scalar scales of zero dimension "
-                                f"for the Trunc node, it currently is {scale}."
-                            )
                         zero_pt = model.get_initializer(t_node.input[2])
                         if len(zero_pt.shape) != 0 or zero_pt != 0:
                             raise ValueError(
                                 f"Finn only supports 0 as the zero point for "
                                 f"the Trunc node, it currently is {zero_pt}."
                             )
-                        trunc_in_bits = model.get_initializer(t_node.input[3])
-                        trunc_out_bits = model.get_initializer(t_node.input[4])
+                        trunc_in_bits = model.get_initializer(t_node.input[3]).flatten()
+                        trunc_out_bits = model.get_initializer(
+                            t_node.input[4]
+                        ).flatten()
                         if (
-                            len(trunc_in_bits.shape) != 0
-                            or len(trunc_out_bits.shape) != 0
+                            len(trunc_in_bits.shape) != 1
+                            or len(trunc_out_bits.shape) != 1
                         ):
                             raise ValueError(
-                                f"Finn only supports scalar bit widths of zero "
-                                f"dimension for the Trunc node. The input bit width "
+                                f"Finn only supports scalar bit widths "
+                                f"for the Trunc node. The input bit width "
                                 f"currently is: {trunc_in_bits}, "
                                 f"while the output bit width is: {trunc_out_bits}."
                             )
-                        trunc_in_bits = int(trunc_in_bits)
-                        trunc_out_bits = int(trunc_out_bits)
+                        trunc_in_bits = int(trunc_in_bits[0])
+                        trunc_out_bits = int(trunc_out_bits[0])
 
                         # Calculate parameters for the QuantAvgPool2d node,
                         # Calculate input bit width. Basically this backwards:
@@ -170,6 +166,7 @@ class AvgPoolAndTruncToQuantAvgPool(Transformation):
                         data_layout = "NCHW"
 
                         # Insert scale nodes, QuantAvgPool2d node and required tensors
+                        scale = model.get_initializer(t_node.input[1])
                         scale_div_tensor = helper.make_tensor_value_info(
                             model.make_new_valueinfo_name(),
                             TensorProto.FLOAT,
-- 
GitLab