From 529771f5a7426cd35c5816d33f2e7e168f6c9bd7 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Thu, 17 Oct 2019 22:56:06 +0100
Subject: [PATCH] [Transform] fix batchnorm-to-affine problems, add
 get_tensor_shape
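
At inference time a BatchNormalization node is an affine map: with
A = scale / sqrt(variance + epsilon) and B = bias - A * mean, its output
is A * x + B, so it can be replaced by a Mul (by A) followed by an Add
(of B). The helper calls are now passed the model they operate on, the
replacement nodes are inserted at the position of the removed batchnorm
to keep the graph topologically sorted, and shape inference together with
the new get_tensor_shape helper gives the intermediate Mul output a
correctly shaped value_info.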

---
 src/finn/transformation/general.py | 58 +++++++++++++++++++-----------
 1 file changed, 37 insertions(+), 21 deletions(-)

diff --git a/src/finn/transformation/general.py b/src/finn/transformation/general.py
index 12fd14646..db01b0091 100644
--- a/src/finn/transformation/general.py
+++ b/src/finn/transformation/general.py
@@ -1,6 +1,7 @@
 import copy
 
 import numpy as np
+import onnx.shape_inference as si
 from onnx import TensorProto
 from onnx import helper as oh
 from onnx import numpy_helper as np_helper
@@ -16,6 +17,21 @@ def give_unique_names(model):
     return new_model
 
 
+def get_tensor_shape(model, tensor_name):
+    """Returns the shape of tensor with given name, if it has ValueInfoProto."""
+    graph = model.graph
+    vi_names = [(x.name, x) for x in graph.input]
+    vi_names += [(x.name, x) for x in graph.output]
+    vi_names += [(x.name, x) for x in graph.value_info]
+    try:
+        vi_ind = [x[0] for x in vi_names].index(tensor_name)
+        vi = vi_names[vi_ind][1]
+        dims = [x.dim_value for x in vi.type.tensor_type.shape.dim]
+        return dims
+    except ValueError:
+        return None
+
+
 def set_initializer(model, tensor_name, tensor_value):
     """Set the initializer value for tensor with given name."""
     graph = model.graph
@@ -34,12 +50,9 @@ def set_initializer(model, tensor_name, tensor_value):
     graph.initializer.append(tensor_init_proto)
 
 
-def get_initializer(model, tensor_name, tensor_value):
+def get_initializer(model, tensor_name):
     """Get the initializer value for tensor with given name, if any."""
     graph = model.graph
-    # convert tensor_value (numpy array) into TensorProto w/ correct name
-    tensor_init_proto = np_helper.from_array(tensor_value)
-    tensor_init_proto.name = tensor_name
     init_names = [x.name for x in graph.initializer]
     try:
         init_ind = init_names.index(tensor_name)
@@ -51,7 +64,7 @@ def get_initializer(model, tensor_name, tensor_value):
 def find_producer(model, tensor_name):
     """Find and return the node that produces the tensor with given name.
     Currently only works for linear graphs."""
-    all_outputs = [x.output[0].name for x in model.graph.node]
+    all_outputs = [x.output[0] for x in model.graph.node]
     try:
         producer_ind = all_outputs.index(tensor_name)
         return model.graph.node[producer_ind]
@@ -62,7 +75,7 @@ def find_producer(model, tensor_name):
 def find_consumer(model, tensor_name):
     """Find and return the node that consumes the tensor with given name.
     Currently only works for linear graphs."""
-    all_inputs = [x.input[0].name for x in model.graph.node]
+    all_inputs = [x.input[0] for x in model.graph.node]
     try:
         consumer_ind = all_inputs.index(tensor_name)
         return model.graph.node[consumer_ind]
@@ -85,17 +98,20 @@ def make_new_valueinfo_name(model):
 def replace_batchnorm_with_affine(model):
     """Replaces any test-time BatchNorm layers with Mul-Add layers."""
     new_model = copy.deepcopy(model)
+    new_model = si.infer_shapes(new_model)
     graph = new_model.graph
     nodes_to_remove = []
+    node_ind = 0
     for n in graph.node:
+        node_ind += 1
         if n.op_type == "BatchNormalization":
             bn_input = n.input[0]
             bn_output = n.output[0]
             # extract batchnorm parameters as numpy arrays
-            scale = get_initializer(n.input[1])
-            bias = get_initializer(n.input[2])
-            mean = get_initializer(n.input[3])
-            variance = get_initializer(n.input[4])
+            scale = get_initializer(new_model, n.input[1])
+            bias = get_initializer(new_model, n.input[2])
+            mean = get_initializer(new_model, n.input[3])
+            variance = get_initializer(new_model, n.input[4])
             epsilon = 1e-5
             # find A and B to compute batchnorm as affine transform Ax+B
             # TODO is a division by moving avg factor needed for variance?
@@ -103,16 +119,17 @@ def replace_batchnorm_with_affine(model):
             B = bias - (A * mean)
             nodes_to_remove += [n]
             # see if we have surrounding Unsqueeze/Squeeze nodes we can remove
-            producer = find_producer(n)
+            producer = find_producer(new_model, bn_input)
             if producer is not None:
                 if producer.op_type == "Unsqueeze":
                     bn_input = producer.input[0]
                     nodes_to_remove += [producer]
-            consumer = find_consumer(n)
+            consumer = find_consumer(new_model, bn_output)
             if consumer is not None:
                 if consumer.op_type == "Squeeze":
                     bn_output = consumer.output[0]
                     nodes_to_remove += [consumer]
+            data_shape = get_tensor_shape(new_model, bn_input)
             # create value_info and initializers for Mul and Add constants
             mul_const = oh.make_tensor_value_info(
                 make_new_valueinfo_name(new_model), TensorProto.FLOAT, A.shape
@@ -120,7 +137,7 @@ def replace_batchnorm_with_affine(model):
             graph.value_info.append(mul_const)
             set_initializer(new_model, mul_const.name, A)
             mul_output = oh.make_tensor_value_info(
-                make_new_valueinfo_name(new_model), TensorProto.FLOAT, A.shape
+                make_new_valueinfo_name(new_model), TensorProto.FLOAT, data_shape
             )
             graph.value_info.append(mul_output)
             add_const = oh.make_tensor_value_info(
@@ -130,17 +147,16 @@ def replace_batchnorm_with_affine(model):
             set_initializer(new_model, add_const.name, B)
             # create Mul and Add nodes to replace the batchnorm
             mul_node = oh.make_node(
-                "Mul", [bn_input.name, mul_const.name], [mul_output.name]
+                "Mul", [bn_input, mul_const.name], [mul_output.name]
             )
             add_node = oh.make_node(
-                "Add", [mul_output.name, add_const.name], [bn_output.name]
+                "Add", [mul_output.name, add_const.name], [bn_output]
             )
-            graph.node.append(mul_node)
-            graph.node.append(add_node)
-
-    # delete marked nodes
+            # insert where the batchnorm is to preserve topological ordering
+            graph.node.insert(node_ind, mul_node)
+            graph.node.insert(node_ind + 1, add_node)
+    # delete marked nodes (batchnorm and (un)squeezing)
     for n in nodes_to_remove:
         graph.node.remove(n)
-    # TODO topologically sort nodes
-    # TODO give new names, maybe run shape inference?
+    new_model = si.infer_shapes(new_model)
     return new_model
-- 
GitLab
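
A minimal sketch, outside the patch itself, of how the new get_tensor_shape
helper is expected to behave and of the affine identity that the Mul/Add
replacement relies on. The toy Relu graph, the tensor names and the numeric
constants below are made up purely for illustration; only the
finn.transformation.general import reflects the module touched by this patch.

import numpy as np
import onnx.shape_inference as si
from onnx import TensorProto
from onnx import helper as oh

from finn.transformation.general import get_tensor_shape

# toy single-node graph: out = Relu(inp), with a fixed (1, 4) input shape
inp = oh.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 4])
out = oh.make_tensor_value_info("out", TensorProto.FLOAT, [1, 4])
relu = oh.make_node("Relu", ["inp"], ["out"])
model = oh.make_model(oh.make_graph([relu], "toy", [inp], [out]))
model = si.infer_shapes(model)

# shapes are read from the graph's input/output/value_info entries
assert get_tensor_shape(model, "inp") == [1, 4]
assert get_tensor_shape(model, "no_such_tensor") is None

# affine identity used by replace_batchnorm_with_affine:
# scale * (x - mean) / sqrt(var + eps) + bias  ==  A * x + B
x = np.random.randn(1, 4).astype(np.float32)
scale, bias, mean, var, eps = 1.5, 0.3, 0.1, 2.0, 1e-5
A = scale / np.sqrt(var + eps)
B = bias - A * mean
assert np.allclose(A * x + B, scale * (x - mean) / np.sqrt(var + eps) + bias)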