diff --git a/src/finn/transformation/fpgadataflow/create_dataflow_partition.py b/src/finn/transformation/fpgadataflow/create_dataflow_partition.py index 7592c5f5f8acd85f8d406585dab23f7ad76eec9b..9e4637976e0ed0a1badfea766120d580cf86488d 100644 --- a/src/finn/transformation/fpgadataflow/create_dataflow_partition.py +++ b/src/finn/transformation/fpgadataflow/create_dataflow_partition.py @@ -30,6 +30,7 @@ from finn.core.modelwrapper import ModelWrapper from finn.custom_op.registry import getCustomOp from finn.transformation.base import Transformation from finn.transformation.create_generic_partitions import PartitionFromLambda +from finn.transformation.fpgadataflow.externalize_params import ExternalizeParams from finn.util.basic import get_by_name @@ -46,16 +47,15 @@ class CreateDataflowPartition(Transformation): def apply(self, model): def filter_fc_extw(x): - if x.op_type == "StreamingFCLayer_Batch": - mem_mode = get_by_name(x.attribute, "mem_mode") - if mem_mode is not None: - mem_mode = mem_mode.s.decode("UTF-8") - return mem_mode == "external" + if x.op_type == "IODMA": + burst_mode = get_by_name(x.attribute, "burstMode") + if burst_mode is not None: + burst_mode = burst_mode.s.decode("UTF-8") + return burst_mode == "wrap" - fc_extw_nodes = filter(filter_fc_extw, model.graph.node) - assert ( - len(list(fc_extw_nodes)) == 0 - ), "FIXME make external FC weight tensors into top-level graph inputs" + extw_dma_nodes = list(filter(filter_fc_extw, model.graph.node)) + if len(extw_dma_nodes) > 0: + model = model.transform(ExternalizeParams()) def assign_partition_id(node): if node.op_type in ["GenericPartition", "StreamingDataflowPartition"]: