From f3f823e314db61edd63cb5ad9d8872bfbb82bccd Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Thu, 31 Oct 2019 23:09:53 +0000
Subject: [PATCH] [Test] fix problems in brevitas_cnv_import tests

* use float32 inputs
* use static resizing
* preproc float32 subtraction
---
 tests/test_brevitas_cnv.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/tests/test_brevitas_cnv.py b/tests/test_brevitas_cnv.py
index 49a8db86f..c777d0bd5 100644
--- a/tests/test_brevitas_cnv.py
+++ b/tests/test_brevitas_cnv.py
@@ -1,3 +1,4 @@
+import os
 import pkg_resources as pk
 
 import brevitas.onnx as bo
@@ -105,10 +106,10 @@ class CNV(Module):
         )
 
     def forward(self, x):
-        x = 2.0 * x - 1.0
+        x = 2.0 * x - torch.tensor([1.0])
         for mod in self.conv_features:
             x = mod(x)
-        x = x.view(x.shape[0], -1)
+        x = x.view(1, 256)
         for mod in self.linear_features:
             x = mod(x)
         out = self.fc(x)
@@ -149,7 +150,7 @@ def test_brevitas_cnv_export_exec():
     model = model.transform_single(si.infer_shapes)
     model.save(export_onnx_path)
     fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
-    input_tensor = np.load(fn)["arr_0"]
+    input_tensor = np.load(fn)["arr_0"].astype(np.float32)
     assert input_tensor.shape == (1, 3, 32, 32)
     # run using FINN-based execution
     input_dict = {"0": input_tensor}
@@ -159,3 +160,4 @@ def test_brevitas_cnv_export_exec():
     input_tensor = torch.from_numpy(input_tensor).float()
     expected = cnv.forward(input_tensor).detach().numpy()
     assert np.isclose(produced, expected, atol=1e-3).all()
+    os.remove(export_onnx_path)
-- 
GitLab