diff --git a/src/finn/custom_op/im2col.py b/src/finn/custom_op/im2col.py index 8af1227a4f4b6b1a244fcbf51e1ed71ec2749802..0cd56e755cb1782c60ab10a820d492dc309649a7 100644 --- a/src/finn/custom_op/im2col.py +++ b/src/finn/custom_op/im2col.py @@ -2,62 +2,70 @@ import numpy as np from onnx import TensorProto, helper from finn.custom_op import CustomOp +import finn.util.basic as util +from finn.core.datatype import DataType +# adapted from A. Karpathy's CS231 im2col code +# utilities to generate a patch matrix from a multichannel image +# of shape (batches, channels, height, width) -def get_im2col_indices(x_shape, k, stride): + +def compute_conv_output_dim(ifm_dim, k, stride, pad=0): + "Return spatial output dimension size for convolution with given params." + return int(((ifm_dim + 2 * pad - k) / stride) + 1) + + +def get_im2col_indices_nchw( + x_shape, field_height, field_width, padding=0, stride_y=1, stride_x=1 +): # First figure out what the size of the output should be N, C, H, W = x_shape - assert H == W - assert (W - k) % stride == 0 - ofm_dim = int((W - k) / stride + 1) + assert (H + 2 * padding - field_height) % stride_y == 0 + assert (W + 2 * padding - field_width) % stride_x == 0 + out_height = compute_conv_output_dim(H, field_height, stride_y, padding) + out_width = compute_conv_output_dim(W, field_width, stride_x, padding) - i0 = np.repeat(np.arange(k), k) + i0 = np.repeat(np.arange(field_height), field_width) i0 = np.tile(i0, C) - i1 = stride * np.repeat(np.arange(ofm_dim), ofm_dim) - j0 = np.tile(np.arange(k), k * C) - j1 = stride * np.tile(np.arange(ofm_dim), ofm_dim) + i1 = stride_y * np.repeat(np.arange(out_height), out_width) + j0 = np.tile(np.arange(field_width), field_height * C) + j1 = stride_x * np.tile(np.arange(out_width), out_height) i = i0.reshape(-1, 1) + i1.reshape(1, -1) j = j0.reshape(-1, 1) + j1.reshape(1, -1) - k = np.repeat(np.arange(C), k * k).reshape(-1, 1) + k = np.repeat(np.arange(C), field_height * field_width).reshape(-1, 1) return (k, i, j) -def im2col_indices(x, k, stride): - """ An implementation of im2col based on indexing """ +def im2col_indices_nchw( + x, field_height, field_width, padding=0, stride_y=1, stride_x=1, pad_val=0 +): + """ An implementation of im2col based on some fancy indexing """ + # Zero-pad the input + p = padding + x_padded = np.pad( + x, ((0, 0), (0, 0), (p, p), (p, p)), mode="constant", constant_values=pad_val + ) - l, i, j = get_im2col_indices(x.shape, k, stride) + k, i, j = get_im2col_indices_nchw( + x.shape, field_height, field_width, padding, stride_y, stride_x + ) - cols = x[:, l, i, j] + cols = x_padded[:, k, i, j] C = x.shape[1] - cols = cols.transpose(1, 2, 0).reshape(k * k * C, -1) - cols = cols.transpose(1, 0) - - # rearranging the output so it matches with finn-hlslib function - # swapping the columns according to the input channel - # if C > 1 : - parts = {} - for ch in range(C): - parts[ch] = [] - - for i in range(cols.shape[1]): - if i % C == 0: - parts[0].append(i) - elif (i + (C - 1)) % C == 0: - parts[1].append(i) - elif (i + (C - 2)) % C == 0: - parts[2].append(i) - elif (i + (C - 3)) % C == 0: - parts[3].append(i) - permutation = [] - for i in parts: - for num in parts[i]: - permutation.append(num) - - i = np.argsort(permutation) - cols = cols[:, i] - return cols.reshape(1, -1, k * k * C) + cols = cols.transpose(1, 2, 0).reshape(field_height * field_width * C, -1) + return cols + + +# ONNX i/o tensor shape assumptions for Im2Col: +# input 0 is the input vector, shape (1, ih, iw, ifm) +# output 0 is the output vector, shape (1, oh, ow, k*k*ifm) +# where: +# * ih, iw are the height and width of the input image +# * oh, ow are the height and width of the output (lowered) image +# * ifm is the number of input channels +# * k is the convolutional kernel size class Im2Col(CustomOp): @@ -66,12 +74,15 @@ class Im2Col(CustomOp): "stride": ("i", True, 1), "kernel_size": ("i", True, 1), "input_shape": ("s", True, ""), + "pad_amount": ("i", False, 0), + "pad_value": ("i", False, 0), } def make_shape_compatible_op(self): k = self.get_nodeattr("kernel_size") stride = self.get_nodeattr("stride") ishape = self.get_nodeattr("input_shape") + pad = self.get_nodeattr("pad_amount") # convert string into list of integers ishape = ishape.strip("(") @@ -81,13 +92,14 @@ class Im2Col(CustomOp): ishape[i] = int(ishape[i]) # extract all necessary information and determine output dimensions - ifm_ch = ishape[1] - ifm_dim = ishape[2] - ofm_dim = int(((ifm_dim - k) / stride) + 1) - outpix = ofm_dim * ofm_dim + ifm_ch = ishape[-1] + assert len(ishape) == 4, "Unexpected input shape for Im2Col" + assert ishape[1] == ishape[2], "Im2Col for non-square images unsupported" + ifm_dim = ishape[1] + ofm_dim = compute_conv_output_dim(ifm_dim, k, stride, pad) # implement tensor with correct shape - values = np.random.randn(1, outpix, k * k * ifm_ch).astype(np.float32) + values = np.random.randn(1, ofm_dim, ofm_dim, k * k * ifm_ch).astype(np.float32) return helper.make_node( "Constant", inputs=[], @@ -110,9 +122,30 @@ class Im2Col(CustomOp): node = self.onnx_node k = self.get_nodeattr("kernel_size") stride = self.get_nodeattr("stride") - x = context[node.input[0]] - output = im2col_indices(x, k, stride) - context[node.output[0]] = output + pad = self.get_nodeattr("pad_amount") + pad_val = self.get_nodeattr("pad_value") + iname = node.input[0] + x = context[iname] + qnt_annotations = graph.quantization_annotation + ret = util.get_by_name(qnt_annotations, iname, "tensor_name") + ret = util.get_by_name(ret.quant_parameter_tensor_names, "finn_datatype", "key") + idt = DataType[ret.value] + if pad != 0: + assert idt.allowed(pad_val), "Im2Col dtype must allow pad_val" + # check that input is NHWC + assert x.ndim == 4, "Unexpected number of input dims for Im2Col" + N, H, W, C = x.shape + assert H == W, "Unexpected input shape for Im2Col" + out_dim = compute_conv_output_dim(H, k, stride, pad) + # internally convert input to NCHW + x = x.transpose(0, 3, 1, 2) + # call NCHW im2col implementation + ret = im2col_indices_nchw(x, k, k, pad, stride, stride, pad_val=pad_val) + # result shape is (k*k*N, out_dim*out_dim), convert to NCHW + ret = ret.reshape(N, k * k * N, out_dim, out_dim) + # convert output back to NHWC + ret = ret.transpose(0, 2, 3, 1) + context[node.output[0]] = ret def verify_node(self): node = self.onnx_node diff --git a/tests/custom_op/test_im2col.py b/tests/custom_op/test_im2col.py index 4db3e4250e59c476e0fad17cc717263425b79700..6ed67d67a497ee6c17e1611aa2faf10a48dbd154 100644 --- a/tests/custom_op/test_im2col.py +++ b/tests/custom_op/test_im2col.py @@ -24,14 +24,13 @@ def check_two_dict_for_equality(dict1, dict2): def execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim): ofm_dim = int(((ifm_dim - k) / stride) + 1) - out_pix = ofm_dim * ofm_dim # set up onnx model inp = helper.make_tensor_value_info( - "inp", TensorProto.FLOAT, [1, ifm_ch, ifm_dim, ifm_dim] + "inp", TensorProto.FLOAT, [1, ifm_dim, ifm_dim, ifm_ch] ) outp = helper.make_tensor_value_info( - "outp", TensorProto.FLOAT, [1, out_pix, k * k * ifm_ch] + "outp", TensorProto.FLOAT, [1, ofm_dim, ofm_dim, k * k * ifm_ch] ) Im2Col_node = helper.make_node( @@ -41,7 +40,7 @@ def execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim): domain="finn", stride=stride, kernel_size=k, - input_shape="(1,{},{},{})".format(ifm_ch, ifm_dim, ifm_dim), + input_shape="(1,{},{},{})".format(ifm_dim, ifm_dim, ifm_ch), ) graph = helper.make_graph( @@ -55,7 +54,7 @@ def execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim): # test shape inference model.transform(InferShapes()) - assert model.get_tensor_shape("outp") == [1, out_pix, k * k * ifm_ch] + assert model.get_tensor_shape("outp") == [1, ofm_dim, ofm_dim, k * k * ifm_ch] # test datatype inference assert model.get_tensor_datatype("outp") is DataType.FLOAT32 @@ -94,7 +93,6 @@ def test_im2col(): ifm_ch = 1 ifm_dim = 4 ofm_dim = int(((ifm_dim - k) / stride) + 1) - out_pix = ofm_dim * ofm_dim x = np.asarray( [ @@ -116,7 +114,7 @@ def test_im2col(): 1.0, ], dtype=np.float32, - ).reshape(1, ifm_ch, ifm_dim, ifm_dim) + ).reshape(1, ifm_dim, ifm_dim, ifm_ch) expected = np.asarray( [ @@ -158,629 +156,7 @@ def test_im2col(): 1.0, ], dtype=np.float32, - ).reshape(1, out_pix, k * k * ifm_ch) - - produced = execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim) - assert (produced == expected).all() - - # bipolar inputs with following im2col parameters - idt = DataType.BIPOLAR - k = 3 - stride = 1 - ifm_ch = 1 - ifm_dim = 4 - ofm_dim = int(((ifm_dim - k) / stride) + 1) - out_pix = ofm_dim * ofm_dim - - expected = np.asarray( - [ - -1.0, - -1.0, - 1.0, - 1.0, - -1.0, - 1.0, - -1.0, - 1.0, - -1.0, - -1.0, - 1.0, - 1.0, - -1.0, - 1.0, - -1.0, - 1.0, - -1.0, - -1.0, - 1.0, - -1.0, - 1.0, - -1.0, - 1.0, - -1.0, - 1.0, - 1.0, - 1.0, - -1.0, - 1.0, - -1.0, - 1.0, - -1.0, - -1.0, - 1.0, - 1.0, - 1.0, - ], - dtype=np.float32, - ).reshape(1, out_pix, k * k * ifm_ch) - - produced = execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim) - assert (produced == expected).all() - - # bipolar inputs with following im2col parameters - idt = DataType.BIPOLAR - k = 2 - stride = 2 - ifm_ch = 1 - ifm_dim = 4 - ofm_dim = int(((ifm_dim - k) / stride) + 1) - out_pix = ofm_dim * ofm_dim - - expected = np.asarray( - [ - -1.0, - -1.0, - 1.0, - -1.0, - 1.0, - 1.0, - 1.0, - -1.0, - -1.0, - 1.0, - 1.0, - 1.0, - -1.0, - -1.0, - 1.0, - 1.0, - ], - dtype=np.float32, - ).reshape(1, out_pix, k * k * ifm_ch) - - produced = execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim) - assert (produced == expected).all() - - # TO DO: check multiple channel result - # bipolar inputs with following im2col parameters - idt = DataType.BIPOLAR - k = 2 - stride = 2 - ifm_ch = 2 - ifm_dim = 4 - ofm_dim = int(((ifm_dim - k) / stride) + 1) - out_pix = ofm_dim * ofm_dim - - x = np.asarray( - [ - -1.0, - -1.0, - 1.0, - 1.0, - 1.0, - -1.0, - 1.0, - -1.0, - -1.0, - 1.0, - -1.0, - -1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - -1.0, - -1.0, - -1.0, - 1.0, - 1.0, - -1.0, - 1.0, - -1.0, - -1.0, - -1.0, - 1.0, - 1.0, - -1.0, - -1.0, - 1.0, - ], - dtype=np.float32, - ).reshape(1, ifm_ch, ifm_dim, ifm_dim) - - expected = np.asarray( - [ - -1.0, - 1.0, - -1.0, - -1.0, - 1.0, - 1.0, - -1.0, - 1.0, - 1.0, - -1.0, - 1.0, - -1.0, - 1.0, - -1.0, - -1.0, - 1.0, - -1.0, - -1.0, - 1.0, - -1.0, - 1.0, - 1.0, - 1.0, - -1.0, - -1.0, - -1.0, - -1.0, - 1.0, - 1.0, - -1.0, - 1.0, - 1.0, - ], - dtype=np.float32, - ).reshape(1, out_pix, k * k * ifm_ch) - - produced = execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim) - assert (produced == expected).all() - - # bipolar inputs with following im2col parameters - idt = DataType.BIPOLAR - k = 2 - stride = 2 - ifm_ch = 1 - ifm_dim = 6 - ofm_dim = int(((ifm_dim - k) / stride) + 1) - out_pix = ofm_dim * ofm_dim - - x = np.asarray( - [ - 1.0, - 1.0, - 1.0, - -1.0, - -1.0, - -1.0, - -1.0, - 1.0, - -1.0, - 1.0, - -1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - -1.0, - -1.0, - -1.0, - 1.0, - 1.0, - -1.0, - 1.0, - -1.0, - -1.0, - 1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - 1.0, - -1.0, - 1.0, - 1.0, - -1.0, - ], - dtype=np.float32, - ).reshape(1, ifm_ch, ifm_dim, ifm_dim) - - expected = np.asarray( - [ - 1.0, - 1.0, - -1.0, - 1.0, - 1.0, - -1.0, - -1.0, - 1.0, - -1.0, - -1.0, - -1.0, - 1.0, - 1.0, - 1.0, - -1.0, - 1.0, - 1.0, - 1.0, - 1.0, - -1.0, - -1.0, - -1.0, - 1.0, - -1.0, - -1.0, - 1.0, - -1.0, - 1.0, - -1.0, - -1.0, - -1.0, - 1.0, - -1.0, - -1.0, - 1.0, - -1.0, - ], - dtype=np.float32, - ).reshape(1, out_pix, k * k * ifm_ch) - - produced = execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim) - assert (produced == expected).all() - - # int2 inputs with following im2col parameters - idt = DataType.INT2 - k = 2 - stride = 1 - ifm_ch = 1 - ifm_dim = 4 - ofm_dim = int(((ifm_dim - k) / stride) + 1) - out_pix = ofm_dim * ofm_dim - - x = np.asarray( - [ - 1.0, - -1.0, - -2.0, - 0.0, - 1.0, - 1.0, - -2.0, - 0.0, - 0.0, - -2.0, - 1.0, - -2.0, - -1.0, - -1.0, - 0.0, - -2.0, - ], - dtype=np.float32, - ).reshape(1, ifm_ch, ifm_dim, ifm_dim) - - expected = np.asarray( - [ - 1.0, - -1.0, - 1.0, - 1.0, - -1.0, - -2.0, - 1.0, - -2.0, - -2.0, - 0.0, - -2.0, - 0.0, - 1.0, - 1.0, - 0.0, - -2.0, - 1.0, - -2.0, - -2.0, - 1.0, - -2.0, - 0.0, - 1.0, - -2.0, - 0.0, - -2.0, - -1.0, - -1.0, - -2.0, - 1.0, - -1.0, - 0.0, - 1.0, - -2.0, - 0.0, - -2.0, - ], - dtype=np.float32, - ).reshape(1, out_pix, k * k * ifm_ch) - - produced = execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim) - assert (produced == expected).all() - - # int2 inputs with following im2col parameters - idt = DataType.INT2 - k = 3 - stride = 1 - ifm_ch = 1 - ifm_dim = 4 - ofm_dim = int(((ifm_dim - k) / stride) + 1) - out_pix = ofm_dim * ofm_dim - - expected = np.asarray( - [ - 1.0, - -1.0, - -2.0, - 1.0, - 1.0, - -2.0, - 0.0, - -2.0, - 1.0, - -1.0, - -2.0, - 0.0, - 1.0, - -2.0, - 0.0, - -2.0, - 1.0, - -2.0, - 1.0, - 1.0, - -2.0, - 0.0, - -2.0, - 1.0, - -1.0, - -1.0, - 0.0, - 1.0, - -2.0, - 0.0, - -2.0, - 1.0, - -2.0, - -1.0, - 0.0, - -2.0, - ], - dtype=np.float32, - ).reshape(1, out_pix, k * k * ifm_ch) - - produced = execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim) - assert (produced == expected).all() - - # int2 inputs with following im2col parameters - idt = DataType.INT2 - k = 2 - stride = 2 - ifm_ch = 1 - ifm_dim = 4 - ofm_dim = int(((ifm_dim - k) / stride) + 1) - out_pix = ofm_dim * ofm_dim - - expected = np.asarray( - [ - 1.0, - -1.0, - 1.0, - 1.0, - -2.0, - 0.0, - -2.0, - 0.0, - 0.0, - -2.0, - -1.0, - -1.0, - 1.0, - -2.0, - 0.0, - -2.0, - ], - dtype=np.float32, - ).reshape(1, out_pix, k * k * ifm_ch) - - produced = execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim) - assert (produced == expected).all() - - # TO DO: check multiple channel result - # int2 inputs with following im2col parameters - idt = DataType.INT2 - k = 2 - stride = 2 - ifm_ch = 2 - ifm_dim = 4 - ofm_dim = int(((ifm_dim - k) / stride) + 1) - out_pix = ofm_dim * ofm_dim - - x = np.asarray( - [ - 1.0, - -1.0, - -2.0, - 0.0, - 1.0, - 1.0, - -2.0, - 0.0, - 0.0, - -2.0, - 1.0, - -2.0, - -1.0, - -1.0, - 0.0, - -2.0, - -2.0, - -1.0, - -1.0, - -2.0, - 1.0, - -2.0, - 0.0, - -1.0, - -1.0, - 0.0, - -2.0, - -2.0, - -2.0, - 1.0, - 0.0, - 1.0, - ], - dtype=np.float32, - ).reshape(1, ifm_ch, ifm_dim, ifm_dim) - - expected = np.asarray( - [ - 1.0, - -2.0, - -1.0, - -1.0, - 1.0, - 1.0, - 1.0, - -2.0, - -2.0, - -1.0, - 0.0, - -2.0, - -2.0, - 0.0, - 0.0, - -1.0, - 0.0, - -1.0, - -2.0, - 0.0, - -1.0, - -2.0, - -1.0, - 1.0, - 1.0, - -2.0, - -2.0, - -2.0, - 0.0, - 0.0, - -2.0, - 1.0, - ], - dtype=np.float32, - ).reshape(1, out_pix, k * k * ifm_ch) - - produced = execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim) - assert (produced == expected).all() - - # int2 inputs with following im2col parameters - idt = DataType.INT2 - k = 2 - stride = 2 - ifm_ch = 1 - ifm_dim = 6 - ofm_dim = int(((ifm_dim - k) / stride) + 1) - out_pix = ofm_dim * ofm_dim - - x = np.asarray( - [ - 0.0, - -1.0, - -2.0, - -1.0, - 1.0, - 0.0, - 1.0, - 0.0, - 1.0, - 0.0, - -2.0, - -1.0, - 0.0, - -1.0, - 0.0, - -2.0, - -2.0, - 0.0, - 1.0, - -2.0, - -2.0, - -1.0, - -1.0, - -1.0, - -2.0, - -2.0, - -2.0, - -2.0, - -2.0, - -2.0, - -2.0, - -2.0, - 0.0, - -1.0, - 0.0, - 0.0, - ], - dtype=np.float32, - ).reshape(1, ifm_ch, ifm_dim, ifm_dim) - - expected = np.asarray( - [ - 0.0, - -1.0, - 1.0, - 0.0, - -2.0, - -1.0, - 1.0, - 0.0, - 1.0, - 0.0, - -2.0, - -1.0, - 0.0, - -1.0, - 1.0, - -2.0, - 0.0, - -2.0, - -2.0, - -1.0, - -2.0, - 0.0, - -1.0, - -1.0, - -2.0, - -2.0, - -2.0, - -2.0, - -2.0, - -2.0, - 0.0, - -1.0, - -2.0, - -2.0, - 0.0, - 0.0, - ], - dtype=np.float32, - ).reshape(1, out_pix, k * k * ifm_ch) + ).reshape(1, ofm_dim, ofm_dim, k * k * ifm_ch) produced = execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim) assert (produced == expected).all() @@ -793,14 +169,13 @@ def test_im2col_infer_shapes(): ifm_ch = 1 ifm_dim = 4 ofm_dim = int(((ifm_dim - k) / stride) + 1) - out_pix = ofm_dim * ofm_dim # set up onnx model inp = helper.make_tensor_value_info( - "inp", TensorProto.FLOAT, [1, ifm_ch, ifm_dim, ifm_dim] + "inp", TensorProto.FLOAT, [1, ifm_dim, ifm_dim, ifm_ch] ) outp = helper.make_tensor_value_info( - "outp", TensorProto.FLOAT, [1, out_pix, k * k * ifm_ch] + "outp", TensorProto.FLOAT, [1, ofm_dim, ofm_dim, k * k * ifm_ch] ) abs_node = helper.make_node("Abs", inputs=["inp"], outputs=["abs"],) @@ -812,7 +187,7 @@ def test_im2col_infer_shapes(): domain="finn", stride=stride, kernel_size=k, - input_shape="(1,{},{},{})".format(ifm_ch, ifm_dim, ifm_dim), + input_shape="(1,{},{},{})".format(ifm_dim, ifm_dim, ifm_ch), ) abs1_node = helper.make_node("Abs", inputs=["im2col"], outputs=["outp"],) @@ -824,10 +199,10 @@ def test_im2col_infer_shapes(): outputs=[outp], value_info=[ helper.make_tensor_value_info( - "abs", TensorProto.FLOAT, [1, ifm_ch, ifm_dim, ifm_dim] + "abs", TensorProto.FLOAT, [1, ifm_dim, ifm_dim, ifm_ch] ), helper.make_tensor_value_info( - "im2col", TensorProto.FLOAT, [1, out_pix, k * k * ifm_ch] + "im2col", TensorProto.FLOAT, [1, ofm_dim, ofm_dim, k * k * ifm_ch] ), ], ) @@ -839,4 +214,4 @@ def test_im2col_infer_shapes(): # test shape inference model.transform(InferShapes()) - assert model.get_tensor_shape("im2col") == [1, out_pix, k * k * ifm_ch] + assert model.get_tensor_shape("im2col") == [1, ofm_dim, ofm_dim, k * k * ifm_ch]