From 3319c81edeade8aa8fafc5bb3246d63633b4a384 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Wed, 23 Oct 2019 15:00:16 +0100 Subject: [PATCH] [Transform] complete move_add_past_mul --- src/finn/transformation/streamline.py | 38 +++++++++++++++++++++------ 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/src/finn/transformation/streamline.py b/src/finn/transformation/streamline.py index 3d025c97c..39684af25 100644 --- a/src/finn/transformation/streamline.py +++ b/src/finn/transformation/streamline.py @@ -1,3 +1,5 @@ +from onnx import helper as oh + import finn.transformation.infer_shapes as si @@ -12,16 +14,36 @@ def move_add_past_mul(model): node_ind += 1 if n.op_type == "Add": consumer = model.find_consumer(n.output[0]) - if consumer.op_type == "Mul": + if consumer is not None and consumer.op_type == "Mul": + # have: (x) -> add(,B) -> (x+B) -> mul(,A) -> (xA+BA) + # want: (x) -> mul(,A) -> (xA) -> add(,BA) -> (xA+BA) # assume input 0 is from the previous layer, input 1 is the # trained (constant) parameter - mul_param = model.get_initializer(consumer.input[1]) - add_param = model.get_initializer(n.input[1]) - assert mul_param is not None - assert add_param is not None - # TODO compute new param values - # TODO make new nodes - # TODO mark nodes for removal + mul_weight_name = consumer.input[1] + add_weight_name = n.input[1] + A = model.get_initializer(mul_weight_name) + B = model.get_initializer(add_weight_name) + assert A is not None + assert B is not None + start_name = n.input[0] + middle_name = n.output[0] + end_name = consumer.output[0] + # compute new param value for add + BA = B * A + # make and insert new nodes + new_mul = oh.make_node( + "Mul", [start_name, mul_weight_name], [middle_name] + ) + new_add = oh.make_node( + "Add", [middle_name, add_weight_name], [end_name] + ) + graph.node.insert(node_ind, new_mul) + graph.node.insert(node_ind + 1, new_add) + # replace add value + model.set_initializer(add_weight_name, BA) + # mark old nodes for removal + nodes_to_remove += [n, consumer] + graph_modified = True # delete marked nodes (batchnorm and (un)squeezing) for n in nodes_to_remove: graph.node.remove(n) -- GitLab