diff --git a/src/finn/transformation/qonnx/convert_qonnx_to_finn.py b/src/finn/transformation/qonnx/convert_qonnx_to_finn.py
index 5b218f2c38592afff3b790395154454e563028bb..ae49d3cd21f805339deab8658aaa6a324a72ea98 100644
--- a/src/finn/transformation/qonnx/convert_qonnx_to_finn.py
+++ b/src/finn/transformation/qonnx/convert_qonnx_to_finn.py
@@ -52,8 +52,6 @@ class ConvertQONNXtoFINN(Transformation):
         model = model.transform(FoldQuantWeights())
         # Convert activations
         model = model.transform(ConvertQuantActToMultiThreshold())
-        # Infer types again
-        model = model.transform(InferDataTypes())
 
         # Unset FINN datatypes from MultiThreshold node output tensors to avoid warnings
         mt_nodes = model.get_nodes_by_op_type("MultiThreshold")
diff --git a/src/finn/transformation/qonnx/qonnx_activation_handlers.py b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
index d5c00c73dab68a479a1f0f7cce7e7395d4fb6bd4..26c65a4cad600029d452326896b041e71f9423e7 100644
--- a/src/finn/transformation/qonnx/qonnx_activation_handlers.py
+++ b/src/finn/transformation/qonnx/qonnx_activation_handlers.py
@@ -123,9 +123,6 @@ class QuantActBaseHandler(ABC):
         graph = model.graph
         n = self._q_node
         running_node_index = self._q_index
-        successor = model.find_direct_successors(n)
-        if successor is not None:
-            successor = successor[0]
 
         # Calculate insertion parameters
         parameter_dict = self.calculate_node_parameters()
@@ -155,69 +152,105 @@ class QuantActBaseHandler(ABC):
             graph.node.insert(running_node_index, outp_trans_node)
             running_node_index += 1
 
-        # Insert Add node
-        if adder_bias.shape == (1,):
-            adder_bias = adder_bias[0]
-            add_shape = tuple()
+        # Get the MultiThreshold node instance to work with
+        mt_inst = getCustomOp(graph.node[running_node_index - 1])
+
+        # Set scale and bias
+        # If these values are scalar then they can be set as attributes
+        # of the MultiThreshold node, if not they get inserted as adder and mul nodes
+        # behind the MultiThreshold nodes.
+        scale_compatible = mul_scale.shape == (1,) or len(mul_scale.shape) == 0
+        bias_compatible = adder_bias.shape == (1,) or len(adder_bias.shape) == 0
+        if scale_compatible and bias_compatible:
+            # Get Quant parameters
+            mul_scale = np.atleast_1d(mul_scale)
+            # ONNX only accepts 64bit floats as attributes
+            mul_scale = mul_scale.astype(dtype=np.float64)
+            adder_bias = np.atleast_1d(adder_bias)
+            adder_bias = adder_bias.astype(dtype=np.float64)
+
+            # Set Bias and scale
+            mt_inst.set_nodeattr("out_scale", mul_scale[0])
+            # FINN applies scale first then bias,
+            # which is the other way around in Brevitas,
+            # we thus need to adjust the bias in the MultiThreshold node
+            mt_inst.set_nodeattr("out_bias", adder_bias[0] * mul_scale[0])
         else:
-            add_shape = adder_bias.shape
-        add_tensor = helper.make_tensor_value_info(
-            model.make_new_valueinfo_name(),
-            TensorProto.FLOAT,
-            add_shape,
-        )
-        graph.value_info.append(add_tensor)
-        model.set_initializer(add_tensor.name, adder_bias)
-
-        output_shape = model.get_tensor_shape(n.output[0])
-        act_add_tensor = helper.make_tensor_value_info(
-            model.make_new_valueinfo_name(),
-            TensorProto.FLOAT,
-            output_shape,
-        )
-        graph.value_info.append(act_add_tensor)
-        if successor is not None:
-            successor.input[0] = act_add_tensor.name
-
-        add_node = helper.make_node(
-            "Add",
-            [n.output[0], add_tensor.name],
-            [act_add_tensor.name],
-        )
-        graph.node.insert(running_node_index, add_node)
-        running_node_index += 1
-
-        # Insert Mul node
-        if mul_scale.shape == (1,):
-            mul_scale = mul_scale[0]
-            mul_shape = tuple()
-        else:
-            mul_shape = mul_scale.shape
-        mul_tensor = helper.make_tensor_value_info(
-            model.make_new_valueinfo_name(),
-            TensorProto.FLOAT,
-            mul_shape,
-        )
-        graph.value_info.append(mul_tensor)
-        model.set_initializer(mul_tensor.name, mul_scale)
-
-        output_shape = model.get_tensor_shape(n.output[0])
-        act_mul_tensor = helper.make_tensor_value_info(
-            model.make_new_valueinfo_name(),
-            TensorProto.FLOAT,
-            output_shape,
-        )
-        graph.value_info.append(act_mul_tensor)
-        if successor is not None:
-            successor.input[0] = act_mul_tensor.name
-
-        mul_node = helper.make_node(
-            "Mul",
-            [act_add_tensor.name, mul_tensor.name],
-            [act_mul_tensor.name],
-        )
-        graph.node.insert(running_node_index, mul_node)
-        running_node_index += 1
+            if bias_compatible:
+                adder_bias = np.atleast_1d(adder_bias)
+                # ONNX only accepts 64bit floats as attributes
+                adder_bias = adder_bias.astype(dtype=np.float64)[0]
+                add_shape = tuple()
+            else:
+                add_shape = adder_bias.shape
+
+            in_tensor = n.output[0]
+            successor_node = model.find_direct_successors(n)
+            if successor_node is not None:
+                successor_node = successor_node[0]
+            # Insert Add node
+            add_tensor = helper.make_tensor_value_info(
+                model.make_new_valueinfo_name(),
+                TensorProto.FLOAT,
+                add_shape,
+            )
+            graph.value_info.append(add_tensor)
+            model.set_initializer(add_tensor.name, adder_bias)
+
+            output_shape = model.get_tensor_shape(n.output[0])
+            act_add_tensor = helper.make_tensor_value_info(
+                model.make_new_valueinfo_name(),
+                TensorProto.FLOAT,
+                output_shape,
+            )
+            graph.value_info.append(act_add_tensor)
+            if successor_node is not None:
+                successor_node.input[0] = act_add_tensor.name
+
+            add_node = helper.make_node(
+                "Add",
+                [in_tensor, add_tensor.name],
+                [act_add_tensor.name],
+            )
+            graph.node.insert(running_node_index, add_node)
+            running_node_index += 1
+
+            # Re-point the input node for the next node to insert
+            in_tensor = act_add_tensor.name
+
+            # Set scale
+            # Insert Mul node
+            if scale_compatible:
+                mul_scale = np.atleast_1d(mul_scale)
+                mul_scale = mul_scale.astype(dtype=np.float64)[0]
+                mul_shape = tuple()
+            else:
+                mul_shape = mul_scale.shape
+            mul_tensor = helper.make_tensor_value_info(
+                model.make_new_valueinfo_name(),
+                TensorProto.FLOAT,
+                mul_shape,
+            )
+            graph.value_info.append(mul_tensor)
+            model.set_initializer(mul_tensor.name, mul_scale)
+
+            output_shape = model.get_tensor_shape(n.output[0])
+            act_mul_tensor = helper.make_tensor_value_info(
+                model.make_new_valueinfo_name(),
+                TensorProto.FLOAT,
+                output_shape,
+            )
+            graph.value_info.append(act_mul_tensor)
+            if successor_node is not None:
+                successor_node.input[0] = act_mul_tensor.name
+
+            mul_node = helper.make_node(
+                "Mul",
+                [in_tensor, mul_tensor.name],
+                [act_mul_tensor.name],
+            )
+            graph.node.insert(running_node_index, mul_node)
+            running_node_index += 1
 
         # Remove activation node
         self._remove_activation_node()