From b2eec8207d899dd95ab759412007a7798dcd1d96 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Tue, 22 Oct 2019 14:03:41 +0100
Subject: [PATCH] [Transform] use new interface for batchnorm_to_affine

---
 .../transformation/batchnorm_to_affine.py     | 38 +++++++++----------
 tests/test_batchnorm_to_affine.py             | 11 +++---
 2 files changed, 24 insertions(+), 25 deletions(-)

diff --git a/src/finn/transformation/batchnorm_to_affine.py b/src/finn/transformation/batchnorm_to_affine.py
index dfdad3ba2..4f3984cd7 100644
--- a/src/finn/transformation/batchnorm_to_affine.py
+++ b/src/finn/transformation/batchnorm_to_affine.py
@@ -1,29 +1,26 @@
-import copy
-
 import numpy as np
 import onnx.shape_inference as si
 from onnx import TensorProto
 from onnx import helper as oh
 
-import finn.transformation.general as tg
-
 
 def batchnorm_to_affine(model):
     """Replaces any test-time BatchNorm layers with Mul-Add layers."""
-    new_model = copy.deepcopy(model)
-    graph = new_model.graph
+    graph = model.graph
     nodes_to_remove = []
     node_ind = 0
+    graph_modified = False
     for n in graph.node:
         node_ind += 1
         if n.op_type == "BatchNormalization":
+            graph_modified = True
             bn_input = n.input[0]
             bn_output = n.output[0]
             # extract batchnorm parameters as numpy arrays
-            scale = tg.get_initializer(new_model, n.input[1])
-            bias = tg.get_initializer(new_model, n.input[2])
-            mean = tg.get_initializer(new_model, n.input[3])
-            variance = tg.get_initializer(new_model, n.input[4])
+            scale = model.get_initializer(n.input[1])
+            bias = model.get_initializer(n.input[2])
+            mean = model.get_initializer(n.input[3])
+            variance = model.get_initializer(n.input[4])
             epsilon = 1e-5
             # find A and B to compute batchnorm as affine transpose Ax+B
             # TODO is a division by moving avg factor needed for variance?
@@ -31,32 +28,32 @@ def batchnorm_to_affine(model):
             B = bias - (A * mean)
             nodes_to_remove += [n]
             # see if we have surrounding Unsqueeze/Squeeze nodes we can remove
-            producer = tg.find_producer(new_model, bn_input)
+            producer = model.find_producer(bn_input)
             if producer is not None:
                 if producer.op_type == "Unsqueeze":
                     bn_input = producer.input[0]
                     nodes_to_remove += [producer]
-            consumer = tg.find_consumer(new_model, bn_output)
+            consumer = model.find_consumer(bn_output)
             if consumer is not None:
                 if consumer.op_type == "Squeeze":
                     bn_output = consumer.output[0]
                     nodes_to_remove += [consumer]
-            data_shape = tg.get_tensor_shape(new_model, bn_input)
+            data_shape = model.get_tensor_shape(bn_input)
             # create value_info and initializers for Mul and Add constants
             mul_const = oh.make_tensor_value_info(
-                tg.make_new_valueinfo_name(new_model), TensorProto.FLOAT, A.shape
+                model.make_new_valueinfo_name(), TensorProto.FLOAT, A.shape
             )
             graph.value_info.append(mul_const)
-            tg.set_initializer(new_model, mul_const.name, A)
+            model.set_initializer(mul_const.name, A)
             mul_output = oh.make_tensor_value_info(
-                tg.make_new_valueinfo_name(new_model), TensorProto.FLOAT, data_shape
+                model.make_new_valueinfo_name(), TensorProto.FLOAT, data_shape
             )
             graph.value_info.append(mul_output)
             add_const = oh.make_tensor_value_info(
-                tg.make_new_valueinfo_name(new_model), TensorProto.FLOAT, B.shape
+                model.make_new_valueinfo_name(), TensorProto.FLOAT, B.shape
             )
             graph.value_info.append(add_const)
-            tg.set_initializer(new_model, add_const.name, B)
+            model.set_initializer(add_const.name, B)
             # create Mul and Add nodes to replace the batchnorm
             mul_node = oh.make_node(
                 "Mul", [bn_input, mul_const.name], [mul_output.name]
@@ -70,5 +67,6 @@ def batchnorm_to_affine(model):
     # delete marked nodes (batchnorm and (un)squeezing)
     for n in nodes_to_remove:
         graph.node.remove(n)
-    new_model = si.infer_shapes(new_model)
-    return new_model
+        graph_modified = True
+    model.model = si.infer_shapes(model.model)
+    return (model, graph_modified)
diff --git a/tests/test_batchnorm_to_affine.py b/tests/test_batchnorm_to_affine.py
index 34abb9a19..7c7efc494 100644
--- a/tests/test_batchnorm_to_affine.py
+++ b/tests/test_batchnorm_to_affine.py
@@ -15,6 +15,7 @@ from torch.nn import BatchNorm1d, Dropout, Module, ModuleList
 
 import finn.core.onnx_exec as oxe
 import finn.transformation.batchnorm_to_affine as tx
+from finn.core.modelwrapper import ModelWrapper
 
 FC_OUT_FEATURES = [1024, 1024, 1024]
 INTERMEDIATE_FC_PER_OUT_CH_SCALING = True
@@ -94,9 +95,9 @@ def test_batchnorm_to_affine():
     checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")
     lfc.load_state_dict(checkpoint["state_dict"])
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
-    model = onnx.load(export_onnx_path)
-    model = si.infer_shapes(model)
-    new_model = tx.batchnorm_to_affine(model)
+    model = ModelWrapper(export_onnx_path)
+    model.model = si.infer_shapes(model.model)
+    new_model = model.transform_single(tx.batchnorm_to_affine)
     try:
         os.remove("/tmp/" + mnist_onnx_filename)
     except OSError:
@@ -108,8 +109,8 @@ def test_batchnorm_to_affine():
     with open(mnist_onnx_local_dir + "/mnist/test_data_set_0/input_0.pb", "rb") as f:
         input_tensor.ParseFromString(f.read())
     input_dict = {"0": nph.to_array(input_tensor)}
-    output_original = oxe.execute_onnx(model, input_dict)["53"]
-    output_transformed = oxe.execute_onnx(new_model, input_dict)["53"]
+    output_original = oxe.execute_onnx(model.model, input_dict)["53"]
+    output_transformed = oxe.execute_onnx(new_model.model, input_dict)["53"]
     assert np.isclose(output_transformed, output_original, atol=1e-3).all()
     # remove the downloaded model and extracted files
     os.remove(dl_ret)
-- 
GitLab