Commit 8a6fb6ff authored by Hendrik Borras

Enabled non-scalar scales for the Trunc node and made bit width readout more robust.

parent 2a73e420
@@ -127,32 +127,28 @@ class AvgPoolAndTruncToQuantAvgPool(Transformation):
                         f"except the first, must be statically "
                         f"initialized. However, {inp} is not."
                     )
-                scale = model.get_initializer(t_node.input[1])
-                if len(scale.shape) != 0:
-                    raise ValueError(
-                        f"Finn only supports scalar scales of zero dimension "
-                        f"for the Trunc node, it currently is {scale}."
-                    )
                 zero_pt = model.get_initializer(t_node.input[2])
                 if len(zero_pt.shape) != 0 or zero_pt != 0:
                     raise ValueError(
                         f"Finn only supports 0 as the zero point for "
                         f"the Trunc node, it currently is {zero_pt}."
                     )
-                trunc_in_bits = model.get_initializer(t_node.input[3])
-                trunc_out_bits = model.get_initializer(t_node.input[4])
+                trunc_in_bits = model.get_initializer(t_node.input[3]).flatten()
+                trunc_out_bits = model.get_initializer(
+                    t_node.input[4]
+                ).flatten()
                 if (
-                    len(trunc_in_bits.shape) != 0
-                    or len(trunc_out_bits.shape) != 0
+                    len(trunc_in_bits.shape) != 1
+                    or len(trunc_out_bits.shape) != 1
                 ):
                     raise ValueError(
-                        f"Finn only supports scalar bit widths of zero "
-                        f"dimension for the Trunc node. The input bit width "
+                        f"Finn only supports scalar bit widths "
+                        f"for the Trunc node. The input bit width "
                         f"currently is: {trunc_in_bits}, "
                         f"while the output bit width is: {trunc_out_bits}."
                     )
-                trunc_in_bits = int(trunc_in_bits)
-                trunc_out_bits = int(trunc_out_bits)
+                trunc_in_bits = int(trunc_in_bits[0])
+                trunc_out_bits = int(trunc_out_bits[0])
                 # Calculate parameters for the QuantAvgPool2d node,
                 # Calculate input bit width. Basically this backwards:
@@ -170,6 +166,7 @@ class AvgPoolAndTruncToQuantAvgPool(Transformation):
                     data_layout = "NCHW"
                 # Insert scale nodes, QuantAvgPool2d node and required tensors
+                scale = model.get_initializer(t_node.input[1])
                 scale_div_tensor = helper.make_tensor_value_info(
                     model.make_new_valueinfo_name(),
                     TensorProto.FLOAT,
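The "more robust" bit-width readout can be illustrated outside the transformation. Below is a minimal sketch, not part of the commit and using plain NumPy with made-up example values, of why reading the Trunc bit-width initializers via .flatten() and [0] tolerates more initializer shapes than the old zero-dimension shape check followed by a direct int() cast.

    import numpy as np

    # A scalar bit-width initializer may be stored as a 0-d array,
    # a single-element 1-d array, or even a 1x1 array.
    candidates = [np.array(8.0), np.array([8.0]), np.array([[8.0]])]

    for bits in candidates:
        # Old behaviour: only shape () was accepted.
        old_check_passes = len(bits.shape) == 0

        # New behaviour: flatten first, so every candidate becomes shape (1,),
        # then the value is read out with [0].
        flat = bits.flatten()
        new_check_passes = len(flat.shape) == 1
        print(bits.shape, old_check_passes, new_check_passes, int(flat[0]))

Only the 0-d candidate passes the old check, while all three pass the new one and yield the same integer bit width.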