diff --git a/src/finn/custom_op/fpgadataflow/__init__.py b/src/finn/custom_op/fpgadataflow/__init__.py
index e5938bfaad56cffca895df0eb7c2eefec3730212..80f8649fc2f203ce3bf6bc45d728827b9b413ee6 100644
--- a/src/finn/custom_op/fpgadataflow/__init__.py
+++ b/src/finn/custom_op/fpgadataflow/__init__.py
@@ -38,11 +38,15 @@ class HLSCustomOp(CustomOp):
         self.code_gen_dict = {}
 
         self.tmp_dir = ""
-        self.code_gen_dir = (util.get_by_name(onnx_node.attribute, "code_gen_dir")).s
+        self.code_gen_dir = util.get_by_name(onnx_node.attribute, "code_gen_dir")
         self.executable_path = ""
 
-    def code_generation(self):
+    def code_generation(self, context):
         node = self.onnx_node
+        if "weights" in context:
+            self.generate_weights(context)
+        if "thresh" in context:
+            self.generate_thresholds(context)
         self.global_includes()
         self.defines()
         self.read_npy_data()
@@ -62,6 +66,14 @@ class HLSCustomOp(CustomOp):
         f.write(template)
         f.close()
 
+    @abstractmethod
+    def generate_weights(self, context):
+        pass
+
+    @abstractmethod
+    def generate_thresholds(self, context):
+        pass
+
     @abstractmethod
     def global_includes(self):
         pass
diff --git a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
index 4857620a5d9b2f7c035f2da40c1d71a884ae1178..d8d67713700505a0078f08c7542beb13d98930da 100644
--- a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
+++ b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
@@ -148,6 +148,71 @@ class StreamingFCLayer_Batch(HLSCustomOp):
         assert ret.shape[2] == n_thres_steps
         return ret
 
+    def generate_weights(self, context):
+
+        weights = context["weights"]
+        # convert weights into hlslib-compatible format
+        weight_tensor = self.get_hls_compatible_weight_tensor(weights)
+        export_wdt = self.get_weight_datatype()
+        # we have converted bipolar weights to binary for export,
+        # so use it as such for weight generation
+        if self.get_weight_datatype() == DataType.BIPOLAR:
+            export_wdt = DataType.BINARY
+        weight_hls_code = numpy_to_hls_code(
+            weight_tensor, export_wdt, "weights", True, True
+        )
+        # write weights into params.h
+        f_weights = open("{}/params.h".format(self.tmp_dir), "w")
+
+        if export_wdt.bitwidth() != 1:
+            f_weights.write(
+                "static FixedPointWeights<{},{},{},{}> weights = ".format(
+                    self.get_nodeattr("SIMD"),
+                    export_wdt.get_hls_datatype_str(),
+                    self.get_nodeattr("PE"),
+                    self.get_nodeattr("WMEM"),
+                )
+            )
+        else:
+            f_weights.write(
+                "static BinaryWeights<{},{},{}> weights = ".format(
+                    self.get_nodeattr("SIMD"),
+                    self.get_nodeattr("PE"),
+                    self.get_nodeattr("WMEM"),
+                )
+            )
+        f_weights.write(weight_hls_code)
+        f_weights.close()
+
+    def generate_thresholds(self, context):
+        thresholds = context["thresh"]
+        threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
+        tdt = DataType.INT32
+        # use UINT32 threshold export for bipolar times bipolar
+        inp_is_bipolar = self.get_input_datatype() == DataType.BIPOLAR
+        wt_is_bipolar = self.get_weight_datatype() == DataType.BIPOLAR
+        if inp_is_bipolar and wt_is_bipolar:
+            tdt = DataType.UINT32
+        thresholds_hls_code = numpy_to_hls_code(
+            threshold_tensor, tdt, "thresholds", False, True
+        )
+        # write thresholds into thresh.h
+        f_thresh = open("{}/thresh.h".format(self.tmp_dir), "w")
+        tdt_hls = tdt.get_hls_datatype_str()
+        odt_hls = self.get_output_datatype().get_hls_datatype_str()
+        f_thresh.write(
+            "static ThresholdsActivation<{},{},{},{},{},{}> threshs = ".format(
+                self.get_nodeattr("TMEM"),
+                self.get_nodeattr("PE"),
+                threshold_tensor.shape[-1],
+                tdt_hls,
+                odt_hls,
+                self.get_nodeattr("ActVal"),
+            )
+        )
+        f_thresh.write(thresholds_hls_code)
+        f_thresh.close()
+
     def execute_node(self, context, graph):
         node = self.onnx_node
         # make temporary directory for generated files
@@ -180,78 +245,16 @@ class StreamingFCLayer_Batch(HLSCustomOp):
                         context[inputs],
                     )
                 temp_files.append("{}/input_{}.npy".format(self.tmp_dir, in_ind))
-            elif in_ind == 1:
-                weights = context[inputs]
-                # convert weights into hlslib-compatible format
-                weight_tensor = self.get_hls_compatible_weight_tensor(weights)
-                export_wdt = self.get_weight_datatype()
-                # we have converted bipolar weights to binary for export,
-                # so use it as such for weight generation
-                if self.get_weight_datatype() == DataType.BIPOLAR:
-                    export_wdt = DataType.BINARY
-                weight_hls_code = numpy_to_hls_code(
-                    weight_tensor, export_wdt, "weights", True, True
-                )
-                # write weights into params.h
-                f_weights = open("{}/params.h".format(self.tmp_dir), "w")
-
-                if export_wdt.bitwidth() != 1:
-                    f_weights.write(
-                        "static FixedPointWeights<{},{},{},{}> weights = ".format(
-                            self.get_nodeattr("SIMD"),
-                            export_wdt.get_hls_datatype_str(),
-                            self.get_nodeattr("PE"),
-                            self.get_nodeattr("WMEM"),
-                        )
-                    )
-                else:
-                    f_weights.write(
-                        "static BinaryWeights<{},{},{}> weights = ".format(
-                            self.get_nodeattr("SIMD"),
-                            self.get_nodeattr("PE"),
-                            self.get_nodeattr("WMEM"),
-                        )
-                    )
-                f_weights.write(weight_hls_code)
-                f_weights.close()
-                temp_files.append("{}/params.h".format(self.tmp_dir))
-
-            elif in_ind == 2:
-                thresholds = context[inputs]
-                threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
-                tdt = DataType.INT32
-                # use UINT32 threshold export for bipolar times bipolar
-                inp_is_bipolar = self.get_input_datatype() == DataType.BIPOLAR
-                wt_is_bipolar = self.get_weight_datatype() == DataType.BIPOLAR
-                if inp_is_bipolar and wt_is_bipolar:
-                    tdt = DataType.UINT32
-                thresholds_hls_code = numpy_to_hls_code(
-                    threshold_tensor, tdt, "thresholds", False, True
-                )
-                # write weights into thresh.h
-                f_thresh = open("{}/thresh.h".format(self.tmp_dir), "w")
-                tdt_hls = tdt.get_hls_datatype_str()
-                odt_hls = self.get_output_datatype().get_hls_datatype_str()
-                f_thresh.write(
-                    "static ThresholdsActivation<{},{},{},{},{},{}> threshs = ".format(
-                        self.get_nodeattr("TMEM"),
-                        self.get_nodeattr("PE"),
-                        threshold_tensor.shape[-1],
-                        tdt_hls,
-                        odt_hls,
-                        self.get_nodeattr("ActVal"),
-                    )
-                )
-                f_thresh.write(thresholds_hls_code)
-                f_thresh.close()
-                temp_files.append("{}/thresh.h".format(self.tmp_dir))
-            else:
+            elif in_ind > 2:
                 raise Exception("Unexpected input found for StreamingFCLayer")
 
             in_ind += 1
 
+        temp_files.append("{}/params.h".format(self.tmp_dir))
+        temp_files.append("{}/thresh.h".format(self.tmp_dir))
+
         # code generation
-        self.code_generation()
+        self.code_generation(context)
 
         # c++ compilation and execution flow
         temp_files.append("{}/execute_{}.cpp".format(self.tmp_dir, node.op_type))
diff --git a/src/finn/transformation/code_gen_transformation.py b/src/finn/transformation/code_gen_transformation.py
index 35ce2a841ce6cb53d6a81b1a9a5c85396f981557..7e0ec8429dcfeb1478f90184e24c6b8936781607 100644
--- a/src/finn/transformation/code_gen_transformation.py
+++ b/src/finn/transformation/code_gen_transformation.py
@@ -12,7 +12,7 @@ def code_gen_transformation(node):
         # get the path of the code generation directory if already set
         # check instance and check node attributes for value
         code_gen_dir = inst.code_gen_dir
-
+        print(code_gen_dir)
         # parameter is empty
         if not code_gen_dir:
             print("parameter is empty")
diff --git a/tests/fpgadataflow/test_code_gen_trafo.py b/tests/fpgadataflow/test_code_gen_trafo.py
index f19dca73cd41d644757851056ee71e62719346a0..39e41f3d2f75e6f420005a7ea6579bcaa9aa6752 100644
--- a/tests/fpgadataflow/test_code_gen_trafo.py
+++ b/tests/fpgadataflow/test_code_gen_trafo.py
@@ -1,113 +1,62 @@
 import numpy as np
 from onnx import TensorProto, helper
 
+import finn.core.utils as util
 import finn.transformation.code_gen_transformation as cg_trafo
 from finn.core.datatype import DataType
 from finn.core.modelwrapper import ModelWrapper
 
 
 def test_code_gen_trafo():
-    inp = helper.make_tensor_value_info("in", TensorProto.FLOAT, [2, 2, 4, 4])
-    outp = helper.make_tensor_value_info("out", TensorProto.FLOAT, [2, 2, 2, 2])
-
-    MaxPool_batch_node = helper.make_node(
-        "StreamingMaxPool_Batch",
-        ["in"],
-        ["out"],
+    idt = wdt = odt = DataType.BIPOLAR
+    tdt = DataType.UINT32
+    mw = 8
+    mh = 8
+    pe = 4
+    simd = 4
+    wmem = mw * mh // (pe * simd)
+    assert mw * mh == wmem * pe * simd
+    nf = mh // pe
+    sf = mw // simd
+    tmem = nf
+
+    inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, sf, simd])
+    outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, [1, nf, pe])
+    node_inp_list = ["inp", "weights", "thresh"]
+    FCLayer_node = helper.make_node(
+        "StreamingFCLayer_Batch",
+        node_inp_list,
+        ["outp"],
         domain="finn",
         backend="fpgadataflow",
-        code_gen_dir="hifch",
+        code_gen_dir="dummy_directory",
         executable_path="",
-        ImgDim=4,
-        PoolDim=2,
-        NumChannels=2,
+        resType="ap_resource_lut()",
+        MW=mw,
+        MH=mh,
+        SIMD=simd,
+        PE=pe,
+        WMEM=wmem,
+        TMEM=tmem,
+        inputDataType=idt.name,
+        weightDataType=wdt.name,
+        outputDataType=odt.name,
     )
-
     graph = helper.make_graph(
-        nodes=[MaxPool_batch_node],
-        name="max_pool_batch_graph",
-        inputs=[inp],
-        outputs=[outp],
+        nodes=[FCLayer_node], name="fclayer_graph", inputs=[inp], outputs=[outp]
     )
-    model = helper.make_model(graph, producer_name="finn-hls-onnx-model")
-    model = ModelWrapper(model)
 
-    # set the tensor datatypes (in this case: all to bipolar)
-    for tensor in graph.input:
-        model.set_tensor_datatype(tensor.name, DataType["BIPOLAR"])
-    for tensor in graph.output:
-        model.set_tensor_datatype(tensor.name, DataType["BIPOLAR"])
+    model = helper.make_model(graph, producer_name="fclayer-model")
+    model = ModelWrapper(model)
 
-    input_tensor = np.asarray(
-        [
-            1,
-            1,
-            1,
-            1,
-            1,
-            1,
-            1,
-            1,
-            0,
-            0,
-            0,
-            0,
-            0,
-            0,
-            0,
-            0,
-            0,
-            0,
-            0,
-            0,
-            0,
-            0,
-            0,
-            0,
-            1,
-            1,
-            1,
-            1,
-            1,
-            1,
-            1,
-            1,
-            1,
-            1,
-            1,
-            1,
-            1,
-            1,
-            1,
-            1,
-            0,
-            0,
-            0,
-            0,
-            0,
-            0,
-            0,
-            0,
-            0,
-            0,
-            0,
-            0,
-            0,
-            0,
-            0,
-            0,
-            1,
-            1,
-            1,
-            1,
-            1,
-            1,
-            1,
-            1,
-        ],
-        dtype=np.float32,
-    ).reshape(2, 2, 4, 4)
+    model.set_tensor_datatype("inp", idt)
+    model.set_tensor_datatype("outp", odt)
+    model.set_tensor_datatype("weights", wdt)
+    W = util.gen_finn_dt_tensor(wdt, (mh, mw))
+    model.set_initializer("weights", W)
+    model.set_tensor_datatype("thresh", tdt)
+    T = np.zeros((1, 1))
+    model.set_initializer("thresh", T)
 
-    input_dict = {"in": input_tensor}
     for nodes in model.graph.node:
         cg_trafo.code_gen_transformation(nodes)