Skip to content
Snippets Groups Projects
Commit 0b458215 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Refactor] make varname more explanatory in InsertDWC

parent 4f6be619
No related branches found
No related tags found
No related merge requests found
......@@ -44,8 +44,8 @@ class InsertDWC(Transformation):
for n in graph.node:
node_ind += 1
if _suitable_node(n):
for n_output in n.output:
consumers = model.find_consumers(n_output)
for output_name in n.output:
consumers = model.find_consumers(output_name)
if consumers is None:
continue
if len(consumers) > 1:
......@@ -61,14 +61,16 @@ class InsertDWC(Transformation):
n0_out_shape = n0.get_folded_output_shape()
# If FC and external mem, it could be connected to input 1
if (consumer.op_type == "StreamingFCLayer_Batch" and
n1.get_nodeattr("mem_mode") == "external"):
if (
consumer.op_type == "StreamingFCLayer_Batch"
and n1.get_nodeattr("mem_mode") == "external"
):
# get input idx
in_idx = None
for idx, n_input in enumerate(consumer.input):
if n_output == n_input:
if output_name == n_input:
in_idx = idx
assert in_idx is not None,"Malformed model"
assert in_idx is not None, "Malformed model"
n1_in_shape = n1.get_folded_input_shape(in_idx)
else:
n1_in_shape = n1.get_folded_input_shape()
......@@ -95,7 +97,7 @@ class InsertDWC(Transformation):
dwc_node = oh.make_node(
"StreamingDataWidthConverter_Batch",
[n_output],
[output_name],
[dwc_output_tensor.name],
domain="finn.custom_op.fpgadataflow",
backend="fpgadataflow",
......@@ -109,7 +111,7 @@ class InsertDWC(Transformation):
# set dwc output tensor as new input tensor of second node
for idx, inp in enumerate(consumer.input):
if inp == n_output:
if inp == output_name:
consumer.input[idx] = dwc_output_tensor.name
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment