diff --git a/src/finn/transformation/fpgadataflow/set_fifo_depths.py b/src/finn/transformation/fpgadataflow/set_fifo_depths.py
index 6d1202d8618ca43242d4d1a8b920f9c7e9d0321c..f3e6a277364745425105c2ef9741a1e921105877 100644
--- a/src/finn/transformation/fpgadataflow/set_fifo_depths.py
+++ b/src/finn/transformation/fpgadataflow/set_fifo_depths.py
@@ -422,6 +422,54 @@ class InsertAndSetFIFODepths(Transformation):
         return (model, False)
 
 
+def get_fifo_split_configs(depth, max_qsrl_depth, max_vivado_depth):
+    def floor_pow2(x):
+        if (x & (x - 1) == 0) and x != 0:
+            return x
+        else:
+            return 1 << ((x - 1).bit_length() - 1)
+
+    def decompose_pow2(x):
+        if x <= max_qsrl_depth:
+            return [x]
+        else:
+            r = floor_pow2(x)
+            if x == r:
+                return [x]
+            else:
+                return [r, *decompose_pow2(x - r)]
+
+    ret = []
+    # trivial case: for small FIFOs, return as-is with rtl style
+    if depth <= max_qsrl_depth:
+        return [(depth, "rtl")]
+    # first pass: ensure max depth is respected
+    # (restricted by Vivado AXIS infra IP)
+    remainder = depth
+    while remainder != 0:
+        if remainder > max_vivado_depth:
+            ret.append(max_vivado_depth)
+            remainder -= max_vivado_depth
+        else:
+            ret.append(remainder)
+            remainder = 0
+    # second pass: break non-power-of-2 sized FIFOs
+    # into several ones
+
+    ret_pass2 = list(map(decompose_pow2, ret))
+    ret_pass2 = [x for dec_list in ret_pass2 for x in dec_list]
+
+    # finally, add impl_style to each split FIFO
+    ret_final = []
+    for cand_depth in ret_pass2:
+        if cand_depth <= max_qsrl_depth:
+            ret_final.append((cand_depth, "rtl"))
+        else:
+            ret_final.append((cand_depth, "vivado"))
+
+    return ret_final
+
+
 class SplitLargeFIFOs(Transformation):
     """Split large FIFOs before implementation, for two reasons:
     - impl_style="vivado" supports a max depth of 32k. Any larger
@@ -438,46 +486,6 @@ class SplitLargeFIFOs(Transformation):
         self.max_qsrl_depth = max_qsrl_depth
         self.max_vivado_depth = max_vivado_depth
 
-    def get_split_configs(self, depth):
-        def floor_pow2(x):
-            if (x & (x - 1) == 0) and x != 0:
-                return x
-            else:
-                return 1 << ((x - 1).bit_length() - 1)
-
-        ret = []
-        # trivial case: for small FIFOs, return as-is with rtl style
-        if depth <= self.max_qsrl_depth:
-            return [(depth, "rtl")]
-        # first pass: ensure max depth is respected
-        # (restricted by Vivado AXIS infra IP)
-        remainder = depth
-        while remainder != 0:
-            if remainder > self.max_vivado_depth:
-                ret.append(self.max_vivado_depth)
-                remainder -= self.max_vivado_depth
-            else:
-                ret.append(remainder)
-                remainder = 0
-        # second pass: break non-power-of-2 sized FIFOs
-        # into several ones
-        ret_pass2 = []
-
-        for cand_depth in ret:
-            cand_floor_pow2 = floor_pow2(cand_depth)
-            ret_pass2.append(cand_floor_pow2)
-            if cand_floor_pow2 < cand_depth:
-                ret_pass2.append(cand_depth - cand_floor_pow2)
-        # finally, add impl_style to each split FIFO
-        ret_final = []
-        for cand_depth in ret_pass2:
-            if cand_depth <= self.max_qsrl_depth:
-                ret_final.append((cand_depth, "rtl"))
-            else:
-                ret_final.append((cand_depth, "vivado"))
-
-        return ret_final
-
     def apply(self, model):
         graph = model.graph
         node_ind = 0
@@ -487,7 +495,9 @@ class SplitLargeFIFOs(Transformation):
             if node.op_type == "StreamingFIFO":
                 n_inst = getCustomOp(node)
                 depth = n_inst.get_nodeattr("depth")
-                cfgs = self.get_split_configs(depth)
+                cfgs = get_fifo_split_configs(
+                    depth, self.max_qsrl_depth, self.max_vivado_depth
+                )
                 if len(cfgs) > 1:
                     fld_shape = n_inst.get_folded_output_shape()
                     dtype = n_inst.get_nodeattr("dataType")