Commit 4633b63d authored by Hendrik Borras

Store test results in temporary folder.

parent d6461301
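
For reference, the pattern this commit adopts: tempfile.mkdtemp() creates a unique directory and returns its absolute path, so test artifacts no longer land in the current working directory. A minimal sketch of the idea, with write_dummy_model as a hypothetical stand-in for the real FINNManager.export call:

import os
import tempfile


def write_dummy_model(path):
    # Hypothetical stand-in for the ONNX export done in the test.
    with open(path, "w") as f:
        f.write("onnx-model-placeholder")


def test_artifacts_go_to_tmpdir():
    # mkdtemp() creates a fresh directory (e.g. /tmp/tmpXXXXXXXX)
    # and returns its path; cleanup is left to the caller.
    tmpdir = tempfile.mkdtemp()
    export_path = f"{tmpdir}/Upsample_exported.onnx"
    write_dummy_model(export_path)
    assert os.path.isfile(export_path)

Note that mkdtemp() does not remove the directory afterwards; pytest's built-in tmp_path fixture is a common alternative that hands each test a directory the framework cleans up on its own.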
@@ -29,6 +29,7 @@
 import pytest
 import numpy as np
+import tempfile
 import torch
 from brevitas.export import FINNManager
 from torch import nn
@@ -124,6 +125,8 @@ class PyTorchTestModel(nn.Module):
 @pytest.mark.vivado
 @pytest.mark.slow
 def test_fpgadataflow_upsampler(dt, IFMDim, OFMDim, NumChannels, exec_mode):
+    tmpdir = tempfile.mkdtemp()
+    atol = 1e-3
     # Create the test model and inputs for it
     torch_model = PyTorchTestModel(upscale_factor=OFMDim / IFMDim)
     input_shape = (1, NumChannels, IFMDim, IFMDim)
@@ -132,7 +135,7 @@ def test_fpgadataflow_upsampler(dt, IFMDim, OFMDim, NumChannels, exec_mode):
     # Get golden PyTorch and ONNX inputs
     golden_torch_float = torch_model(test_in)
-    export_path = "Upsample_exported.onnx"
+    export_path = f"{tmpdir}/Upsample_exported.onnx"
     FINNManager.export(
         torch_model, input_shape=input_shape, export_path=export_path, opset_version=11
     )
@@ -140,19 +143,19 @@ def test_fpgadataflow_upsampler(dt, IFMDim, OFMDim, NumChannels, exec_mode):
-    input_dict = {model.graph.input[0].name: test_in.numpy().astype(np.int8)}
+    input_dict = {model.graph.input[0].name: test_in.numpy()}
     golden_output_dict = oxe.execute_onnx(model, input_dict, True)
-    golden_result_float = golden_output_dict[model.graph.output[0].name]
+    golden_result = golden_output_dict[model.graph.output[0].name]
     # Make sure PyTorch and ONNX match
-    pyTorch_onnx_match = np.isclose(golden_result_float, golden_torch_float).all()
+    pyTorch_onnx_match = np.isclose(golden_result, golden_torch_float).all()
     assert pyTorch_onnx_match, "ONNX and PyTorch upsampling output don't match."
     # Prep model for execution
     model = ModelWrapper(export_path)
-    transpose_path = "Upsample_transposed.onnx"
+    transpose_path = f"{tmpdir}/Upsample_transposed.onnx"
     model = model.transform(TransposeUpsampleIO())
     model.save(transpose_path)
-    hls_upsample_path = "Upsample_hls.onnx"
+    hls_upsample_path = f"{tmpdir}/Upsample_hls.onnx"
     model = ModelWrapper(transpose_path)
     model = model.transform(ForceDataTypeForTensors(dType=dt))
     model = model.transform(GiveUniqueNodeNames())
@@ -188,8 +191,8 @@ def test_fpgadataflow_upsampler(dt, IFMDim, OFMDim, NumChannels, exec_mode):
     output_dict = oxe.execute_onnx(model, input_dict, True)
     test_result = output_dict[model.graph.output[0].name]
-    test_restuls_transposed = test_result.transpose(_to_chan_first_args)
-    output_matches = np.isclose(golden_result_float, test_restuls_transposed).all()
+    test_result_transposed = test_result.transpose(_to_chan_first_args)
+    output_matches = np.isclose(golden_result, test_result_transposed, atol=atol).all()
     if exec_mode == "cppsim":
         assert output_matches, "Cppsim output doesn't match ONNX/PyTorch."
@@ -197,8 +200,7 @@ def test_fpgadataflow_upsampler(dt, IFMDim, OFMDim, NumChannels, exec_mode):
         # os.environ["LIVENESS_THRESHOLD"] = str(liveness)
         assert output_matches, "Rtlsim output doesn't match ONNX/PyTorch."
-    # Should this be done as well?
+    # ToDo: Should this be done as well?
     # if exec_mode == "rtlsim":
     #     hls_synt_res_est = model.analysis(hls_synth_res_estimation)
     #     assert "ChannelwiseOp_Batch_0" in hls_synt_res_est