diff --git a/src/finn/builder/build_dataflow_steps.py b/src/finn/builder/build_dataflow_steps.py
index b19f43c9f5810db11090630bc6d865747d582fdf..c7eaf95bd0725eac087029d24f69a6aa20fac3bc 100644
--- a/src/finn/builder/build_dataflow_steps.py
+++ b/src/finn/builder/build_dataflow_steps.py
@@ -213,7 +213,12 @@ def step_create_dataflow_partition(model: ModelWrapper, cfg: DataflowBuildConfig
     nodes, which point to a separate ONNX file. Dataflow accelerator synthesis
     can only be performed on those HLSCustomOp sub-graphs."""
 
-    parent_model = model.transform(CreateDataflowPartition())
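+    # save the partitioned models under the build's intermediate_models dir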
+    parent_model = model.transform(
+        CreateDataflowPartition(
+            partition_model_dir=cfg.output_dir
+            + "/intermediate_models/dataflow_partitions"
+        )
+    )
     sdp_nodes = parent_model.get_nodes_by_op_type("StreamingDataflowPartition")
     assert len(sdp_nodes) == 1, "Only a single StreamingDataflowPartition supported."
     sdp_node = sdp_nodes[0]
diff --git a/src/finn/transformation/fpgadataflow/create_dataflow_partition.py b/src/finn/transformation/fpgadataflow/create_dataflow_partition.py
index 0aba60f9b6f08210c40f305694495b77f517f323..7592c5f5f8acd85f8d406585dab23f7ad76eec9b 100644
--- a/src/finn/transformation/fpgadataflow/create_dataflow_partition.py
+++ b/src/finn/transformation/fpgadataflow/create_dataflow_partition.py
@@ -26,12 +26,11 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-import copy
-from onnx import helper
-
+from finn.core.modelwrapper import ModelWrapper
 from finn.custom_op.registry import getCustomOp
 from finn.transformation.base import Transformation
-from finn.util.basic import get_by_name, make_build_dir
+from finn.transformation.create_generic_partitions import PartitionFromLambda
+from finn.util.basic import get_by_name, make_build_dir
 
 
 class CreateDataflowPartition(Transformation):
@@ -41,120 +40,74 @@ class CreateDataflowPartition(Transformation):
     that indicates the filename for the second graph that only contains
     dataflow nodes. No action is taken if there are no dataflow nodes."""
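+    # typical usage (path is illustrative):
+    #   parent_model = model.transform(
+    #       CreateDataflowPartition("output/dataflow_partitions")
+    #   )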
 
-    def __init__(self):
+    def __init__(self, partition_model_dir=None):
         super().__init__()
+        # fall back to a fresh temporary build dir (as the old implementation
+        # did) rather than a path relative to the current working directory
+        if partition_model_dir is None:
+            partition_model_dir = make_build_dir("dataflow_partition_")
+        self.partition_model_dir = partition_model_dir
 
     def apply(self, model):
-        target_partition_id = 0
-        # we currently assume that all dataflow nodes belonging to the same partition
-        # are connected to each other and there is a single input/output to/from each.
-        # NOTE: all dataflow nodes with no partition_id set are moved to partition 0
-        # TODO: check the assumption and/or improve this.
-        while True:
-            all_nodes = list(model.graph.node)
-            df_nodes = filter(
-                lambda x: get_by_name(x.attribute, "backend") is not None, all_nodes
-            )
-            df_nodes = filter(
-                lambda x: get_by_name(x.attribute, "backend").s.decode("UTF-8")
-                == "fpgadataflow"
-                and (
-                    get_by_name(x.attribute, "partition_id") is None
-                    or get_by_name(x.attribute, "partition_id").i == target_partition_id
-                )
-                and x.op_type != "StreamingDataflowPartition",
-                df_nodes,
-            )
-            df_nodes = list(df_nodes)
-            non_df_nodes = filter(lambda x: x not in df_nodes, all_nodes)
-            non_df_nodes = list(non_df_nodes)
-
-            if len(df_nodes) == 0:
-                # no changes if no dataflow nodes are present
-                break
-            else:
-                # partition the model into two models
-                df_model = copy.deepcopy(model)
-                non_df_model = model
-                # remove all non-dataflow nodes from the dataflow model
-                for node_to_remove in non_df_nodes:
-                    df_model.graph.node.remove(node_to_remove)
-                # identify the entry and exit points for the dataflow part
-                df_in = df_model.graph.node[0].input[0]
-                df_out = df_model.graph.node[-1].output[0]
-                df_in_vi = df_model.get_tensor_valueinfo(df_in)
-                df_out_vi = df_model.get_tensor_valueinfo(df_out)
-                # set df graph in/out to be df_in/df_out
-                df_model.graph.input.remove(df_model.graph.input[0])
-                df_model.graph.input.insert(0, df_in_vi)
-                df_model.graph.output.remove(df_model.graph.output[0])
-                df_model.graph.output.insert(0, df_out_vi)
-                # parse StreamingFCLayers looking for external weight memories
-                fc_extw_nodes = filter(
-                    lambda x: x.op_type == "StreamingFCLayer_Batch"
-                    and get_by_name(x.attribute, "mem_mode") is not None
-                    and get_by_name(x.attribute, "mem_mode").s.decode("UTF-8")
-                    == "external",
-                    df_nodes,
-                )
-                fc_extw_nodes = list(fc_extw_nodes)
-                extra_df_inputs = []
+        def filter_fc_extw(x):
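+            # predicate: True only for StreamingFCLayer_Batch nodes that use
+            # external weight memories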
+            if x.op_type == "StreamingFCLayer_Batch":
+                mem_mode = get_by_name(x.attribute, "mem_mode")
+                if mem_mode is not None:
+                    mem_mode = mem_mode.s.decode("UTF-8")
+                    return mem_mode == "external"
+            return False
 
-                for i in range(len(fc_extw_nodes)):
-                    fc_weight_vi = df_model.get_tensor_valueinfo(
-                        fc_extw_nodes[i].input[1]
-                    )
-                    df_model.graph.input.insert(i + 1, fc_weight_vi)
-                    extra_df_inputs.append(fc_extw_nodes[i].input[1])
+        fc_extw_nodes = list(filter(filter_fc_extw, model.graph.node))
+        assert (
+            len(fc_extw_nodes) == 0
+        ), "FIXME make external FC weight tensors into top-level graph inputs"
 
-                # save model
-                df_model_dir = make_build_dir(
-                    "dataflow_partition" + str(target_partition_id) + "_"
-                )
-                df_model_filename = df_model_dir + "/df_model.onnx"
-                df_model.cleanup()
-                df_model.save(df_model_filename)
-                # remove all dataflow nodes from the non-dataflow model
-                # keep track of where the dataflow part starts
-                df_start_ind = all_nodes.index(df_nodes[0])
-
-                # get and check floorplan
-                inst = getCustomOp(df_nodes[0])
-                slr = inst.get_nodeattr("slr")
-                for node in df_nodes[1:]:
-                    inst = getCustomOp(node)
-                    assert slr == inst.get_nodeattr(
-                        "slr"
-                    ), """all nodes with
-                same partition_id must have the same slr id"""
-
-                # check that there is only one non-null mem_port per partition
-                nmemports = 0
-                mem_port = ""
-                for node in df_nodes:
-                    inst = getCustomOp(node)
-                    port = inst.get_nodeattr("mem_port")
-                    if port is not None and port != "":
-                        nmemports += 1
-                        mem_port = port
-                assert nmemports <= 1, """too many memory ports per partition"""
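+        # map every node to a partition id for PartitionFromLambda: nodes that
+        # share a nonnegative id end up in the same partition, while -1 leaves
+        # a node in the parent model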
+        def assign_partition_id(node):
+            if node.op_type in ["GenericPartition", "StreamingDataflowPartition"]:
+                return -1
+            else:
+                backend = get_by_name(node.attribute, "backend")
+                if backend is not None and backend.s.decode("UTF-8") == "fpgadataflow":
+                    assigned_partition = get_by_name(node.attribute, "partition_id")
+                    if assigned_partition is not None:
+                        return assigned_partition.i
+                    else:
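+                        # dataflow nodes without an explicit partition_id
+                        # default to partition 0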
+                        return 0
+                else:
+                    return -1
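+        # e.g. an fpgadataflow node with partition_id=2 maps to partition 2,
+        # while a plain ONNX node maps to -1 and stays in the parent graph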
 
-                for node_to_remove in df_nodes:
-                    non_df_model.graph.node.remove(node_to_remove)
-                # create StreamingDataflow node with df_in/df_out io
-                df_node = helper.make_node(
-                    "StreamingDataflowPartition",
-                    [df_in] + extra_df_inputs,
-                    [df_out],
-                    # use the model attribute to mark the df model
-                    model=df_model_filename,
-                    domain="finn.custom_op.fpgadataflow",
-                    partition_id=target_partition_id,
-                    slr=slr,
-                    mem_port=mem_port,
-                )
-                non_df_model.graph.node.insert(df_start_ind, df_node)
-                model = non_df_model
-                target_partition_id += 1
+        # first, use the generic partitioning functionality to split up the graph
+        parent_model = model.transform(
+            PartitionFromLambda(
+                partitioning=assign_partition_id, partition_dir=self.partition_model_dir
+            )
+        )
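+        # the parent model now contains one GenericPartition node per
+        # partition, each carrying a "model" attribute that points to its
+        # saved .onnx file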
+        # change node types to StreamingDataflowPartition
+        p_nodes = parent_model.get_nodes_by_op_type("GenericPartition")
+        for partition_ind, p_node in enumerate(p_nodes):
+            # go into partition to extract some info
+            p_node_inst = getCustomOp(p_node)
+            node_model_filename = p_node_inst.get_nodeattr("model")
+            p_model = ModelWrapper(node_model_filename)
+            # check floorplan (SLR assignment per node)
+            inst = getCustomOp(p_model.graph.node[0])
+            slr = inst.get_nodeattr("slr")
+            for node in p_model.graph.node:
+                inst = getCustomOp(node)
+                assert slr == inst.get_nodeattr(
+                    "slr"
+                ), """all nodes within the same partition must have the same slr id"""
+            # check that there is only one non-null mem_port per partition
+            nmemports = 0
+            mem_port = ""
+            for node in p_model.graph.node:
+                inst = getCustomOp(node)
+                port = inst.get_nodeattr("mem_port")
+                if port is not None and port != "":
+                    nmemports += 1
+                    mem_port = port
+            assert nmemports <= 1, """Too many memory ports per partition"""
+            # done, change node type and add info in parent graph
+            p_node.op_type = "StreamingDataflowPartition"
+            p_node.domain = "finn.custom_op.fpgadataflow"
+            new_p_node_inst = getCustomOp(p_node)
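+            # partition ids are re-assigned consecutively, in parent-graph order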
+            new_p_node_inst.set_nodeattr("partition_id", partition_ind)
+            new_p_node_inst.set_nodeattr("slr", slr)
+            new_p_node_inst.set_nodeattr("mem_port", mem_port)
 
-        return (model, False)
+        return (parent_model, False)