diff --git a/src/finn/transformation/qonnx/qonnx_activation_handlers.py b/src/finn/transformation/qonnx/qonnx_activation_handlers.py index cbb94aa4846d8edb1456b559b2a4ca89deeaac47..a64ffb53845d73f0458a685fda239adc0b752703 100644 --- a/src/finn/transformation/qonnx/qonnx_activation_handlers.py +++ b/src/finn/transformation/qonnx/qonnx_activation_handlers.py @@ -152,7 +152,8 @@ class QuantActBaseHandler(ABC): running_node_index += 1 # Get the MultiThreshold node instance to work with - mt_inst = getCustomOp(graph.node[running_node_index - 1]) + mt_node = graph.node[running_node_index - 1] + mt_inst = getCustomOp(mt_node) # Set scale and bias # If these values are scalar then they can be set as attributes @@ -184,10 +185,7 @@ class QuantActBaseHandler(ABC): mt_inst.set_nodeattr("out_dtype", out_dtype) # Insertion parameters - in_tensor = n.output[0] - successor_node = model.find_direct_successors(n) - if successor_node is not None: - successor_node = successor_node[0] + up_stream_node = mt_node # Set bias zero_bias = False @@ -218,19 +216,19 @@ class QuantActBaseHandler(ABC): output_shape, ) graph.value_info.append(act_add_tensor) - if successor_node is not None: - successor_node.input[0] = act_add_tensor.name add_node = helper.make_node( "Add", - [in_tensor, add_tensor.name], - [act_add_tensor.name], + [act_add_tensor.name, add_tensor.name], + [n.output[0]], ) graph.node.insert(running_node_index, add_node) running_node_index += 1 + add_node = graph.node[running_node_index - 1] - # Re-point the input node for the next node to insert - in_tensor = act_add_tensor.name + # Re-point the upstream node + up_stream_node.output[0] = act_add_tensor.name + up_stream_node = add_node # Set scale # Insert Mul node @@ -260,16 +258,19 @@ class QuantActBaseHandler(ABC): output_shape, ) graph.value_info.append(act_mul_tensor) - if successor_node is not None: - successor_node.input[0] = act_mul_tensor.name mul_node = helper.make_node( "Mul", - [in_tensor, mul_tensor.name], - [act_mul_tensor.name], + [act_mul_tensor.name, mul_tensor.name], + [n.output[0]], ) graph.node.insert(running_node_index, mul_node) running_node_index += 1 + mul_node = graph.node[running_node_index - 1] + + # Re-point the upstream node + up_stream_node.output[0] = act_mul_tensor.name + up_stream_node = mul_node # Remove activation node self._remove_activation_node()