Skip to content
Snippets Groups Projects
Commit efb7057d authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Refactor] move RemoveShallowFIFOs to own transformation

parent a1e7889b
No related merge requests found
......@@ -77,6 +77,44 @@ def optimize_depth(depth):
return int(2 ** math.ceil(math.log2(depth)))
class RemoveShallowFIFOs(Transformation):
"""Remove small FIFOs as the streaming components have depth-2 FIFOs on the
input/outputs by default."""
# TODO add unit test
def __init__(self, shallow_threshold=2):
self.shallow_threshold = shallow_threshold
def apply(self, model):
shallow_fifos = []
for node in model.graph.node:
if (
node.op_type == "StreamingFIFO"
and getCustomOp(node).get_nodeattr("depth") <= self.shallow_threshold
):
# bypass shallow fifos
shallow_fifos.append(node)
consumers = model.find_consumers(node.output[0])
if consumers is None:
producer = model.find_producer(node.input[0])
for idx, inp in enumerate(producer.output):
if inp == node.input[0]:
producer.output[idx] = node.output[0]
else:
assert len(consumers) == 1, "Fanout detected from FIFO output"
consumer = consumers[0]
# set fifo input tensor as new input tensor of second node
for idx, inp in enumerate(consumer.input):
if inp == node.output[0]:
consumer.input[idx] = node.input[0]
# now filter out
for node_to_remove in shallow_fifos:
model.graph.node.remove(node_to_remove)
return (model, False)
class CapConvolutionFIFODepths(Transformation):
"""Make the size of FIFOs for convolution layers smaller where possible.
Will be automatically called from InsertAndSetFIFODepths if the appropriate
......@@ -101,6 +139,8 @@ class CapConvolutionFIFODepths(Transformation):
than 1 row.
"""
# TODO add unit test
def __init__(self, max_qsrl_depth=256):
super().__init__()
self.max_qsrl_depth = max_qsrl_depth
......@@ -346,32 +386,7 @@ class InsertAndSetFIFODepths(Transformation):
model = model.transform(
CapConvolutionFIFODepths(max_qsrl_depth=self.max_qsrl_depth)
)
# Remove FIFOs which have depth <= 2
# TODO move this to own transformation
shallow_fifos = []
# First, bypass them
for node in model.graph.node:
if (
node.op_type == "StreamingFIFO"
and getCustomOp(node).get_nodeattr("depth") <= 2
):
shallow_fifos.append(node)
consumers = model.find_consumers(node.output[0])
if consumers is None:
producer = model.find_producer(node.input[0])
for idx, inp in enumerate(producer.output):
if inp == node.input[0]:
producer.output[idx] = node.output[0]
else:
assert len(consumers) == 1, "Fanout detected from FIFO output"
consumer = consumers[0]
# set fifo input tensor as new input tensor of second node
for idx, inp in enumerate(consumer.input):
if inp == node.output[0]:
consumer.input[idx] = node.input[0]
# now filter out
for node_to_remove in shallow_fifos:
model.graph.node.remove(node_to_remove)
# remove shallow FIFOs
model = model.transform(RemoveShallowFIFOs())
return (model, False)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment