diff --git a/src/finn/transformation/convert_qonnx_to_finn.py b/src/finn/transformation/convert_qonnx_to_finn.py index 47662d23b96be02242c7e557c7f9bc8f1d1186b6..a2fbd7f6f1901619c101e69f1b74bbb6893e07ca 100644 --- a/src/finn/transformation/convert_qonnx_to_finn.py +++ b/src/finn/transformation/convert_qonnx_to_finn.py @@ -6,12 +6,10 @@ from finn.core.modelwrapper import ModelWrapper from finn.custom_op.registry import getCustomOp from finn.transformation.base import Transformation -allowed_identity_successors = [ - "MatMul", - "Conv", - "MaxPool", - "Reshape", - None, +allowed_identity_predecessor = [ + "BatchNormalization", + "Sub", + # None, ] @@ -31,9 +29,9 @@ class ConvertQuantActToMultiThreshold(Transformation): out = model.get_initializer(n.output[0]) if not (inp is None and out is None): continue - successor = model.find_direct_successors(n) - if successor is not None: - successor = successor[0] + predecessor = model.find_direct_predecessors(n) + if predecessor is not None: + predecessor = predecessor[0] if model.is_fork_node(n): raise RuntimeError( "Forking Quant nodes are not currently supported by FINN." @@ -43,13 +41,14 @@ class ConvertQuantActToMultiThreshold(Transformation): # the Quant node, such as ReLu # Check that this is an idendity operation - if successor.op_type in allowed_identity_successors: + if predecessor.op_type in allowed_identity_predecessor: handler = QuantIdentityHandler(model, n, node_ind) else: raise RuntimeError( - f"Quant nodes with successor nodes of type {successor.op_type} " - f"are currently not supported by FINN and can not be converted " - f"to MultiThreshold nodes." + f"Quant nodes in the activation path and with predecessor " + f"nodes of type {predecessor.op_type} are currently not " + f"supported by FINN and can not be converted to " + f"MultiThreshold nodes." ) model = handler.replace_quant_node() graph_modified = True @@ -197,7 +196,7 @@ class QuantActBaseHandler(ABC): graph.node.insert(running_node_index, mul_node) running_node_index += 1 - # Now remove the Quant node + # Remove the Quant node graph.node.remove(n) # return the internal model representation