diff --git a/tests/test_brevitas_cnv.py b/tests/test_brevitas_cnv.py
index 49a8db86f48384a730c879e9c308f46c7a9f8016..c777d0bd57d6d152e900dd6977979020c840926f 100644
--- a/tests/test_brevitas_cnv.py
+++ b/tests/test_brevitas_cnv.py
@@ -1,3 +1,4 @@
+import os
 import pkg_resources as pk
 
 import brevitas.onnx as bo
@@ -105,10 +106,10 @@ class CNV(Module):
         )
 
     def forward(self, x):
-        x = 2.0 * x - 1.0
+        x = 2.0 * x - torch.tensor([1.0])
         for mod in self.conv_features:
             x = mod(x)
-        x = x.view(x.shape[0], -1)
+        x = x.view(1, 256)
         for mod in self.linear_features:
             x = mod(x)
         out = self.fc(x)
@@ -149,7 +150,7 @@ def test_brevitas_cnv_export_exec():
     model = model.transform_single(si.infer_shapes)
     model.save(export_onnx_path)
     fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
-    input_tensor = np.load(fn)["arr_0"]
+    input_tensor = np.load(fn)["arr_0"].astype(np.float32)
     assert input_tensor.shape == (1, 3, 32, 32)
     # run using FINN-based execution
     input_dict = {"0": input_tensor}
@@ -159,3 +160,4 @@ def test_brevitas_cnv_export_exec():
     input_tensor = torch.from_numpy(input_tensor).float()
     expected = cnv.forward(input_tensor).detach().numpy()
     assert np.isclose(produced, expected, atol=1e-3).all()
+    os.remove(export_onnx_path)