diff --git a/src/finn/transformation/bipolar_to_xnor.py b/src/finn/transformation/bipolar_to_xnor.py index 2e43dbc8137fe2c0a12c097fe523a470ae6cafb5..4c7ebaf04e35f94e84e52e0b4520ee2369502120 100644 --- a/src/finn/transformation/bipolar_to_xnor.py +++ b/src/finn/transformation/bipolar_to_xnor.py @@ -33,6 +33,7 @@ from onnx import helper as oh from finn.core.datatype import DataType from finn.transformation import Transformation from finn.transformation.infer_shapes import InferShapes +from finn.transformation.infer_datatypes import InferDataTypes from finn.util.basic import get_by_name @@ -68,6 +69,7 @@ class ConvertBipolarMatMulToXnorPopcount(Transformation): """Could not find upstream bipolar MultiThreshold""" ) + graph_modified = True mt = mt_chain[-1] bin_dt_attr = "BINARY".encode("utf-8") get_by_name(mt.attribute, "out_dtype").s = bin_dt_attr @@ -83,7 +85,6 @@ class ConvertBipolarMatMulToXnorPopcount(Transformation): K = Wbin.shape[0] model.set_initializer(mm_weight, Wbin) model.set_tensor_datatype(mm_weight, DataType.BINARY) - graph_modified = True # make new output node with correct shape mm_out_shape = model.get_tensor_shape(mm_output) xnorpcout = oh.make_tensor_value_info( @@ -121,5 +122,7 @@ class ConvertBipolarMatMulToXnorPopcount(Transformation): # insert where the batchnorm is to preserve topological ordering graph.node.insert(node_ind, mul_node) graph.node.insert(node_ind + 1, add_node) - model = model.transform(InferShapes()) + if graph_modified: + model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) return (model, graph_modified)