From 8a6fb6ff00b975f5af0f4dbc493d974d337b381c Mon Sep 17 00:00:00 2001 From: Hendrik Borras <hendrikborras@web.de> Date: Tue, 19 Oct 2021 19:17:15 +0100 Subject: [PATCH] Enabled non-scalar scales for the Trunc node and made bit width readout more robust. --- .../qonnx/infer_quant_avg_pool_2d.py | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py b/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py index 70d97946a..3de317b9e 100644 --- a/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py +++ b/src/finn/transformation/qonnx/infer_quant_avg_pool_2d.py @@ -127,32 +127,28 @@ class AvgPoolAndTruncToQuantAvgPool(Transformation): f"except the first, must be statically " f"initialized. However, {inp} is not." ) - scale = model.get_initializer(t_node.input[1]) - if len(scale.shape) != 0: - raise ValueError( - f"Finn only supports scalar scales of zero dimension " - f"for the Trunc node, it currently is {scale}." - ) zero_pt = model.get_initializer(t_node.input[2]) if len(zero_pt.shape) != 0 or zero_pt != 0: raise ValueError( f"Finn only supports 0 as the zero point for " f"the Trunc node, it currently is {zero_pt}." ) - trunc_in_bits = model.get_initializer(t_node.input[3]) - trunc_out_bits = model.get_initializer(t_node.input[4]) + trunc_in_bits = model.get_initializer(t_node.input[3]).flatten() + trunc_out_bits = model.get_initializer( + t_node.input[4] + ).flatten() if ( - len(trunc_in_bits.shape) != 0 - or len(trunc_out_bits.shape) != 0 + len(trunc_in_bits.shape) != 1 + or len(trunc_out_bits.shape) != 1 ): raise ValueError( - f"Finn only supports scalar bit widths of zero " - f"dimension for the Trunc node. The input bit width " + f"Finn only supports scalar bit widths " + f"for the Trunc node. The input bit width " f"currently is: {trunc_in_bits}, " f"while the output bit width is: {trunc_out_bits}." ) - trunc_in_bits = int(trunc_in_bits) - trunc_out_bits = int(trunc_out_bits) + trunc_in_bits = int(trunc_in_bits[0]) + trunc_out_bits = int(trunc_out_bits[0]) # Calculate parameters for the QuantAvgPool2d node, # Calculate input bit width. Basically this backwards: @@ -170,6 +166,7 @@ class AvgPoolAndTruncToQuantAvgPool(Transformation): data_layout = "NCHW" # Insert scale nodes, QuantAvgPool2d node and required tensors + scale = model.get_initializer(t_node.input[1]) scale_div_tensor = helper.make_tensor_value_info( model.make_new_valueinfo_name(), TensorProto.FLOAT, -- GitLab