Skip to content
Snippets Groups Projects
Commit fa5b11a7 authored by Yaman Umuroglu
Browse files

[CustomOp] change ConvInpGen layout, various shape fixes

parent 7aed59b6
No related branches found
No related tags found
No related merge requests found
......@@ -64,18 +64,44 @@ class ConvolutionInputGenerator(HLSCustomOp):
my_attrs.update(super().get_nodeattr_types())
return my_attrs
def make_shape_compatible_op(self, model):
def get_normal_input_shape(self):
    """Return the expected (unfolded) input tensor shape.

    The shape is assembled from the ``IFMDim`` and ``IFMChannels`` node
    attributes as ``(1, IFMDim, IFMDim, IFMChannels)``.
    """
    dim = self.get_nodeattr("IFMDim")
    channels = self.get_nodeattr("IFMChannels")
    return (1, dim, dim, channels)
def get_normal_output_shape(self):
    """Return the expected (unfolded) output tensor shape.

    Reads the ``ConvKernelDim``, ``IFMDim``, ``IFMChannels`` and ``Stride``
    node attributes and returns
    ``(1, ofm_dim, ofm_dim, ConvKernelDim * ConvKernelDim * IFMChannels)``,
    where ``ofm_dim`` is whatever ``compute_conv_output_dim`` yields for
    those values.
    """
    k = self.get_nodeattr("ConvKernelDim")
    ifm_dim = self.get_nodeattr("IFMDim")
    ifm_ch = self.get_nodeattr("IFMChannels")
    stride = self.get_nodeattr("Stride")
    # padding is fixed at zero for this computation
    pad = 0
    ofm_dim = compute_conv_output_dim(ifm_dim, k, stride, pad)
    # last axis holds one flattened k x k window across all input channels
    return (1, ofm_dim, ofm_dim, k * k * ifm_ch)
def get_folded_output_shape(self):
    """Return the folded output shape ``(1, ofm_dim, ofm_dim, wf, SIMD)``.

    ``wf`` is the sliding-window size
    (``ConvKernelDim * ConvKernelDim * IFMChannels``) divided by ``SIMD``;
    an assertion fails if ``SIMD`` does not divide that size evenly.
    """
    kernel_dim = self.get_nodeattr("ConvKernelDim")
    in_dim = self.get_nodeattr("IFMDim")
    in_ch = self.get_nodeattr("IFMChannels")
    stride = self.get_nodeattr("Stride")
    simd = self.get_nodeattr("SIMD")
    # padding is fixed at zero, passed positionally as the last argument
    out_dim = compute_conv_output_dim(in_dim, kernel_dim, stride, 0)
    window_size = kernel_dim * kernel_dim * in_ch
    assert window_size % simd == 0, "SIMD must divide sliding window size"
    fold = int(window_size // simd)
    return (1, out_dim, out_dim, fold, simd)
def make_shape_compatible_op(self, model):
exp_ishape = self.get_normal_input_shape()
oshape = self.get_normal_output_shape()
ishape = tuple(model.get_tensor_shape(self.onnx_node.input[0]))
assert ishape == exp_ishape, "Unexpect input shape for ConvInpGen."
ofm_dim = compute_conv_output_dim(ifm_dim, k, stride, pad)
# implement tensor with correct shape
values = np.random.randn(1, ofm_dim, ofm_dim, k * k * ifm_ch).astype(np.float32)
values = np.random.randn(*oshape).astype(np.float32)
return helper.make_node(
"Constant",
inputs=[],
......@@ -118,6 +144,7 @@ class ConvolutionInputGenerator(HLSCustomOp):
return self.get_nodeattr("SIMD") * ibits
def get_number_output_values(self):
# TODO this seems incorrect -- double check
k = self.get_nodeattr("ConvKernelDim")
ifm_ch = self.get_nodeattr("IFMChannels")
ofm_dim = self.get_nodeattr("OFMDim")
......@@ -128,11 +155,8 @@ class ConvolutionInputGenerator(HLSCustomOp):
def execute_node(self, context, graph):
mode = self.get_nodeattr("exec_mode")
node = self.onnx_node
k = self.get_nodeattr("ConvKernelDim")
ifm_dim = self.get_nodeattr("IFMDim")
ifm_ch = self.get_nodeattr("IFMChannels")
ofm_dim = self.get_nodeattr("OFMDim")
out_pix = ofm_dim * ofm_dim
exp_ishape = self.get_normal_input_shape()
exp_oshape = self.get_normal_output_shape()
if mode == "npysim":
idt = self.get_input_datatype()
......@@ -146,16 +170,12 @@ class ConvolutionInputGenerator(HLSCustomOp):
inp = context[node.input[0]]
assert str(inp.dtype) == "float32", "Input datatype is not float32"
assert inp.shape == (
1,
ifm_ch,
ifm_dim,
ifm_dim,
assert (
inp.shape == exp_ishape
), """Input shape doesn't
match expected shape (1, ifm_ch, ifm_dim, ifm_dim)."""
reshaped_inp = inp.transpose(0, 2, 3, 1)
match expected shape (1, ifm_dim, ifm_dim, ifm_ch)."""
# make copy before saving array
reshaped_inp = reshaped_inp.copy()
reshaped_inp = inp.copy()
np.save(os.path.join(code_gen_dir, "input_0.npy"), reshaped_inp)
# execute the precompiled model
super().exec_precompiled_singlenode_model()
......@@ -165,17 +185,8 @@ class ConvolutionInputGenerator(HLSCustomOp):
out = context[node.output[0]]
out = 2 * out - 1
context[node.output[0]] = out
assert context[node.output[0]].shape == (
1,
out_pix,
k * k,
ifm_ch,
), """Output
shape doesn't match expected shape (1, out_pix, k*k, ifm_ch)."""
# reshape output to have expected shape
context[node.output[0]] = context[node.output[0]].reshape(
1, out_pix, k * k * ifm_ch
)
context[node.output[0]] = context[node.output[0]].reshape(*exp_oshape)
elif mode == "rtlsim":
code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
prefixed_top_name = "%s_%s" % (node.name, node.name)
......@@ -185,7 +196,10 @@ class ConvolutionInputGenerator(HLSCustomOp):
)
if os.path.isfile(verilog_file):
inp = context[node.input[0]]
inp = inp.transpose(0, 2, 3, 1)
assert (
inp.shape == exp_ishape
), """Input shape doesn't
match expected shape (1, ifm_dim, ifm_dim, ifm_ch)."""
inp = inp.flatten()
# TODO: check how to sort inputs for multichannel inputs
......@@ -209,7 +223,7 @@ class ConvolutionInputGenerator(HLSCustomOp):
odt = self.get_output_datatype()
if odt == DataType.BIPOLAR:
output = [2 * x - 1 for x in output]
# TODO use utils here to handle datatype intricacies...
# pyverilator interprets int2 as uint2, so output has to be corrected
elif odt == DataType.INT2:
mask = 2 ** (odt.bitwidth() - 1)
......@@ -220,9 +234,8 @@ class ConvolutionInputGenerator(HLSCustomOp):
# output_ch2 = [int(x[1:]) for x in output]
# reshape output
output = np.asarray([output], dtype=np.float32).reshape(
1, out_pix, k * k * ifm_ch
)
output = np.asarray([output], dtype=np.float32)
output = output.reshape(*exp_oshape)
context[node.output[0]] = output
else:
......@@ -237,6 +250,10 @@ class ConvolutionInputGenerator(HLSCustomOp):
mode
)
)
assert (
context[node.output[0]].shape == exp_oshape
), """Output
shape doesn't match expected shape (1, ofm_dim, ofm_dim, k*k*ifm_ch)."""
def global_includes(self):
self.code_gen_dict["$GLOBALS$"] = ['#include "slidingwindow.h"']
......@@ -306,12 +323,8 @@ class ConvolutionInputGenerator(HLSCustomOp):
elem_hls_type = dtype.get_hls_datatype_str()
npy_type = "float"
npy_out = "%s/output.npy" % code_gen_dir
ofm_dim = self.get_nodeattr("OFMDim")
out_pix = ofm_dim * ofm_dim
k = self.get_nodeattr("ConvKernelDim")
ifm_ch = self.get_nodeattr("IFMChannels")
shape = (1, out_pix, k * k, ifm_ch)
shape_cpp_str = str(shape).replace("(", "{").replace(")", "}")
oshape = self.get_folded_output_shape()
oshape_cpp_str = str(oshape).replace("(", "{").replace(")", "}")
self.code_gen_dict["$DATAOUTSTREAM$"] = [
'apintstream2npy<%s, %s, %d, %s>(out, %s, "%s");'
......@@ -320,7 +333,7 @@ class ConvolutionInputGenerator(HLSCustomOp):
elem_hls_type,
elem_bits,
npy_type,
shape_cpp_str,
oshape_cpp_str,
npy_out,
)
]
......
......@@ -28,7 +28,6 @@
import pytest
import numpy as np
from onnx import TensorProto, helper
import finn.core.onnx_exec as oxe
......@@ -43,74 +42,49 @@ from finn.transformation.general import GiveUniqueNodeNames
from finn.util.basic import gen_finn_dt_tensor
def get_im2col_indices(x_shape, k, stride):
    """Build the index arrays used to gather im2col columns by fancy indexing.

    Args:
        x_shape: 4-tuple unpacked as (N, C, H, W); H must equal W.
        k: kernel size (same in both spatial dimensions).
        stride: step between neighbouring windows; must divide ``W - k``.

    Returns:
        Tuple ``(ch_idx, row_idx, col_idx)``. ``ch_idx`` has shape
        ``(C * k * k, 1)``; ``row_idx`` and ``col_idx`` have shape
        ``(C * k * k, ofm_dim * ofm_dim)`` with
        ``ofm_dim = (W - k) // stride + 1``.

    Raises:
        AssertionError: if the input is not square or the stride does not
            evenly divide ``W - k``.
    """
    # First figure out what the size of the output should be
    _, C, H, W = x_shape
    assert H == W
    assert (W - k) % stride == 0
    # exact integer division is safe: divisibility was asserted just above
    ofm_dim = (W - k) // stride + 1

    # row/column offsets inside one k x k window, repeated for every channel
    i0 = np.tile(np.repeat(np.arange(k), k), C)
    j0 = np.tile(np.arange(k), k * C)
    # top-left row/column of every window position
    i1 = stride * np.repeat(np.arange(ofm_dim), ofm_dim)
    j1 = stride * np.tile(np.arange(ofm_dim), ofm_dim)

    # broadcast (window offset) + (window origin) -> absolute coordinates
    row_idx = i0.reshape(-1, 1) + i1.reshape(1, -1)
    col_idx = j0.reshape(-1, 1) + j1.reshape(1, -1)

    # channel index per gathered row; kept separate from the kernel size `k`
    ch_idx = np.repeat(np.arange(C), k * k).reshape(-1, 1)

    return (ch_idx, row_idx, col_idx)
def im2col_indices(x, k, stride):
""" An implementation of im2col based on some fancy indexing """
l, i, j = get_im2col_indices(x.shape, k, stride)
def make_single_im2col_modelwrapper(k, ifm_ch, ifm_dim, ofm_dim, simd, stride, idt):
odt = idt
inp = helper.make_tensor_value_info(
"inp", TensorProto.FLOAT, [1, ifm_dim, ifm_dim, ifm_ch]
)
outp = helper.make_tensor_value_info(
"outp", TensorProto.FLOAT, [1, ofm_dim, ofm_dim, k * k * ifm_ch]
)
cols = x[:, l, i, j]
C = x.shape[1]
cols = cols.transpose(1, 2, 0).reshape(k * k * C, -1)
cols = cols.transpose(1, 0)
im2col_node = helper.make_node(
"Im2Col",
["inp"],
["outp"],
domain="finn",
backend="fpgadataflow",
stride=stride,
kernel_size=k,
input_shape=str((1, ifm_dim, ifm_dim, ifm_ch)),
pad_amount=0,
pad_value=0,
)
graph = helper.make_graph(
nodes=[im2col_node], name="im2col_graph", inputs=[inp], outputs=[outp]
)
# rearranging the output so it matches with finn-hlslib function
# swapping the columns according to the input channel
# if C > 1 :
parts = {}
for ch in range(C):
parts[ch] = []
model = helper.make_model(graph, producer_name="im2col-model")
model = ModelWrapper(model)
for i in range(cols.shape[1]):
if i % C == 0:
parts[0].append(i)
elif (i + (C - 1)) % C == 0:
parts[1].append(i)
elif (i + (C - 2)) % C == 0:
parts[2].append(i)
elif (i + (C - 3)) % C == 0:
parts[3].append(i)
permutation = []
for i in parts:
for num in parts[i]:
permutation.append(num)
model.set_tensor_datatype("inp", idt)
model.set_tensor_datatype("outp", odt)
i = np.argsort(permutation)
cols = cols[:, i]
return cols
return model
def make_single_slidingwindow_modelwrapper(
k, ifm_ch, ifm_dim, ofm_dim, simd, stride, idt
):
odt = idt
out_pix = ofm_dim * ofm_dim
inp = helper.make_tensor_value_info(
"inp", TensorProto.FLOAT, [1, ifm_ch, ifm_dim, ifm_dim]
"inp", TensorProto.FLOAT, [1, ifm_dim, ifm_dim, ifm_ch]
)
outp = helper.make_tensor_value_info(
"outp", TensorProto.FLOAT, [1, out_pix, k * k * ifm_ch]
"outp", TensorProto.FLOAT, [1, ofm_dim, ofm_dim, k * k * ifm_ch]
)
SlidingWindow_node = helper.make_node(
......@@ -168,7 +142,7 @@ def test_fpgadataflow_slidingwindow(idt, k, ifm_dim, ifm_ch, stride, exec_mode):
simd = ifm_ch
ofm_dim = int(((ifm_dim - k) / stride) + 1)
x = gen_finn_dt_tensor(idt, (1, ifm_ch, ifm_dim, ifm_dim))
x = gen_finn_dt_tensor(idt, (1, ifm_dim, ifm_dim, ifm_ch))
model = make_single_slidingwindow_modelwrapper(
k, ifm_ch, ifm_dim, ofm_dim, simd, stride, idt
)
......@@ -189,8 +163,10 @@ def test_fpgadataflow_slidingwindow(idt, k, ifm_dim, ifm_ch, stride, exec_mode):
input_dict = prepare_inputs(x, idt)
# execute model
y_produced = oxe.execute_onnx(model, input_dict)["outp"]
y_expected = im2col_indices(x, k, stride)
# reshape expected output to match node output
oshape = y_produced.shape
y_expected = y_expected.reshape(oshape)
golden = make_single_im2col_modelwrapper(
k, ifm_ch, ifm_dim, ofm_dim, simd, stride, idt
)
y_expected = oxe.execute_onnx(golden, input_dict)["outp"]
if idt == DataType.BIPOLAR:
y_expected = 2 * y_expected - 1
assert (y_produced == y_expected).all()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment