diff --git a/src/finn/transformation/fpgadataflow/set_fifo_depths.py b/src/finn/transformation/fpgadataflow/set_fifo_depths.py
index 067a98f7db469251b677ced38a3e370d9493ce85..0f8880b3bba42701b2e2ec3991cded4c1c89b958 100644
--- a/src/finn/transformation/fpgadataflow/set_fifo_depths.py
+++ b/src/finn/transformation/fpgadataflow/set_fifo_depths.py
@@ -43,7 +43,7 @@ from finn.core.rtlsim_exec import (
     _reset_rtlsim,
     _toggle_clk,
 )
-from finn.util.fpgadataflow import pyverilate_stitched_ip
+from finn.util.fpgadataflow import pyverilate_stitched_ip, is_fpgadataflow_node
 
 
 def reset_implementation(node):
@@ -76,12 +76,26 @@ def optimize_depth(depth):
     return int(math.ceil(depth / 1024) * 1024)
 
 
-class SetFIFODepths(Transformation):
-    """Determines minimum depths of StreamingFIFOs through RTLSim.
-    We assume we get a dataflow partition (all nodes are dataflow, no FIFOs)
-    We set initial depths very high (16k), run sim with multiple
-    images on input (random/constant data) and keep track of maximum
-    occupancy counts in each FIFO."""
+class InsertAndSetFIFODepths(Transformation):
+    """Insert appropriate-depth StreamingFIFOs through RTLSim that preserve
+    throughput in the created accelerator.
+
+    Assumed input graph properties:
+    - all nodes are fpgadataflow nodes
+    - no FIFOs inserted
+    - (existing inFIFODepth/outFIFODepth attr values are ignored and overwritten)
+
+    Output:
+    - graph with appropriate-depth FIFOs inserted
+
+    How it works:
+    - insert very deep (default 16k deep) FIFOs between all fpgadataflow nodes
+    - create stitched design
+    - run rtlsim with a stream of multiple random input images (to fill pipeline)
+    - keep track of the observed maximum occupancy for each FIFO during rtlsim
+    - when sim has finished, update each FIFO depth to its maximum observed
+      occupancy and set inFIFODepth/outFIFODepth attrs to 0 on relevant nodes
+    """
 
     def __init__(self, fpgapart, clk_ns=10.0, max_qsrl_depth=256, max_depth=2 ** 14):
         super().__init__()
@@ -91,11 +105,15 @@ class SetFIFODepths(Transformation):
         self.max_depth = max_depth
 
     def apply(self, model):
-
         # change external to decoupled and warn user
         # this way we are sure we have exactly one input/output
         modified_fc_nodes = []
         for node in model.graph.node:
+            # verify assumptions
+            assert is_fpgadataflow_node(node), "Found non-fpgadataflow node: " + str(
+                node
+            )
+            assert node.op_type != "StreamingFIFO", "Found existing StreamingFIFO node"
             node = getCustomOp(node)
             node.set_nodeattr("inFIFODepth", self.max_depth)
             node.set_nodeattr("outFIFODepth", self.max_depth)
diff --git a/tests/end2end/test_end2end_bnn_pynq.py b/tests/end2end/test_end2end_bnn_pynq.py
index 84149c91b4096163ba2a85fb7bde8b914b0d1632..30bcddb28d1bf105ad3c8338ff32aa28a91e41f6 100644
--- a/tests/end2end/test_end2end_bnn_pynq.py
+++ b/tests/end2end/test_end2end_bnn_pynq.py
@@ -84,7 +84,7 @@ from finn.transformation.fpgadataflow.prepare_rtlsim import PrepareRTLSim
 from finn.transformation.fpgadataflow.insert_dwc import InsertDWC
 from finn.transformation.fpgadataflow.insert_fifo import InsertFIFO
 from finn.transformation.fpgadataflow.annotate_cycles import AnnotateCycles
-from finn.transformation.fpgadataflow.set_fifo_depths import SetFIFODepths
+from finn.transformation.fpgadataflow.set_fifo_depths import InsertAndSetFIFODepths
 from finn.analysis.fpgadataflow.dataflow_performance import dataflow_performance
 from finn.core.modelwrapper import ModelWrapper
 from scipy.stats import linregress
@@ -454,7 +454,7 @@ class TestEnd2End:
         prev_chkpt_name = get_checkpoint_name(topology, wbits, abits, "ipgen_" + kind)
         model = load_test_checkpoint_or_skip(prev_chkpt_name)
         test_fpga_part = get_build_env(kind, target_clk_ns)["part"]
-        model = model.transform(SetFIFODepths(test_fpga_part, target_clk_ns))
+        model = model.transform(InsertAndSetFIFODepths(test_fpga_part, target_clk_ns))
         model.save(get_checkpoint_name(topology, wbits, abits, "fifodepth_" + kind))
 
     @pytest.mark.slow