Skip to content
Snippets Groups Projects
Commit 341a8a3e authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[ToHLS] draft a InferStreamingEltwiseAbsDiff conversion

parent 4a0bff5d
No related branches found
No related tags found
No related merge requests found
......@@ -1671,3 +1671,95 @@ class InferConcatLayer(Transformation):
model = model.transform(InferShapes())
model = model.transform(InferDataTypes())
return (model, graph_modified)
class InferStreamingEltwiseAbsDiff(Transformation):
"""Convert eltwise Sub -> Abs to StreamingEltwise layer
with AbsDiffEltwise op."""
def apply(self, model):
graph = model.graph
node_ind = 0
graph_modified = False
for node in graph.node:
node_ind += 1
if node.op_type == "Sub":
in0 = node.input[0]
in1 = node.input[1]
result = node.output[0]
in0_shape = model.get_tensor_shape(in0)
in1_shape = model.get_tensor_shape(in1)
# skip if different shapes on inputs
if in0_shape != in1_shape:
continue
idt0 = model.get_tensor_datatype(in0)
idt1 = model.get_tensor_datatype(in1)
# skip conversion for layers with float input
if not (idt0.is_integer() and idt1.is_integer()):
continue
# look for a downstream Abs node
res_consumer = model.find_consumer(result)
if res_consumer is None:
continue
if res_consumer.op_type != "Abs":
continue
result = res_consumer.output[0]
# check layout and convert if necessary
in0_layout = model.get_tensor_layout(in0)
in1_layout = model.get_tensor_layout(in1)
result_layout = model.get_tensor_layout(result)
if in0_layout == DataLayout.NCHW:
in0 = nchw_to_nhwc(in0, model, node_ind)
node_ind += 1
in0_shape = model.get_tensor_shape(in0)
if in1_layout == DataLayout.NCHW:
in1 = nchw_to_nhwc(in1, model, node_ind)
node_ind += 1
in1_shape = model.get_tensor_shape(in1)
# keep track of where we need to insert the HLS Op
# it has to be ahead of the output transform
insert_point = node_ind
if result_layout == DataLayout.NCHW:
result = nchw_to_nhwc(result, model, node_ind, reverse=True)
node_ind += 1
# now safe to assume num_channels is size of last dimension
num_channels = int(in0_shape[-1])
# create node with no parallelization first
pe = 1
# create and insert new Eltwise node
new_node = helper.make_node(
"StreamingEltwise",
[in0, in1],
[result],
domain="finn.custom_op.fpgadataflow",
backend="fpgadataflow",
NumChannels=num_channels,
PE=pe,
inputDataType0=idt0.name,
inputDataType1=idt1.name,
eltwiseOp="AbsDiff",
numInputVectors=in0_shape[:-1],
name="StreamingEltwise_" + node.name,
)
graph.node.insert(insert_point, new_node)
# remove old nodes
graph.node.remove(node)
graph.node.remove(res_consumer)
graph_modified = True
# if graph_modified:
# model = model.transform(InferShapes())
# model = model.transform(InferDataTypes())
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment