From ce13a866c1ea82029fa8cd2a85a99987e01a8168 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <yamanu@xilinx.com> Date: Wed, 22 Sep 2021 10:23:04 +0200 Subject: [PATCH] [Transform] call ExternalizeParams in CreateDataflowPartition --- .../fpgadataflow/create_dataflow_partition.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/finn/transformation/fpgadataflow/create_dataflow_partition.py b/src/finn/transformation/fpgadataflow/create_dataflow_partition.py index 7592c5f5f..9e4637976 100644 --- a/src/finn/transformation/fpgadataflow/create_dataflow_partition.py +++ b/src/finn/transformation/fpgadataflow/create_dataflow_partition.py @@ -30,6 +30,7 @@ from finn.core.modelwrapper import ModelWrapper from finn.custom_op.registry import getCustomOp from finn.transformation.base import Transformation from finn.transformation.create_generic_partitions import PartitionFromLambda +from finn.transformation.fpgadataflow.externalize_params import ExternalizeParams from finn.util.basic import get_by_name @@ -46,16 +47,15 @@ class CreateDataflowPartition(Transformation): def apply(self, model): def filter_fc_extw(x): - if x.op_type == "StreamingFCLayer_Batch": - mem_mode = get_by_name(x.attribute, "mem_mode") - if mem_mode is not None: - mem_mode = mem_mode.s.decode("UTF-8") - return mem_mode == "external" + if x.op_type == "IODMA": + burst_mode = get_by_name(x.attribute, "burstMode") + if burst_mode is not None: + burst_mode = burst_mode.s.decode("UTF-8") + return burst_mode == "wrap" - fc_extw_nodes = filter(filter_fc_extw, model.graph.node) - assert ( - len(list(fc_extw_nodes)) == 0 - ), "FIXME make external FC weight tensors into top-level graph inputs" + extw_dma_nodes = list(filter(filter_fc_extw, model.graph.node)) + if len(extw_dma_nodes) > 0: + model = model.transform(ExternalizeParams()) def assign_partition_id(node): if node.op_type in ["GenericPartition", "StreamingDataflowPartition"]: -- GitLab