diff --git a/src/finn/transformation/fpgadataflow/insert_fifo.py b/src/finn/transformation/fpgadataflow/insert_fifo.py index f66d0dc087ecbdd112422484ee1e01cb5ceef1c0..9850ad7e2cf8065646d49f9cc272ab0c3aaae86d 100644 --- a/src/finn/transformation/fpgadataflow/insert_fifo.py +++ b/src/finn/transformation/fpgadataflow/insert_fifo.py @@ -4,6 +4,7 @@ from onnx import helper as oh from finn.custom_op.registry import getCustomOp from finn.transformation import Transformation from finn.util.fpgadataflow import is_fpgadataflow_node +import numpy as np def _is_fifo_node(node): @@ -26,6 +27,14 @@ def _suitable_node(node): return False +def _suitable_folded_shapes(ishape, oshape): + i_dummy = np.random.rand(*ishape) + o_dummy = np.random.rand(*oshape) + ishape_canonical = np.squeeze(i_dummy).shape + oshape_canonical = np.squeeze(o_dummy).shape + return ishape_canonical == oshape_canonical + + class InsertFIFO(Transformation): """Inserting FIFOs in the beginning and end of the graph as well as between fpgadataflow nodes. @@ -59,8 +68,9 @@ class InsertFIFO(Transformation): # check if folded_shape of output of first node and # input of the second node is equal n1 = getCustomOp(consumer) - assert ( - fld_shape == n1.get_folded_input_shape() + fld_shape_2 = n1.get_folded_input_shape() + assert _suitable_folded_shapes( + fld_shape, fld_shape_2 ), """The folded output shape of the first node is not the same as the folded output shape of the second node. A streaming fifo can't