diff --git a/tests/test_brevitas_cnv.py b/tests/test_brevitas_cnv.py index a0904e37dd9e9dc00d8e0bc52193ae5faff00c76..631e8073f437052958d6c8aa22126bda88468c49 100644 --- a/tests/test_brevitas_cnv.py +++ b/tests/test_brevitas_cnv.py @@ -4,40 +4,13 @@ import pkg_resources as pk import brevitas.onnx as bo import numpy as np import torch -from models.common import ( - get_act_quant, - get_quant_conv2d, - get_quant_linear, - get_quant_type, - get_stats_op -) -from torch.nn import BatchNorm1d, BatchNorm2d, MaxPool2d, Module, ModuleList, Sequential +from models.CNV import CNV import finn.core.onnx_exec as oxe from finn.core.modelwrapper import ModelWrapper +from finn.transformation.fold_constants import FoldConstants from finn.transformation.infer_shapes import InferShapes -# QuantConv2d configuration -CNV_OUT_CH_POOL = [ - (0, 64, False), - (1, 64, True), - (2, 128, False), - (3, 128, True), - (4, 256, False), - (5, 256, False), -] - -# Intermediate QuantLinear configuration -INTERMEDIATE_FC_PER_OUT_CH_SCALING = True -INTERMEDIATE_FC_FEATURES = [(256, 512), (512, 512)] - -# Last QuantLinear configuration -LAST_FC_IN_FEATURES = 512 -LAST_FC_PER_OUT_CH_SCALING = False - -# MaxPool2d configuration -POOL_SIZE = 2 - export_onnx_path = "test_output_cnv.onnx" # TODO get from config instead, hardcoded to Docker path for now trained_cnv_checkpoint = ( @@ -45,92 +18,7 @@ trained_cnv_checkpoint = ( ) -class CNV(Module): - def __init__( - self, - num_classes=10, - weight_bit_width=None, - act_bit_width=None, - in_bit_width=None, - in_ch=3, - ): - super(CNV, self).__init__() - - weight_quant_type = get_quant_type(weight_bit_width) - act_quant_type = get_quant_type(act_bit_width) - in_quant_type = get_quant_type(in_bit_width) - stats_op = get_stats_op(weight_quant_type) - - self.conv_features = ModuleList() - self.linear_features = ModuleList() - self.conv_features.append(get_act_quant(in_bit_width, in_quant_type)) - - for i, out_ch, is_pool_enabled in CNV_OUT_CH_POOL: - self.conv_features.append( - get_quant_conv2d( - in_ch=in_ch, - out_ch=out_ch, - bit_width=weight_bit_width, - quant_type=weight_quant_type, - stats_op=stats_op, - ) - ) - in_ch = out_ch - if is_pool_enabled: - self.conv_features.append(MaxPool2d(kernel_size=2)) - if i == 5: - self.conv_features.append(Sequential()) - self.conv_features.append(BatchNorm2d(in_ch)) - self.conv_features.append(get_act_quant(act_bit_width, act_quant_type)) - - for in_features, out_features in INTERMEDIATE_FC_FEATURES: - self.linear_features.append( - get_quant_linear( - in_features=in_features, - out_features=out_features, - per_out_ch_scaling=INTERMEDIATE_FC_PER_OUT_CH_SCALING, - bit_width=weight_bit_width, - quant_type=weight_quant_type, - stats_op=stats_op, - ) - ) - self.linear_features.append(BatchNorm1d(out_features)) - self.linear_features.append(get_act_quant(act_bit_width, act_quant_type)) - self.fc = get_quant_linear( - in_features=LAST_FC_IN_FEATURES, - out_features=num_classes, - per_out_ch_scaling=LAST_FC_PER_OUT_CH_SCALING, - bit_width=weight_bit_width, - quant_type=weight_quant_type, - stats_op=stats_op, - ) - - def forward(self, x): - x = 2.0 * x - torch.tensor([1.0]) - for mod in self.conv_features: - x = mod(x) - x = x.view(1, 256) - for mod in self.linear_features: - x = mod(x) - out = self.fc(x) - return out - - -def test_brevitas_trained_cnv_pytorch(): - # load pretrained weights into CNV-w1a1 - cnv = CNV(weight_bit_width=1, act_bit_width=1, in_bit_width=1, in_ch=3).eval() - checkpoint = torch.load(trained_cnv_checkpoint, map_location="cpu") - cnv.load_state_dict(checkpoint["state_dict"]) - fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz") - input_tensor = np.load(fn)["arr_0"] - input_tensor = torch.from_numpy(input_tensor).float() - assert input_tensor.shape == (1, 3, 32, 32) - # do forward pass in PyTorch/Brevitas - cnv.forward(input_tensor).detach().numpy() - # TODO verify produced answer - - -def test_brevitas_cnv_export(): +def test_brevitas_cnv_w1a1_export(): cnv = CNV(weight_bit_width=1, act_bit_width=1, in_bit_width=1, in_ch=3).eval() bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path) model = ModelWrapper(export_onnx_path) @@ -139,15 +27,17 @@ def test_brevitas_cnv_export(): conv0_wname = model.graph.node[3].input[1] assert list(model.get_initializer(conv0_wname).shape) == [64, 3, 3, 3] assert model.graph.node[4].op_type == "Mul" + os.remove(export_onnx_path) -def test_brevitas_cnv_export_exec(): +def test_brevitas_cnv_w1a1_export_exec(): cnv = CNV(weight_bit_width=1, act_bit_width=1, in_bit_width=1, in_ch=3).eval() checkpoint = torch.load(trained_cnv_checkpoint, map_location="cpu") cnv.load_state_dict(checkpoint["state_dict"]) bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path) model = ModelWrapper(export_onnx_path) model = model.transform(InferShapes()) + model = model.transform(FoldConstants()) model.save(export_onnx_path) fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz") input_tensor = np.load(fn)["arr_0"].astype(np.float32) @@ -161,3 +51,17 @@ def test_brevitas_cnv_export_exec(): expected = cnv.forward(input_tensor).detach().numpy() assert np.isclose(produced, expected, atol=1e-3).all() os.remove(export_onnx_path) + + +def test_brevitas_trained_cnv_w1a1_pytorch(): + # load pretrained weights into CNV-w1a1 + cnv = CNV(weight_bit_width=1, act_bit_width=1, in_bit_width=1, in_ch=3).eval() + checkpoint = torch.load(trained_cnv_checkpoint, map_location="cpu") + cnv.load_state_dict(checkpoint["state_dict"]) + fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz") + input_tensor = np.load(fn)["arr_0"] + input_tensor = torch.from_numpy(input_tensor).float() + assert input_tensor.shape == (1, 3, 32, 32) + # do forward pass in PyTorch/Brevitas + cnv.forward(input_tensor).detach().numpy() + # TODO verify produced answer