Skip to content
Snippets Groups Projects
Commit 0e7dca9a authored by Lucian Petrica's avatar Lucian Petrica
Browse files

Added to_hls transform and test for it

parent d37aa6c9
No related branches found
No related tags found
No related merge requests found
......@@ -398,3 +398,60 @@ class InferQuantizedStreamingFCLayer(Transformation):
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
return (model, graph_modified)
class InferThresholdingLayer(Transformation):
    """Convert any MultiThreshold into a standalone thresholding HLS layer.

    Each MultiThreshold node whose input tensor has an integer datatype is
    replaced in-place by a finn/fpgadataflow "Thresholding_Batch" node that
    consumes the same input and threshold tensors and produces the same
    output tensor. Nodes with floating-point input are skipped. If any node
    was replaced, shape and datatype inference are re-run on the model.

    Returns (model, graph_modified) from apply(), following the
    Transformation convention used by the sibling classes in this file.
    """

    def apply(self, model):
        graph = model.graph
        node_ind = 0
        graph_modified = False
        for node in graph.node:
            node_ind += 1
            if node.op_type == "MultiThreshold":
                # inputs: [0] = activation tensor, [1] = threshold tensor
                thl_input = node.input[0]
                thl_threshold = node.input[1]
                thl_output = node.output[0]
                thl_in_shape = model.get_tensor_shape(thl_input)
                thl_out_shape = model.get_tensor_shape(thl_output)
                idt = model.get_tensor_datatype(thl_input)
                # skip conversion for layers with float input
                if not idt.is_integer():
                    continue
                # number of channels = innermost (last) input dimension;
                # the leading dimensions become numInputVectors below
                ifc = int(thl_in_shape[-1])
                # create node with no parallelization first
                pe = 1
                assert ifc % pe == 0, "Requirement IFC divisable by PE is violated."
                odt = model.get_tensor_datatype(thl_output)
                # NOTE(review): these write back the shapes read above —
                # presumably to make sure value_info entries exist for both
                # tensors before the op-type swap; confirm against ModelWrapper
                model.set_tensor_shape(thl_input, thl_in_shape)
                model.set_tensor_shape(thl_output, thl_out_shape)
                # create and insert new Thresholding_Batch node
                new_node = helper.make_node(
                    "Thresholding_Batch",
                    [thl_input, thl_threshold],
                    [thl_output],
                    domain="finn",
                    backend="fpgadataflow",
                    NumChannels=ifc,
                    PE=pe,
                    inputDataType=idt.name,
                    outputDataType=odt.name,
                    numInputVectors=list(thl_in_shape[:-1]),
                )
                graph.node.insert(node_ind, new_node)
                # remove old node
                graph.node.remove(node)
                graph_modified = True
        if graph_modified:
            model = model.transform(InferShapes())
            model = model.transform(InferDataTypes())
        return (model, graph_modified)
......@@ -54,7 +54,9 @@ export_onnx_path_cnv = "test_output_cnv.onnx"
@pytest.mark.vivado
def test_convert_to_hls_layers_cnv_w1a1():
# Standalone or fused thresholding-based activation
@pytest.mark.parametrize("fused_activation", [True, False])
def test_convert_to_hls_layers_cnv_w1a1(fused_activation):
cnv = get_test_model_trained("CNV", 1, 1)
bo.export_finn_onnx(cnv, (1, 3, 32, 32), export_onnx_path_cnv)
model = ModelWrapper(export_onnx_path_cnv)
......@@ -80,6 +82,10 @@ def test_convert_to_hls_layers_cnv_w1a1():
expected_ctx = oxe.execute_onnx(model, input_dict, True)
expected = expected_ctx[model.graph.output[0].name]
# if we infer thresholding first, all MultiThresholds get converted to HLS
# subsequently, the FC inference will generate passthrough MVAUs
if not fused_activation:
model = model.transform(to_hls.InferThresholdingLayer())
model = model.transform(to_hls.InferBinaryStreamingFCLayer())
model = model.transform(to_hls.InferQuantizedStreamingFCLayer())
for node in model.graph.node:
......@@ -102,7 +108,12 @@ def test_convert_to_hls_layers_cnv_w1a1():
model = model.transform(to_hls.InferStreamingMaxPool())
# check topology status
finn_nodes = model.get_finn_nodes()
assert len(finn_nodes) == 18
if fused_activation:
assert len(finn_nodes) == 18
else:
assert len(finn_nodes) == 26
thr_nodes = model.get_nodes_by_op_type("Thresholding_Batch")
assert len(thr_nodes) == 8
non_finn_nodes = model.get_non_finn_nodes()
assert len(non_finn_nodes) == 4
exp_non_finn_nodes = ["Transpose", "Reshape", "Mul", "Add"]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment