diff --git a/tests/fpgadataflow/test_fpgadataflow_upsampler.py b/tests/fpgadataflow/test_fpgadataflow_upsampler.py index cc398887f944e507ee804c6e859eace70041663b..64a0519c92c30cfa40f21db54b70f833dd7f2f1d 100644 --- a/tests/fpgadataflow/test_fpgadataflow_upsampler.py +++ b/tests/fpgadataflow/test_fpgadataflow_upsampler.py @@ -131,8 +131,11 @@ def test_fpgadataflow_upsampler(dt, IFMDim, OFMDim, NumChannels, exec_mode): torch_model = PyTorchTestModel(upscale_factor=OFMDim / IFMDim) input_shape = (1, NumChannels, IFMDim, IFMDim) test_in = torch.arange(0, np.prod(np.asarray(input_shape))) + # Limit the input to values valid for the given datatype test_in %= dt.max() - dt.min() + 1 test_in += dt.min() + # Additionally make sure we always start with 0, for convenience purposes. + test_in = torch.roll(test_in, dt.min()) test_in = test_in.view(*input_shape).type(torch.float32) # Get golden PyTorch and ONNX inputs