diff --git a/src/finn/custom_op/fpgadataflow/tlastmarker.py b/src/finn/custom_op/fpgadataflow/tlastmarker.py index b1d332cdd85373839c1912c6fdd0fe5da652df31..a2855fb00b7f2dba8c2f01d0bb362bbd31b0b614 100644 --- a/src/finn/custom_op/fpgadataflow/tlastmarker.py +++ b/src/finn/custom_op/fpgadataflow/tlastmarker.py @@ -30,24 +30,29 @@ from finn.custom_op.fpgadataflow import HLSCustomOp class TLastMarker(HLSCustomOp): - """Class that corresponds to the TLastMarker node that needs to be - inserted at the end of the model for rtlsim with stitched IP. - It marks the end of the current image/input sample.""" + """Node that adds/removes AXI stream TLAST signals where needed. Its behavior + is transparent in node-by-node execution, only visible in IP-stitched rtlsim or + actual hardware. + This node may be needed at the end of the network to signal a DMA write (needed by the + FINN PYNQ shell) or at the beginning to remove the end-of-burst from DMA read.""" def __init__(self, onnx_node): super().__init__(onnx_node) def get_nodeattr_types(self): my_attrs = { + # number of (static) iterations until TLAST=1 is generated for Direction=out "NumIters": ("i", True, 0), - "DynIters": ("s", False, "true"), - # direction + # whether static or dynamic (from AXI lite) number of iterations are used + "DynIters": ("i", False, 1), + # direction: whether to insert or remove TLAST "Direction": ("s", False, "out"), # width of input-output data streams, in bits "StreamWidth": ("i", True, 0), # width of individual element in stream, in bits "ElemWidth": ("i", True, 0), - # Protocol + # Protocol: external or kernel2kernel + # Vitis docs recommend using qdma_axis for external, ap_axiu for kernel2kernel "Protocol": ("s", False, "external"), } my_attrs.update(super().get_nodeattr_types()) @@ -88,15 +93,21 @@ class TLastMarker(HLSCustomOp): if direction == "out": if protocol == "external": out_stream_dtype = "qdma_axis<%d,0,0,0>" % stream_width - else: + elif protocol == "kernel2kernel: out_stream_dtype = "ap_axiu<%d,0,0,0>" % stream_width + else: + raise Exception("Unrecognized Protocol in TLastMarker") in_stream_dtype = "ap_uint<%d>" % stream_width - else: + elif direction == "in": out_stream_dtype = "ap_uint<%d>" % stream_width if protocol == "external": in_stream_dtype = "qdma_axis<%d,0,0,0>" % stream_width - else: + elif protocol == "kernel2kernel: in_stream_dtype = "ap_axiu<%d,0,0,0>" % stream_width + else: + raise Exception("Unrecognized Protocol in TLastMarker") + else: + raise Exception("Unrecognized Direction in TLastMarker") self.code_gen_dict["$DEFINES$"] = [ "#define StreamWidth %d" % stream_width, @@ -124,7 +135,7 @@ class TLastMarker(HLSCustomOp): "}", ] - elif dyn_iters == "true": + elif dyn_iters == 1: # output, with dynamic iteration counts self.code_gen_dict["$DOCOMPUTE$"] = [ "unsigned int n = 1;", @@ -173,7 +184,7 @@ class TLastMarker(HLSCustomOp): def blackboxfunction(self): dyn_iters = self.get_nodeattr("DynIters") - if dyn_iters == "true": + if dyn_iters == 1: self.code_gen_dict["$BLACKBOXFUNCTION$"] = [ """void %s(hls::stream<InDType> &in0, hls::stream<OutDType> &out, unsigned int numIters)""" @@ -190,7 +201,7 @@ class TLastMarker(HLSCustomOp): self.code_gen_dict["$PRAGMAS$"].append("#pragma HLS INTERFACE axis port=out") dyn_iters = self.get_nodeattr("DynIters") - if dyn_iters == "true": + if dyn_iters == 1: self.code_gen_dict["$PRAGMAS$"].append( "#pragma HLS INTERFACE s_axilite port=numIters bundle=control" )