diff --git a/src/finn/custom_op/im2col.py b/src/finn/custom_op/im2col.py
index 8af1227a4f4b6b1a244fcbf51e1ed71ec2749802..0cd56e755cb1782c60ab10a820d492dc309649a7 100644
--- a/src/finn/custom_op/im2col.py
+++ b/src/finn/custom_op/im2col.py
@@ -2,62 +2,70 @@ import numpy as np
 from onnx import TensorProto, helper
 
 from finn.custom_op import CustomOp
+import finn.util.basic as util
+from finn.core.datatype import DataType
 
+# adapted from A. Karpathy's CS231n im2col code
+# utilities to generate a patch matrix from a multichannel image
+# of shape (batches, channels, height, width)
 
-def get_im2col_indices(x_shape, k, stride):
+
+def compute_conv_output_dim(ifm_dim, k, stride, pad=0):
+    "Return spatial output dimension size for convolution with given params."
+    return int(((ifm_dim + 2 * pad - k) / stride) + 1)
+
+
+def get_im2col_indices_nchw(
+    x_shape, field_height, field_width, padding=0, stride_y=1, stride_x=1
+):
     # First figure out what the size of the output should be
     N, C, H, W = x_shape
-    assert H == W
-    assert (W - k) % stride == 0
-    ofm_dim = int((W - k) / stride + 1)
+    assert (H + 2 * padding - field_height) % stride_y == 0
+    assert (W + 2 * padding - field_width) % stride_x == 0
+    out_height = compute_conv_output_dim(H, field_height, stride_y, padding)
+    out_width = compute_conv_output_dim(W, field_width, stride_x, padding)
 
-    i0 = np.repeat(np.arange(k), k)
+    i0 = np.repeat(np.arange(field_height), field_width)
     i0 = np.tile(i0, C)
-    i1 = stride * np.repeat(np.arange(ofm_dim), ofm_dim)
-    j0 = np.tile(np.arange(k), k * C)
-    j1 = stride * np.tile(np.arange(ofm_dim), ofm_dim)
+    i1 = stride_y * np.repeat(np.arange(out_height), out_width)
+    j0 = np.tile(np.arange(field_width), field_height * C)
+    j1 = stride_x * np.tile(np.arange(out_width), out_height)
     i = i0.reshape(-1, 1) + i1.reshape(1, -1)
     j = j0.reshape(-1, 1) + j1.reshape(1, -1)
 
-    k = np.repeat(np.arange(C), k * k).reshape(-1, 1)
+    k = np.repeat(np.arange(C), field_height * field_width).reshape(-1, 1)
 
     return (k, i, j)
 
 
-def im2col_indices(x, k, stride):
-    """ An implementation of im2col based on indexing """
+def im2col_indices_nchw(
+    x, field_height, field_width, padding=0, stride_y=1, stride_x=1, pad_val=0
+):
+    """ An implementation of im2col based on some fancy indexing """
+    # Pad the spatial (H, W) dimensions of the input with pad_val
+    p = padding
+    x_padded = np.pad(
+        x, ((0, 0), (0, 0), (p, p), (p, p)), mode="constant", constant_values=pad_val
+    )
 
-    l, i, j = get_im2col_indices(x.shape, k, stride)
+    k, i, j = get_im2col_indices_nchw(
+        x.shape, field_height, field_width, padding, stride_y, stride_x
+    )
 
-    cols = x[:, l, i, j]
+    cols = x_padded[:, k, i, j]
     C = x.shape[1]
-    cols = cols.transpose(1, 2, 0).reshape(k * k * C, -1)
-    cols = cols.transpose(1, 0)
-
-    # rearranging the output so it matches with finn-hlslib function
-    # swapping the columns according to the input channel
-    # if C > 1 :
-    parts = {}
-    for ch in range(C):
-        parts[ch] = []
-
-    for i in range(cols.shape[1]):
-        if i % C == 0:
-            parts[0].append(i)
-        elif (i + (C - 1)) % C == 0:
-            parts[1].append(i)
-        elif (i + (C - 2)) % C == 0:
-            parts[2].append(i)
-        elif (i + (C - 3)) % C == 0:
-            parts[3].append(i)
-    permutation = []
-    for i in parts:
-        for num in parts[i]:
-            permutation.append(num)
-
-    i = np.argsort(permutation)
-    cols = cols[:, i]
-    return cols.reshape(1, -1, k * k * C)
+    cols = cols.transpose(1, 2, 0).reshape(field_height * field_width * C, -1)
+    return cols
+
+
+# ONNX i/o tensor shape assumptions for Im2Col:
+# input 0 is the input tensor, shape (1, ih, iw, ifm)
+# output 0 is the output tensor, shape (1, oh, ow, k*k*ifm)
+# where:
+# * ih, iw are the height and width of the input image
+# * oh, ow are the height and width of the output (lowered) image
+# * ifm is the number of input channels
+# * k is the convolutional kernel size
 
 
 class Im2Col(CustomOp):
@@ -66,12 +74,15 @@ class Im2Col(CustomOp):
             "stride": ("i", True, 1),
             "kernel_size": ("i", True, 1),
             "input_shape": ("s", True, ""),
+            "pad_amount": ("i", False, 0),
+            "pad_value": ("i", False, 0),
         }
 
     def make_shape_compatible_op(self):
         k = self.get_nodeattr("kernel_size")
         stride = self.get_nodeattr("stride")
         ishape = self.get_nodeattr("input_shape")
+        pad = self.get_nodeattr("pad_amount")
 
         # convert string into list of integers
         ishape = ishape.strip("(")
@@ -81,13 +92,14 @@ class Im2Col(CustomOp):
             ishape[i] = int(ishape[i])
 
         # extract all necessary information and determine output dimensions
-        ifm_ch = ishape[1]
-        ifm_dim = ishape[2]
-        ofm_dim = int(((ifm_dim - k) / stride) + 1)
-        outpix = ofm_dim * ofm_dim
+        ifm_ch = ishape[-1]
+        assert len(ishape) == 4, "Unexpected input shape for Im2Col"
+        assert ishape[1] == ishape[2], "Im2Col for non-square images unsupported"
+        ifm_dim = ishape[1]
+        ofm_dim = compute_conv_output_dim(ifm_dim, k, stride, pad)
 
         # implement tensor with correct shape
-        values = np.random.randn(1, outpix, k * k * ifm_ch).astype(np.float32)
+        values = np.random.randn(1, ofm_dim, ofm_dim, k * k * ifm_ch).astype(np.float32)
         return helper.make_node(
             "Constant",
             inputs=[],
@@ -110,9 +122,30 @@ class Im2Col(CustomOp):
         node = self.onnx_node
         k = self.get_nodeattr("kernel_size")
         stride = self.get_nodeattr("stride")
-        x = context[node.input[0]]
-        output = im2col_indices(x, k, stride)
-        context[node.output[0]] = output
+        pad = self.get_nodeattr("pad_amount")
+        pad_val = self.get_nodeattr("pad_value")
+        iname = node.input[0]
+        x = context[iname]
+        qnt_annotations = graph.quantization_annotation
+        ret = util.get_by_name(qnt_annotations, iname, "tensor_name")
+        ret = util.get_by_name(ret.quant_parameter_tensor_names, "finn_datatype", "key")
+        idt = DataType[ret.value]
+        if pad != 0:
+            assert idt.allowed(pad_val), "Im2Col dtype must allow pad_val"
+        # check that input is NHWC
+        assert x.ndim == 4, "Unexpected number of input dims for Im2Col"
+        N, H, W, C = x.shape
+        assert H == W, "Unexpected input shape for Im2Col"
+        out_dim = compute_conv_output_dim(H, k, stride, pad)
+        # internally convert input to NCHW
+        x = x.transpose(0, 3, 1, 2)
+        # call NCHW im2col implementation
+        ret = im2col_indices_nchw(x, k, k, pad, stride, stride, pad_val=pad_val)
+        # result is (k*k*C, out_dim*out_dim*N) with batch innermost; make it NCHW
+        ret = ret.reshape(k * k * C, out_dim, out_dim, N).transpose(3, 0, 1, 2)
+        # convert output back to NHWC
+        ret = ret.transpose(0, 2, 3, 1)
+        context[node.output[0]] = ret
 
     def verify_node(self):
         node = self.onnx_node
diff --git a/tests/custom_op/test_im2col.py b/tests/custom_op/test_im2col.py
index 4db3e4250e59c476e0fad17cc717263425b79700..6ed67d67a497ee6c17e1611aa2faf10a48dbd154 100644
--- a/tests/custom_op/test_im2col.py
+++ b/tests/custom_op/test_im2col.py
@@ -24,14 +24,13 @@ def check_two_dict_for_equality(dict1, dict2):
 
 def execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim):
     ofm_dim = int(((ifm_dim - k) / stride) + 1)
-    out_pix = ofm_dim * ofm_dim
 
     # set up onnx model
     inp = helper.make_tensor_value_info(
-        "inp", TensorProto.FLOAT, [1, ifm_ch, ifm_dim, ifm_dim]
+        "inp", TensorProto.FLOAT, [1, ifm_dim, ifm_dim, ifm_ch]
     )
     outp = helper.make_tensor_value_info(
-        "outp", TensorProto.FLOAT, [1, out_pix, k * k * ifm_ch]
+        "outp", TensorProto.FLOAT, [1, ofm_dim, ofm_dim, k * k * ifm_ch]
     )
 
     Im2Col_node = helper.make_node(
@@ -41,7 +40,7 @@ def execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim):
         domain="finn",
         stride=stride,
         kernel_size=k,
-        input_shape="(1,{},{},{})".format(ifm_ch, ifm_dim, ifm_dim),
+        input_shape="(1,{},{},{})".format(ifm_dim, ifm_dim, ifm_ch),
     )
 
     graph = helper.make_graph(
@@ -55,7 +54,7 @@ def execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim):
 
     # test shape inference
     model.transform(InferShapes())
-    assert model.get_tensor_shape("outp") == [1, out_pix, k * k * ifm_ch]
+    assert model.get_tensor_shape("outp") == [1, ofm_dim, ofm_dim, k * k * ifm_ch]
 
     # test datatype inference
     assert model.get_tensor_datatype("outp") is DataType.FLOAT32
@@ -94,7 +93,6 @@ def test_im2col():
     ifm_ch = 1
     ifm_dim = 4
     ofm_dim = int(((ifm_dim - k) / stride) + 1)
-    out_pix = ofm_dim * ofm_dim
 
     x = np.asarray(
         [
@@ -116,7 +114,7 @@ def test_im2col():
             1.0,
         ],
         dtype=np.float32,
-    ).reshape(1, ifm_ch, ifm_dim, ifm_dim)
+    ).reshape(1, ifm_dim, ifm_dim, ifm_ch)
 
     expected = np.asarray(
         [
@@ -158,629 +156,7 @@ def test_im2col():
             1.0,
         ],
         dtype=np.float32,
-    ).reshape(1, out_pix, k * k * ifm_ch)
-
-    produced = execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim)
-    assert (produced == expected).all()
-
-    # bipolar inputs with following im2col parameters
-    idt = DataType.BIPOLAR
-    k = 3
-    stride = 1
-    ifm_ch = 1
-    ifm_dim = 4
-    ofm_dim = int(((ifm_dim - k) / stride) + 1)
-    out_pix = ofm_dim * ofm_dim
-
-    expected = np.asarray(
-        [
-            -1.0,
-            -1.0,
-            1.0,
-            1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            -1.0,
-            1.0,
-            1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            1.0,
-            1.0,
-            1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            -1.0,
-            1.0,
-            1.0,
-            1.0,
-        ],
-        dtype=np.float32,
-    ).reshape(1, out_pix, k * k * ifm_ch)
-
-    produced = execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim)
-    assert (produced == expected).all()
-
-    # bipolar inputs with following im2col parameters
-    idt = DataType.BIPOLAR
-    k = 2
-    stride = 2
-    ifm_ch = 1
-    ifm_dim = 4
-    ofm_dim = int(((ifm_dim - k) / stride) + 1)
-    out_pix = ofm_dim * ofm_dim
-
-    expected = np.asarray(
-        [
-            -1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            1.0,
-            1.0,
-            1.0,
-            -1.0,
-            -1.0,
-            1.0,
-            1.0,
-            1.0,
-            -1.0,
-            -1.0,
-            1.0,
-            1.0,
-        ],
-        dtype=np.float32,
-    ).reshape(1, out_pix, k * k * ifm_ch)
-
-    produced = execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim)
-    assert (produced == expected).all()
-
-    # TO DO: check multiple channel result
-    # bipolar inputs with following im2col parameters
-    idt = DataType.BIPOLAR
-    k = 2
-    stride = 2
-    ifm_ch = 2
-    ifm_dim = 4
-    ofm_dim = int(((ifm_dim - k) / stride) + 1)
-    out_pix = ofm_dim * ofm_dim
-
-    x = np.asarray(
-        [
-            -1.0,
-            -1.0,
-            1.0,
-            1.0,
-            1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            -1.0,
-            1.0,
-            1.0,
-            1.0,
-            1.0,
-            1.0,
-            -1.0,
-            -1.0,
-            -1.0,
-            1.0,
-            1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            -1.0,
-            -1.0,
-            1.0,
-            1.0,
-            -1.0,
-            -1.0,
-            1.0,
-        ],
-        dtype=np.float32,
-    ).reshape(1, ifm_ch, ifm_dim, ifm_dim)
-
-    expected = np.asarray(
-        [
-            -1.0,
-            1.0,
-            -1.0,
-            -1.0,
-            1.0,
-            1.0,
-            -1.0,
-            1.0,
-            1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            1.0,
-            1.0,
-            1.0,
-            -1.0,
-            -1.0,
-            -1.0,
-            -1.0,
-            1.0,
-            1.0,
-            -1.0,
-            1.0,
-            1.0,
-        ],
-        dtype=np.float32,
-    ).reshape(1, out_pix, k * k * ifm_ch)
-
-    produced = execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim)
-    assert (produced == expected).all()
-
-    # bipolar inputs with following im2col parameters
-    idt = DataType.BIPOLAR
-    k = 2
-    stride = 2
-    ifm_ch = 1
-    ifm_dim = 6
-    ofm_dim = int(((ifm_dim - k) / stride) + 1)
-    out_pix = ofm_dim * ofm_dim
-
-    x = np.asarray(
-        [
-            1.0,
-            1.0,
-            1.0,
-            -1.0,
-            -1.0,
-            -1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            1.0,
-            1.0,
-            1.0,
-            1.0,
-            1.0,
-            -1.0,
-            -1.0,
-            -1.0,
-            1.0,
-            1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            -1.0,
-            -1.0,
-            -1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            1.0,
-            1.0,
-            -1.0,
-        ],
-        dtype=np.float32,
-    ).reshape(1, ifm_ch, ifm_dim, ifm_dim)
-
-    expected = np.asarray(
-        [
-            1.0,
-            1.0,
-            -1.0,
-            1.0,
-            1.0,
-            -1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            -1.0,
-            -1.0,
-            1.0,
-            1.0,
-            1.0,
-            -1.0,
-            1.0,
-            1.0,
-            1.0,
-            1.0,
-            -1.0,
-            -1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            -1.0,
-            -1.0,
-            1.0,
-            -1.0,
-            -1.0,
-            1.0,
-            -1.0,
-        ],
-        dtype=np.float32,
-    ).reshape(1, out_pix, k * k * ifm_ch)
-
-    produced = execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim)
-    assert (produced == expected).all()
-
-    # int2 inputs with following im2col parameters
-    idt = DataType.INT2
-    k = 2
-    stride = 1
-    ifm_ch = 1
-    ifm_dim = 4
-    ofm_dim = int(((ifm_dim - k) / stride) + 1)
-    out_pix = ofm_dim * ofm_dim
-
-    x = np.asarray(
-        [
-            1.0,
-            -1.0,
-            -2.0,
-            0.0,
-            1.0,
-            1.0,
-            -2.0,
-            0.0,
-            0.0,
-            -2.0,
-            1.0,
-            -2.0,
-            -1.0,
-            -1.0,
-            0.0,
-            -2.0,
-        ],
-        dtype=np.float32,
-    ).reshape(1, ifm_ch, ifm_dim, ifm_dim)
-
-    expected = np.asarray(
-        [
-            1.0,
-            -1.0,
-            1.0,
-            1.0,
-            -1.0,
-            -2.0,
-            1.0,
-            -2.0,
-            -2.0,
-            0.0,
-            -2.0,
-            0.0,
-            1.0,
-            1.0,
-            0.0,
-            -2.0,
-            1.0,
-            -2.0,
-            -2.0,
-            1.0,
-            -2.0,
-            0.0,
-            1.0,
-            -2.0,
-            0.0,
-            -2.0,
-            -1.0,
-            -1.0,
-            -2.0,
-            1.0,
-            -1.0,
-            0.0,
-            1.0,
-            -2.0,
-            0.0,
-            -2.0,
-        ],
-        dtype=np.float32,
-    ).reshape(1, out_pix, k * k * ifm_ch)
-
-    produced = execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim)
-    assert (produced == expected).all()
-
-    # int2 inputs with following im2col parameters
-    idt = DataType.INT2
-    k = 3
-    stride = 1
-    ifm_ch = 1
-    ifm_dim = 4
-    ofm_dim = int(((ifm_dim - k) / stride) + 1)
-    out_pix = ofm_dim * ofm_dim
-
-    expected = np.asarray(
-        [
-            1.0,
-            -1.0,
-            -2.0,
-            1.0,
-            1.0,
-            -2.0,
-            0.0,
-            -2.0,
-            1.0,
-            -1.0,
-            -2.0,
-            0.0,
-            1.0,
-            -2.0,
-            0.0,
-            -2.0,
-            1.0,
-            -2.0,
-            1.0,
-            1.0,
-            -2.0,
-            0.0,
-            -2.0,
-            1.0,
-            -1.0,
-            -1.0,
-            0.0,
-            1.0,
-            -2.0,
-            0.0,
-            -2.0,
-            1.0,
-            -2.0,
-            -1.0,
-            0.0,
-            -2.0,
-        ],
-        dtype=np.float32,
-    ).reshape(1, out_pix, k * k * ifm_ch)
-
-    produced = execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim)
-    assert (produced == expected).all()
-
-    # int2 inputs with following im2col parameters
-    idt = DataType.INT2
-    k = 2
-    stride = 2
-    ifm_ch = 1
-    ifm_dim = 4
-    ofm_dim = int(((ifm_dim - k) / stride) + 1)
-    out_pix = ofm_dim * ofm_dim
-
-    expected = np.asarray(
-        [
-            1.0,
-            -1.0,
-            1.0,
-            1.0,
-            -2.0,
-            0.0,
-            -2.0,
-            0.0,
-            0.0,
-            -2.0,
-            -1.0,
-            -1.0,
-            1.0,
-            -2.0,
-            0.0,
-            -2.0,
-        ],
-        dtype=np.float32,
-    ).reshape(1, out_pix, k * k * ifm_ch)
-
-    produced = execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim)
-    assert (produced == expected).all()
-
-    # TO DO: check multiple channel result
-    # int2 inputs with following im2col parameters
-    idt = DataType.INT2
-    k = 2
-    stride = 2
-    ifm_ch = 2
-    ifm_dim = 4
-    ofm_dim = int(((ifm_dim - k) / stride) + 1)
-    out_pix = ofm_dim * ofm_dim
-
-    x = np.asarray(
-        [
-            1.0,
-            -1.0,
-            -2.0,
-            0.0,
-            1.0,
-            1.0,
-            -2.0,
-            0.0,
-            0.0,
-            -2.0,
-            1.0,
-            -2.0,
-            -1.0,
-            -1.0,
-            0.0,
-            -2.0,
-            -2.0,
-            -1.0,
-            -1.0,
-            -2.0,
-            1.0,
-            -2.0,
-            0.0,
-            -1.0,
-            -1.0,
-            0.0,
-            -2.0,
-            -2.0,
-            -2.0,
-            1.0,
-            0.0,
-            1.0,
-        ],
-        dtype=np.float32,
-    ).reshape(1, ifm_ch, ifm_dim, ifm_dim)
-
-    expected = np.asarray(
-        [
-            1.0,
-            -2.0,
-            -1.0,
-            -1.0,
-            1.0,
-            1.0,
-            1.0,
-            -2.0,
-            -2.0,
-            -1.0,
-            0.0,
-            -2.0,
-            -2.0,
-            0.0,
-            0.0,
-            -1.0,
-            0.0,
-            -1.0,
-            -2.0,
-            0.0,
-            -1.0,
-            -2.0,
-            -1.0,
-            1.0,
-            1.0,
-            -2.0,
-            -2.0,
-            -2.0,
-            0.0,
-            0.0,
-            -2.0,
-            1.0,
-        ],
-        dtype=np.float32,
-    ).reshape(1, out_pix, k * k * ifm_ch)
-
-    produced = execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim)
-    assert (produced == expected).all()
-
-    # int2 inputs with following im2col parameters
-    idt = DataType.INT2
-    k = 2
-    stride = 2
-    ifm_ch = 1
-    ifm_dim = 6
-    ofm_dim = int(((ifm_dim - k) / stride) + 1)
-    out_pix = ofm_dim * ofm_dim
-
-    x = np.asarray(
-        [
-            0.0,
-            -1.0,
-            -2.0,
-            -1.0,
-            1.0,
-            0.0,
-            1.0,
-            0.0,
-            1.0,
-            0.0,
-            -2.0,
-            -1.0,
-            0.0,
-            -1.0,
-            0.0,
-            -2.0,
-            -2.0,
-            0.0,
-            1.0,
-            -2.0,
-            -2.0,
-            -1.0,
-            -1.0,
-            -1.0,
-            -2.0,
-            -2.0,
-            -2.0,
-            -2.0,
-            -2.0,
-            -2.0,
-            -2.0,
-            -2.0,
-            0.0,
-            -1.0,
-            0.0,
-            0.0,
-        ],
-        dtype=np.float32,
-    ).reshape(1, ifm_ch, ifm_dim, ifm_dim)
-
-    expected = np.asarray(
-        [
-            0.0,
-            -1.0,
-            1.0,
-            0.0,
-            -2.0,
-            -1.0,
-            1.0,
-            0.0,
-            1.0,
-            0.0,
-            -2.0,
-            -1.0,
-            0.0,
-            -1.0,
-            1.0,
-            -2.0,
-            0.0,
-            -2.0,
-            -2.0,
-            -1.0,
-            -2.0,
-            0.0,
-            -1.0,
-            -1.0,
-            -2.0,
-            -2.0,
-            -2.0,
-            -2.0,
-            -2.0,
-            -2.0,
-            0.0,
-            -1.0,
-            -2.0,
-            -2.0,
-            0.0,
-            0.0,
-        ],
-        dtype=np.float32,
-    ).reshape(1, out_pix, k * k * ifm_ch)
+    ).reshape(1, ofm_dim, ofm_dim, k * k * ifm_ch)
 
     produced = execution_im2col(x, idt, k, stride, ifm_ch, ifm_dim)
     assert (produced == expected).all()
@@ -793,14 +169,13 @@ def test_im2col_infer_shapes():
     ifm_ch = 1
     ifm_dim = 4
     ofm_dim = int(((ifm_dim - k) / stride) + 1)
-    out_pix = ofm_dim * ofm_dim
 
     # set up onnx model
     inp = helper.make_tensor_value_info(
-        "inp", TensorProto.FLOAT, [1, ifm_ch, ifm_dim, ifm_dim]
+        "inp", TensorProto.FLOAT, [1, ifm_dim, ifm_dim, ifm_ch]
     )
     outp = helper.make_tensor_value_info(
-        "outp", TensorProto.FLOAT, [1, out_pix, k * k * ifm_ch]
+        "outp", TensorProto.FLOAT, [1, ofm_dim, ofm_dim, k * k * ifm_ch]
     )
 
     abs_node = helper.make_node("Abs", inputs=["inp"], outputs=["abs"],)
@@ -812,7 +187,7 @@ def test_im2col_infer_shapes():
         domain="finn",
         stride=stride,
         kernel_size=k,
-        input_shape="(1,{},{},{})".format(ifm_ch, ifm_dim, ifm_dim),
+        input_shape="(1,{},{},{})".format(ifm_dim, ifm_dim, ifm_ch),
     )
 
     abs1_node = helper.make_node("Abs", inputs=["im2col"], outputs=["outp"],)
@@ -824,10 +199,10 @@ def test_im2col_infer_shapes():
         outputs=[outp],
         value_info=[
             helper.make_tensor_value_info(
-                "abs", TensorProto.FLOAT, [1, ifm_ch, ifm_dim, ifm_dim]
+                "abs", TensorProto.FLOAT, [1, ifm_dim, ifm_dim, ifm_ch]
             ),
             helper.make_tensor_value_info(
-                "im2col", TensorProto.FLOAT, [1, out_pix, k * k * ifm_ch]
+                "im2col", TensorProto.FLOAT, [1, ofm_dim, ofm_dim, k * k * ifm_ch]
             ),
         ],
     )
@@ -839,4 +214,4 @@ def test_im2col_infer_shapes():
 
     # test shape inference
     model.transform(InferShapes())
-    assert model.get_tensor_shape("im2col") == [1, out_pix, k * k * ifm_ch]
+    assert model.get_tensor_shape("im2col") == [1, ofm_dim, ofm_dim, k * k * ifm_ch]