Skip to content
Snippets Groups Projects
Commit a46882de authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] simplify upsample tests, remove unsupported cases

parent 463c73c5
No related branches found
No related tags found
No related merge requests found
...@@ -29,12 +29,13 @@ ...@@ -29,12 +29,13 @@
import pytest import pytest
import numpy as np import numpy as np
import tempfile import os
import torch import torch
from brevitas.export import FINNManager from brevitas.export import FINNManager
from torch import nn from torch import nn
import finn.core.onnx_exec as oxe import finn.core.onnx_exec as oxe
import finn.transformation.streamline.absorb as absorb
from finn.core.datatype import DataType from finn.core.datatype import DataType
from finn.core.modelwrapper import ModelWrapper from finn.core.modelwrapper import ModelWrapper
from finn.transformation.base import Transformation from finn.transformation.base import Transformation
...@@ -46,8 +47,12 @@ from finn.transformation.fpgadataflow.prepare_ip import PrepareIP ...@@ -46,8 +47,12 @@ from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim
from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
from finn.transformation.general import GiveUniqueNodeNames from finn.transformation.general import GiveUniqueNodeNames
from finn.transformation.infer_data_layouts import InferDataLayouts
from finn.transformation.infer_datatypes import InferDataTypes from finn.transformation.infer_datatypes import InferDataTypes
from finn.transformation.infer_shapes import InferShapes from finn.transformation.infer_shapes import InferShapes
from finn.transformation.make_input_chanlast import MakeInputChannelsLast
tmpdir = os.environ["FINN_BUILD_DIR"]
class ForceDataTypeForTensors(Transformation): class ForceDataTypeForTensors(Transformation):
...@@ -114,19 +119,18 @@ class PyTorchTestModel(nn.Module): ...@@ -114,19 +119,18 @@ class PyTorchTestModel(nn.Module):
@pytest.mark.parametrize("dt", [DataType.INT8]) @pytest.mark.parametrize("dt", [DataType.INT8])
# Width/height of square input feature map # Width/height of square input feature map
@pytest.mark.parametrize("IFMDim", [3, 5]) @pytest.mark.parametrize("IFMDim", [3, 5])
# Width/height of square output feature map # upscaling factor
@pytest.mark.parametrize("OFMDim", [6, 7, 14]) @pytest.mark.parametrize("scale", [2, 3])
# Number of input/output channels # Number of input/output channels
@pytest.mark.parametrize("NumChannels", [1, 4]) @pytest.mark.parametrize("NumChannels", [4])
# execution mode # execution mode
@pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"]) @pytest.mark.parametrize("exec_mode", ["cppsim", "rtlsim"])
@pytest.mark.vivado @pytest.mark.vivado
@pytest.mark.slow @pytest.mark.slow
def test_fpgadataflow_upsampler(dt, IFMDim, OFMDim, NumChannels, exec_mode): def test_fpgadataflow_upsampler(dt, IFMDim, scale, NumChannels, exec_mode):
tmpdir = tempfile.mkdtemp()
atol = 1e-3 atol = 1e-3
# Create the test model and inputs for it # Create the test model and inputs for it
torch_model = PyTorchTestModel(upscale_factor=OFMDim / IFMDim) torch_model = PyTorchTestModel(upscale_factor=scale)
input_shape = (1, NumChannels, IFMDim, IFMDim) input_shape = (1, NumChannels, IFMDim, IFMDim)
test_in = torch.arange(0, np.prod(np.asarray(input_shape))) test_in = torch.arange(0, np.prod(np.asarray(input_shape)))
# Limit the input to values valid for the given datatype # Limit the input to values valid for the given datatype
...@@ -154,18 +158,16 @@ def test_fpgadataflow_upsampler(dt, IFMDim, OFMDim, NumChannels, exec_mode): ...@@ -154,18 +158,16 @@ def test_fpgadataflow_upsampler(dt, IFMDim, OFMDim, NumChannels, exec_mode):
# Prep model for execution # Prep model for execution
model = ModelWrapper(export_path) model = ModelWrapper(export_path)
transpose_path = f"{tmpdir}/Upsample_transposed.onnx" # model = model.transform(TransposeUpsampleIO())
model = model.transform(TransposeUpsampleIO()) model = model.transform(MakeInputChannelsLast())
model.save(transpose_path) model = model.transform(InferDataLayouts())
model = model.transform(absorb.AbsorbTransposeIntoResize())
hls_upsample_path = f"{tmpdir}/Upsample_hls.onnx" model = model.transform(InferShapes())
model = ModelWrapper(transpose_path)
model = model.transform(ForceDataTypeForTensors(dType=dt)) model = model.transform(ForceDataTypeForTensors(dType=dt))
model = model.transform(GiveUniqueNodeNames()) model = model.transform(GiveUniqueNodeNames())
model = model.transform(InferUpsample()) model = model.transform(InferUpsample())
model = model.transform(InferShapes()) model = model.transform(InferShapes())
model = model.transform(InferDataTypes()) model = model.transform(InferDataTypes())
model.save(hls_upsample_path)
# Check that all nodes are UpsampleNearestNeighbour_Batch nodes # Check that all nodes are UpsampleNearestNeighbour_Batch nodes
for n in model.get_finn_nodes(): for n in model.get_finn_nodes():
...@@ -191,24 +193,9 @@ def test_fpgadataflow_upsampler(dt, IFMDim, OFMDim, NumChannels, exec_mode): ...@@ -191,24 +193,9 @@ def test_fpgadataflow_upsampler(dt, IFMDim, OFMDim, NumChannels, exec_mode):
input_dict = {model.graph.input[0].name: test_in_transposed} input_dict = {model.graph.input[0].name: test_in_transposed}
output_dict = oxe.execute_onnx(model, input_dict, True) output_dict = oxe.execute_onnx(model, input_dict, True)
test_result = output_dict[model.graph.output[0].name] test_result = output_dict[model.graph.output[0].name]
output_matches = np.isclose(golden_result, test_result, atol=atol).all()
test_result_transposed = test_result.transpose(_to_chan_first_args)
output_matches = np.isclose(golden_result, test_result_transposed, atol=atol).all()
if exec_mode == "cppsim": if exec_mode == "cppsim":
assert output_matches, "Cppsim output doesn't match ONNX/PyTorch." assert output_matches, "Cppsim output doesn't match ONNX/PyTorch."
elif exec_mode == "rtlsim": elif exec_mode == "rtlsim":
assert output_matches, "Rtlsim output doesn't match ONNX/PyTorch." assert output_matches, "Rtlsim output doesn't match ONNX/PyTorch."
# ToDo: Should this be done / implemented as well?
# if exec_mode == "rtlsim":
# hls_synt_res_est = model.analysis(hls_synth_res_estimation)
# assert "ChannelwiseOp_Batch_0" in hls_synt_res_est
#
# node = model.get_nodes_by_op_type("ChannelwiseOp_Batch")[0]
# inst = getCustomOp(node)
# cycles_rtlsim = inst.get_nodeattr("cycles_rtlsim")
# exp_cycles_dict = model.analysis(exp_cycles_per_layer)
# exp_cycles = exp_cycles_dict[node.name]
# assert np.isclose(exp_cycles, cycles_rtlsim, atol=10)
# assert exp_cycles != 0
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment