Commit ea3522ba authored by Yaman Umuroglu

[InsertDWC] handle StreamingConcat correctly

parent d966e8c8
-import warnings
 from onnx import TensorProto
 from onnx import helper as oh
@@ -48,23 +47,23 @@ class InsertDWC(Transformation):
                     consumers = model.find_consumers(output_name)
                     if consumers is None:
                         continue
-                    if len(consumers) > 1:
-                        warnings.warn(
-                            n.name
-                            + ": HLS node with fan-out higher than 1 cannot be stitched"
-                        )
+                    assert len(consumers) == 1, (
+                        n.name
+                        + ": HLS node with fan-out higher than 1 cannot be stitched"
+                    )
                     consumer = consumers[0]
                     if _suitable_node(consumer) is True:
                         n0 = getCustomOp(n)
                         n1 = getCustomOp(consumer)
                         n0_out_shape = n0.get_folded_output_shape()
-                        # If FC and external mem, it could be connected to input 1
+                        # in some special cases, we need to get folded shapes of
+                        # non-default inputs for the consumer
+                        # - if FC and external mem, it could be connected to input 1
+                        # - if concat, could be connected to any input
                         if (
                             consumer.op_type == "StreamingFCLayer_Batch"
                             and n1.get_nodeattr("mem_mode") == "external"
-                        ):
+                        ) or (consumer.op_type == "StreamingConcat"):
                             # get input idx
                             in_idx = None
                             for idx, n_input in enumerate(consumer.input):
@@ -73,6 +72,7 @@ class InsertDWC(Transformation):
                             assert in_idx is not None, "Malformed model"
                             n1_in_shape = n1.get_folded_input_shape(in_idx)
                         else:
+                            # use default folded input shape
                             n1_in_shape = n1.get_folded_input_shape()
                         if n0_out_shape[-1] != n1_in_shape[-1]:
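
For context, the decision this change affects can be summarised as a small sketch. This is not code from the commit: the helper name producer_needs_dwc, the matching of the producer output to the consumer input by string equality, and the import path for getCustomOp (which differs across FINN versions) are assumptions for illustration. The shape and attribute queries (get_folded_output_shape, get_folded_input_shape, get_nodeattr("mem_mode")) are the same ones used in the diff above.

# Illustrative sketch only; mirrors the logic of insert_dwc.py after this commit.
# Assumption: getCustomOp import path may differ depending on the FINN version.
from finn.custom_op.registry import getCustomOp


def producer_needs_dwc(producer, consumer, output_name):
    """Hypothetical helper: return True if a data width converter is needed
    between producer's tensor output_name and the given consumer node."""
    n0 = getCustomOp(producer)
    n1 = getCustomOp(consumer)
    n0_out_shape = n0.get_folded_output_shape()
    if (
        consumer.op_type == "StreamingFCLayer_Batch"
        and n1.get_nodeattr("mem_mode") == "external"
    ) or (consumer.op_type == "StreamingConcat"):
        # these consumers can receive the stream on a non-default input
        # (FC with external memory on input 1, concat on any input), so look
        # up the folded shape of the specific input actually being driven
        in_idx = None
        for idx, n_input in enumerate(consumer.input):
            if n_input == output_name:
                in_idx = idx
        assert in_idx is not None, "Malformed model"
        n1_in_shape = n1.get_folded_input_shape(in_idx)
    else:
        # default case: consumer reads the stream on its first input
        n1_in_shape = n1.get_folded_input_shape()
    # a DWC is needed when the innermost folded dimensions (elements per
    # stream transaction) of producer output and consumer input disagree
    return n0_out_shape[-1] != n1_in_shape[-1]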