From ce13a866c1ea82029fa8cd2a85a99987e01a8168 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <yamanu@xilinx.com>
Date: Wed, 22 Sep 2021 10:23:04 +0200
Subject: [PATCH] [Transform] call ExternalizeParams in CreateDataflowPartition

---
 .../fpgadataflow/create_dataflow_partition.py  | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/finn/transformation/fpgadataflow/create_dataflow_partition.py b/src/finn/transformation/fpgadataflow/create_dataflow_partition.py
index 7592c5f5f..9e4637976 100644
--- a/src/finn/transformation/fpgadataflow/create_dataflow_partition.py
+++ b/src/finn/transformation/fpgadataflow/create_dataflow_partition.py
@@ -30,6 +30,7 @@ from finn.core.modelwrapper import ModelWrapper
 from finn.custom_op.registry import getCustomOp
 from finn.transformation.base import Transformation
 from finn.transformation.create_generic_partitions import PartitionFromLambda
+from finn.transformation.fpgadataflow.externalize_params import ExternalizeParams
 from finn.util.basic import get_by_name
 
 
@@ -46,16 +47,15 @@ class CreateDataflowPartition(Transformation):
 
     def apply(self, model):
         def filter_fc_extw(x):
-            if x.op_type == "StreamingFCLayer_Batch":
-                mem_mode = get_by_name(x.attribute, "mem_mode")
-                if mem_mode is not None:
-                    mem_mode = mem_mode.s.decode("UTF-8")
-                    return mem_mode == "external"
+            if x.op_type == "IODMA":
+                burst_mode = get_by_name(x.attribute, "burstMode")
+                if burst_mode is not None:
+                    burst_mode = burst_mode.s.decode("UTF-8")
+                    return burst_mode == "wrap"
 
-        fc_extw_nodes = filter(filter_fc_extw, model.graph.node)
-        assert (
-            len(list(fc_extw_nodes)) == 0
-        ), "FIXME make external FC weight tensors into top-level graph inputs"
+        extw_dma_nodes = list(filter(filter_fc_extw, model.graph.node))
+        if len(extw_dma_nodes) > 0:
+            model = model.transform(ExternalizeParams())
 
         def assign_partition_id(node):
             if node.op_type in ["GenericPartition", "StreamingDataflowPartition"]:
-- 
GitLab