Skip to content
Snippets Groups Projects
Commit ea3522ba authored by Yaman Umuroglu
Browse files

[InsertDWC] handle StreamingConcat correctly

parent d966e8c8
No related branches found
No related tags found
No related merge requests found
import warnings
from onnx import TensorProto from onnx import TensorProto
from onnx import helper as oh from onnx import helper as oh
...@@ -48,23 +47,23 @@ class InsertDWC(Transformation): ...@@ -48,23 +47,23 @@ class InsertDWC(Transformation):
consumers = model.find_consumers(output_name) consumers = model.find_consumers(output_name)
if consumers is None: if consumers is None:
continue continue
if len(consumers) > 1: assert len(consumers) == 1, (
warnings.warn( n.name
n.name + ": HLS node with fan-out higher than 1 cannot be stitched"
+ ": HLS node with fan-out higher than 1 cannot be stitched" )
)
consumer = consumers[0] consumer = consumers[0]
if _suitable_node(consumer) is True: if _suitable_node(consumer) is True:
n0 = getCustomOp(n) n0 = getCustomOp(n)
n1 = getCustomOp(consumer) n1 = getCustomOp(consumer)
n0_out_shape = n0.get_folded_output_shape() n0_out_shape = n0.get_folded_output_shape()
# in some special cases, we need to get folded shapes of
# If FC and external mem, it could be connected to input 1 # non-default inputs for the consumer
# - if FC and external mem, it could be connected to input 1
# - if concat, could be connected to any input
if ( if (
consumer.op_type == "StreamingFCLayer_Batch" consumer.op_type == "StreamingFCLayer_Batch"
and n1.get_nodeattr("mem_mode") == "external" and n1.get_nodeattr("mem_mode") == "external"
): ) or (consumer.op_type == "StreamingConcat"):
# get input idx # get input idx
in_idx = None in_idx = None
for idx, n_input in enumerate(consumer.input): for idx, n_input in enumerate(consumer.input):
...@@ -73,6 +72,7 @@ class InsertDWC(Transformation): ...@@ -73,6 +72,7 @@ class InsertDWC(Transformation):
assert in_idx is not None, "Malformed model" assert in_idx is not None, "Malformed model"
n1_in_shape = n1.get_folded_input_shape(in_idx) n1_in_shape = n1.get_folded_input_shape(in_idx)
else: else:
# use default folded input shape
n1_in_shape = n1.get_folded_input_shape() n1_in_shape = n1.get_folded_input_shape()
if n0_out_shape[-1] != n1_in_shape[-1]: if n0_out_shape[-1] != n1_in_shape[-1]:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment