Skip to content
Snippets Groups Projects
Commit 565b42cc authored by Hendrik Borras's avatar Hendrik Borras
Browse files

Modified activation handler selection to consider predecessor instead of successor.

parent 86acb28f
No related branches found
No related tags found
No related merge requests found
...@@ -6,12 +6,10 @@ from finn.core.modelwrapper import ModelWrapper ...@@ -6,12 +6,10 @@ from finn.core.modelwrapper import ModelWrapper
from finn.custom_op.registry import getCustomOp from finn.custom_op.registry import getCustomOp
from finn.transformation.base import Transformation from finn.transformation.base import Transformation
allowed_identity_successors = [ allowed_identity_predecessor = [
"MatMul", "BatchNormalization",
"Conv", "Sub",
"MaxPool", # None,
"Reshape",
None,
] ]
...@@ -31,9 +29,9 @@ class ConvertQuantActToMultiThreshold(Transformation): ...@@ -31,9 +29,9 @@ class ConvertQuantActToMultiThreshold(Transformation):
out = model.get_initializer(n.output[0]) out = model.get_initializer(n.output[0])
if not (inp is None and out is None): if not (inp is None and out is None):
continue continue
successor = model.find_direct_successors(n) predecessor = model.find_direct_predecessors(n)
if successor is not None: if predecessor is not None:
successor = successor[0] predecessor = predecessor[0]
if model.is_fork_node(n): if model.is_fork_node(n):
raise RuntimeError( raise RuntimeError(
"Forking Quant nodes are not currently supported by FINN." "Forking Quant nodes are not currently supported by FINN."
...@@ -43,13 +41,14 @@ class ConvertQuantActToMultiThreshold(Transformation): ...@@ -43,13 +41,14 @@ class ConvertQuantActToMultiThreshold(Transformation):
# the Quant node, such as ReLu # the Quant node, such as ReLu
# Check that this is an idendity operation # Check that this is an idendity operation
if successor.op_type in allowed_identity_successors: if predecessor.op_type in allowed_identity_predecessor:
handler = QuantIdentityHandler(model, n, node_ind) handler = QuantIdentityHandler(model, n, node_ind)
else: else:
raise RuntimeError( raise RuntimeError(
f"Quant nodes with successor nodes of type {successor.op_type} " f"Quant nodes in the activation path and with predecessor "
f"are currently not supported by FINN and can not be converted " f"nodes of type {predecessor.op_type} are currently not "
f"to MultiThreshold nodes." f"supported by FINN and can not be converted to "
f"MultiThreshold nodes."
) )
model = handler.replace_quant_node() model = handler.replace_quant_node()
graph_modified = True graph_modified = True
...@@ -197,7 +196,7 @@ class QuantActBaseHandler(ABC): ...@@ -197,7 +196,7 @@ class QuantActBaseHandler(ABC):
graph.node.insert(running_node_index, mul_node) graph.node.insert(running_node_index, mul_node)
running_node_index += 1 running_node_index += 1
# Now remove the Quant node # Remove the Quant node
graph.node.remove(n) graph.node.remove(n)
# return the internal model representation # return the internal model representation
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment