From cc0ccaf3164c9bf2335bfce1d9ac66e52b2f105e Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Thu, 17 Oct 2019 18:09:57 +0100
Subject: [PATCH] [Transform] add a first unfinished version of batchnorm to
 affine

---
 src/finn/transformation/general.py | 67 ++++++++++++++++++++++++++++++
 1 file changed, 67 insertions(+)

diff --git a/src/finn/transformation/general.py b/src/finn/transformation/general.py
index 412fddaed..12fd14646 100644
--- a/src/finn/transformation/general.py
+++ b/src/finn/transformation/general.py
@@ -1,5 +1,8 @@
 import copy
 
+import numpy as np
+from onnx import TensorProto
+from onnx import helper as oh
 from onnx import numpy_helper as np_helper
 
 
@@ -77,3 +80,67 @@ def make_new_valueinfo_name(model):
     while candidate in names:
         candidate = str(int(candidate) + 1)
     return candidate
+
+
+def replace_batchnorm_with_affine(model):
+    """Replaces any test-time BatchNorm layers with Mul-Add layers."""
+    new_model = copy.deepcopy(model)
+    graph = new_model.graph
+    nodes_to_remove = []
+    for n in graph.node:
+        if n.op_type == "BatchNormalization":
+            bn_input = n.input[0]
+            bn_output = n.output[0]
+            # extract batchnorm parameters as numpy arrays
+            scale = get_initializer(new_model, n.input[1])
+            bias = get_initializer(new_model, n.input[2])
+            mean = get_initializer(new_model, n.input[3])
+            variance = get_initializer(new_model, n.input[4])
+            # read epsilon from the node attributes, falling back to the ONNX default
+            epsilon = 1e-5
+            for a in n.attribute:
+                if a.name == "epsilon":
+                    epsilon = a.f
+            # find A and B to compute batchnorm as affine transform A*x + B
+            # TODO is a division by moving avg factor needed for variance?
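+            # at test time, BatchNormalization computes
+            #   y = scale * (x - mean) / sqrt(variance + epsilon) + bias
+            # which can be folded into y = A*x + B as follows: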
+            A = scale / np.sqrt(epsilon + variance)
+            B = bias - (A * mean)
+            nodes_to_remove += [n]
+            # see if we have surrounding Unsqueeze/Squeeze nodes we can remove
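+            # (exports often wrap 2D activations in Unsqueeze/Squeeze around
+            # BatchNorm; Mul and Add broadcast, so the wrappers become redundant)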
+            producer = find_producer(new_model, bn_input)
+            if producer is not None:
+                if producer.op_type == "Unsqueeze":
+                    bn_input = producer.input[0]
+                    nodes_to_remove += [producer]
+            consumer = find_consumer(new_model, bn_output)
+            if consumer is not None:
+                if consumer.op_type == "Squeeze":
+                    bn_output = consumer.output[0]
+                    nodes_to_remove += [consumer]
+            # create value_info and initializers for Mul and Add constants
+            mul_const = oh.make_tensor_value_info(
+                make_new_valueinfo_name(new_model), TensorProto.FLOAT, A.shape
+            )
+            graph.value_info.append(mul_const)
+            set_initializer(new_model, mul_const.name, A)
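+            # fresh intermediate tensor connecting the new Mul node to the Add node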
+            mul_output = oh.make_tensor_value_info(
+                make_new_valueinfo_name(new_model), TensorProto.FLOAT, A.shape
+            )
+            graph.value_info.append(mul_output)
+            add_const = oh.make_tensor_value_info(
+                make_new_valueinfo_name(new_model), TensorProto.FLOAT, B.shape
+            )
+            graph.value_info.append(add_const)
+            set_initializer(new_model, add_const.name, B)
+            # create Mul and Add nodes to replace the batchnorm
+            mul_node = oh.make_node(
+                "Mul", [bn_input, mul_const.name], [mul_output.name]
+            )
+            add_node = oh.make_node(
+                "Add", [mul_output.name, add_const.name], [bn_output]
+            )
+            graph.node.append(mul_node)
+            graph.node.append(add_node)
+
+    # delete marked nodes
+    for n in nodes_to_remove:
+        graph.node.remove(n)
+    # TODO topologically sort nodes
+    # TODO give new names, maybe run shape inference?
+    return new_model
-- 
GitLab