diff --git a/src/finn/custom_op/fpgadataflow/__init__.py b/src/finn/custom_op/fpgadataflow/__init__.py
index ce4a72d355a54c6096b159236b727959f9b2e3fd..47bced5d7dc9a30f13b21820e79f3642521c02a0 100644
--- a/src/finn/custom_op/fpgadataflow/__init__.py
+++ b/src/finn/custom_op/fpgadataflow/__init__.py
@@ -91,11 +91,11 @@ class HLSCustomOp(CustomOp):
         f.close()
 
     @abstractmethod
-    def global_includes(self, node, code_gen_dict):
+    def global_includes(self, node):
         pass
 
     @abstractmethod
-    def defines(self, node, code_gen_dict):
+    def defines(self, node):
         pass
 
     @abstractmethod
@@ -103,17 +103,17 @@ class HLSCustomOp(CustomOp):
         pass
 
     @abstractmethod
-    def strm_decl(self, node, code_gen_dict):
+    def strm_decl(self, node):
         pass
 
     @abstractmethod
-    def docompute(self, node, code_gen_dict):
+    def docompute(self, node):
         pass
 
     @abstractmethod
-    def dataoutstrm(self, node, code_gen_dict):
+    def dataoutstrm(self, node):
         pass
 
     @abstractmethod
-    def save_as_npy(self, node, code_gen_dict):
+    def save_as_npy(self, node):
         pass
diff --git a/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py b/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..90edf06268942ac59ac9129217b3face1401dd69
--- /dev/null
+++ b/src/finn/custom_op/fpgadataflow/streamingmaxpool_batch.py
@@ -0,0 +1,136 @@
+from finn.core.utils import get_by_name
+from finn.custom_op.fpgadataflow import HLSCustomOp
+
+
+class StreamingMaxPool_Batch(HLSCustomOp):
+    """Generate HLS C++ snippets for the StreamingMaxPool_Batch custom op.
+
+    Each generator method stores a list of C++ source strings under one
+    placeholder key of self.code_gen_dict.
+    """
+
+    # batch size written into the generated code; kept in one place so that
+    # defines() and save_as_npy() cannot drift apart
+    _NUM_REPS = 2
+
+    def make_shape_compatible_op(self, node):
+        """Not implemented yet; returns None."""
+        pass
+
+    def infer_node_datatype(self, node, model):
+        """Not implemented yet; does nothing."""
+        pass
+
+    def get_attributes(self, node):
+        """Read ImgDim, PoolDim and NumChannels from the node attributes."""
+        self.ImgDim = get_by_name(node.attribute, "ImgDim").i
+        self.PoolDim = get_by_name(node.attribute, "PoolDim").i
+        self.NumChannels = get_by_name(node.attribute, "NumChannels").i
+
+    def global_includes(self, node):
+        """Set $GLOBALS$ to the include of maxpool.h."""
+        self.code_gen_dict["$GLOBALS$"] = ['#include "maxpool.h"']
+
+    def defines(self, node):
+        """Set $DEFINES$ to the ImgDim/PoolDim/NumChannels/numReps macros."""
+        self.code_gen_dict["$DEFINES$"] = [
+            """#define ImgDim {}\n #define PoolDim {}\n
+            #define NumChannels {}\n #define numReps {}""".format(
+                self.ImgDim, self.PoolDim, self.NumChannels, self._NUM_REPS
+            )
+        ]
+
+    def read_npy_data(self, node):
+        """Set $READNPYDATA$ to C++ that loads each input_<i>.npy into stream in<i>.
+
+        Per loop iteration one ap_uint<NumChannels> word is pushed; bit c of
+        the word is assigned from element i + c * (num_values / NumChannels)
+        of the loaded float array.
+        """
+        # NOTE(review): arr, num_values and dat are emitted once per input, so
+        # the generated C++ would redeclare them for a node with several inputs
+        code = []
+        self.code_gen_dict["$READNPYDATA$"] = code
+        for input_ind, _ in enumerate(node.input):
+            code.append(
+                """cnpy::NpyArray arr = cnpy::npy_load("{}");\n
+                float* loaded_data{} = arr.data<float>();""".format(
+                    "input_{}.npy".format(input_ind), input_ind
+                )
+            )
+            code.append(
+                """int num_values = 1; \n
+                for(int i = 0; i < arr.shape.size(); i++){\n
+                num_values *= arr.shape[i]; \n }"""
+            )
+            code.append("ap_uint<{}> dat;".format(self.NumChannels))
+            code.append(
+                "for(int i=0; i < num_values/{}; i++){{".format(self.NumChannels)
+            )
+            for channel in range(self.NumChannels):
+                code.append(
+                    "dat.range({},{}) = loaded_data{}[i+((num_values/{})*{})];".format(
+                        channel, channel, input_ind, self.NumChannels, channel
+                    )
+                )
+            code.append("in{} << dat;".format(input_ind))
+            code.append("}")
+
+    def strm_decl(self, node):
+        """Set $STREAMDECLARATIONS$: one in<i> stream per node input plus out."""
+        decls = [
+            'hls::stream<ap_uint<{}>> in{} ("in{}");'.format(
+                self.NumChannels, input_ind, input_ind
+            )
+            for input_ind, _ in enumerate(node.input)
+        ]
+        decls.append('hls::stream<ap_uint<{}>> out ("out");'.format(self.NumChannels))
+        self.code_gen_dict["$STREAMDECLARATIONS$"] = decls
+
+    def docompute(self, node):
+        """Set $DOCOMPUTE$ to the call of the function named node.op_type."""
+        self.code_gen_dict["$DOCOMPUTE$"] = [
+            "{}<ImgDim, PoolDim, NumChannels>(in0, out, numReps);".format(node.op_type)
+        ]
+
+    def dataoutstrm(self, node):
+        """Set $DATAOUTSTREAM$ to C++ draining out into output_data_vector.
+
+        The generated code reads words from out until read_nb fails, then
+        pushes bits 0..NumChannels-1 of every word as floats.
+        """
+        code = [
+            "ap_uint<{}> out_data;\n std::vector<ap_uint<{}>> out_data_vector;".format(
+                self.NumChannels, self.NumChannels
+            ),
+            "while(out.read_nb(out_data)){",
+            "out_data_vector.push_back(out_data);\n}",
+            "std::vector<float> output_data_vector;",
+            """for(std::vector<ap_uint<{}>>::iterator it = out_data_vector.begin();
+            it != out_data_vector.end(); ++it){{""".format(
+                self.NumChannels
+            ),
+            "ap_uint<{}> output_data = *it;".format(self.NumChannels),
+        ]
+        code.extend(
+            "output_data_vector.push_back(output_data.range({},{}));".format(
+                channel, channel
+            )
+            for channel in range(self.NumChannels)
+        )
+        code.append("}")
+        self.code_gen_dict["$DATAOUTSTREAM$"] = code
+
+    def save_as_npy(self, node):
+        """Set $SAVEASCNPY$ to the cnpy::npy_save call writing output.npy.
+
+        The saved shape is {numReps, NumChannels, ImgDim/PoolDim, ImgDim/PoolDim}.
+        """
+        # floor division keeps the dimension an exact int without a float round trip
+        out_dim = self.ImgDim // self.PoolDim
+        self.code_gen_dict["$SAVEASCNPY$"] = [
+            """cnpy::npy_save("output.npy",&output_data_vector[0],
+            {{{},{},{},{}}},"w");""".format(
+                self._NUM_REPS, self.NumChannels, out_dim, out_dim
+            )
+        ]
diff --git a/src/finn/custom_op/registry.py b/src/finn/custom_op/registry.py
index 978e7ce1434bfa53977950776b8264de160c2f2f..7b1d2cf4f405552f13ca5d1c9b21a858f350130a 100644
--- a/src/finn/custom_op/registry.py
+++ b/src/finn/custom_op/registry.py
@@ -1,8 +1,9 @@
 # make sure new CustomOp subclasses are imported here so that they get
 # registered and plug in correctly into the infrastructure
+from finn.custom_op.fpgadataflow.streamingmaxpool import StreamingMaxPool
+from finn.custom_op.fpgadataflow.streamingmaxpool_batch import StreamingMaxPool_Batch
 from finn.custom_op.multithreshold import MultiThreshold
 from finn.custom_op.xnorpopcount import XnorPopcountMatMul
-from finn.custom_op.fpgadataflow.streamingmaxpool import StreamingMaxPool
 
 # create a mapping of all known CustomOp names and classes
 custom_op = {}
@@ -10,3 +11,4 @@ custom_op = {}
 custom_op["MultiThreshold"] = MultiThreshold
 custom_op["XnorPopcountMatMul"] = XnorPopcountMatMul
 custom_op["StreamingMaxPool"] = StreamingMaxPool
+custom_op["StreamingMaxPool_Batch"] = StreamingMaxPool_Batch
diff --git a/tests/test_layer_streaming_maxpool_batch.py b/tests/test_layer_streaming_maxpool_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..75cbd64376055996de58455c1d28530b770500a5
--- /dev/null
+++ b/tests/test_layer_streaming_maxpool_batch.py
@@ -0,0 +1,49 @@
+import numpy as np
+from onnx import TensorProto, helper
+
+import finn.core.onnx_exec as oxe
+from finn.core.datatype import DataType
+from finn.core.modelwrapper import ModelWrapper
+
+
+def test_layer_streaming_maxpool_batch():
+    """Build a one-node StreamingMaxPool_Batch model and execute it once.
+
+    Smoke test only: the input and the execution result are printed, nothing
+    is asserted.
+    """
+    in_info = helper.make_tensor_value_info("in", TensorProto.FLOAT, [2, 2, 4, 4])
+    out_info = helper.make_tensor_value_info("out", TensorProto.FLOAT, [2, 2, 2, 2])
+
+    pool_node = helper.make_node(
+        "StreamingMaxPool_Batch",
+        ["in"],
+        ["out"],
+        domain="finn",
+        backend="fpgadataflow",
+        ImgDim=4,
+        PoolDim=2,
+        NumChannels=2,
+    )
+
+    graph = helper.make_graph(
+        nodes=[pool_node],
+        name="max_pool_batch_graph",
+        inputs=[in_info],
+        outputs=[out_info],
+    )
+    onnx_model = helper.make_model(graph, producer_name="finn-hls-onnx-model")
+    model = ModelWrapper(onnx_model)
+
+    # annotate every graph input, then every graph output, as bipolar
+    for tensor in list(graph.input) + list(graph.output):
+        model.set_tensor_datatype(tensor.name, DataType["BIPOLAR"])
+
+    # flat pattern: 8 ones, 16 zeros, 16 ones, 16 zeros, 8 ones (64 values)
+    input_tensor = np.repeat(
+        np.asarray([1, 0, 1, 0, 1], dtype=np.float32), [8, 16, 16, 16, 8]
+    ).reshape(2, 2, 4, 4)
+    print(input_tensor)
+
+    output_dict = oxe.execute_onnx(model, {"in": input_tensor})
+    print(output_dict)