From 213611d6abb417afc52bd8ac90b592d65d568be1 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Mon, 21 Oct 2019 14:35:58 +0100 Subject: [PATCH] [Test] add test_modelwrapper --- tests/test_modelwrapper.py | 102 +++++++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 tests/test_modelwrapper.py diff --git a/tests/test_modelwrapper.py b/tests/test_modelwrapper.py new file mode 100644 index 000000000..3dfeee89e --- /dev/null +++ b/tests/test_modelwrapper.py @@ -0,0 +1,102 @@ +import os +from collections import Counter +from functools import reduce +from operator import mul + +import brevitas.onnx as bo +import numpy as np +import torch +from models.common import get_act_quant, get_quant_linear, get_quant_type, get_stats_op +from torch.nn import BatchNorm1d, Dropout, Module, ModuleList + +from finn.core.modelwrapper import ModelWrapper + +FC_OUT_FEATURES = [1024, 1024, 1024] +INTERMEDIATE_FC_PER_OUT_CH_SCALING = True +LAST_FC_PER_OUT_CH_SCALING = False +IN_DROPOUT = 0.2 +HIDDEN_DROPOUT = 0.2 + +export_onnx_path = "test_output_lfc.onnx" +# TODO get from config instead, hardcoded to Docker path for now +trained_lfc_checkpoint = ( + "/workspace/brevitas_cnv_lfc/pretrained_models/LFC_1W1A/checkpoints/best.tar" +) + + +class LFC(Module): + def __init__( + self, + num_classes=10, + weight_bit_width=None, + act_bit_width=None, + in_bit_width=None, + in_ch=1, + in_features=(28, 28), + ): + super(LFC, self).__init__() + + weight_quant_type = get_quant_type(weight_bit_width) + act_quant_type = get_quant_type(act_bit_width) + in_quant_type = get_quant_type(in_bit_width) + stats_op = get_stats_op(weight_quant_type) + + self.features = ModuleList() + self.features.append(get_act_quant(in_bit_width, in_quant_type)) + self.features.append(Dropout(p=IN_DROPOUT)) + in_features = reduce(mul, in_features) + for out_features in FC_OUT_FEATURES: + self.features.append( + get_quant_linear( + in_features=in_features, + out_features=out_features, + per_out_ch_scaling=INTERMEDIATE_FC_PER_OUT_CH_SCALING, + bit_width=weight_bit_width, + quant_type=weight_quant_type, + stats_op=stats_op, + ) + ) + in_features = out_features + self.features.append(BatchNorm1d(num_features=in_features)) + self.features.append(get_act_quant(act_bit_width, act_quant_type)) + self.features.append(Dropout(p=HIDDEN_DROPOUT)) + self.fc = get_quant_linear( + in_features=in_features, + out_features=num_classes, + per_out_ch_scaling=LAST_FC_PER_OUT_CH_SCALING, + bit_width=weight_bit_width, + quant_type=weight_quant_type, + stats_op=stats_op, + ) + + def forward(self, x): + x = x.view(1, 784) + # removing the torch.tensor here creates a float64 op for some reason.. + # so explicitly wrapped with torch.tensor to make a float32 one instead + x = 2.0 * x - torch.tensor([1.0]) + for mod in self.features: + x = mod(x) + out = self.fc(x) + return out + + +def test_modelwrapper(): + lfc = LFC(weight_bit_width=1, act_bit_width=1, in_bit_width=1) + checkpoint = torch.load(trained_lfc_checkpoint, map_location="cpu") + lfc.load_state_dict(checkpoint["state_dict"]) + bo.export_finn_onnx(lfc, (1, 1, 28, 28), export_onnx_path) + model = ModelWrapper(export_onnx_path) + inp_shape = model.get_tensor_shape("0") + assert inp_shape == [1, 1, 28, 28] + l0_weights = model.get_initializer("26") + assert l0_weights.shape == (784, 1024) + l0_weights_hist = Counter(l0_weights.flatten()) + assert l0_weights_hist[1.0] == 401311 and l0_weights_hist[-1.0] == 401505 + l0_weights_rand = np.random.randn(784, 1024) + model.set_initializer("26", l0_weights_rand) + assert (model.get_initializer("26") == l0_weights_rand).all() + inp_cons = model.find_consumer("0") + assert inp_cons.op_type == "Flatten" + out_prod = model.find_producer("53") + assert out_prod.op_type == "Mul" + os.remove(export_onnx_path) -- GitLab