Commit 3addec34 authored by mmrahorovic

Merge remote-tracking branch 'origin/feature/pe_maxpool' into feature/maxpool_ceil

parents 57b047a8 7b970fa8
@@ -32,7 +32,6 @@ import warnings
from finn.core.datatype import DataType
from finn.custom_op.fpgadataflow.hlscustomop import HLSCustomOp
-from finn.custom_op.general.im2col import compute_conv_output_dim
from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy
@@ -44,6 +43,7 @@ class StreamingMaxPool_Batch(HLSCustomOp):
"ImgDim": ("ints", True, []), # [H, W] = [Y, X]
"PoolDim": ("ints", True, []), # [H, W] = [Y, X]
"NumChannels": ("i", True, 0),
"PE": ("i", True, 0),
# FINN DataTypes for inputs/outputs
"dataType": ("s", True, ""),
}
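
Each entry above follows FINN's node-attribute convention: a name mapped to a (type, required, default) tuple, where "i" is a scalar integer, "ints" a list of integers, and "s" a string. A minimal standalone illustration of the convention (not FINN API, names illustrative):

# node-attribute spec: name -> (type string, required flag, default value)
spec = {
    "NumChannels": ("i", True, 0),  # scalar integer, required
    "PE": ("i", True, 0),           # channels processed in parallel per cycle
    "ImgDim": ("ints", True, []),   # list of integers, [H, W]
}
dtype, required, default = spec["PE"]
assert dtype == "i" and required and default == 0
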
@@ -82,24 +82,29 @@ class StreamingMaxPool_Batch(HLSCustomOp):
return ishape
def get_folded_input_shape(self):
-# even though there is no folding in the current hlslib op,
-# insert a time multiplexing axis to remain compatible with the
-# shapes produced by the rest of the dataflow pipeline
-ret = list(self.get_normal_input_shape())
-ret.insert(-1, 1)
-return tuple(ret)
+ifm_dim_h, ifm_dim_w = self.get_nodeattr("ImgDim")
+ifm_ch = self.get_nodeattr("NumChannels")
+pe = self.get_nodeattr("PE")
+nf = int(ifm_ch / pe)
+if self.is_1d():
+    folded_ishape = (1, ifm_dim_h, ifm_dim_w, nf, pe)
+else:
+    folded_ishape = (1, ifm_dim_h, ifm_dim_w, 1, ifm_ch)
+return folded_ishape
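
The folding arithmetic introduced here can be checked in isolation: in the 1D case the channel axis splits into nf folds of PE channels each, while the 2D op still streams a whole pixel per word. A minimal sketch (the standalone helper is illustrative, not FINN API):

def folded_input_shape(ifm_dim, ifm_ch, pe, is_1d):
    ifm_dim_h, ifm_dim_w = ifm_dim
    assert ifm_ch % pe == 0, "PE must evenly divide NumChannels"
    nf = ifm_ch // pe
    if is_1d:
        return (1, ifm_dim_h, ifm_dim_w, nf, pe)
    return (1, ifm_dim_h, ifm_dim_w, 1, ifm_ch)

# e.g. a 1x16 map with 6 channels and PE=2 folds to (1, 1, 16, 3, 2)
assert folded_input_shape((1, 16), 6, 2, True) == (1, 1, 16, 3, 2)
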
def get_normal_output_shape(self):
ifm_dim_h, ifm_dim_w = self.get_nodeattr("ImgDim")
k_h, k_w = tuple(self.get_nodeattr("PoolDim"))
ifm_ch = self.get_nodeattr("NumChannels")
-stride_h = k_h
-stride_w = k_w
-pad = 0
-assert ifm_dim_h % k_h == 0, "StreamingMaxPool needs ImgDim_h % PoolDim_h == 0"
-assert ifm_dim_w % k_w == 0, "StreamingMaxPool needs ImgDim_w % PoolDim_w == 0"
-ofm_dim_h = compute_conv_output_dim(ifm_dim_h, k_h, stride_h, pad)
-ofm_dim_w = compute_conv_output_dim(ifm_dim_w, k_w, stride_w, pad)
+if not self.is_1d():
+    assert (
+        ifm_dim_h % k_h == 0
+    ), "StreamingMaxPool needs ImgDim_h % PoolDim_h == 0"
+    assert (
+        ifm_dim_w % k_w == 0
+    ), "StreamingMaxPool needs ImgDim_w % PoolDim_w == 0"
+ofm_dim_h = int(np.floor(ifm_dim_h / k_h))
+ofm_dim_w = int(np.floor(ifm_dim_w / k_w))
oshape = (1, ofm_dim_h, ofm_dim_w, ifm_ch)
return oshape
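
A quick worked example of the shape rule (stride equals the pool window, no padding; values illustrative):

import numpy as np

# non-overlapping 2x2 pooling over a 4x4 map yields a 2x2 output
ifm_dim_h, ifm_dim_w, k_h, k_w, ifm_ch = 4, 4, 2, 2, 3
ofm_dim_h = int(np.floor(ifm_dim_h / k_h))  # 2
ofm_dim_w = int(np.floor(ifm_dim_w / k_w))  # 2
assert (1, ofm_dim_h, ofm_dim_w, ifm_ch) == (1, 2, 2, 3)
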
@@ -107,8 +112,15 @@ class StreamingMaxPool_Batch(HLSCustomOp):
# even though there is no folding in the current hlslib op,
# insert a time multiplexing axis to remain compatible with the
# shapes produced by the rest of the dataflow pipeline
+ifm_ch = self.get_nodeattr("NumChannels")
+pe = self.get_nodeattr("PE")
+nf = int(ifm_ch / pe)
ret = list(self.get_normal_output_shape())
-ret.insert(-1, 1)
+if self.is_1d():
+    ret[-1] = nf
+    ret.append(pe)
+else:
+    ret.insert(-1, 1)
return tuple(ret)
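
The output side mirrors the input folding. A small sketch with illustrative values:

# 1D: normal output (1, 1, 8, 6) with PE=2 folds to (1, 1, 8, 3, 2);
# 2D instead gains a singleton fold axis: (1, 2, 2, 3) -> (1, 2, 2, 1, 3)
ret, pe = [1, 1, 8, 6], 2
nf = ret[-1] // pe
ret[-1] = nf
ret.append(pe)
assert tuple(ret) == (1, 1, 8, 3, 2)
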
def get_number_output_values(self):
@@ -118,20 +130,27 @@ class StreamingMaxPool_Batch(HLSCustomOp):
def get_exp_cycles(self):
# derived from StreamingMaxPool_Batch loop nest
ifm_dim, k, ifm_ch = self.get_1d_attrs_normalized()
+_, _, ofm_dim_w, nf, _ = self.get_folded_output_shape()
if self.is_1d():
-return int(ifm_dim[1] + k[1])
+exp_cycles = ofm_dim_w * nf * (k[1] + 1)
+return int(exp_cycles)
else:
# TODO: adjust inaccurate formula
return int(ifm_dim[1] * (ifm_dim[1] + (ifm_dim[1] / k[1])))
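
Plugging sample numbers into the 1D estimate: with a folded output width of 8, nf = 3 folds and a pool window of 4, the op is expected to take 8 * 3 * (4 + 1) = 120 cycles (values illustrative):

ofm_dim_w, nf, k_w = 8, 3, 4
exp_cycles = ofm_dim_w * nf * (k_w + 1)
assert exp_cycles == 120
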
def get_instream_width(self):
dt_bits = self.get_input_datatype().bitwidth()
+pe = self.get_nodeattr("PE")
ifm_ch = self.get_nodeattr("NumChannels")
-in_width = int(dt_bits * ifm_ch)
+if self.is_1d():
+    in_width = int(dt_bits * pe)
+else:
+    in_width = int(dt_bits * ifm_ch)
return in_width
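
For example, with 8-bit elements the 1D op streams PE channels per beat while the 2D op streams a full pixel per beat (values illustrative):

dt_bits, pe, ifm_ch = 8, 2, 6
assert dt_bits * pe == 16      # 1D: PE elements per stream word
assert dt_bits * ifm_ch == 48  # 2D: all NumChannels elements per stream word
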
def get_outstream_width(self):
"""For streaming maxpool out stream with is the same as in stream width"""
"""For streaming maxpool out stream width is the same as in stream width"""
return self.get_instream_width()
def make_shape_compatible_op(self, model):
@@ -179,15 +198,27 @@ class StreamingMaxPool_Batch(HLSCustomOp):
numReps = 1
ifm_dim, k, ifm_ch = self.get_1d_attrs_normalized()
-self.code_gen_dict["$DEFINES$"] = [
-    """#define ImgDim {}\n #define PoolDim {}\n
-    #define NumChannels {}\n #define numReps {}""".format(
-        ifm_dim[1],
-        k[1],
-        self.get_nodeattr("NumChannels"),
-        numReps,
-    )
-]
+if self.is_1d():
+    self.code_gen_dict["$DEFINES$"] = [
+        """#define ImgDim {}\n #define PoolDim {}\n
+        #define NumChannels {}\n #define PE {}\n #define numReps {}""".format(
+            ifm_dim[1],
+            k[1],
+            self.get_nodeattr("NumChannels"),
+            self.get_nodeattr("PE"),
+            numReps,
+        )
+    ]
+else:
+    self.code_gen_dict["$DEFINES$"] = [
+        """#define ImgDim {}\n #define PoolDim {}\n
+        #define NumChannels {}\n #define numReps {}""".format(
+            ifm_dim[1],
+            k[1],
+            self.get_nodeattr("NumChannels"),
+            numReps,
+        )
+    ]
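
Rendered with sample values (ImgDim=16, PoolDim=4, NumChannels=6, PE=2, numReps=1, all illustrative), the 1D template above expands to the C defines consumed by the hlslib kernel:

template = (
    "#define ImgDim {}\n#define PoolDim {}\n"
    "#define NumChannels {}\n#define PE {}\n#define numReps {}"
)
print(template.format(16, 4, 6, 2, 1))
# -> #define ImgDim 16 ... #define PE 2 ... #define numReps 1
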
def read_npy_data(self):
code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
@@ -227,17 +258,21 @@ class StreamingMaxPool_Batch(HLSCustomOp):
"%s<ImgDim, PoolDim, NumChannels>(in0, out);" % (op)
]
else:
+dtype = self.get_input_datatype()
+dtype_hls = dtype.get_hls_datatype_str()
+minval_str = str(int(dtype.min()))
+if self.is_1d():
+    op = "StreamingMaxPool_Precision_1d"
+    self.code_gen_dict["$DOCOMPUTE$"] = [
+        "%s<ImgDim, PoolDim, NumChannels, PE, %s, %s>(in0, out);"
+        % (op, dtype_hls, minval_str)
+    ]
+else:
+    op = "StreamingMaxPool_Precision"
-dtype = self.get_input_datatype()
-dtype_hls = dtype.get_hls_datatype_str()
-minval_str = str(int(dtype.min()))
+    self.code_gen_dict["$DOCOMPUTE$"] = [
+        "%s<ImgDim, PoolDim, NumChannels, %s, %s>(in0, out);"
+        % (op, dtype_hls, minval_str)
+    ]
-self.code_gen_dict["$DOCOMPUTE$"] = [
-    "%s<ImgDim, PoolDim, NumChannels, %s, %s>(in0, out);"
-    % (op, dtype_hls, minval_str)
-]
def dataoutstrm(self):
code_gen_dir = self.get_nodeattr("code_gen_dir_cppsim")
@@ -293,6 +328,7 @@ class StreamingMaxPool_Batch(HLSCustomOp):
node = self.onnx_node
exp_ishape = self.get_normal_input_shape()
exp_oshape = self.get_normal_output_shape()
+folded_ishape = self.get_folded_input_shape()
folded_oshape = self.get_folded_output_shape()
# TODO ensure codegen dir exists
@@ -320,9 +356,8 @@ class StreamingMaxPool_Batch(HLSCustomOp):
export_idt = DataType["BINARY"]
else:
export_idt = self.get_input_datatype()
-# no reshaping for input since assuming no folding on input
# make copy before saving array
-reshaped_input = inp.copy()
+reshaped_input = inp.reshape(folded_ishape)
np.save(os.path.join(code_gen_dir, "input_0.npy"), reshaped_input)
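
The reshape is element-count preserving, so moving from the normal NHWC layout to the folded layout is just a rearrangement. A minimal numpy illustration (shapes chosen for illustration):

import numpy as np

# normal input (1, 1, 16, 6) rearranged to folded (1, 1, 16, 3, 2): nf=3, pe=2
inp = np.arange(1 * 1 * 16 * 6).reshape(1, 1, 16, 6)
folded = inp.reshape(1, 1, 16, 3, 2)  # innermost axis now holds PE channels
assert folded.size == inp.size
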
if mode == "cppsim":
@@ -333,7 +368,7 @@ class StreamingMaxPool_Batch(HLSCustomOp):
assert (
context[node.output[0]].shape == folded_oshape
), "cppsim \
-did not produce expected ofolded utput shape"
+did not produce expected folded output shape"
context[node.output[0]] = context[node.output[0]].reshape(*exp_oshape)
elif mode == "rtlsim":
sim = self.get_rtlsim()
@@ -371,4 +406,4 @@ class StreamingMaxPool_Batch(HLSCustomOp):
assert (
context[node.output[0]].shape == exp_oshape
), """Output
-shape doesn't match expected shape (1, ofm_dim, ofm_dim, k*k*ifm_ch)."""
+shape doesn't match expected shape (1, ofm_dim, ofm_dim, ifm_ch)."""
@@ -362,7 +362,10 @@ class InferStreamingMaxPool(Transformation):
ifm_ch = mp_in_shape[-1]
ifm_dim_h = mp_in_shape[1]
ifm_dim_w = mp_in_shape[2]
-if ifm_dim_h % k_h == 0 and ifm_dim_w % k_w == 0:
+pe = 1
+is_1d = (ifm_dim_h == 1 and k_h == 1) or (ifm_dim_w == 1 and k_w == 1)
+is_divisable = ifm_dim_h % k_h == 0 or ifm_dim_w % k_w == 0
+if is_1d or is_divisable:
# create equivalent StreamingMaxPool_Batch node
new_node = helper.make_node(
"StreamingMaxPool_Batch",
@@ -374,6 +377,7 @@ class InferStreamingMaxPool(Transformation):
NumChannels=ifm_ch,
ImgDim=(ifm_dim_h, ifm_dim_w),
dataType=dt.name,
+PE=pe,
name="StreamingMaxPool_Batch_" + n.name,
)
graph.node.insert(node_ind, new_node)
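
The conversion predicate above reads: accept any 1D maxpool (one spatial dimension and its pool window are degenerate), otherwise require the image dimension to be divisible by the pool window. Restated standalone (illustrative only; mirrors the logic above, including its use of 'or'):

def can_infer_streaming_maxpool(ifm_dim_h, ifm_dim_w, k_h, k_w):
    is_1d = (ifm_dim_h == 1 and k_h == 1) or (ifm_dim_w == 1 and k_w == 1)
    is_divisable = ifm_dim_h % k_h == 0 or ifm_dim_w % k_w == 0
    return is_1d or is_divisable

assert can_infer_streaming_maxpool(1, 16, 1, 4)  # 1D case
assert can_infer_streaming_maxpool(4, 4, 2, 2)   # divisible 2D case
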
@@ -81,7 +81,7 @@ def make_single_maxpoolnhwc_modelwrapper(k, ifm_ch, ifm_dim, ofm_dim, idt):
return model
-def make_single_streamingmaxpool_modelwrapper(k, ifm_ch, ifm_dim, ofm_dim, idt):
+def make_single_streamingmaxpool_modelwrapper(k, ifm_ch, pe, ifm_dim, ofm_dim, idt):
k_h, k_w = k
ifm_dim_h, ifm_dim_w = ifm_dim
ofm_dim_h, ofm_dim_w = ofm_dim
@@ -101,6 +101,7 @@ def make_single_streamingmaxpool_modelwrapper(k, ifm_ch, ifm_dim, ofm_dim, idt):
backend="fpgadataflow",
PoolDim=[k_h, k_w],
NumChannels=ifm_ch,
+PE=pe,
ImgDim=[ifm_dim_h, ifm_dim_w],
dataType=idt.name,
)
@@ -131,11 +132,13 @@ def prepare_inputs(input_tensor):
@pytest.mark.parametrize("ifm_dim", [4, 8])
# input channels
@pytest.mark.parametrize("ifm_ch", [1, 3]) # 1,3
+# pe
+@pytest.mark.parametrize("pe", [1, 3])
# execution mode
@pytest.mark.parametrize("exec_mode", ["rtlsim", "cppsim"])
@pytest.mark.slow
@pytest.mark.vivado
-def test_fpgadataflow_streamingmaxpool(idt, dim_1d, k, ifm_dim, ifm_ch, exec_mode):
+def test_fpgadataflow_streamingmaxpool(idt, dim_1d, k, ifm_dim, ifm_ch, pe, exec_mode):
ifm_dim_h = ifm_dim
k_h = k
if dim_1d:
@@ -156,6 +159,8 @@ def test_fpgadataflow_streamingmaxpool(idt, dim_1d, k, ifm_dim, ifm_ch, exec_mod
pytest.skip("Skipping binary StreamingMaxPool_1d (not implemented)")
if ifm_dim_h % k_h != 0 or ifm_dim_w % k_w != 0:
pytest.skip("Skipping StreamingMaxPool test w/ ImgDim % PoolDim != 0")
+if pe > ifm_ch:
+    pytest.skip("PE cannot be larger than number of input channels")
x = gen_finn_dt_tensor(idt, (1, ifm_dim_h, ifm_dim_w, ifm_ch))
# prepare input data
@@ -164,7 +169,9 @@ def test_fpgadataflow_streamingmaxpool(idt, dim_1d, k, ifm_dim, ifm_ch, exec_mod
golden = make_single_maxpoolnhwc_modelwrapper(k, ifm_ch, ifm_dim, ofm_dim, idt)
y_expected = oxe.execute_onnx(golden, input_dict)["outp"]
-model = make_single_streamingmaxpool_modelwrapper(k, ifm_ch, ifm_dim, ofm_dim, idt)
+model = make_single_streamingmaxpool_modelwrapper(
+    k, ifm_ch, pe, ifm_dim, ofm_dim, idt
+)
if exec_mode == "cppsim":
model = model.transform(SetExecMode("cppsim"))
@@ -173,7 +180,7 @@ def test_fpgadataflow_streamingmaxpool(idt, dim_1d, k, ifm_dim, ifm_ch, exec_mod
elif exec_mode == "rtlsim":
model = model.transform(SetExecMode("rtlsim"))
model = model.transform(GiveUniqueNodeNames())
-model = model.transform(PrepareIP("xc7z020clg400-1", 5))
+model = model.transform(PrepareIP("xczu3eg-sbva484-1-e", 5))
model = model.transform(HLSSynthIP())
model = model.transform(PrepareRTLSim())
else: