Skip to content
Snippets Groups Projects
Commit a46882de authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] simplify upsample tests, remove unsupported cases

parent 463c73c5
No related branches found
No related tags found
No related merge requests found
......@@ -29,12 +29,13 @@
import pytest
import numpy as np
import tempfile
import os
import torch
from brevitas.export import FINNManager
from torch import nn
import finn.core.onnx_exec as oxe
import finn.transformation.streamline.absorb as absorb
from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper
from finn.transformation.base import Transformation
......@@ -46,8 +47,12 @@ from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim
from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
from finn.transformation.general import GiveUniqueNodeNames
from finn.transformation.infer_data_layouts import InferDataLayouts
from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_shapes import InferShapes
from finn.transformation.make_input_chanlast import MakeInputChannelsLast
tmpdir = os.environ["FINN_BUILD_DIR"]
class ForceDataTypeForTensors(Transformation):
......@@ -114,19 +119,18 @@ class PyTorchTestModel(nn.Module):
@pytest.mark.parametrize("dt", [DataType.INT8])
# Width/height of square input feature map
@pytest.mark.parametrize("IFMDim", [3, 5])
# Width/height of square output feature map
@pytest.mark.parametrize("OFMDim", [6, 7, 14])
# upscaling factor
@pytest.mark.parametrize("scale", [2, 3])
# Number of input/output channels
@pytest.mark.parametrize("NumChannels", [1, 4])
@pytest.mark.parametrize("NumChannels", [4])
# execution mode
@pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"])
@pytest.mark.vivado
@pytest.mark.slow
def test_fpgadataflow_upsampler(dt, IFMDim, OFMDim, NumChannels, exec_mode):
tmpdir = tempfile.mkdtemp()
def test_fpgadataflow_upsampler(dt, IFMDim, scale, NumChannels, exec_mode):
atol = 1e-3
# Create the test model and inputs for it
torch_model = PyTorchTestModel(upscale_factor=OFMDim / IFMDim)
torch_model = PyTorchTestModel(upscale_factor=scale)
input_shape = (1, NumChannels, IFMDim, IFMDim)
test_in = torch.arange(0, np.prod(np.asarray(input_shape)))
# Limit the input to values valid for the given datatype
......@@ -154,18 +158,16 @@ def test_fpgadataflow_upsampler(dt, IFMDim, OFMDim, NumChannels, exec_mode):
# Prep model for execution
model = ModelWrapper(export_path)
transpose_path = f"{tmpdir}/Upsample_transposed.onnx"
model = model.transform(TransposeUpsampleIO())
model.save(transpose_path)
hls_upsample_path = f"{tmpdir}/Upsample_hls.onnx"
model = ModelWrapper(transpose_path)
# model = model.transform(TransposeUpsampleIO())
model = model.transform(MakeInputChannelsLast())
model = model.transform(InferDataLayouts())
model = model.transform(absorb.AbsorbTransposeIntoResize())
model = model.transform(InferShapes())
model = model.transform(ForceDataTypeForTensors(dType=dt))
model = model.transform(GiveUniqueNodeNames())
model = model.transform(InferUpsample())
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
model.save(hls_upsample_path)
# Check that all nodes are UpsampleNearestNeighbour_Batch nodes
for n in model.get_finn_nodes():
......@@ -191,24 +193,9 @@ def test_fpgadataflow_upsampler(dt, IFMDim, OFMDim, NumChannels, exec_mode):
input_dict = {model.graph.input[0].name: test_in_transposed}
output_dict = oxe.execute_onnx(model, input_dict, True)
test_result = output_dict[model.graph.output[0].name]
test_result_transposed = test_result.transpose(_to_chan_first_args)
output_matches = np.isclose(golden_result, test_result_transposed, atol=atol).all()
output_matches = np.isclose(golden_result, test_result, atol=atol).all()
if exec_mode == "cppsim":
assert output_matches, "Cppsim output doesn't match ONNX/PyTorch."
elif exec_mode == "rtlsim":
assert output_matches, "Rtlsim output doesn't match ONNX/PyTorch."
# ToDo: Should this be done / implemented as well?
# if exec_mode == "rtlsim":
# hls_synt_res_est = model.analysis(hls_synth_res_estimation)
# assert "ChannelwiseOp_Batch_0" in hls_synt_res_est
#
# node = model.get_nodes_by_op_type("ChannelwiseOp_Batch")[0]
# inst = getCustomOp(node)
# cycles_rtlsim = inst.get_nodeattr("cycles_rtlsim")
# exp_cycles_dict = model.analysis(exp_cycles_per_layer)
# exp_cycles = exp_cycles_dict[node.name]
# assert np.isclose(exp_cycles, cycles_rtlsim, atol=10)
# assert exp_cycles != 0
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment