Commit cd72c1cc authored by Tobi-Alonso
Fix merge conflict in docker/finn_entrypoint.sh while merging upstream/dev

parents 1eef6625 ef3d4c9d
@@ -15,7 +15,7 @@ gecho () {
 # the repos themselves are cloned in the Dockerfile
 BREVITAS_COMMIT=026a509186b7e7b0b65d46a2f905043d41069306
 CNPY_COMMIT=4e8810b1a8637695171ed346ce68f6984e585ef4
-HLSLIB_COMMIT=afcfe75f3404249bddeeb3f15df65bd1fcb1072e
+HLSLIB_COMMIT=8aed899c278c36c977a249558d71795086cf852c
 PYVERILATOR_COMMIT=c97a5ba41bbc7c419d6f25c74cdf3bdc3393174f
 PYNQSHELL_COMMIT=0c82a61b0ec1a07fa275a14146233824ded7a13d
 OMX_COMMIT=1bae737669901e762f581af73348332b5c4b2ada
@@ -18,6 +18,7 @@ Requirements
 * A working Vivado 2019.1 installation
 * A `VIVADO_PATH` environment variable pointing to the Vivado installation directory (e.g. the directory where settings64.sh is located)
 * (optional) A PYNQ board with a network connection
+* the ``bitstring`` package must be installed on the PYNQ: ``sudo pip install bitstring``
 
 Running FINN in Docker
 ======================
@@ -21,6 +21,8 @@ class FMPadding_Batch(HLSCustomOp):
             "Padding": ("i", True, 2),
             # number of channels in input image
             "NumChannels": ("i", True, 0),
+            # SIMD input parallelism
+            "SIMD": ("i", False, 1),
             # FINN input datatype
             "inputDataType": ("s", True, ""),
             # controls distribution of padded pixels
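For context on the new attribute: FINN custom op node attributes are declared as (datatype, required, default) tuples, where "i" marks an integer and "s" a string, so SIMD is an optional integer defaulting to 1. A small standalone sketch of that convention (the default_for helper is hypothetical, not FINN API):

    # (datatype, required, default) convention used by get_nodeattr_types;
    # "i" = integer, "s" = string
    attr_types = {
        "NumChannels": ("i", True, 0),  # required: must be set on the node
        "SIMD": ("i", False, 1),        # optional: defaults to 1 (no parallelism)
    }

    def default_for(name):
        # hypothetical helper, not part of FINN
        dtype, required, default = attr_types[name]
        return None if required else default

    print(default_for("SIMD"))  # 1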
@@ -55,20 +57,22 @@ class FMPadding_Batch(HLSCustomOp):
         return oshape
 
     def get_folded_input_shape(self):
-        # even though there is no folding in the current hlslib op,
-        # insert a time multiplexing axis to remain compatible with the
-        # shapes produced by the rest of the dataflow pipeline
-        ret = list(self.get_normal_input_shape())
-        ret.insert(-1, 1)
-        return tuple(ret)
+        normal_ishape = list(self.get_normal_input_shape())
+        ifm_ch = self.get_nodeattr("NumChannels")
+        simd = self.get_nodeattr("SIMD")
+        assert ifm_ch % simd == 0, "SIMD must divide input channels"
+        fold = int(normal_ishape[-1] / simd)
+        folded_ishape = normal_ishape[:-1] + [fold, simd]
+        return tuple(folded_ishape)
 
     def get_folded_output_shape(self):
-        # even though there is no folding in the current hlslib op,
-        # insert a time multiplexing axis to remain compatible with the
-        # shapes produced by the rest of the dataflow pipeline
-        ret = list(self.get_normal_output_shape())
-        ret.insert(-1, 1)
-        return tuple(ret)
+        normal_oshape = list(self.get_normal_output_shape())
+        ifm_ch = self.get_nodeattr("NumChannels")
+        simd = self.get_nodeattr("SIMD")
+        assert ifm_ch % simd == 0, "SIMD must divide input channels"
+        fold = int(normal_oshape[-1] / simd)
+        folded_oshape = normal_oshape[:-1] + [fold, simd]
+        return tuple(folded_oshape)
 
     def make_shape_compatible_op(self, model):
         exp_ishape = self.get_normal_input_shape()
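For reference, the new folding logic splits the channel (innermost) dimension into SIMD-wide groups instead of inserting a dummy axis. A minimal standalone sketch of the computation above (the example shape and SIMD value are made up):

    # NHWC shape as used by FINN dataflow ops: (batch, H, W, channels)
    normal_ishape = [1, 8, 8, 4]
    simd = 2  # channels processed per stream beat
    assert normal_ishape[-1] % simd == 0, "SIMD must divide input channels"
    fold = normal_ishape[-1] // simd
    folded_ishape = normal_ishape[:-1] + [fold, simd]
    print(folded_ishape)  # [1, 8, 8, 2, 2]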
@@ -114,15 +118,13 @@ class FMPadding_Batch(HLSCustomOp):
 
     def get_instream_width(self):
         ibits = self.get_input_datatype().bitwidth()
-        num_ch = self.get_nodeattr("NumChannels")
-
-        return ibits * num_ch
+        simd = self.get_nodeattr("SIMD")
+        return ibits * simd
 
     def get_outstream_width(self):
         obits = self.get_output_datatype().bitwidth()
-        num_ch = self.get_nodeattr("NumChannels")
-
-        return obits * num_ch
+        simd = self.get_nodeattr("SIMD")
+        return obits * simd
 
     def get_number_output_values(self):
         folded_oshape = self.get_folded_output_shape()
@@ -135,13 +137,15 @@ class FMPadding_Batch(HLSCustomOp):
         self.code_gen_dict["$DEFINES$"] = [
             """#define ImgDim1 {}\n#define OutputDim1 {}\n
             #define Padding1 {}\n#define NumChannels1 {}\n
-            #define PaddingStyle1 {}\n#define numReps {}\n""".format(
+            #define PaddingStyle1 {}\n#define numReps {}
+            #define SIMD1 {}\n""".format(
                 self.get_nodeattr("ImgDim"),
                 self.get_padded_odim(),
                 self.get_nodeattr("Padding"),
                 self.get_nodeattr("NumChannels"),
                 self.get_nodeattr("PaddingStyle"),
                 self.get_nodeattr("numInputVectors"),
+                self.get_nodeattr("SIMD"),
             )
         ]
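The template now also emits a SIMD1 define for the HLS code. A simplified rendering sketch with hypothetical attribute values (the real code uses a triple-quoted template, so its whitespace differs slightly):

    template = (
        "#define ImgDim1 {}\n#define OutputDim1 {}\n"
        "#define Padding1 {}\n#define NumChannels1 {}\n"
        "#define PaddingStyle1 {}\n#define numReps {}\n"
        "#define SIMD1 {}\n"
    )
    # hypothetical attributes: 8x8 input, pad 2 -> 10x10 output, 4 channels,
    # pad_style 2, numInputVectors 1, SIMD 2
    print(template.format(8, 10, 2, 4, 2, 1, 2))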
@@ -176,7 +180,7 @@ class FMPadding_Batch(HLSCustomOp):
         in_t = self.get_input_datatype().get_hls_datatype_str()
         node = self.onnx_node
         self.code_gen_dict["$DOCOMPUTE$"] = [
-            """{}<ImgDim1, OutputDim1, Padding1, NumChannels1,
+            """{}<ImgDim1, OutputDim1, Padding1, NumChannels1, SIMD1,
             {}, PaddingStyle1> (in0, out, numReps);""".format(
                 node.op_type, in_t
             )
@@ -232,6 +236,7 @@ class FMPadding_Batch(HLSCustomOp):
         node = self.onnx_node
         exp_ishape = self.get_normal_input_shape()
         exp_oshape = self.get_normal_output_shape()
+        folded_ishape = self.get_folded_input_shape()
         folded_oshape = self.get_folded_output_shape()
 
         if mode == "cppsim":
@@ -254,10 +259,8 @@ class FMPadding_Batch(HLSCustomOp):
         match expected shape (1, ImgDim, ImgDim, NumChannels)."""
 
         export_idt = self.get_input_datatype()
-        # no reshaping for input since assuming no folding on input
-        # make copy before saving array
-        inp = inp.copy()
-        np.save(os.path.join(code_gen_dir, "input_0.npy"), inp)
+        reshaped_input = inp.reshape(folded_ishape)
+        np.save(os.path.join(code_gen_dir, "input_0.npy"), reshaped_input)
 
         if mode == "cppsim":
             # execute the precompiled model
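Because the input stream is now folded, execute_node must reshape the NHWC input into the folded layout before exporting it for simulation. A minimal sketch (array contents, shapes, and the output path are illustrative):

    import numpy as np

    # hypothetical NHWC input: 4x4 image with 4 channels, SIMD=2
    x = np.arange(1 * 4 * 4 * 4, dtype=np.float32).reshape(1, 4, 4, 4)
    simd = 2
    folded_ishape = (1, 4, 4, 4 // simd, simd)  # channel dim split into (fold, simd)
    reshaped_input = x.reshape(folded_ishape)
    np.save("input_0.npy", reshaped_input)  # layout the C++ simulation expects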
@@ -30,20 +30,30 @@ from finn.custom_op.fpgadataflow import HLSCustomOp
 
 class TLastMarker(HLSCustomOp):
-    """Class that corresponds to the TLastMarker node that needs to be
-    inserted at the end of the model for rtlsim with stitched IP.
-    It marks the end of the current image/input sample."""
+    """Node that adds/removes AXI stream TLAST signals where needed. Its behavior
+    is transparent in node-by-node execution, only visible in IP-stitched rtlsim or
+    actual hardware.
+    This node may be needed at the end of the network to signal a DMA write (needed by the
+    FINN PYNQ shell) or at the beginning to remove the end-of-burst from DMA read."""
 
     def __init__(self, onnx_node):
         super().__init__(onnx_node)
 
     def get_nodeattr_types(self):
         my_attrs = {
+            # number of (static) iterations until TLAST=1 is generated for Direction=out
             "NumIters": ("i", True, 0),
+            # whether static or dynamic (from AXI lite) number of iterations are used
+            "DynIters": ("i", False, 1),
+            # direction: whether to insert or remove TLAST
+            "Direction": ("s", False, "out"),
             # width of input-output data streams, in bits
             "StreamWidth": ("i", True, 0),
             # width of individual element in stream, in bits
             "ElemWidth": ("i", True, 0),
+            # Protocol: external or internal
+            # Vitis docs recommend using qdma_axis for external, ap_axiu for internal
+            "Protocol": ("s", False, "external"),
         }
         my_attrs.update(super().get_nodeattr_types())
         return my_attrs
@@ -76,12 +86,33 @@ class TLastMarker(HLSCustomOp):
 
     def defines(self, var):
         stream_width = self.get_nodeattr("StreamWidth")
+        direction = self.get_nodeattr("Direction")
+        protocol = self.get_nodeattr("Protocol")
         # output stream must have TLAST, so we use this stream data type:
         # qdma_axis<stream_data_width,0,0,0 >
-        out_stream_dtype = "qdma_axis<%d,0,0,0>" % stream_width
+        if direction == "out":
+            if protocol == "external":
+                out_stream_dtype = "qdma_axis<%d,0,0,0>" % stream_width
+            elif protocol == "internal":
+                out_stream_dtype = "ap_axiu<%d,0,0,0>" % stream_width
+            else:
+                raise Exception("Unrecognized Protocol in TLastMarker")
+            in_stream_dtype = "ap_uint<%d>" % stream_width
+        elif direction == "in":
+            out_stream_dtype = "ap_uint<%d>" % stream_width
+            if protocol == "external":
+                in_stream_dtype = "qdma_axis<%d,0,0,0>" % stream_width
+            elif protocol == "internal":
+                in_stream_dtype = "ap_axiu<%d,0,0,0>" % stream_width
+            else:
+                raise Exception("Unrecognized Protocol in TLastMarker")
+        else:
+            raise Exception("Unrecognized Direction in TLastMarker")
+
         self.code_gen_dict["$DEFINES$"] = [
             "#define StreamWidth %d" % stream_width,
             "#define OutDType %s" % out_stream_dtype,
+            "#define InDType %s" % in_stream_dtype,
             "#define NumItersPerImg %d" % self.get_nodeattr("NumIters"),
         ]
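The stream datatypes now depend on Direction and Protocol: the side that carries TLAST uses qdma_axis (external) or ap_axiu (internal), while the plain side stays ap_uint. A compact standalone sketch of the selection, mirroring the diff above:

    def stream_dtypes(stream_width, direction, protocol):
        # the TLAST-carrying side gets an AXI-stream struct type,
        # the other side stays a plain ap_uint
        tlast_t = {
            "external": "qdma_axis<%d,0,0,0>",
            "internal": "ap_axiu<%d,0,0,0>",
        }[protocol]  # unknown protocol raises here (Exception in the real code)
        plain_t = "ap_uint<%d>"
        if direction == "out":  # inserting TLAST on the output stream
            in_t, out_t = plain_t, tlast_t
        else:  # direction == "in": stripping TLAST from the input stream
            in_t, out_t = tlast_t, plain_t
        return in_t % stream_width, out_t % stream_width

    print(stream_dtypes(32, "out", "external"))
    # ('ap_uint<32>', 'qdma_axis<32,0,0,0>')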
@@ -89,27 +120,60 @@ class TLastMarker(HLSCustomOp):
         self.code_gen_dict["$READNPYDATA$"] = []
 
     def docompute(self):
-        self.code_gen_dict["$DOCOMPUTE$"] = [
-            "unsigned int n = 1;",
-            "OutDType t;",
-            "t.set_keep(-1);",
-            "io_section: { // start of cycle accurate region",
-            "#pragma HLS protocol fixed",
-            "// do a first read from stream before we decide on numIters",
-            "// giving software a chance to set up the numIters prior to startup",
-            "t.set_data(in0.read());",
-            "n = (numIters == 0 ? NumItersPerImg : numIters);",
-            "t.set_last(n==1);",
-            "out.write(t);",
-            "} // end of cycle accurate region",
-            "// do one less iteration than spec since we already did one",
-            "for(unsigned int i=1; i<n; i++) {",
-            "#pragma HLS PIPELINE II=1",
-            "t.set_data(in0.read());",
-            "t.set_last(i==(n-1));",
-            "out.write(t);",
-            "}",
-        ]
+        dyn_iters = self.get_nodeattr("DynIters")
+        direction = self.get_nodeattr("Direction")
+        use_qdma_axis = self.get_nodeattr("Protocol") == "external"
+        if direction == "in":
+            # read from input and just pass data along; ignore tlast
+            # no dyn iters on input, it doesn't make sense
+            self.code_gen_dict["$DOCOMPUTE$"] = [
+                "for(unsigned int i=0; i<NumItersPerImg; i++) {",
+                "#pragma HLS PIPELINE II=1",
+                "out.write(in0.read().get_data());"
+                if use_qdma_axis
+                else "out.write(in0.read().data);",
+                "}",
+            ]
+        elif dyn_iters == 1:
+            # output, with dynamic iteration counts
+            self.code_gen_dict["$DOCOMPUTE$"] = [
+                "unsigned int n = 1;",
+                "OutDType t;",
+                "t.set_keep(-1);" if use_qdma_axis else "t.keep = -1;",
+                "io_section: { // start of cycle accurate region",
+                "#pragma HLS protocol fixed",
+                "// do a first read from stream before we decide on numIters",
+                "// giving software a chance to set up the numIters prior to startup",
+                "t.set_data(in0.read());" if use_qdma_axis else "t.data = in0.read();",
+                "n = (numIters == 0 ? NumItersPerImg : numIters);",
+                "t.set_last(n==1);" if use_qdma_axis else "t.last = (n==1);",
+                "out.write(t);",
+                "} // end of cycle accurate region",
+                "// do one less iteration than spec since we already did one",
+                "for(unsigned int i=1; i<n; i++) {",
+                "#pragma HLS PIPELINE II=1",
+                "t.set_data(in0.read());" if use_qdma_axis else "t.data = in0.read();",
+                "t.set_last(i==(n-1));" if use_qdma_axis else "t.last = (i==(n-1));",
+                "out.write(t);",
+                "}",
+            ]
+        else:
+            # output, with static iteration counts
+            self.code_gen_dict["$DOCOMPUTE$"] = [
+                "unsigned int n = 1;",
+                "OutDType t;",
+                "t.set_keep(-1);" if use_qdma_axis else "t.keep = -1;",
+                "for(unsigned int i=0; i<NumItersPerImg; i++) {",
+                "#pragma HLS PIPELINE II=1",
+                "t.set_data(in0.read());" if use_qdma_axis else "t.data = in0.read();",
+                "t.set_last(i==(NumItersPerImg-1));"
+                if use_qdma_axis
+                else "t.last = (i==(NumItersPerImg-1));",
+                "out.write(t);",
+                "}",
+            ]
 
     def dataoutstrm(self):
         self.code_gen_dict["$DATAOUTSTREAM$"] = []
@@ -118,18 +182,30 @@ class TLastMarker(HLSCustomOp):
         self.code_gen_dict["$SAVEASCNPY$"] = []
 
     def blackboxfunction(self):
-        self.code_gen_dict["$BLACKBOXFUNCTION$"] = [
-            """void %s(hls::stream<ap_uint<StreamWidth> > &in0,
-            hls::stream<OutDType> &out, unsigned int numIters)"""
-            % self.onnx_node.name
-        ]
+        dyn_iters = self.get_nodeattr("DynIters")
+        if dyn_iters == 1:
+            self.code_gen_dict["$BLACKBOXFUNCTION$"] = [
+                """void %s(hls::stream<InDType> &in0,
+                hls::stream<OutDType> &out, unsigned int numIters)"""
+                % self.onnx_node.name
+            ]
+        else:
+            self.code_gen_dict["$BLACKBOXFUNCTION$"] = [
+                """void %s(hls::stream<InDType> &in0, hls::stream<OutDType> &out)"""
+                % self.onnx_node.name
+            ]
 
     def pragmas(self):
         self.code_gen_dict["$PRAGMAS$"] = ["#pragma HLS INTERFACE axis port=in0"]
         self.code_gen_dict["$PRAGMAS$"].append("#pragma HLS INTERFACE axis port=out")
-        self.code_gen_dict["$PRAGMAS$"].append(
-            "#pragma HLS INTERFACE s_axilite port=numIters bundle=control"
-        )
+        dyn_iters = self.get_nodeattr("DynIters")
+        if dyn_iters == 1:
+            self.code_gen_dict["$PRAGMAS$"].append(
+                "#pragma HLS INTERFACE s_axilite port=numIters bundle=control"
+            )
         self.code_gen_dict["$PRAGMAS$"].append(
             "#pragma HLS INTERFACE ap_ctrl_none port=return"
         )
@@ -158,7 +234,7 @@ class TLastMarker(HLSCustomOp):
     def strm_decl(self):
         self.code_gen_dict["$STREAMDECLARATIONS$"] = []
         self.code_gen_dict["$STREAMDECLARATIONS$"].append(
-            'hls::stream<ap_uint<{}>> in0 ("in0");'.format(self.get_instream_width())
+            'hls::stream<InDType> in0 ("in0");'
         )
         self.code_gen_dict["$STREAMDECLARATIONS$"].append(
             'hls::stream<OutDType> out ("out");'
         )
@@ -44,7 +44,7 @@ from finn.custom_op.fpgadataflow.streamingdatawidthconverter_batch import (
     StreamingDataWidthConverter_Batch,
 )
 from finn.custom_op.fpgadataflow.globalaccpool_batch import GlobalAccPool_Batch
-from finn.custom_op.fpgadataflow.fmpadding import FMPadding_Batch
+from finn.custom_op.fpgadataflow.fmpadding_batch import FMPadding_Batch
 from finn.custom_op.fpgadataflow.thresholding_batch import Thresholding_Batch
 from finn.custom_op.fpgadataflow.addstreams_batch import AddStreams_Batch
 from finn.custom_op.fpgadataflow.labelselect_batch import LabelSelect_Batch
@@ -31,23 +31,34 @@ from onnx import helper as oh
 
 from finn.custom_op.registry import getCustomOp
 from finn.transformation import Transformation
+from finn.util.basic import get_by_name
+
+import numpy as np
 
 
 class InsertTLastMarker(Transformation):
-    """Ensure that the graph is terminated with a TLastMarker node, inserting
-    one if necessary."""
+    """Ensure that the graph is started/terminated with a TLastMarker node, inserting
+    one if necessary. Use constructor args to determine the type of TLastMarker to be
+    inserted. More information is available in the TLastMarker documentation.
+    """
 
-    def __init__(self):
+    def __init__(self, both=False, external=True, dynamic=True):
         super().__init__()
+        self.dyniters = dynamic
+        self.external = external
+        self.both = both
 
     def apply(self, model):
         # TODO only makes sense for a pure fpgadataflow graph -- check!
         graph_out_name = model.graph.output[0].name
         final_node = model.find_producer(graph_out_name)
-        if final_node.op_type == "TLastMarker":
-            # TODO maybe check the correctness of properties
-            return (model, False)
-        else:
+        graph_modified = False
+        if final_node.op_type != "TLastMarker" and not (
+            final_node.op_type == "IODMA"
+            and get_by_name(final_node.attribute, "direction").s.decode("UTF-8")
+            == "out"
+        ):
             custom_op = getCustomOp(final_node)
             num_iters = int(custom_op.get_number_output_values())
             stream_width = int(custom_op.get_outstream_width())
@@ -69,8 +80,51 @@ class InsertTLastMarker(Transformation):
                 NumIters=num_iters,
                 StreamWidth=stream_width,
                 ElemWidth=elem_width,
+                DynIters=(1 if self.dyniters else 0),
+                Direction="out",
+                Protocol=("external" if self.external else "internal"),
                 domain="finn",
                 backend="fpgadataflow",
             )
             model.graph.node.append(tlast_node)
-            return (model, True)
+            graph_modified = True
+        # if both is True, also insert marker on input
+        if self.both:
+            graph_in_name = model.graph.input[0].name
+            first_node = model.find_consumer(graph_in_name)
+            if first_node.op_type != "TLastMarker" and not (
+                first_node.op_type == "IODMA"
+                and get_by_name(first_node.attribute, "direction").s.decode("UTF-8")
+                == "in"
+            ):
+                custom_op = getCustomOp(first_node)
+                num_iters = np.prod(custom_op.get_folded_input_shape()[1:-1])
+                stream_width = int(custom_op.get_instream_width())
+                in_shape = model.get_tensor_shape(graph_in_name)
+                in_dtype = model.get_tensor_datatype(graph_in_name)
+                elem_width = in_dtype.bitwidth()
+                # make new buffer
+                first_node_in = oh.make_tensor_value_info(
+                    model.make_new_valueinfo_name(), TensorProto.FLOAT, in_shape
+                )
+                model.graph.value_info.append(first_node_in)
+                model.set_tensor_datatype(first_node_in.name, in_dtype)
+                # reroute the first node's input to the new buffer
+                first_node.input[0] = first_node_in.name
+                tlast_node = oh.make_node(
+                    "TLastMarker",
+                    [graph_in_name],
+                    [first_node_in.name],
+                    NumIters=num_iters,
+                    StreamWidth=stream_width,
+                    ElemWidth=elem_width,
+                    DynIters=(1 if self.dyniters else 0),
+                    Direction="in",
+                    Protocol=("external" if self.external else "internal"),
+                    domain="finn",
+                    backend="fpgadataflow",
+                )
+                model.graph.node.insert(0, tlast_node)
+                graph_modified = True
+        return (model, graph_modified)
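With the new constructor arguments, callers choose where and what kind of markers get inserted. A hypothetical usage sketch (the model file name is illustrative; module paths are assumed from the repo layout):

    from finn.core.modelwrapper import ModelWrapper
    from finn.transformation.fpgadataflow.insert_tlastmarker import InsertTLastMarker

    model = ModelWrapper("dataflow_model.onnx")  # hypothetical input model
    # markers on both input and output, ap_axiu streams (internal protocol),
    # static iteration counts (no numIters AXI lite register generated)
    model = model.transform(InsertTLastMarker(both=True, external=False, dynamic=False))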
@@ -23,7 +23,7 @@ test_fpga_part = pynq_part_map[test_pynq_board]
 target_clk_ns = 10
 
 
-def make_single_fmpadding_modelwrapper(idim, padding, num_ch, idt, pad_style):
+def make_single_fmpadding_modelwrapper(idim, padding, num_ch, simd, idt, pad_style):
     assert pad_style == 2, "only pad_style == 2 supported in hlslib"
     assert padding > 0, "Output dim should be greater than input dim"
     odim = idim + padding
@@ -47,6 +47,7 @@ def make_single_fmpadding_modelwrapper(idim, padding, num_ch, idt, pad_style):
         inputDataType=str(idt.name),
         PaddingStyle=pad_style,
         numInputVectors=1,
+        SIMD=simd,
     )
 
     graph = helper.make_graph(
@@ -63,11 +64,13 @@ def make_single_fmpadding_modelwrapper(idim, padding, num_ch, idt, pad_style):
 
 # input image dimension
-@pytest.mark.parametrize("idim", [8, 16])
+@pytest.mark.parametrize("idim", [8])
 # number of rows and number of cols to add
 @pytest.mark.parametrize("pad", [2, 3])
 # number of channels
-@pytest.mark.parametrize("num_ch", [1, 2])
+@pytest.mark.parametrize("num_ch", [2, 4])
+# input parallelism
+@pytest.mark.parametrize("simd", [1, 2])
 # PaddingStyle: selects behavior when (odim-idim)%2 != 0
 @pytest.mark.parametrize("pad_style", [2])
 # FINN input datatype
@@ -76,14 +79,15 @@ def make_single_fmpadding_modelwrapper(idim, padding, num_ch, idt, pad_style):
 @pytest.mark.parametrize("mode", ["cppsim", "rtlsim"])
 @pytest.mark.slow
 @pytest.mark.vivado
-def test_fpgadataflow_fmpadding(idim, pad, num_ch, pad_style, idt, mode):
+def test_fpgadataflow_fmpadding(idim, pad, num_ch, simd, pad_style, idt, mode):
+    if num_ch % simd != 0:
+        pytest.skip("num_ch % simd != 0, skipping")
 
     # generate input data
     x = gen_finn_dt_tensor(idt, [1, idim, idim, num_ch])
     input_dict = {"inp": x}
 
     odim = idim + pad
-    model = make_single_fmpadding_modelwrapper(idim, pad, num_ch, idt, pad_style)
+    model = make_single_fmpadding_modelwrapper(idim, pad, num_ch, simd, idt, pad_style)
     model = model.transform(InferShapes())
     model = model.transform(SetExecMode(mode))
     model = model.transform(GiveUniqueNodeNames())