Skip to content
Snippets Groups Projects
Commit 2d7a31e7 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Test] add test_end2end_tfc_verify_ip_stitch

parent 6d0a79c5
No related branches found
No related tags found
No related merge requests found
import os
import numpy as np
# as of Feb'20 there is a bug that segfaults ONNX shape inference if we
# import pytorch before onnx, so we make sure to import onnx first
import onnx # NOQA
......@@ -7,11 +8,14 @@ import onnx # NOQA
import finn.transformation.fpgadataflow.convert_to_hls_layers as to_hls
import finn.transformation.streamline.absorb as absorb
from finn.core.modelwrapper import ModelWrapper
from finn.core.onnx_exec import execute_onnx
from finn.custom_op.registry import getCustomOp
from finn.transformation.bipolar_to_xnor import ConvertBipolarMatMulToXnorPopcount
from finn.transformation.fold_constants import FoldConstants
from finn.transformation.fpgadataflow.codegen_ipgen import CodeGen_ipgen
from finn.transformation.fpgadataflow.codegen_ipstitch import CodeGen_ipstitch
from finn.transformation.fpgadataflow.codegen_npysim import CodeGen_npysim
from finn.transformation.fpgadataflow.compile import Compile
from finn.transformation.fpgadataflow.create_dataflow_partition import (
CreateDataflowPartition,
)
......@@ -22,6 +26,7 @@ from finn.transformation.fpgadataflow.make_pynq_proj import MakePYNQProject
from finn.transformation.fpgadataflow.replace_verilog_relpaths import (
ReplaceVerilogRelPaths,
)
from finn.transformation.fpgadataflow.set_exec_mode import SetExecMode
from finn.transformation.fpgadataflow.synth_pynq_proj import SynthPYNQProject
from finn.transformation.general import GiveReadableTensorNames, GiveUniqueNodeNames
from finn.transformation.infer_datatypes import InferDataTypes
......@@ -124,6 +129,27 @@ def test_end2end_tfc_ip_stitch():
model.save(build_dir + "/end2end_tfc_w1_a1_ipstitch.onnx")
def test_end2end_tfc_verify_ip_stitch():
model = ModelWrapper(build_dir + "/end2end_tfc_w1_a1_ipstitch.onnx")
x = np.zeros((1, 784), dtype=np.float32)
inp_name = model.graph.input[0].name
out_name = model.graph.output[0].name
inp_dict = {inp_name: x}
# npysim
model = model.transform(CodeGen_npysim())
model = model.transform(Compile())
model = model.transform(SetExecMode("npysim"))
res_npysim = execute_onnx(model, inp_dict)[out_name]
# node-by-node rtlsim
model = model.transform(SetExecMode("rtlsim"))
res_rtlsim_nodebynode = execute_onnx(model, inp_dict)[out_name]
# whole-network (ip-stitched) rtlsim
model.set_metadata_prop("exec_mode", "rtlsim")
res_rtlsim_ipstitched = execute_onnx(model, inp_dict)[out_name]
assert np.isclose(res_npysim, res_rtlsim_nodebynode).all()
assert np.isclose(res_npysim, res_rtlsim_ipstitched).all()
def test_end2end_tfc_make_pynq_proj():
model = ModelWrapper(build_dir + "/end2end_tfc_w1_a1_ipstitch.onnx")
model = model.transform(MakePYNQProject(test_pynq_board))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment