diff --git a/custom_hls/checksum.hpp b/custom_hls/checksum.hpp
index bf580f31a6228ffd446221ff5c7cd5f29e439837..bb92027d9e3a466c9e82cfa57071ffe73b0ed1c3 100644
--- a/custom_hls/checksum.hpp
+++ b/custom_hls/checksum.hpp
@@ -74,6 +74,7 @@ void checksum(
 	hls::stream<T> &src,
 	hls::stream<T> &dst,
 	ap_uint<32>    &chk,
+	ap_uint<1>     drain,	// drain data after checksumming without forwarding it to `dst`
 	F&& f = F()
 ) {
 	ap_uint<2>  coeff[3] = { 1, 2, 3 };
@@ -84,7 +85,7 @@ void checksum(
 		T const  x = src.read();
 
 		// Pass-thru copy
-		dst.write(x);
+		if(!drain)  dst.write(x);
 
 		// Actual checksum update
 		for(unsigned  j = 0; j < K; j++) {
@@ -118,14 +119,16 @@ void checksum(
 	void checksum_ ## WORDS_PER_FRAME ## _ ## WORD_SIZE ## _ ## ITEMS_PER_WORD ( \
 		hls::stream<T> &src, \
 		hls::stream<T> &dst, \
-		ap_uint<32>    &chk \
+		ap_uint<32>    &chk, \
+		ap_uint< 1>    drain \
 	) { \
 	_Pragma("HLS interface port=src axis") \
 	_Pragma("HLS interface port=dst axis") \
 	_Pragma("HLS interface port=chk s_axilite") \
+	_Pragma("HLS interface port=drain s_axilite") \
 	_Pragma("HLS interface port=return ap_ctrl_none") \
-	_Pragma("HLS dataflow") \
-		checksum<WORDS_PER_FRAME, ITEMS_PER_WORD>(src, dst, chk); \
+	_Pragma("HLS dataflow disable_start_propagation") \
+		checksum<WORDS_PER_FRAME, ITEMS_PER_WORD>(src, dst, chk, drain); \
 	}
 #define CHECKSUM_TOP(WORDS_PER_FRAME, WORD_SIZE, ITEMS_PER_WORD) \
 	CHECKSUM_TOP_(WORDS_PER_FRAME, WORD_SIZE, ITEMS_PER_WORD)
diff --git a/src/finn/custom_op/fpgadataflow/checksum.py b/src/finn/custom_op/fpgadataflow/checksum.py
index a17e2418ca80f4118733c7e0f2245a5bc6ef72a0..89cce74795fd26c214e789bcde45a1c4c5cc3148 100644
--- a/src/finn/custom_op/fpgadataflow/checksum.py
+++ b/src/finn/custom_op/fpgadataflow/checksum.py
@@ -254,10 +254,12 @@ class CheckSum(HLSCustomOp):
             'hls::stream<ap_uint<{}>> out ("out");'.format(self.get_outstream_width())
         )
         self.code_gen_dict["$STREAMDECLARATIONS$"].append("ap_uint<32> chk;")
+        # drain = false in cppsim, so checksum() still forwards every word to `out`
+        self.code_gen_dict["$STREAMDECLARATIONS$"].append("ap_uint<1> drain = false;")
 
     def docompute(self):
         self.code_gen_dict["$DOCOMPUTE$"] = [
-            """checksum<WORDS_PER_FRAME, ITEMS_PER_WORD>(in0, out, chk);"""
+            """checksum<WORDS_PER_FRAME, ITEMS_PER_WORD>(in0, out, chk, drain);"""
         ]
 
     def dataoutstrm(self):
@@ -298,7 +300,7 @@ class CheckSum(HLSCustomOp):
     def blackboxfunction(self):
         self.code_gen_dict["$BLACKBOXFUNCTION$"] = [
             """using T = ap_uint<WORD_SIZE>;\n void {}(hls::stream<T> &in0,
-            hls::stream<T> &out, ap_uint<32> &chk)""".format(
+            hls::stream<T> &out, ap_uint<32> &chk, ap_uint<1> &drain)""".format(
                 self.onnx_node.name
             )
         ]
@@ -313,10 +315,18 @@ class CheckSum(HLSCustomOp):
         self.code_gen_dict["$PRAGMAS$"].append(
             "#pragma HLS interface s_axilite port=chk bundle=checksum"
         )
+        # drain is exposed on the same AXI-lite bundle ("checksum") as chk
+        self.code_gen_dict["$PRAGMAS$"].append(
+            "#pragma HLS interface s_axilite port=drain bundle=checksum"
+        )
         self.code_gen_dict["$PRAGMAS$"].append(
             "#pragma HLS interface ap_ctrl_none port=return"
         )
-        self.code_gen_dict["$PRAGMAS$"].append("#pragma HLS dataflow")
+        # emit exactly one dataflow pragma: the disable_start_propagation form
+        # replaces the plain one, as in CHECKSUM_TOP_ (custom_hls/checksum.hpp)
+        self.code_gen_dict["$PRAGMAS$"].append(
+            "#pragma HLS dataflow disable_start_propagation"
+        )
 
     def get_verilog_top_module_intf_names(self):
         intf_names = super().get_verilog_top_module_intf_names()
diff --git a/tests/fpgadataflow/test_fpgadataflow_checksum.py b/tests/fpgadataflow/test_fpgadataflow_checksum.py
index 62d73a9b0fa5771e6ed8b58ada2cfb9d399bb49d..871a315fe51d0a00b6e42db1f13831189c68e0bb 100644
--- a/tests/fpgadataflow/test_fpgadataflow_checksum.py
+++ b/tests/fpgadataflow/test_fpgadataflow_checksum.py
@@ -30,7 +30,7 @@ import pytest
 
 import numpy as np
 from onnx import TensorProto, helper
-from pyverilator.util.axi_utils import axilite_read
+from pyverilator.util.axi_utils import axilite_read, axilite_write
 from qonnx.core.datatype import DataType
 from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.custom_op.registry import getCustomOp
@@ -181,16 +181,30 @@ def test_fpgadataflow_checksum():
 
     # define function to read out the checksums from axilite
     checksums = []
+    drain = []
 
-    def read_checksum(sim):
-        addr = 16
+    def read_checksum_and_drain(sim):
+        chk_addr = 16
+        drain_addr = 32
         for i in range(len(model.get_nodes_by_op_type("CheckSum"))):
             axi_name = "s_axi_checksum_{}_".format(i)
-            checksums.append(axilite_read(sim, addr, basename=axi_name))
+            checksums.append(axilite_read(sim, chk_addr, basename=axi_name))
+            drain.append(axilite_read(sim, drain_addr, basename=axi_name))
 
-    rtlsim_exec(model, inp, post_hook=read_checksum)
+    drain_value = False
+
+    def write_drain(sim):
+        addr = 32
+        for i in range(len(model.get_nodes_by_op_type("CheckSum"))):
+            axi_name = "s_axi_checksum_{}_".format(i)
+            axilite_write(sim, addr, drain_value, basename=axi_name)
+
+    rtlsim_exec(model, inp, pre_hook=write_drain, post_hook=read_checksum_and_drain)
     checksum0_rtlsim = int(checksums[0])
     checksum1_rtlsim = int(checksums[1])
+    checksum0_drain = int(drain[0])
+    checksum1_drain = int(drain[1])
+
     assert (
         checksum0_rtlsim == checksum0_cppsim
     ), """The first checksums do not
@@ -199,3 +213,13 @@ def test_fpgadataflow_checksum():
         checksum1_rtlsim == checksum1_cppsim
     ), """The second checksums do not
         match in cppsim vs. rtlsim"""
+
+    # drain read back over AXI-lite must equal the value written by write_drain
+    assert (
+        checksum0_drain == int(drain_value)
+    ), "Drain read doesn't match drain write for first checksum"
+    assert (
+        checksum1_drain == int(drain_value)
+    ), "Drain read doesn't match drain write for second checksum"
+
+    # TODO: test for drain set to true