Skip to content
Snippets Groups Projects
Commit 2b13d42a authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] add InsertTLastMarker and test

parent d5785d44
No related branches found
No related tags found
No related merge requests found
from onnx import TensorProto
from onnx import helper as oh
from finn.custom_op.registry import getCustomOp
from finn.transformation import Transformation
class InsertTLastMarker(Transformation):
"""Ensure that the graph is terminated with a TLastMarker node, inserting
one if necessary."""
def __init__(self):
super().__init__()
def apply(self, model):
# TODO only makes sense for a pure fpgadataflow graph -- check!
graph_out_name = model.graph.output[0].name
final_node = model.find_producer(graph_out_name)
if final_node.op_type == "TLastMarker":
# TODO maybe check the correctness of properties
return (model, False)
else:
custom_op = getCustomOp(final_node)
num_iters = int(custom_op.get_number_output_values())
stream_width = int(custom_op.get_outstream_width())
out_shape = model.get_tensor_shape(graph_out_name)
out_dtype = model.get_tensor_datatype(graph_out_name)
# make new buffer
final_node_out = oh.make_tensor_value_info(
model.make_new_valueinfo_name(), TensorProto.FLOAT, out_shape
)
model.graph.value_info.append(final_node_out)
model.set_tensor_datatype(final_node_out.name, out_dtype)
# reroute final node output to final_node_out_name
final_node.output[0] = final_node_out.name
tlast_node = oh.make_node(
"TLastMarker",
[final_node_out.name],
[graph_out_name],
NumIters=num_iters,
StreamWidth=stream_width,
)
model.graph.node.append(tlast_node)
return (model, True)
import os.path
from pkgutil import get_data
import pytest
from finn.core.modelwrapper import ModelWrapper
from finn.custom_op.registry import getCustomOp
from finn.transformation.fpgadataflow.create_dataflow_partition import (
CreateDataflowPartition,
)
from finn.transformation.fpgadataflow.insert_tlastmarker import InsertTLastMarker
from finn.util.basic import make_build_dir
build_dir = make_build_dir("test_dataflow_partition_")
def test_create_dataflow_partition():
@pytest.mark.dependency()
def test_dataflow_partition_create():
# load the onnx model
raw_m = get_data(
"finn", "data/onnx/finn-hls-model/tfc_w1_a1_after_conv_to_hls.onnx"
......@@ -19,3 +26,19 @@ def test_create_dataflow_partition():
sdp_node = getCustomOp(model.graph.node[2])
assert sdp_node.__class__.__name__ == "StreamingDataflowPartition"
assert os.path.isfile(sdp_node.get_nodeattr("model"))
model.save(build_dir + "/test_dataflow_partition_create.onnx")
@pytest.mark.dependency(depends=["test_dataflow_partition_create"])
def test_dataflow_partition_tlastmarker():
model = ModelWrapper(build_dir + "/test_dataflow_partition_create.onnx")
model_path = getCustomOp(model.graph.node[2]).get_nodeattr("model")
model = ModelWrapper(model_path)
model = model.transform(InsertTLastMarker())
assert model.graph.node[-1].op_type == "TLastMarker"
tl_node = getCustomOp(model.graph.node[-1])
assert tl_node.get_nodeattr("NumIters") == 1
assert tl_node.get_nodeattr("StreamWidth") == 320
model.save(build_dir + "/test_dataflow_partition_tlastmarker.onnx")
model = model.transform(InsertTLastMarker())
model.save(build_dir + "/test_dataflow_partition_tlastmarker2.onnx")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment