Skip to content
Snippets Groups Projects
Commit 78d481c0 authored by Hendrik Borras's avatar Hendrik Borras
Browse files

Fixed a bug where node insertion of the QuantActBaseHandler would fail at the end of a graph.

parent 6d7026ee
No related branches found
No related tags found
No related merge requests found
......@@ -152,7 +152,8 @@ class QuantActBaseHandler(ABC):
running_node_index += 1
# Get the MultiThreshold node instance to work with
mt_inst = getCustomOp(graph.node[running_node_index - 1])
mt_node = graph.node[running_node_index - 1]
mt_inst = getCustomOp(mt_node)
# Set scale and bias
# If these values are scalar then they can be set as attributes
......@@ -184,10 +185,7 @@ class QuantActBaseHandler(ABC):
mt_inst.set_nodeattr("out_dtype", out_dtype)
# Insertion parameters
in_tensor = n.output[0]
successor_node = model.find_direct_successors(n)
if successor_node is not None:
successor_node = successor_node[0]
up_stream_node = mt_node
# Set bias
zero_bias = False
......@@ -218,19 +216,19 @@ class QuantActBaseHandler(ABC):
output_shape,
)
graph.value_info.append(act_add_tensor)
if successor_node is not None:
successor_node.input[0] = act_add_tensor.name
add_node = helper.make_node(
"Add",
[in_tensor, add_tensor.name],
[act_add_tensor.name],
[act_add_tensor.name, add_tensor.name],
[n.output[0]],
)
graph.node.insert(running_node_index, add_node)
running_node_index += 1
add_node = graph.node[running_node_index - 1]
# Re-point the input node for the next node to insert
in_tensor = act_add_tensor.name
# Re-point the upstream node
up_stream_node.output[0] = act_add_tensor.name
up_stream_node = add_node
# Set scale
# Insert Mul node
......@@ -260,16 +258,19 @@ class QuantActBaseHandler(ABC):
output_shape,
)
graph.value_info.append(act_mul_tensor)
if successor_node is not None:
successor_node.input[0] = act_mul_tensor.name
mul_node = helper.make_node(
"Mul",
[in_tensor, mul_tensor.name],
[act_mul_tensor.name],
[act_mul_tensor.name, mul_tensor.name],
[n.output[0]],
)
graph.node.insert(running_node_index, mul_node)
running_node_index += 1
mul_node = graph.node[running_node_index - 1]
# Re-point the upstream node
up_stream_node.output[0] = act_mul_tensor.name
up_stream_node = mul_node
# Remove activation node
self._remove_activation_node()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment