From 607551b8cb0621d28347e2ed351087c6c5eefba9 Mon Sep 17 00:00:00 2001
From: auphelia <jakobapk@web.de>
Date: Mon, 27 Jun 2022 13:48:01 +0100
Subject: [PATCH] [Test] Add write/read test for checksum layer drain

---
 src/finn/custom_op/fpgadataflow/checksum.py   |  4 ++-
 .../test_fpgadataflow_checksum.py             | 33 ++++++++++++++++---
 2 files changed, 31 insertions(+), 6 deletions(-)

diff --git a/src/finn/custom_op/fpgadataflow/checksum.py b/src/finn/custom_op/fpgadataflow/checksum.py
index 9a492aa7d..89cce7479 100644
--- a/src/finn/custom_op/fpgadataflow/checksum.py
+++ b/src/finn/custom_op/fpgadataflow/checksum.py
@@ -322,7 +322,9 @@ class CheckSum(HLSCustomOp):
             "#pragma HLS interface ap_ctrl_none port=return"
         )
         self.code_gen_dict["$PRAGMAS$"].append("#pragma HLS dataflow")
-        self.code_gen_dict["$PRAGMAS$"].append("#pragma HLS dataflow disable_start_propagation")
+        self.code_gen_dict["$PRAGMAS$"].append(
+            "#pragma HLS dataflow disable_start_propagation"
+        )
 
     def get_verilog_top_module_intf_names(self):
         intf_names = super().get_verilog_top_module_intf_names()
diff --git a/tests/fpgadataflow/test_fpgadataflow_checksum.py b/tests/fpgadataflow/test_fpgadataflow_checksum.py
index 62d73a9b0..cbbb8fdcc 100644
--- a/tests/fpgadataflow/test_fpgadataflow_checksum.py
+++ b/tests/fpgadataflow/test_fpgadataflow_checksum.py
@@ -30,7 +30,7 @@ import pytest
 
 import numpy as np
 from onnx import TensorProto, helper
-from pyverilator.util.axi_utils import axilite_read
+from pyverilator.util.axi_utils import axilite_read, axilite_write
 from qonnx.core.datatype import DataType
 from qonnx.core.modelwrapper import ModelWrapper
 from qonnx.custom_op.registry import getCustomOp
@@ -181,16 +181,30 @@ def test_fpgadataflow_checksum():
 
     # define function to read out the checksums from axilite
     checksums = []
+    drain = []
 
-    def read_checksum(sim):
-        addr = 16
+    def read_checksum_and_drain(sim):
+        chk_addr = 16
+        drain_addr = 32
         for i in range(len(model.get_nodes_by_op_type("CheckSum"))):
             axi_name = "s_axi_checksum_{}_".format(i)
-            checksums.append(axilite_read(sim, addr, basename=axi_name))
+            checksums.append(axilite_read(sim, chk_addr, basename=axi_name))
+            drain.append(axilite_read(sim, drain_addr, basename=axi_name))
 
-    rtlsim_exec(model, inp, post_hook=read_checksum)
+    drain_value = False
+
+    def write_drain(sim):
+        addr = 32
+        for i in range(len(model.get_nodes_by_op_type("CheckSum"))):
+            axi_name = "s_axi_checksum_{}_".format(i)
+            axilite_write(sim, addr, drain_value, basename=axi_name)
+
+    rtlsim_exec(model, inp, pre_hook=write_drain, post_hook=read_checksum_and_drain)
     checksum0_rtlsim = int(checksums[0])
     checksum1_rtlsim = int(checksums[1])
+    checksum0_drain = int(drain[0])
+    checksum1_drain = int(drain[0])
+
     assert (
         checksum0_rtlsim == checksum0_cppsim
     ), """The first checksums do not
@@ -199,3 +213,12 @@ def test_fpgadataflow_checksum():
         checksum1_rtlsim == checksum1_cppsim
     ), """The second checksums do not
         match in cppsim vs. rtlsim"""
+
+    assert (
+        checksum0_drain == 0
+    ), "Drain read doesn't match drain write for first checksum"
+    assert (
+        checksum1_drain == 0
+    ), "Drain read doesn't match drain write for second checksum"
+
+    # TODO: test for drain set to true
-- 
GitLab