diff --git a/src/finn/transformation/fpgadataflow/set_fifo_depths.py b/src/finn/transformation/fpgadataflow/set_fifo_depths.py
index dccd97020b49cf54d3fb35edc1f7ba3e65301dbe..9b882eeac00d600e24f9099245af7c3da99fd1a5 100644
--- a/src/finn/transformation/fpgadataflow/set_fifo_depths.py
+++ b/src/finn/transformation/fpgadataflow/set_fifo_depths.py
@@ -39,7 +39,6 @@ from qonnx.transformation.general import (
     GiveUniqueNodeNames,
     SortGraph,
 )
-from qonnx.util.basic import get_by_name
 
 from finn.analysis.fpgadataflow.dataflow_performance import dataflow_performance
 from finn.transformation.fpgadataflow.annotate_cycles import AnnotateCycles
@@ -427,32 +426,80 @@ class SplitLargeFifos(Transformation):
     """Split FIFOs with a depth larger than 32768 into smaller ones
     to ensure that they can be correctly generated."""
 
+    def __init__(self, max_qsrl_depth=256):
+        """Create the transformation. FIFOs up to max_qsrl_depth deep get
+        impl_style "rtl" from get_split_configs, deeper ones "vivado"."""
+        super().__init__()
+        # threshold consulted by get_split_configs
+        self.max_qsrl_depth = max_qsrl_depth
+
+    def get_split_configs(self, depth):
+        """Return a list of (depth, impl_style) tuples describing the chain
+        of FIFOs that replaces a single FIFO of the given depth. The depths
+        sum to the given depth; "rtl" entries are at most max_qsrl_depth
+        deep, "vivado" entries are powers of two no larger than 32768."""
+        max_size = 32768
+
+        def floor_pow2(x):
+            # largest power of two <= x, valid for x >= 1
+            return 1 << (x.bit_length() - 1)
+
+        def decompose_pow2(x):
+            # peel off powers of two, largest first, until what is left
+            # fits an rtl FIFO; one split alone could leave a remainder
+            # that is neither a power of two nor small enough for rtl
+            parts = []
+            while x > self.max_qsrl_depth and x > 0:
+                parts.append(floor_pow2(x))
+                x -= parts[-1]
+            if x > 0:
+                parts.append(x)
+            return parts
+
+        # trivial case: for small FIFOs, return as-is with rtl style
+        if depth <= self.max_qsrl_depth:
+            return [(depth, "rtl")]
+        # first pass: ensure max depth of 32k is respected
+        # (restricted by Vivado AXIS infra IP)
+        ret = []
+        remainder = depth
+        while remainder > max_size:
+            ret.append(max_size)
+            remainder -= max_size
+        if remainder != 0:
+            ret.append(remainder)
+        # second pass: break each chunk into power-of-2 sized FIFOs
+        # plus at most one small rtl-sized remainder
+        ret_pass2 = [d for cand_depth in ret for d in decompose_pow2(cand_depth)]
+        # finally, add impl_style to each split FIFO
+        return [
+            (d, "rtl" if d <= self.max_qsrl_depth else "vivado") for d in ret_pass2
+        ]
+
     def apply(self, model):
         graph = model.graph
         node_ind = 0
         graph_modified = False
-        for n in graph.node:
+        for node in graph.node:
             node_ind += 1
-            if n.op_type == "StreamingFIFO":
-                depth = get_by_name(n.attribute, "depth")
-                if depth.i > 32768:
-                    n0 = getCustomOp(n)
-                    fld_shape = n0.get_folded_output_shape()
-                    dtype = n0.get_nodeattr("dataType")
-                    impl_style = n0.get_nodeattr("impl_style")
-                    ram_style = n0.get_nodeattr("ram_style")
-                    shape = model.get_tensor_shape(n.input[0])
-                    split_n = math.ceil(depth.i / 32768)
-                    fifo_depth = math.ceil(depth.i / split_n)
-                    for i in range(split_n):
+            if node.op_type == "StreamingFIFO":
+                n_inst = getCustomOp(node)
+                depth = n_inst.get_nodeattr("depth")
+                cfgs = self.get_split_configs(depth)
+                if len(cfgs) > 1:
+                    fld_shape = n_inst.get_folded_output_shape()
+                    dtype = n_inst.get_nodeattr("dataType")
+                    ram_style = n_inst.get_nodeattr("ram_style")
+                    shape = model.get_tensor_shape(node.input[0])
+                    for i, (fifo_depth, impl_style) in enumerate(cfgs):
                         if i == 0:
-                            inp = n.input[0]
+                            inp = node.input[0]
                         else:
-                            inp = n.name + "_" + str(i - 1) + "_out"
-                        if i == split_n - 1:
-                            outp = n.output[0]
+                            inp = node.name + "_" + str(i - 1) + "_out"
+                        if i == len(cfgs) - 1:
+                            outp = node.output[0]
                         else:
-                            outp = n.name + "_" + str(i) + "_out"
+                            outp = node.name + "_" + str(i) + "_out"
                             out_tensor = helper.make_tensor_value_info(
                                 outp, TensorProto.FLOAT, shape
                             )
@@ -469,21 +516,13 @@ class SplitLargeFifos(Transformation):
                             dataType=dtype,
                             impl_style=impl_style,
                             ram_style=ram_style,
+                            name=node.name + "_" + str(i),
                         )
                         graph.node.insert(node_ind + i, fifo_node)
 
-                    graph.node.remove(n)
-                    if n.output[0] != "global_out":
-                        consumer = model.find_consumer(n.output[0])
-                        n1 = getCustomOp(consumer)
-                        n1.set_nodeattr("outFIFODepth", fifo_depth)
-                    if n.input[0] != "global_in":
-                        producer = model.find_producer(n.input[0])
-                        n2 = getCustomOp(producer)
-                        n2.set_nodeattr("inFIFODepth", fifo_depth)
+                    graph.node.remove(node)
                     graph_modified = True
         if graph_modified:
             model = model.transform(SortGraph())
-            model = model.transform(GiveUniqueNodeNames())
             model = model.transform(GiveReadableTensorNames())
-        return (model, graph_modified)
+        return (model, False)