From 801c92da96754d5871d91539641ee0fe908e3f73 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Thu, 14 Jul 2022 16:38:34 +0200 Subject: [PATCH] [Test] make upsampler test suitable for parallel exec --- tests/fpgadataflow/test_fpgadataflow_upsampler.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/fpgadataflow/test_fpgadataflow_upsampler.py b/tests/fpgadataflow/test_fpgadataflow_upsampler.py index 534e1ce50..a08d31f7b 100644 --- a/tests/fpgadataflow/test_fpgadataflow_upsampler.py +++ b/tests/fpgadataflow/test_fpgadataflow_upsampler.py @@ -30,6 +30,7 @@ import pytest import numpy as np import os +import shutil import torch from brevitas.export import FINNManager from qonnx.core.datatype import DataType @@ -51,6 +52,7 @@ from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim from finn.transformation.fpgadataflow.prepare_ip import PrepareIP from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode +from finn.util.basic import make_build_dir tmpdir = os.environ["FINN_BUILD_DIR"] @@ -131,13 +133,16 @@ class PyTorchTestModel(nn.Module): @pytest.mark.vivado @pytest.mark.slow def test_fpgadataflow_upsampler(dt, IFMDim, scale, NumChannels, exec_mode, is_1d): + tmpdir = make_build_dir("upsample_export_") atol = 1e-3 - # Create the test model and inputs for it - torch_model = PyTorchTestModel(upscale_factor=scale) if is_1d: input_shape = (1, NumChannels, IFMDim, 1) + upscale_factor = (scale, 1) else: input_shape = (1, NumChannels, IFMDim, IFMDim) + upscale_factor = (scale, scale) + # Create the test model and inputs for it + torch_model = PyTorchTestModel(upscale_factor=upscale_factor) test_in = torch.arange(0, np.prod(np.asarray(input_shape))) # Limit the input to values valid for the given datatype test_in %= dt.max() - dt.min() + 1 @@ -205,3 +210,4 @@ def test_fpgadataflow_upsampler(dt, IFMDim, scale, NumChannels, exec_mode, is_1d assert output_matches, "Cppsim output doesn't match ONNX/PyTorch." elif exec_mode == "rtlsim": assert output_matches, "Rtlsim output doesn't match ONNX/PyTorch." + shutil.rmtree(tmpdir, ignore_errors=True) -- GitLab