diff --git a/tests/test_brevitas_cnv.py b/tests/test_brevitas_cnv.py index 49a8db86f48384a730c879e9c308f46c7a9f8016..c777d0bd57d6d152e900dd6977979020c840926f 100644 --- a/tests/test_brevitas_cnv.py +++ b/tests/test_brevitas_cnv.py @@ -1,3 +1,4 @@ +import os import pkg_resources as pk import brevitas.onnx as bo @@ -105,10 +106,10 @@ class CNV(Module): ) def forward(self, x): - x = 2.0 * x - 1.0 + x = 2.0 * x - torch.tensor([1.0]) for mod in self.conv_features: x = mod(x) - x = x.view(x.shape[0], -1) + x = x.view(1, 256) for mod in self.linear_features: x = mod(x) out = self.fc(x) @@ -149,7 +150,7 @@ def test_brevitas_cnv_export_exec(): model = model.transform_single(si.infer_shapes) model.save(export_onnx_path) fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz") - input_tensor = np.load(fn)["arr_0"] + input_tensor = np.load(fn)["arr_0"].astype(np.float32) assert input_tensor.shape == (1, 3, 32, 32) # run using FINN-based execution input_dict = {"0": input_tensor} @@ -159,3 +160,4 @@ def test_brevitas_cnv_export_exec(): input_tensor = torch.from_numpy(input_tensor).float() expected = cnv.forward(input_tensor).detach().numpy() assert np.isclose(produced, expected, atol=1e-3).all() + os.remove(export_onnx_path)