diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py index 52a922c766fd8ef71c019edd8d65f65260cbe373..309177e548be490a7de05aad6006b838b207f3bc 100644 --- a/src/finn/transformation/streamline.py +++ b/src/finn/transformation/streamline.py @@ -335,3 +335,31 @@ def factor_out_mul_sign_magnitude(model): graph.node.insert(node_ind - 1, new_mul) graph_modified = True return (model, graph_modified) + + +def absorb_1bit_mul_into_matmul(model): + """Absorb bipolar or binary multiplications into the preciding matrix + multiply.""" + graph = model.graph + node_ind = 0 + graph_modified = False + for n in graph.node: + node_ind += 1 + if n.op_type == "MatMul": + matmul_weight_name = n.input[1] + W = model.get_initializer(matmul_weight_name) + assert W is not None + consumer = model.find_consumer(n.output[0]) + if consumer is not None and consumer.op_type == "Mul": + mul_weight_name = consumer.input[1] + A = model.get_initializer(mul_weight_name) + assert A is not None + is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1 + if is_1bit: + Wnew = A * W + assert Wnew.shape == W.shape + model.set_initializer(matmul_weight_name, Wnew) + n.output[0] = consumer.output[0] + graph.node.remove(consumer) + graph_modified = True + return (model, graph_modified)