Skip to content
Snippets Groups Projects
Commit 067d165b authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] add ipstitch+rtlsim test for concat

parent db37e945
No related branches found
No related tags found
No related merge requests found
...@@ -39,7 +39,9 @@ from finn.core.modelwrapper import ModelWrapper ...@@ -39,7 +39,9 @@ from finn.core.modelwrapper import ModelWrapper
from finn.core.onnx_exec import execute_onnx from finn.core.onnx_exec import execute_onnx
from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim from finn.transformation.fpgadataflow.compile_cppsim import CompileCppSim
from finn.transformation.fpgadataflow.convert_to_hls_layers import InferConcatLayer from finn.transformation.fpgadataflow.convert_to_hls_layers import InferConcatLayer
from finn.transformation.fpgadataflow.create_stitched_ip import CreateStitchedIP
from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP from finn.transformation.fpgadataflow.hlssynth_ip import HLSSynthIP
from finn.transformation.fpgadataflow.insert_fifo import InsertFIFO
from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim from finn.transformation.fpgadataflow.prepare_cppsim import PrepareCppSim
from finn.transformation.fpgadataflow.prepare_ip import PrepareIP from finn.transformation.fpgadataflow.prepare_ip import PrepareIP
from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim
...@@ -89,9 +91,8 @@ def test_fpgadataflow_concat(exec_mode, idt): ...@@ -89,9 +91,8 @@ def test_fpgadataflow_concat(exec_mode, idt):
assert (ret[oname] == exp_out).all() assert (ret[oname] == exp_out).all()
# call transformation to convert to HLS and verify conversion # call transformation to convert to HLS and verify conversion
model = model.transform(InferConcatLayer()) model = model.transform(InferConcatLayer())
assert model.graph.node[0].op_type == "Concat" assert model.graph.node[0].op_type == "StreamingConcat"
assert model.graph.node[0].domain == "finn.custom_op.fpgadataflow" assert model.graph.node[0].domain == "finn.custom_op.fpgadataflow"
model.save("/tmp/finn_dev_maltanar/dbg.onnx")
if exec_mode == "cppsim": if exec_mode == "cppsim":
model = model.transform(PrepareCppSim()) model = model.transform(PrepareCppSim())
model = model.transform(CompileCppSim()) model = model.transform(CompileCppSim())
...@@ -104,3 +105,45 @@ def test_fpgadataflow_concat(exec_mode, idt): ...@@ -104,3 +105,45 @@ def test_fpgadataflow_concat(exec_mode, idt):
model = model.transform(PrepareRTLSim()) model = model.transform(PrepareRTLSim())
ret_sim = execute_onnx(model, inp_dict) ret_sim = execute_onnx(model, inp_dict)
assert (exp_out == ret_sim[oname]).all() assert (exp_out == ret_sim[oname]).all()
@pytest.mark.vivado
@pytest.mark.slow
def test_fpgadataflow_concat_stitchedip():
idt = DataType["INT4"]
fpga_part = "xc7z020clg400-1"
clk_ns = 10
i_shapes = [(1, 2, 4), (1, 2, 6), (1, 2, 1)]
i_data = [gen_finn_dt_tensor(idt, x) for x in i_shapes]
model = make_concat_model(i_shapes, idt)
assert len(i_shapes) == len(model.graph.input)
assert len(model.graph.output) == 1
exp_oshape = list(i_shapes[0][:-1]) + [sum(x[-1] for x in i_shapes)]
oname = model.graph.output[0].name
assert model.get_tensor_shape(oname) == exp_oshape
exp_out = np.concatenate(i_data, axis=-1)
inp_dict = {}
for i in range(len(i_shapes)):
inp_dict[model.graph.input[i].name] = i_data[i]
ret = execute_onnx(model, inp_dict)
assert (ret[oname] == exp_out).all()
# call transformation to convert to HLS and verify conversion
model = model.transform(InferConcatLayer())
assert model.graph.node[0].op_type == "StreamingConcat"
assert model.graph.node[0].domain == "finn.custom_op.fpgadataflow"
model = model.transform(InsertFIFO(create_shallow_fifos=True))
model = model.transform(GiveUniqueNodeNames())
model = model.transform(PrepareIP(fpga_part, clk_ns))
model = model.transform(HLSSynthIP())
model = model.transform(
CreateStitchedIP(
fpga_part,
clk_ns,
vitis=False,
)
)
model.set_metadata_prop("exec_mode", "rtlsim")
model.set_metadata_prop("rtlsim_trace", "trace.vcd")
model.save("dbg.onnx")
ret_sim = execute_onnx(model, inp_dict)
assert (exp_out == ret_sim[oname]).all()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment