Skip to content
Snippets Groups Projects
Commit 836e6e32 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] add absorb_1bit_mul_into_matmul

parent 5f757163
No related branches found
No related tags found
No related merge requests found
......@@ -335,3 +335,31 @@ def factor_out_mul_sign_magnitude(model):
graph.node.insert(node_ind - 1, new_mul)
graph_modified = True
return (model, graph_modified)
def absorb_1bit_mul_into_matmul(model):
"""Absorb bipolar or binary multiplications into the preciding matrix
multiply."""
graph = model.graph
node_ind = 0
graph_modified = False
for n in graph.node:
node_ind += 1
if n.op_type == "MatMul":
matmul_weight_name = n.input[1]
W = model.get_initializer(matmul_weight_name)
assert W is not None
consumer = model.find_consumer(n.output[0])
if consumer is not None and consumer.op_type == "Mul":
mul_weight_name = consumer.input[1]
A = model.get_initializer(mul_weight_name)
assert A is not None
is_1bit = model.get_tensor_datatype(mul_weight_name).bitwidth() == 1
if is_1bit:
Wnew = A * W
assert Wnew.shape == W.shape
model.set_initializer(matmul_weight_name, Wnew)
n.output[0] = consumer.output[0]
graph.node.remove(consumer)
graph_modified = True
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment