From 9a3a2eea4323f068624816e8c4e64c3dae756026 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Sun, 3 Nov 2019 17:58:19 +0000
Subject: [PATCH] [Test] refactor test_batchnorm_to_affine to use models.LFC

---
 tests/test_batchnorm_to_affine.py | 75 +++----------------------------
 1 file changed, 7 insertions(+), 68 deletions(-)

diff --git a/tests/test_batchnorm_to_affine.py b/tests/test_batchnorm_to_affine.py
index 5f0e64d3d..6c72aab2b 100644
--- a/tests/test_batchnorm_to_affine.py
+++ b/tests/test_batchnorm_to_affine.py
@@ -1,6 +1,4 @@
 import os
-from functools import reduce
-from operator import mul
 from pkgutil import get_data
 
 import brevitas.onnx as bo
@@ -8,20 +6,14 @@ import numpy as np
 import onnx
 import onnx.numpy_helper as nph
 import torch
-from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op
-from torch.nn import BatchNorm1d, Dropout, Module, ModuleList
+from models.LFC import LFC
 
 import finn.core.onnx_exec as oxe
 import finn.transformation.batchnorm_to_affine as tx
+import finn.transformation.fold_constants as fc
 import finn.transformation.infer_shapes as si
 from finn.core.modelwrapper import ModelWrapper
 
-FC_OUT_FEATURES = [1024, 1024, 1024]
-INTERMEDIATE_FC_PER_OUT_CH_SCALING = True
-LAST_FC_PER_OUT_CH_SCALING = False
-IN_DROPOUT = 0.2
-HIDDEN_DROPOUT = 0.2
-
 export_onnx_path = "test_output_lfc.onnx"
 transformed_onnx_path = "test_output_lfc_transformed.onnx"
 # TODO get from config instead, hardcoded to Docker path for now
@@ -30,62 +22,6 @@ trained_lfc_checkpoint = (
 )
 
 
-class LFC(Module):
-    def __init__(
-        self,
-        num_classes=10,
-        weight_bit_width=None,
-        act_bit_width=None,
-        in_bit_width=None,
-        in_ch=1,
-        in_features=(28, 28),
-    ):
-        super(LFC, self).__init__()
-
-        weight_quant_type = get_quant_type(weight_bit_width)
-        act_quant_type = get_quant_type(act_bit_width)
-        in_quant_type = get_quant_type(in_bit_width)
-        stats_op = get_stats_op(weight_quant_type)
-
-        self.features = ModuleList()
-        self.features.append(get_act_quant(in_bit_width, in_quant_type))
-        self.features.append(Dropout(p=IN_DROPOUT))
-        in_features = reduce(mul, in_features)
-        for out_features in FC_OUT_FEATURES:
-            self.features.append(
-                get_quant_linear(
-                    in_features=in_features,
-                    out_features=out_features,
-                    per_out_ch_scaling=INTERMEDIATE_FC_PER_OUT_CH_SCALING,
-                    bit_width=weight_bit_width,
-                    quant_type=weight_quant_type,
-                    stats_op=stats_op,
-                )
-            )
-            in_features = out_features
-            self.features.append(BatchNorm1d(num_features=in_features))
-            self.features.append(get_act_quant(act_bit_width, act_quant_type))
-            self.features.append(Dropout(p=HIDDEN_DROPOUT))
-        self.fc = get_quant_linear(
-            in_features=in_features,
-            out_features=num_classes,
-            per_out_ch_scaling=LAST_FC_PER_OUT_CH_SCALING,
-            bit_width=weight_bit_width,
-            quant_type=weight_quant_type,
-            stats_op=stats_op,
-        )
-
-    def forward(self, x):
-        x = x.view(1, 784)
-        # removing the torch.tensor here creates a float64 op for some reason..
-        # so explicitly wrapped with torch.tensor to make a float32 one instead
-        x = 2.0 * x - torch.tensor([1.0])
-        for mod in self.features:
-            x = mod(x)
-        out = self.fc(x)
-        return out
-
-
 def test_batchnorm_to_affine():
     lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1)
     checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu")
@@ -93,12 +29,15 @@ def test_batchnorm_to_affine():
     bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path)
     model = ModelWrapper(export_onnx_path)
     model = model.transform_single(si.infer_shapes)
+    model = model.transform_repeated(fc.fold_constants)
     new_model = model.transform_single(tx.batchnorm_to_affine)
     # load one of the test vectors
     raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb")
     input_tensor = onnx.load_tensor_from_string(raw_i)
+    out_old = model.graph.output[0].name
+    out_new = new_model.graph.output[0].name
     input_dict = {"0": nph.to_array(input_tensor)}
-    output_original = oxe.execute_onnx(model, input_dict)["53"]
-    output_transformed = oxe.execute_onnx(new_model, input_dict)["53"]
+    output_original = oxe.execute_onnx(model, input_dict)[out_old]
+    output_transformed = oxe.execute_onnx(new_model, input_dict)[out_new]
     assert np.isclose(output_transformed, output_original, atol=1e-3).all()
     os.remove(export_onnx_path)
-- 
GitLab