From 801c92da96754d5871d91539641ee0fe908e3f73 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Thu, 14 Jul 2022 16:38:34 +0200
Subject: [PATCH] [Test] make upsampler test suitable for parallel exec

---
 tests/fpgadataflow/test_fpgadataflow_upsampler.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/tests/fpgadataflow/test_fpgadataflow_upsampler.py b/tests/fpgadataflow/test_fpgadataflow_upsampler.py
index 534e1ce50..a08d31f7b 100644
--- a/tests/fpgadataflow/test_fpgadataflow_upsampler.py
+++ b/tests/fpgadataflow/test_fpgadataflow_upsampler.py
@@ -30,6 +30,7 @@ import pytest
 
 import numpy as np
 import os
+import shutil
 import torch
 from brevitas.export import FINNManager
 from qonnx.core.datatype import DataType
@@ -51,6 +52,7 @@ from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim
 from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
 from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim
 from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
+from finn.util.basic import make_build_dir
 
 tmpdir = os.environ["FINN_BUILD_DIR"]
 
@@ -131,13 +133,16 @@ class PyTorchTestModel(nn.Module):
 @pytest.mark.vivado
 @pytest.mark.slow
 def test_fpgadataflow_upsampler(dt, IFMDim, scale, NumChannels, exec_mode, is_1d):
+    tmpdir = make_build_dir("upsample_export_")
     atol = 1e-3
-    # Create the test model and inputs for it
-    torch_model = PyTorchTestModel(upscale_factor=scale)
     if is_1d:
         input_shape = (1, NumChannels, IFMDim, 1)
+        upscale_factor = (scale, 1)
     else:
         input_shape = (1, NumChannels, IFMDim, IFMDim)
+        upscale_factor = (scale, scale)
+    # Create the test model and inputs for it
+    torch_model = PyTorchTestModel(upscale_factor=upscale_factor)
     test_in = torch.arange(0, np.prod(np.asarray(input_shape)))
     # Limit the input to values valid for the given datatype
     test_in %= dt.max() - dt.min() + 1
@@ -205,3 +210,4 @@ def test_fpgadataflow_upsampler(dt, IFMDim, scale, NumChannels, exec_mode, is_1d
         assert output_matches, "Cppsim output doesn't match ONNX/PyTorch."
     elif exec_mode == "rtlsim":
         assert output_matches, "Rtlsim output doesn't match ONNX/PyTorch."
+    shutil.rmtree(tmpdir, ignore_errors=True)
-- 
GitLab