diff --git a/src/finn/transformation/bipolar_to_xnor.py b/src/finn/transformation/bipolar_to_xnor.py index 411acad67039c17bc9730a744ae6ae6afd2a4c3e..2e43dbc8137fe2c0a12c097fe523a470ae6cafb5 100644 --- a/src/finn/transformation/bipolar_to_xnor.py +++ b/src/finn/transformation/bipolar_to_xnor.py @@ -53,7 +53,27 @@ class ConvertBipolarMatMulToXnorPopcount(Transformation): i_bp = model.get_tensor_datatype(mm_input) == DataType.BIPOLAR w_bp = model.get_tensor_datatype(mm_weight) == DataType.BIPOLAR if i_bp and w_bp: - graph_modified = True + # find producing threshold node and adjust output to binary + def find_prod_mt(x): + is_mt = x.op_type == "MultiThreshold" + is_bp = False + if is_mt: + dt = get_by_name(x.attribute, "out_dtype").s + is_bp = dt.decode("utf-8") == "BIPOLAR" + return is_mt and is_bp + + mt_chain = model.find_upstream(mm_input, find_prod_mt) + if len(mt_chain) == 0: + raise Exception( + """Could not find upstream bipolar + MultiThreshold""" + ) + mt = mt_chain[-1] + bin_dt_attr = "BINARY".encode("utf-8") + get_by_name(mt.attribute, "out_dtype").s = bin_dt_attr + get_by_name(mt.attribute, "out_scale").f = 1.0 + get_by_name(mt.attribute, "out_bias").f = 0 + model.set_tensor_datatype(mm_input, DataType.BINARY) # change node type and domain n.op_type = "XnorPopcountMatMul" n.domain = "finn" @@ -63,37 +83,7 @@ class ConvertBipolarMatMulToXnorPopcount(Transformation): K = Wbin.shape[0] model.set_initializer(mm_weight, Wbin) model.set_tensor_datatype(mm_weight, DataType.BINARY) - # find producing threshold node and adjust output to binary - mt = model.find_producer(mm_input) - if mt is not None and mt.op_type == "MultiThreshold": - bin_dt_attr = "BINARY".encode("utf-8") - get_by_name(mt.attribute, "out_dtype").s = bin_dt_attr - get_by_name(mt.attribute, "out_scale").f = 1.0 - get_by_name(mt.attribute, "out_bias").f = 0 - model.set_tensor_datatype(mm_input, DataType.BINARY) - elif mt is not None and mt.op_type == "Im2Col": - # Im2Col node for lowered convolution - # go one more step back to see if we find threshold - # node there - i2c = mt - mt = model.find_producer(i2c.input[0]) - if mt is not None and mt.op_type == "MultiThreshold": - bin_dt_attr = "BINARY".encode("utf-8") - get_by_name(mt.attribute, "out_dtype").s = bin_dt_attr - get_by_name(mt.attribute, "out_scale").f = 1.0 - get_by_name(mt.attribute, "out_bias").f = 0 - model.set_tensor_datatype(mm_input, DataType.BINARY) - else: - raise Exception( - """Found to MultiThreshold before - Im2Col for bipolar->binary conversion - """ - ) - else: - raise Exception( - """Requires Bipolar2Binary, not yet - implemented.""" - ) + graph_modified = True # make new output node with correct shape mm_out_shape = model.get_tensor_shape(mm_output) xnorpcout = oh.make_tensor_value_info(