diff --git a/src/finn/transformation/bipolar_to_xnor.py b/src/finn/transformation/bipolar_to_xnor.py
index 2e43dbc8137fe2c0a12c097fe523a470ae6cafb5..4c7ebaf04e35f94e84e52e0b4520ee2369502120 100644
--- a/src/finn/transformation/bipolar_to_xnor.py
+++ b/src/finn/transformation/bipolar_to_xnor.py
@@ -33,6 +33,7 @@ from onnx import helper as oh
 from finn.core.datatype import DataType
 from finn.transformation import Transformation
 from finn.transformation.infer_shapes import InferShapes
+from finn.transformation.infer_datatypes import InferDataTypes
 from finn.util.basic import get_by_name
 
 
@@ -68,6 +69,7 @@ class ConvertBipolarMatMulToXnorPopcount(Transformation):
                             """Could not find upstream bipolar
                                             MultiThreshold"""
                         )
+                    graph_modified = True
                     mt = mt_chain[-1]
                     bin_dt_attr = "BINARY".encode("utf-8")
                     get_by_name(mt.attribute, "out_dtype").s = bin_dt_attr
@@ -83,7 +85,6 @@ class ConvertBipolarMatMulToXnorPopcount(Transformation):
                     K = Wbin.shape[0]
                     model.set_initializer(mm_weight, Wbin)
                     model.set_tensor_datatype(mm_weight, DataType.BINARY)
-                    graph_modified = True
                     # make new output node with correct shape
                     mm_out_shape = model.get_tensor_shape(mm_output)
                     xnorpcout = oh.make_tensor_value_info(
@@ -121,5 +122,7 @@ class ConvertBipolarMatMulToXnorPopcount(Transformation):
                     # insert where the batchnorm is to preserve topological ordering
                     graph.node.insert(node_ind, mul_node)
                     graph.node.insert(node_ind + 1, add_node)
-        model = model.transform(InferShapes())
+        if graph_modified:
+            model = model.transform(InferShapes())
+            model = model.transform(InferDataTypes())
         return (model, graph_modified)