def move_scalar_add_past_matmul(model):
    """Move scalar Add operations past MatMul operations.

    Rewrites the pattern

        (x) -> Add(,A) -> MatMul(,W)   i.e.  (x + A) @ W

    into

        (x) -> MatMul(,W) -> Add(,A')  with  A' = (A * ones) @ W

    which is numerically equivalent when A is scalar-shaped. The goal is
    to push Add nodes next to each other so a later pass can collapse
    them into a single Add.

    Returns a (model, graph_modified) tuple; shapes are re-inferred
    before returning.
    """
    graph = model.graph
    position = 0
    graph_modified = False
    for node in graph.node:
        position += 1
        if node.op_type != "Add":
            continue
        consumer = model.find_consumer(node.output[0])
        if consumer is None or consumer.op_type != "MatMul":
            continue
        # NOTE(review): assumes input 0 is the activation from the previous
        # layer and input 1 is the trained (constant) parameter — confirm
        # this ordering holds for all producers.
        add_weight_name = node.input[1]
        matmul_weight_name = consumer.input[1]
        A = model.get_initializer(add_weight_name)
        W = model.get_initializer(matmul_weight_name)
        assert A is not None
        assert W is not None
        start_name = node.input[0]
        middle_name = node.output[0]
        end_name = consumer.output[0]
        if all(dim == 1 for dim in A.shape):
            # Scalar add commutes past the matmul:
            # (x + A) @ W == x @ W + (A * ones(k)) @ W, k = W.shape[0].
            Anew = np.dot(A * np.ones(W.shape[0], dtype=np.float32), W)
            # Reuse the old Add's initializer slot for the folded weight.
            model.set_initializer(add_weight_name, Anew)
            new_matmul = oh.make_node(
                "MatMul", [start_name, matmul_weight_name], [middle_name]
            )
            new_add = oh.make_node(
                "Add", [middle_name, add_weight_name], [end_name]
            )
            graph.node.insert(position, new_matmul)
            graph.node.insert(position + 1, new_add)
            # Drop the original Add/MatMul pair.
            graph.node.remove(node)
            graph.node.remove(consumer)
            graph_modified = True
    model = model.transform_single(si.infer_shapes)
    return (model, graph_modified)