From 9a3a2eea4323f068624816e8c4e64c3dae756026 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Sun, 3 Nov 2019 17:58:19 +0000 Subject: [PATCH] [Test] refactor test_batchnorm_to_affine to use models.LFC --- tests/test_batchnorm_to_affine.py | 75 +++---------------------------- 1 file changed, 7 insertions(+), 68 deletions(-) diff --git a/tests/test_batchnorm_to_affine.py b/tests/test_batchnorm_to_affine.py index 5f0e64d3d..6c72aab2b 100644 --- a/tests/test_batchnorm_to_affine.py +++ b/tests/test_batchnorm_to_affine.py @@ -1,6 +1,4 @@ import os -from functools import reduce -from operator import mul from pkgutil import get_data import brevitas.onnx as bo @@ -8,20 +6,14 @@ import numpy as np import onnx import onnx.numpy_helper as nph import torch -from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op -from torch.nn import BatchNorm1d, Dropout, Module, ModuleList +from models.LFC import LFC import finn.core.onnx_exec as oxe import finn.transformation.batchnorm_to_affine as tx +import finn.transformation.fold_constants as fc import finn.transformation.infer_shapes as si from finn.core.modelwrapper import ModelWrapper -FC_OUT_FEATURES = [1024, 1024, 1024] -INTERMEDIATE_FC_PER_OUT_CH_SCALING = True -LAST_FC_PER_OUT_CH_SCALING = False -IN_DROPOUT = 0.2 -HIDDEN_DROPOUT = 0.2 - export_onnx_path = "test_output_lfc.onnx" transformed_onnx_path = "test_output_lfc_transformed.onnx" # TODO get from config instead, hardcoded to Docker path for now @@ -30,62 +22,6 @@ trained_lfc_checkpoint = ( ) -class LFC(Module): - def __init__( - self, - num_classes=10, - weight_bit_width=None, - act_bit_width=None, - in_bit_width=None, - in_ch=1, - in_features=(28, 28), - ): - super(LFC, self).__init__() - - weight_quant_type = get_quant_type(weight_bit_width) - act_quant_type = get_quant_type(act_bit_width) - in_quant_type = get_quant_type(in_bit_width) - stats_op = get_stats_op(weight_quant_type) - - self.features = ModuleList() - self.features.append(get_act_quant(in_bit_width, in_quant_type)) - self.features.append(Dropout(p=IN_DROPOUT)) - in_features = reduce(mul, in_features) - for out_features in FC_OUT_FEATURES: - self.features.append( - get_quant_linear( - in_features=in_features, - out_features=out_features, - per_out_ch_scaling=INTERMEDIATE_FC_PER_OUT_CH_SCALING, - bit_width=weight_bit_width, - quant_type=weight_quant_type, - stats_op=stats_op, - ) - ) - in_features = out_features - self.features.append(BatchNorm1d(num_features=in_features)) - self.features.append(get_act_quant(act_bit_width, act_quant_type)) - self.features.append(Dropout(p=HIDDEN_DROPOUT)) - self.fc = get_quant_linear( - in_features=in_features, - out_features=num_classes, - per_out_ch_scaling=LAST_FC_PER_OUT_CH_SCALING, - bit_width=weight_bit_width, - quant_type=weight_quant_type, - stats_op=stats_op, - ) - - def forward(self, x): - x = x.view(1, 784) - # removing the torch.tensor here creates a float64 op for some reason.. - # so explicitly wrapped with torch.tensor to make a float32 one instead - x = 2.0 * x - torch.tensor([1.0]) - for mod in self.features: - x = mod(x) - out = self.fc(x) - return out - - def test_batchnorm_to_affine(): lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1) checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu") @@ -93,12 +29,15 @@ def test_batchnorm_to_affine(): bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) model = ModelWrapper(export_onnx_path) model = model.transform_single(si.infer_shapes) + model = model.transform_repeated(fc.fold_constants) new_model = model.transform_single(tx.batchnorm_to_affine) # load one of the test vectors raw_i = get_data("finn", "data/onnx/mnist-conv/test_data_set_0/input_0.pb") input_tensor = onnx.load_tensor_from_string(raw_i) + out_old = model.graph.output[0].name + out_new = new_model.graph.output[0].name input_dict = {"0": nph.to_array(input_tensor)} - output_original = oxe.execute_onnx(model, input_dict)["53"] - output_transformed = oxe.execute_onnx(new_model, input_dict)["53"] + output_original = oxe.execute_onnx(model, input_dict)[out_old] + output_transformed = oxe.execute_onnx(new_model, input_dict)[out_new] assert np.isclose(output_transformed, output_original, atol=1e-3).all() os.remove(export_onnx_path) -- GitLab