From f3f823e314db61edd63cb5ad9d8872bfbb82bccd Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Thu, 31 Oct 2019 23:09:53 +0000 Subject: [PATCH] [Test] fix problems in brevitas_cnv_import tests * use float32 inputs * use static resizing * preproc float32 subtraction --- tests/test_brevitas_cnv.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test_brevitas_cnv.py b/tests/test_brevitas_cnv.py index 49a8db86f..c777d0bd5 100644 --- a/tests/test_brevitas_cnv.py +++ b/tests/test_brevitas_cnv.py @@ -1,3 +1,4 @@ +import os import pkg_resources as pk import brevitas.onnx as bo @@ -105,10 +106,10 @@ class CNV(Module): ) def forward(self, x): - x = 2.0 * x - 1.0 + x = 2.0 * x - torch.tensor([1.0]) for mod in self.conv_features: x = mod(x) - x = x.view(x.shape[0], -1) + x = x.view(1, 256) for mod in self.linear_features: x = mod(x) out = self.fc(x) @@ -149,7 +150,7 @@ def test_brevitas_cnv_export_exec(): model = model.transform_single(si.infer_shapes) model.save(export_onnx_path) fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz") - input_tensor = np.load(fn)["arr_0"] + input_tensor = np.load(fn)["arr_0"].astype(np.float32) assert input_tensor.shape == (1, 3, 32, 32) # run using FINN-based execution input_dict = {"0": input_tensor} @@ -159,3 +160,4 @@ def test_brevitas_cnv_export_exec(): input_tensor = torch.from_numpy(input_tensor).float() expected = cnv.forward(input_tensor).detach().numpy() assert np.isclose(produced, expected, atol=1e-3).all() + os.remove(export_onnx_path) -- GitLab