Skip to content
Snippets Groups Projects
Commit f3f823e3 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] fix problems in brevitas_cnv_import tests

* use float32 inputs
* use static resizing
* preproc float32 subtraction
parent d1184c4b
No related branches found
No related tags found
No related merge requests found
import os
import pkg_resources as pk
import brevitas.onnx as bo
......@@ -105,10 +106,10 @@ class CNV(Module):
)
def forward(self, x):
x = 2.0 * x - 1.0
x = 2.0 * x - torch.tensor([1.0])
for mod in self.conv_features:
x = mod(x)
x = x.view(x.shape[0], -1)
x = x.view(1, 256)
for mod in self.linear_features:
x = mod(x)
out = self.fc(x)
......@@ -149,7 +150,7 @@ def test_brevitas_cnv_export_exec():
model = model.transform_single(si.infer_shapes)
model.save(export_onnx_path)
fn = pk.resource_filename("finn", "data/cifar10/cifar10-test-data-class3.npz")
input_tensor = np.load(fn)["arr_0"]
input_tensor = np.load(fn)["arr_0"].astype(np.float32)
assert input_tensor.shape == (1, 3, 32, 32)
# run using FINN-based execution
input_dict = {"0": input_tensor}
......@@ -159,3 +160,4 @@ def test_brevitas_cnv_export_exec():
input_tensor = torch.from_numpy(input_tensor).float()
expected = cnv.forward(input_tensor).detach().numpy()
assert np.isclose(produced, expected, atol=1e-3).all()
os.remove(export_onnx_path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment