Skip to content
Snippets Groups Projects
Commit 565b42cc authored by Hendrik Borras's avatar Hendrik Borras
Browse files

Modified activation handler selection to consider predecessor instead of successor.

parent 86acb28f
No related branches found
No related tags found
No related merge requests found
......@@ -6,12 +6,10 @@ from finn.core.modelwrapper import ModelWrapper
from finn.custom_op.registry import getCustomOp
from finn.transformation.base import Transformation
allowed_identity_successors = [
"MatMul",
"Conv",
"MaxPool",
"Reshape",
None,
allowed_identity_predecessor = [
"BatchNormalization",
"Sub",
# None,
]
......@@ -31,9 +29,9 @@ class ConvertQuantActToMultiThreshold(Transformation):
out = model.get_initializer(n.output[0])
if not (inp is None and out is None):
continue
successor = model.find_direct_successors(n)
if successor is not None:
successor = successor[0]
predecessor = model.find_direct_predecessors(n)
if predecessor is not None:
predecessor = predecessor[0]
if model.is_fork_node(n):
raise RuntimeError(
"Forking Quant nodes are not currently supported by FINN."
......@@ -43,13 +41,14 @@ class ConvertQuantActToMultiThreshold(Transformation):
# the Quant node, such as ReLu
# Check that this is an idendity operation
if successor.op_type in allowed_identity_successors:
if predecessor.op_type in allowed_identity_predecessor:
handler = QuantIdentityHandler(model, n, node_ind)
else:
raise RuntimeError(
f"Quant nodes with successor nodes of type {successor.op_type} "
f"are currently not supported by FINN and can not be converted "
f"to MultiThreshold nodes."
f"Quant nodes in the activation path and with predecessor "
f"nodes of type {predecessor.op_type} are currently not "
f"supported by FINN and can not be converted to "
f"MultiThreshold nodes."
)
model = handler.replace_quant_node()
graph_modified = True
......@@ -197,7 +196,7 @@ class QuantActBaseHandler(ABC):
graph.node.insert(running_node_index, mul_node)
running_node_index += 1
# Now remove the Quant node
# Remove the Quant node
graph.node.remove(n)
# return the internal model representation
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment