diff --git a/src/finn/custom_op/fpgadataflow/convolutioninputgenerator.py b/src/finn/custom_op/fpgadataflow/convolutioninputgenerator.py
index 72dd65385ad0bdc81b77d69226995a366ff6b421..3f8dc30ac302993ba0565cc1f4245ee10065e458 100644
--- a/src/finn/custom_op/fpgadataflow/convolutioninputgenerator.py
+++ b/src/finn/custom_op/fpgadataflow/convolutioninputgenerator.py
@@ -64,18 +64,44 @@ class ConvolutionInputGenerator(HLSCustomOp):
         my_attrs.update(super().get_nodeattr_types())
         return my_attrs
 
-    def make_shape_compatible_op(self, model):
+    def get_normal_input_shape(self):
+
+        ifm_dim = self.get_nodeattr("IFMDim")
+        ifm_ch = self.get_nodeattr("IFMChannels")
+
+        ishape = (1, ifm_dim, ifm_dim, ifm_ch)
+        return ishape
+
+    def get_normal_output_shape(self):
         k = self.get_nodeattr("ConvKernelDim")
         ifm_dim = self.get_nodeattr("IFMDim")
         ifm_ch = self.get_nodeattr("IFMChannels")
         stride = self.get_nodeattr("Stride")
         pad = 0
-        exp_ishape = (1, ifm_dim, ifm_dim, ifm_ch)
+        ofm_dim = compute_conv_output_dim(ifm_dim, k, stride, pad)
+        oshape = (1, ofm_dim, ofm_dim, k * k * ifm_ch)
+        return oshape
+
+    def get_folded_output_shape(self):
+        k = self.get_nodeattr("ConvKernelDim")
+        ifm_dim = self.get_nodeattr("IFMDim")
+        ifm_ch = self.get_nodeattr("IFMChannels")
+        stride = self.get_nodeattr("Stride")
+        simd = self.get_nodeattr("SIMD")
+        pad = 0
+        ofm_dim = compute_conv_output_dim(ifm_dim, k, stride, pad)
+        assert k * k * ifm_ch % simd == 0, "SIMD must divide sliding window size"
+        wf = int(k * k * ifm_ch // simd)
+        folded_oshape = (1, ofm_dim, ofm_dim, wf, simd)
+        return folded_oshape
+
+    def make_shape_compatible_op(self, model):
+        exp_ishape = self.get_normal_input_shape()
+        oshape = self.get_normal_output_shape()
         ishape = tuple(model.get_tensor_shape(self.onnx_node.input[0]))
         assert ishape == exp_ishape, "Unexpect input shape for ConvInpGen."
-        ofm_dim = compute_conv_output_dim(ifm_dim, k, stride, pad)
         # implement tensor with correct shape
-        values = np.random.randn(1, ofm_dim, ofm_dim, k * k * ifm_ch).astype(np.float32)
+        values = np.random.randn(*oshape).astype(np.float32)
         return helper.make_node(
             "Constant",
             inputs=[],
@@ -118,6 +144,7 @@ class ConvolutionInputGenerator(HLSCustomOp):
         return self.get_nodeattr("SIMD") * ibits
 
     def get_number_output_values(self):
+        # TODO this seems incorrect -- double check
         k = self.get_nodeattr("ConvKernelDim")
         ifm_ch = self.get_nodeattr("IFMChannels")
         ofm_dim = self.get_nodeattr("OFMDim")
@@ -128,11 +155,8 @@ class ConvolutionInputGenerator(HLSCustomOp):
     def execute_node(self, context, graph):
         mode = self.get_nodeattr("exec_mode")
         node = self.onnx_node
-        k = self.get_nodeattr("ConvKernelDim")
-        ifm_dim = self.get_nodeattr("IFMDim")
-        ifm_ch = self.get_nodeattr("IFMChannels")
-        ofm_dim = self.get_nodeattr("OFMDim")
-        out_pix = ofm_dim * ofm_dim
+        exp_ishape = self.get_normal_input_shape()
+        exp_oshape = self.get_normal_output_shape()
 
         if mode == "npysim":
             idt = self.get_input_datatype()
@@ -146,16 +170,12 @@ class ConvolutionInputGenerator(HLSCustomOp):
 
             inp = context[node.input[0]]
             assert str(inp.dtype) == "float32", "Input datatype is not float32"
-            assert inp.shape == (
-                1,
-                ifm_ch,
-                ifm_dim,
-                ifm_dim,
+            assert (
+                inp.shape == exp_ishape
             ), """Input shape doesn't
-            match expected shape (1, ifm_ch, ifm_dim, ifm_dim)."""
-            reshaped_inp = inp.transpose(0, 2, 3, 1)
+            match expected shape (1, ifm_dim, ifm_dim, ifm_ch)."""
             # make copy before saving array
-            reshaped_inp = reshaped_inp.copy()
+            reshaped_inp = inp.copy()
             np.save(os.path.join(code_gen_dir, "input_0.npy"), reshaped_inp)
             # execute the precompiled model
             super().exec_precompiled_singlenode_model()
@@ -165,17 +185,8 @@ class ConvolutionInputGenerator(HLSCustomOp):
                 out = context[node.output[0]]
                 out = 2 * out - 1
                 context[node.output[0]] = out
-            assert context[node.output[0]].shape == (
-                1,
-                out_pix,
-                k * k,
-                ifm_ch,
-            ), """Output
-            shape doesn't match expected shape (1, out_pix, k*k, ifm_ch)."""
-            # reshape output to have expected shape
-            context[node.output[0]] = context[node.output[0]].reshape(
-                1, out_pix, k * k * ifm_ch
-            )
+            context[node.output[0]] = context[node.output[0]].reshape(*exp_oshape)
+
         elif mode == "rtlsim":
             code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen")
             prefixed_top_name = "%s_%s" % (node.name, node.name)
@@ -185,7 +196,10 @@ class ConvolutionInputGenerator(HLSCustomOp):
             )
             if os.path.isfile(verilog_file):
                 inp = context[node.input[0]]
-                inp = inp.transpose(0, 2, 3, 1)
+                assert (
+                    inp.shape == exp_ishape
+                ), """Input shape doesn't
+                match expected shape (1, ifm_dim, ifm_dim, ifm_ch)."""
                 inp = inp.flatten()
 
                 # TODO: check how to sort inputs for multichannel inputs
@@ -209,7 +223,7 @@ class ConvolutionInputGenerator(HLSCustomOp):
                 odt = self.get_output_datatype()
                 if odt == DataType.BIPOLAR:
                     output = [2 * x - 1 for x in output]
-
+                # TODO use utils here to handle datatype intricacies...
                 # pyverilator interprets int2 as uint2, so output has to be corrected
                 elif odt == DataType.INT2:
                     mask = 2 ** (odt.bitwidth() - 1)
@@ -220,9 +234,8 @@ class ConvolutionInputGenerator(HLSCustomOp):
                 # output_ch2 = [int(x[1:]) for x in output]
 
                 # reshape output
-                output = np.asarray([output], dtype=np.float32).reshape(
-                    1, out_pix, k * k * ifm_ch
-                )
+                output = np.asarray([output], dtype=np.float32)
+                output = output.reshape(*exp_oshape)
                 context[node.output[0]] = output
 
             else:
@@ -237,6 +250,10 @@ class ConvolutionInputGenerator(HLSCustomOp):
                     mode
                 )
             )
+        assert (
+            context[node.output[0]].shape == exp_oshape
+        ), """Output
+        shape doesn't match expected shape (1, ofm_dim, ofm_dim, k*k*ifm_ch)."""
 
     def global_includes(self):
         self.code_gen_dict["$GLOBALS$"] = ['#include "slidingwindow.h"']
@@ -306,12 +323,8 @@ class ConvolutionInputGenerator(HLSCustomOp):
         elem_hls_type = dtype.get_hls_datatype_str()
         npy_type = "float"
         npy_out = "%s/output.npy" % code_gen_dir
-        ofm_dim = self.get_nodeattr("OFMDim")
-        out_pix = ofm_dim * ofm_dim
-        k = self.get_nodeattr("ConvKernelDim")
-        ifm_ch = self.get_nodeattr("IFMChannels")
-        shape = (1, out_pix, k * k, ifm_ch)
-        shape_cpp_str = str(shape).replace("(", "{").replace(")", "}")
+        oshape = self.get_folded_output_shape()
+        oshape_cpp_str = str(oshape).replace("(", "{").replace(")", "}")
 
         self.code_gen_dict["$DATAOUTSTREAM$"] = [
             'apintstream2npy<%s, %s, %d, %s>(out, %s, "%s");'
@@ -320,7 +333,7 @@ class ConvolutionInputGenerator(HLSCustomOp):
                 elem_hls_type,
                 elem_bits,
                 npy_type,
-                shape_cpp_str,
+                oshape_cpp_str,
                 npy_out,
             )
         ]
diff --git a/tests/fpgadataflow/test_fpgadataflow_convinputgenerator.py b/tests/fpgadataflow/test_fpgadataflow_convinputgenerator.py
index c898393608f5f8f31b6044a5505f5903743eb346..7e2d0767108e7a6ff7297f93b639b86f930ee5f0 100644
--- a/tests/fpgadataflow/test_fpgadataflow_convinputgenerator.py
+++ b/tests/fpgadataflow/test_fpgadataflow_convinputgenerator.py
@@ -28,7 +28,6 @@
 
 import pytest
 
-import numpy as np
 from onnx import TensorProto, helper
 
 import finn.core.onnx_exec as oxe
@@ -43,74 +42,49 @@ from finn.transformation.general import GiveUniqueNodeNames
 from finn.util.basic import gen_finn_dt_tensor
 
 
-def get_im2col_indices(x_shape, k, stride):
-    # First figure out what the size of the output should be
-    N, C, H, W = x_shape
-    assert H == W
-    assert (W - k) % stride == 0
-    ofm_dim = int((W - k) / stride + 1)
-
-    i0 = np.repeat(np.arange(k), k)
-    i0 = np.tile(i0, C)
-    i1 = stride * np.repeat(np.arange(ofm_dim), ofm_dim)
-    j0 = np.tile(np.arange(k), k * C)
-    j1 = stride * np.tile(np.arange(ofm_dim), ofm_dim)
-    i = i0.reshape(-1, 1) + i1.reshape(1, -1)
-    j = j0.reshape(-1, 1) + j1.reshape(1, -1)
-
-    k = np.repeat(np.arange(C), k * k).reshape(-1, 1)
-
-    return (k, i, j)
-
-
-def im2col_indices(x, k, stride):
-    """ An implementation of im2col based on some fancy indexing """
-
-    l, i, j = get_im2col_indices(x.shape, k, stride)
+def make_single_im2col_modelwrapper(k, ifm_ch, ifm_dim, ofm_dim, simd, stride, idt):
+    odt = idt
+    inp = helper.make_tensor_value_info(
+        "inp", TensorProto.FLOAT, [1, ifm_dim, ifm_dim, ifm_ch]
+    )
+    outp = helper.make_tensor_value_info(
+        "outp", TensorProto.FLOAT, [1, ofm_dim, ofm_dim, k * k * ifm_ch]
+    )
 
-    cols = x[:, l, i, j]
-    C = x.shape[1]
-    cols = cols.transpose(1, 2, 0).reshape(k * k * C, -1)
-    cols = cols.transpose(1, 0)
+    im2col_node = helper.make_node(
+        "Im2Col",
+        ["inp"],
+        ["outp"],
+        domain="finn",
+        backend="fpgadataflow",
+        stride=stride,
+        kernel_size=k,
+        input_shape=str((1, ifm_dim, ifm_dim, ifm_ch)),
+        pad_amount=0,
+        pad_value=0,
+    )
+    graph = helper.make_graph(
+        nodes=[im2col_node], name="im2col_graph", inputs=[inp], outputs=[outp]
+    )
 
-    # rearranging the output so it matches with finn-hlslib function
-    # swapping the columns according to the input channel
-    # if C > 1 :
-    parts = {}
-    for ch in range(C):
-        parts[ch] = []
+    model = helper.make_model(graph, producer_name="im2col-model")
+    model = ModelWrapper(model)
 
-    for i in range(cols.shape[1]):
-        if i % C == 0:
-            parts[0].append(i)
-        elif (i + (C - 1)) % C == 0:
-            parts[1].append(i)
-        elif (i + (C - 2)) % C == 0:
-            parts[2].append(i)
-        elif (i + (C - 3)) % C == 0:
-            parts[3].append(i)
-    permutation = []
-    for i in parts:
-        for num in parts[i]:
-            permutation.append(num)
+    model.set_tensor_datatype("inp", idt)
+    model.set_tensor_datatype("outp", odt)
 
-    i = np.argsort(permutation)
-    cols = cols[:, i]
-    return cols
+    return model
 
 
 def make_single_slidingwindow_modelwrapper(
     k, ifm_ch, ifm_dim, ofm_dim, simd, stride, idt
 ):
-
     odt = idt
-    out_pix = ofm_dim * ofm_dim
-
     inp = helper.make_tensor_value_info(
-        "inp", TensorProto.FLOAT, [1, ifm_ch, ifm_dim, ifm_dim]
+        "inp", TensorProto.FLOAT, [1, ifm_dim, ifm_dim, ifm_ch]
     )
     outp = helper.make_tensor_value_info(
-        "outp", TensorProto.FLOAT, [1, out_pix, k * k * ifm_ch]
+        "outp", TensorProto.FLOAT, [1, ofm_dim, ofm_dim, k * k * ifm_ch]
     )
 
     SlidingWindow_node = helper.make_node(
@@ -168,7 +142,7 @@ def test_fpgadataflow_slidingwindow(idt, k, ifm_dim, ifm_ch, stride, exec_mode):
     simd = ifm_ch
     ofm_dim = int(((ifm_dim - k) / stride) + 1)
 
-    x = gen_finn_dt_tensor(idt, (1, ifm_ch, ifm_dim, ifm_dim))
+    x = gen_finn_dt_tensor(idt, (1, ifm_dim, ifm_dim, ifm_ch))
     model = make_single_slidingwindow_modelwrapper(
         k, ifm_ch, ifm_dim, ofm_dim, simd, stride, idt
     )
@@ -189,8 +163,10 @@ def test_fpgadataflow_slidingwindow(idt, k, ifm_dim, ifm_ch, stride, exec_mode):
     input_dict = prepare_inputs(x, idt)
     # execute model
     y_produced = oxe.execute_onnx(model, input_dict)["outp"]
-    y_expected = im2col_indices(x, k, stride)
-    # reshape expected output to match node output
-    oshape = y_produced.shape
-    y_expected = y_expected.reshape(oshape)
+    golden = make_single_im2col_modelwrapper(
+        k, ifm_ch, ifm_dim, ofm_dim, simd, stride, idt
+    )
+    y_expected = oxe.execute_onnx(golden, input_dict)["outp"]
+    if idt == DataType.BIPOLAR:
+        y_expected = 2 * y_expected - 1
     assert (y_produced == y_expected).all()