Skip to content
Snippets Groups Projects
Commit d36baa78 authored by auphelia's avatar auphelia
Browse files

[Transformation] Add condition that reshape node is inbetween two fpgadataflow nodes to trafo

parent 059427ff
No related branches found
No related tags found
No related merge requests found
from finn.transformation import Transformation
from finn.transformation.infer_shapes import InferShapes
from finn.util.basic import get_by_name
def _get_number_of_nodes(model):
node_count = 0
for n in model.graph.node:
node_count += 1
return node_count
def _is_fpgadataflow_node(node):
if node is not None:
if node.domain == "finn":
n_backend = get_by_name(node.attribute, "backend")
if n_backend is None:
return False
backend_value = n_backend.s.decode("UTF-8")
if backend_value == "fpgadataflow":
return True
else:
return False
else:
return False
class MoveReshape(Transformation):
"""Removes a node that implements a (1, -1) reshape and runs
InferShapes on the model"""
"""Removes a node that implements a (1, -1) reshape if it is
between two fpgadataflow nodes"""
def apply(self, model):
......@@ -22,10 +30,11 @@ class MoveReshape(Transformation):
graph_modified = True
shape = model.get_initializer(n.input[1])
if (shape == [1, -1]).all():
consumer = model.find_consumer(n.output[0])
if consumer is not None:
consumer.input[0] = n.input[0]
graph.node.remove(n)
producer = model.find_producer(n.input[0])
if _is_fpgadataflow_node(producer) is True:
consumer = model.find_consumer(n.output[0])
if _is_fpgadataflow_node(consumer) is True:
consumer.input[0] = n.input[0]
graph.node.remove(n)
model = model.transform(InferShapes())
return (model, graph_modified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment