diff --git a/custom_hls/checksum.hpp b/custom_hls/checksum.hpp index bf580f31a6228ffd446221ff5c7cd5f29e439837..bb92027d9e3a466c9e82cfa57071ffe73b0ed1c3 100644 --- a/custom_hls/checksum.hpp +++ b/custom_hls/checksum.hpp @@ -74,6 +74,7 @@ void checksum( hls::stream<T> &src, hls::stream<T> &dst, ap_uint<32> &chk, + ap_uint<1> drain, // drain data after checksuming without forward to `dst` F&& f = F() ) { ap_uint<2> coeff[3] = { 1, 2, 3 }; @@ -84,7 +85,7 @@ void checksum( T const x = src.read(); // Pass-thru copy - dst.write(x); + if(!drain) dst.write(x); // Actual checksum update for(unsigned j = 0; j < K; j++) { @@ -118,14 +119,16 @@ void checksum( void checksum_ ## WORDS_PER_FRAME ## _ ## WORD_SIZE ## _ ## ITEMS_PER_WORD ( \ hls::stream<T> &src, \ hls::stream<T> &dst, \ - ap_uint<32> &chk \ + ap_uint<32> &chk, \ + ap_uint< 1> drain \ ) { \ _Pragma("HLS interface port=src axis") \ _Pragma("HLS interface port=dst axis") \ _Pragma("HLS interface port=chk s_axilite") \ + _Pragma("HLS interface port=drain s_axilite") \ _Pragma("HLS interface port=return ap_ctrl_none") \ - _Pragma("HLS dataflow") \ - checksum<WORDS_PER_FRAME, ITEMS_PER_WORD>(src, dst, chk); \ + _Pragma("HLS dataflow disable_start_propagation") \ + checksum<WORDS_PER_FRAME, ITEMS_PER_WORD>(src, dst, chk, drain); \ } #define CHECKSUM_TOP(WORDS_PER_FRAME, WORD_SIZE, ITEMS_PER_WORD) \ CHECKSUM_TOP_(WORDS_PER_FRAME, WORD_SIZE, ITEMS_PER_WORD) diff --git a/src/finn/custom_op/fpgadataflow/checksum.py b/src/finn/custom_op/fpgadataflow/checksum.py index a17e2418ca80f4118733c7e0f2245a5bc6ef72a0..89cce74795fd26c214e789bcde45a1c4c5cc3148 100644 --- a/src/finn/custom_op/fpgadataflow/checksum.py +++ b/src/finn/custom_op/fpgadataflow/checksum.py @@ -254,10 +254,12 @@ class CheckSum(HLSCustomOp): 'hls::stream<ap_uint<{}>> out ("out");'.format(self.get_outstream_width()) ) self.code_gen_dict["$STREAMDECLARATIONS$"].append("ap_uint<32> chk;") + # set drain = false for cppsim + self.code_gen_dict["$STREAMDECLARATIONS$"].append("ap_uint<1> drain = false;") def docompute(self): self.code_gen_dict["$DOCOMPUTE$"] = [ - """checksum<WORDS_PER_FRAME, ITEMS_PER_WORD>(in0, out, chk);""" + """checksum<WORDS_PER_FRAME, ITEMS_PER_WORD>(in0, out, chk, drain);""" ] def dataoutstrm(self): @@ -298,7 +300,7 @@ class CheckSum(HLSCustomOp): def blackboxfunction(self): self.code_gen_dict["$BLACKBOXFUNCTION$"] = [ """using T = ap_uint<WORD_SIZE>;\n void {}(hls::stream<T> &in0, - hls::stream<T> &out, ap_uint<32> &chk)""".format( + hls::stream<T> &out, ap_uint<32> &chk, ap_uint<1> &drain)""".format( self.onnx_node.name ) ] @@ -313,10 +315,16 @@ class CheckSum(HLSCustomOp): self.code_gen_dict["$PRAGMAS$"].append( "#pragma HLS interface s_axilite port=chk bundle=checksum" ) + self.code_gen_dict["$PRAGMAS$"].append( + "#pragma HLS interface s_axilite port=drain bundle=checksum" + ) self.code_gen_dict["$PRAGMAS$"].append( "#pragma HLS interface ap_ctrl_none port=return" ) self.code_gen_dict["$PRAGMAS$"].append("#pragma HLS dataflow") + self.code_gen_dict["$PRAGMAS$"].append( + "#pragma HLS dataflow disable_start_propagation" + ) def get_verilog_top_module_intf_names(self): intf_names = super().get_verilog_top_module_intf_names() diff --git a/tests/fpgadataflow/test_fpgadataflow_checksum.py b/tests/fpgadataflow/test_fpgadataflow_checksum.py index 62d73a9b0fa5771e6ed8b58ada2cfb9d399bb49d..871a315fe51d0a00b6e42db1f13831189c68e0bb 100644 --- a/tests/fpgadataflow/test_fpgadataflow_checksum.py +++ b/tests/fpgadataflow/test_fpgadataflow_checksum.py @@ -30,7 +30,7 @@ import pytest import numpy as np from onnx import TensorProto, helper -from pyverilator.util.axi_utils import axilite_read +from pyverilator.util.axi_utils import axilite_read, axilite_write from qonnx.core.datatype import DataType from qonnx.core.modelwrapper import ModelWrapper from qonnx.custom_op.registry import getCustomOp @@ -181,16 +181,30 @@ def test_fpgadataflow_checksum(): # define function to read out the checksums from axilite checksums = [] + drain = [] - def read_checksum(sim): - addr = 16 + def read_checksum_and_drain(sim): + chk_addr = 16 + drain_addr = 32 for i in range(len(model.get_nodes_by_op_type("CheckSum"))): axi_name = "s_axi_checksum_{}_".format(i) - checksums.append(axilite_read(sim, addr, basename=axi_name)) + checksums.append(axilite_read(sim, chk_addr, basename=axi_name)) + drain.append(axilite_read(sim, drain_addr, basename=axi_name)) - rtlsim_exec(model, inp, post_hook=read_checksum) + drain_value = False + + def write_drain(sim): + addr = 32 + for i in range(len(model.get_nodes_by_op_type("CheckSum"))): + axi_name = "s_axi_checksum_{}_".format(i) + axilite_write(sim, addr, drain_value, basename=axi_name) + + rtlsim_exec(model, inp, pre_hook=write_drain, post_hook=read_checksum_and_drain) checksum0_rtlsim = int(checksums[0]) checksum1_rtlsim = int(checksums[1]) + checksum0_drain = int(drain[0]) + checksum1_drain = int(drain[1]) + assert ( checksum0_rtlsim == checksum0_cppsim ), """The first checksums do not @@ -199,3 +213,12 @@ def test_fpgadataflow_checksum(): checksum1_rtlsim == checksum1_cppsim ), """The second checksums do not match in cppsim vs. rtlsim""" + + assert ( + checksum0_drain == 0 + ), "Drain read doesn't match drain write for first checksum" + assert ( + checksum1_drain == 0 + ), "Drain read doesn't match drain write for second checksum" + + # TODO: test for drain set to true