diff --git a/src/finn/transformation/fpgadataflow/set_fifo_depths.py b/src/finn/transformation/fpgadataflow/set_fifo_depths.py index bc795aa922595a6c3fecb00844201a210cf1c89c..713148d7fcdfea4411554b6d3b817a14b33a53c6 100644 --- a/src/finn/transformation/fpgadataflow/set_fifo_depths.py +++ b/src/finn/transformation/fpgadataflow/set_fifo_depths.py @@ -77,6 +77,44 @@ def optimize_depth(depth): return int(2 ** math.ceil(math.log2(depth))) +class RemoveShallowFIFOs(Transformation): + """Remove small FIFOs as the streaming components have depth-2 FIFOs on the + input/outputs by default.""" + + # TODO add unit test + + def __init__(self, shallow_threshold=2): + self.shallow_threshold = shallow_threshold + + def apply(self, model): + shallow_fifos = [] + for node in model.graph.node: + if ( + node.op_type == "StreamingFIFO" + and getCustomOp(node).get_nodeattr("depth") <= self.shallow_threshold + ): + # bypass shallow fifos + shallow_fifos.append(node) + consumers = model.find_consumers(node.output[0]) + if consumers is None: + producer = model.find_producer(node.input[0]) + for idx, inp in enumerate(producer.output): + if inp == node.input[0]: + producer.output[idx] = node.output[0] + else: + assert len(consumers) == 1, "Fanout detected from FIFO output" + consumer = consumers[0] + # set fifo input tensor as new input tensor of second node + for idx, inp in enumerate(consumer.input): + if inp == node.output[0]: + consumer.input[idx] = node.input[0] + # now filter out + for node_to_remove in shallow_fifos: + model.graph.node.remove(node_to_remove) + + return (model, False) + + class CapConvolutionFIFODepths(Transformation): """Make the size of FIFOs for convolution layers smaller where possible. Will be automatically called from InsertAndSetFIFODepths if the appropriate @@ -101,6 +139,8 @@ class CapConvolutionFIFODepths(Transformation): than 1 row. """ + # TODO add unit test + def __init__(self, max_qsrl_depth=256): super().__init__() self.max_qsrl_depth = max_qsrl_depth @@ -346,32 +386,7 @@ class InsertAndSetFIFODepths(Transformation): model = model.transform( CapConvolutionFIFODepths(max_qsrl_depth=self.max_qsrl_depth) ) - - # Remove FIFOs which have depth <= 2 - # TODO move this to own transformation - shallow_fifos = [] - # First, bypass them - for node in model.graph.node: - if ( - node.op_type == "StreamingFIFO" - and getCustomOp(node).get_nodeattr("depth") <= 2 - ): - shallow_fifos.append(node) - consumers = model.find_consumers(node.output[0]) - if consumers is None: - producer = model.find_producer(node.input[0]) - for idx, inp in enumerate(producer.output): - if inp == node.input[0]: - producer.output[idx] = node.output[0] - else: - assert len(consumers) == 1, "Fanout detected from FIFO output" - consumer = consumers[0] - # set fifo input tensor as new input tensor of second node - for idx, inp in enumerate(consumer.input): - if inp == node.output[0]: - consumer.input[idx] = node.input[0] - # now filter out - for node_to_remove in shallow_fifos: - model.graph.node.remove(node_to_remove) + # remove shallow FIFOs + model = model.transform(RemoveShallowFIFOs()) return (model, False)