Commit 658fe4ef authored by Hendrik Borras

Automatically move scale and bias nodes into the MultiThreshold node where possible.

parent 58d28b11
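
For context, a minimal sketch (not part of this commit) of the eligibility test introduced below: a scale or bias can be folded into the MultiThreshold node only when it is effectively scalar, i.e. of shape (1,) or rank 0; per-channel parameters still become standalone Add and Mul nodes. The helper name is illustrative:

    import numpy as np

    def is_effectively_scalar(param: np.ndarray) -> bool:
        # Mirrors the scale_compatible / bias_compatible checks in the diff below.
        return param.shape == (1,) or len(param.shape) == 0

    assert is_effectively_scalar(np.array(0.5))             # rank 0: foldable
    assert is_effectively_scalar(np.array([0.5]))           # shape (1,): foldable
    assert not is_effectively_scalar(np.array([0.5, 2.0]))  # per-channel: Add/Mul nodes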
@@ -52,8 +52,6 @@ class ConvertQONNXtoFINN(Transformation):
         model = model.transform(FoldQuantWeights())
         # Convert activations
         model = model.transform(ConvertQuantActToMultiThreshold())
-        # Infer types again
-        model = model.transform(InferDataTypes())
         # Unset FINN datatypes from MultiThreshold node output tensors to avoid warnings
         mt_nodes = model.get_nodes_by_op_type("MultiThreshold")
@@ -123,9 +123,6 @@ class QuantActBaseHandler(ABC):
         graph = model.graph
         n = self._q_node
         running_node_index = self._q_index
-        successor = model.find_direct_successors(n)
-        if successor is not None:
-            successor = successor[0]
         # Calculate insertion parameters
         parameter_dict = self.calculate_node_parameters()
@@ -155,69 +152,105 @@ class QuantActBaseHandler(ABC):
         graph.node.insert(running_node_index, outp_trans_node)
         running_node_index += 1
 
-        # Insert Add node
-        if adder_bias.shape == (1,):
-            adder_bias = adder_bias[0]
-            add_shape = tuple()
-        else:
-            add_shape = adder_bias.shape
-        add_tensor = helper.make_tensor_value_info(
-            model.make_new_valueinfo_name(),
-            TensorProto.FLOAT,
-            add_shape,
-        )
-        graph.value_info.append(add_tensor)
-        model.set_initializer(add_tensor.name, adder_bias)
-        output_shape = model.get_tensor_shape(n.output[0])
-        act_add_tensor = helper.make_tensor_value_info(
-            model.make_new_valueinfo_name(),
-            TensorProto.FLOAT,
-            output_shape,
-        )
-        graph.value_info.append(act_add_tensor)
-        if successor is not None:
-            successor.input[0] = act_add_tensor.name
-        add_node = helper.make_node(
-            "Add",
-            [n.output[0], add_tensor.name],
-            [act_add_tensor.name],
-        )
-        graph.node.insert(running_node_index, add_node)
-        running_node_index += 1
-
-        # Insert Mul node
-        if mul_scale.shape == (1,):
-            mul_scale = mul_scale[0]
-            mul_shape = tuple()
-        else:
-            mul_shape = mul_scale.shape
-        mul_tensor = helper.make_tensor_value_info(
-            model.make_new_valueinfo_name(),
-            TensorProto.FLOAT,
-            mul_shape,
-        )
-        graph.value_info.append(mul_tensor)
-        model.set_initializer(mul_tensor.name, mul_scale)
-        output_shape = model.get_tensor_shape(n.output[0])
-        act_mul_tensor = helper.make_tensor_value_info(
-            model.make_new_valueinfo_name(),
-            TensorProto.FLOAT,
-            output_shape,
-        )
-        graph.value_info.append(act_mul_tensor)
-        if successor is not None:
-            successor.input[0] = act_mul_tensor.name
-        mul_node = helper.make_node(
-            "Mul",
-            [act_add_tensor.name, mul_tensor.name],
-            [act_mul_tensor.name],
-        )
-        graph.node.insert(running_node_index, mul_node)
-        running_node_index += 1
+        # Get the MultiThreshold node instance to work with
+        mt_inst = getCustomOp(graph.node[running_node_index - 1])
+
+        # Set scale and bias
+        # If these values are scalar, they can be set as attributes
+        # of the MultiThreshold node; if not, they get inserted as Add and Mul nodes
+        # behind the MultiThreshold node.
+        scale_compatible = mul_scale.shape == (1,) or len(mul_scale.shape) == 0
+        bias_compatible = adder_bias.shape == (1,) or len(adder_bias.shape) == 0
+        if scale_compatible and bias_compatible:
+            # Get Quant parameters
+            mul_scale = np.atleast_1d(mul_scale)
+            # ONNX only accepts 64bit floats as attributes
+            mul_scale = mul_scale.astype(dtype=np.float64)
+            adder_bias = np.atleast_1d(adder_bias)
+            adder_bias = adder_bias.astype(dtype=np.float64)
+
+            # Set bias and scale
+            mt_inst.set_nodeattr("out_scale", mul_scale[0])
+            # FINN applies scale first then bias,
+            # which is the other way around in Brevitas,
+            # we thus need to adjust the bias in the MultiThreshold node
+            mt_inst.set_nodeattr("out_bias", adder_bias[0] * mul_scale[0])
+        else:
+            # Set bias
+            if bias_compatible:
+                adder_bias = np.atleast_1d(adder_bias)
+                # ONNX only accepts 64bit floats as attributes
+                adder_bias = adder_bias.astype(dtype=np.float64)[0]
+                add_shape = tuple()
+            else:
+                add_shape = adder_bias.shape
+
+            in_tensor = n.output[0]
+            successor_node = model.find_direct_successors(n)
+            if successor_node is not None:
+                successor_node = successor_node[0]
+
+            # Insert Add node
+            add_tensor = helper.make_tensor_value_info(
+                model.make_new_valueinfo_name(),
+                TensorProto.FLOAT,
+                add_shape,
+            )
+            graph.value_info.append(add_tensor)
+            model.set_initializer(add_tensor.name, adder_bias)
+
+            output_shape = model.get_tensor_shape(n.output[0])
+            act_add_tensor = helper.make_tensor_value_info(
+                model.make_new_valueinfo_name(),
+                TensorProto.FLOAT,
+                output_shape,
+            )
+            graph.value_info.append(act_add_tensor)
+            if successor_node is not None:
+                successor_node.input[0] = act_add_tensor.name
+
+            add_node = helper.make_node(
+                "Add",
+                [in_tensor, add_tensor.name],
+                [act_add_tensor.name],
+            )
+            graph.node.insert(running_node_index, add_node)
+            running_node_index += 1
+
+            # Re-point the input tensor for the next node to insert
+            in_tensor = act_add_tensor.name
+
+            # Set scale
+            # Insert Mul node
+            if scale_compatible:
+                mul_scale = np.atleast_1d(mul_scale)
+                mul_scale = mul_scale.astype(dtype=np.float64)[0]
+                mul_shape = tuple()
+            else:
+                mul_shape = mul_scale.shape
+            mul_tensor = helper.make_tensor_value_info(
+                model.make_new_valueinfo_name(),
+                TensorProto.FLOAT,
+                mul_shape,
+            )
+            graph.value_info.append(mul_tensor)
+            model.set_initializer(mul_tensor.name, mul_scale)
+
+            output_shape = model.get_tensor_shape(n.output[0])
+            act_mul_tensor = helper.make_tensor_value_info(
+                model.make_new_valueinfo_name(),
+                TensorProto.FLOAT,
+                output_shape,
+            )
+            graph.value_info.append(act_mul_tensor)
+            if successor_node is not None:
+                successor_node.input[0] = act_mul_tensor.name
+
+            mul_node = helper.make_node(
+                "Mul",
+                [in_tensor, mul_tensor.name],
+                [act_mul_tensor.name],
+            )
+            graph.node.insert(running_node_index, mul_node)
+            running_node_index += 1
 
         # Remove activation node
         self._remove_activation_node()
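
The out_bias adjustment in the hunk above follows from the differing operation order: Brevitas applies the bias before the scale, (t + bias) * scale, while FINN's MultiThreshold applies out_scale before out_bias, t * out_scale + out_bias. Equating the two gives out_scale = scale and out_bias = bias * scale. A small numeric check with made-up values:

    import numpy as np

    t = np.array([0.0, 1.0, 2.0, 3.0])  # hypothetical MultiThreshold integer outputs
    scale, bias = 0.5, -2.0             # hypothetical scalar Quant parameters

    brevitas_order = (t + bias) * scale    # bias first, then scale
    finn_order = t * scale + bias * scale  # out_scale = scale, out_bias = bias * scale

    assert np.allclose(brevitas_order, finn_order)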