From 5b0b186d19cc52802f98f473461a7290ffe54a63 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Wed, 18 Mar 2020 00:59:45 +0000 Subject: [PATCH] [Test] simplify and generalize Brevitas CNV export test --- tests/brevitas/test_brevitas_cnv.py | 58 ++++++----------------------- 1 file changed, 11 insertions(+), 47 deletions(-) diff --git a/tests/brevitas/test_brevitas_cnv.py b/tests/brevitas/test_brevitas_cnv.py index 168906c76..8d21a33f7 100644 --- a/tests/brevitas/test_brevitas_cnv.py +++ b/tests/brevitas/test_brevitas_cnv.py @@ -28,6 +28,7 @@ import os import pkg_resources as pk +import pytest import brevitas.onnx as bo import numpy as np @@ -39,25 +40,17 @@ from finn.transformation.fold_constants import FoldConstants from finn.transformation.infer_shapes import InferShapes from finn.transformation.general import GiveUniqueNodeNames from finn.transformation.double_to_single_float import DoubleToSingleFloat -from finn.util.test import get_test_model_trained, get_test_model_untrained +from finn.util.test import get_test_model_trained export_onnx_path = "test_output_cnv.onnx" -def test_brevitas_cnv_w1a1_export(): - cnv = get_test_model_untrained("CNV", 1, 1) - bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path) - model = ModelWrapper(export_onnx_path) - assert model.graph.node[2].op_type == "Sign" - assert model.graph.node[3].op_type == "Conv" - conv0_wname = model.graph.node[3].input[1] - assert list(model.get_initializer(conv0_wname).shape) == [64, 3, 3, 3] - assert model.graph.node[4].op_type == "Mul" - os.remove(export_onnx_path) - - -def test_brevitas_cnv_w1a1_export_exec(): - cnv = get_test_model_trained("CNV", 1, 1) +@pytest.mark.parametrize("abits", [1, 2]) +@pytest.mark.parametrize("wbits", [1, 2]) +def test_brevitas_cnv_export_exec(wbits, abits): + if wbits > abits: + pytest.skip("No wbits > abits cases at the moment") + cnv = get_test_model_trained("CNV", wbits, abits) bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path) model = ModelWrapper(export_onnx_path) model = model.transform(GiveUniqueNodeNames()) @@ -68,40 +61,11 @@ def test_brevitas_cnv_w1a1_export_exec(): input_tensor = np.load(fn)["arr_0"].astype(np.float32) assert input_tensor.shape == (1, 3, 32, 32) # run using FINN-based execution - input_dict = {"0": input_tensor} - output_dict = oxe.execute_onnx(model, input_dict) - produced = output_dict[list(output_dict.keys())[0]] + input_dict = {model.graph.input[0].name: input_tensor} + output_dict = oxe.execute_onnx(model, input_dict, True) + produced = output_dict[model.graph.output[0].name] # do forward pass in PyTorch/Brevitas input_tensor = torch.from_numpy(input_tensor).float() expected = cnv.forward(input_tensor).detach().numpy() assert np.isclose(produced, expected, atol=1e-3).all() os.remove(export_onnx_path) - - -def test_brevitas_cnv_w1a1_pytorch(): - # load pretrained weights into CNV-w1a1 - cnv = get_test_model_trained("CNV", 1, 1) - fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz") - input_tensor = np.load(fn)["arr_0"] - input_tensor = torch.from_numpy(input_tensor).float() - assert input_tensor.shape == (1, 3, 32, 32) - # do forward pass in PyTorch/Brevitas - produced = cnv.forward(input_tensor).detach().numpy() - expected = np.asarray( - [ - [ - 3.7939777, - -2.3108773, - 0.06898145, - 0.55185133, - 0.37939775, - -1.9659703, - -0.3104164, - -2.828238, - 2.6902752, - 0.48286998, - ] - ], - dtype=np.float32, - ) - assert np.isclose(produced, expected, atol=1e-3).all() -- GitLab