diff --git a/src/finn/transformation/bipolar_to_xnor.py b/src/finn/transformation/bipolar_to_xnor.py
index 411acad67039c17bc9730a744ae6ae6afd2a4c3e..2e43dbc8137fe2c0a12c097fe523a470ae6cafb5 100644
--- a/src/finn/transformation/bipolar_to_xnor.py
+++ b/src/finn/transformation/bipolar_to_xnor.py
@@ -53,7 +53,27 @@ class ConvertBipolarMatMulToXnorPopcount(Transformation):
                 i_bp = model.get_tensor_datatype(mm_input) == DataType.BIPOLAR
                 w_bp = model.get_tensor_datatype(mm_weight) == DataType.BIPOLAR
                 if i_bp and w_bp:
-                    graph_modified = True
+                    # find producing threshold node and adjust output to binary
+                    def find_prod_mt(x):
+                        is_mt = x.op_type == "MultiThreshold"
+                        is_bp = False
+                        if is_mt:
+                            dt = get_by_name(x.attribute, "out_dtype").s
+                            is_bp = dt.decode("utf-8") == "BIPOLAR"
+                        return is_mt and is_bp
+
+                    mt_chain = model.find_upstream(mm_input, find_prod_mt)
+                    if len(mt_chain) == 0:
+                        raise Exception(
+                            """Could not find upstream bipolar
+                                            MultiThreshold"""
+                        )
+                    mt = mt_chain[-1]
+                    bin_dt_attr = "BINARY".encode("utf-8")
+                    get_by_name(mt.attribute, "out_dtype").s = bin_dt_attr
+                    get_by_name(mt.attribute, "out_scale").f = 1.0
+                    get_by_name(mt.attribute, "out_bias").f = 0
+                    model.set_tensor_datatype(mm_input, DataType.BINARY)
                     # change node type and domain
                     n.op_type = "XnorPopcountMatMul"
                     n.domain = "finn"
@@ -63,37 +83,7 @@ class ConvertBipolarMatMulToXnorPopcount(Transformation):
                     K = Wbin.shape[0]
                     model.set_initializer(mm_weight, Wbin)
                     model.set_tensor_datatype(mm_weight, DataType.BINARY)
-                    # find producing threshold node and adjust output to binary
-                    mt = model.find_producer(mm_input)
-                    if mt is not None and mt.op_type == "MultiThreshold":
-                        bin_dt_attr = "BINARY".encode("utf-8")
-                        get_by_name(mt.attribute, "out_dtype").s = bin_dt_attr
-                        get_by_name(mt.attribute, "out_scale").f = 1.0
-                        get_by_name(mt.attribute, "out_bias").f = 0
-                        model.set_tensor_datatype(mm_input, DataType.BINARY)
-                    elif mt is not None and mt.op_type == "Im2Col":
-                        # Im2Col node for lowered convolution
-                        # go one more step back to see if we find threshold
-                        # node there
-                        i2c = mt
-                        mt = model.find_producer(i2c.input[0])
-                        if mt is not None and mt.op_type == "MultiThreshold":
-                            bin_dt_attr = "BINARY".encode("utf-8")
-                            get_by_name(mt.attribute, "out_dtype").s = bin_dt_attr
-                            get_by_name(mt.attribute, "out_scale").f = 1.0
-                            get_by_name(mt.attribute, "out_bias").f = 0
-                            model.set_tensor_datatype(mm_input, DataType.BINARY)
-                        else:
-                            raise Exception(
-                                """Found to MultiThreshold before
-                                            Im2Col for bipolar->binary conversion
-                                            """
-                            )
-                    else:
-                        raise Exception(
-                            """Requires Bipolar2Binary, not yet
-                                        implemented."""
-                        )
+                    graph_modified = True
                     # make new output node with correct shape
                     mm_out_shape = model.get_tensor_shape(mm_output)
                     xnorpcout = oh.make_tensor_value_info(