diff --git a/src/finn/custom_op/fpgadataflow/__init__.py b/src/finn/custom_op/fpgadataflow/__init__.py
index 80f8649fc2f203ce3bf6bc45d728827b9b413ee6..e0cf22bcb7cd98843cb2ab254e9c74027aed533f 100644
--- a/src/finn/custom_op/fpgadataflow/__init__.py
+++ b/src/finn/custom_op/fpgadataflow/__init__.py
@@ -41,12 +41,14 @@ class HLSCustomOp(CustomOp):
         self.code_gen_dir = util.get_by_name(onnx_node.attribute, "code_gen_dir")
         self.executable_path = ""
 
-    def code_generation(self, context):
+    def code_generation(self, model):
         node = self.onnx_node
-        if "weights" in context:
-            self.generate_weights(context)
-        if "thresh" in context:
-            self.generate_thresholds(context)
+        if node.op_type == "StreamingFCLayer_Batch":
+            self.generate_weights(model)
+            try:
+                self.generate_thresholds(model)
+            except:
+                pass
         self.global_includes()
         self.defines()
         self.read_npy_data()
diff --git a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
index d8d67713700505a0078f08c7542beb13d98930da..0487b9a2bfb6716156b4ba37ffcbc4c042523c4f 100644
--- a/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
+++ b/src/finn/custom_op/fpgadataflow/streamingfclayer_batch.py
@@ -148,9 +148,9 @@ class StreamingFCLayer_Batch(HLSCustomOp):
         assert ret.shape[2] == n_thres_steps
         return ret
 
-    def generate_weights(self, context):
+    def generate_weights(self, model):
 
-        weights = context["weights"]
+        weights = model.get_initializer(self.onnx_node.input[1])
         # convert weights into hlslib-compatible format
         weight_tensor = self.get_hls_compatible_weight_tensor(weights)
         export_wdt = self.get_weight_datatype()
@@ -184,8 +184,8 @@ class StreamingFCLayer_Batch(HLSCustomOp):
         f_weights.write(weight_hls_code)
         f_weights.close()
 
-    def generate_thresholds(self, context):
-        thresholds = context["thresh"]
+    def generate_thresholds(self, model):
+        thresholds = model.get_initializer(self.onnx_node.input[2])
         threshold_tensor = self.get_hls_compatible_threshold_tensor(thresholds)
         tdt = DataType.INT32
         # use UINT32 threshold export for bipolar times bipolar
diff --git a/src/finn/transformation/fpgadataflow/__init__.py b/src/finn/transformation/fpgadataflow/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/finn/transformation/code_gen_transformation.py b/src/finn/transformation/fpgadataflow/code_gen_transformation.py
similarity index 86%
rename from src/finn/transformation/code_gen_transformation.py
rename to src/finn/transformation/fpgadataflow/code_gen_transformation.py
index 6b0ab56ee826b6e2b296a7520ba725317c99c3ca..d6fce3ddedc3ca044bd4702e91cb5007f04650ff 100644
--- a/src/finn/transformation/code_gen_transformation.py
+++ b/src/finn/transformation/fpgadataflow/code_gen_transformation.py
@@ -5,7 +5,7 @@ import finn.custom_op.registry as registry
 from finn.transformation import Transformation
 
 
-def code_gen_transformation(node, context, model):
+def code_gen_transformation(node, model):
     """Call custom implementation to generate code for single custom node
     and create folder that contains all the generated files"""
     op_type = node.op_type
@@ -21,7 +21,7 @@ def code_gen_transformation(node, context, model):
         if not code_gen_dir:
             tmp_dir = tmp.mkdtemp(prefix="code_gen_" + str(node.op_type) + "_")
             inst.tmp_dir = tmp_dir
-            inst.code_generation(context)
+            inst.code_generation(model)
             # check if directory exists
             if os.path.isdir(tmp_dir):
                 if len(os.listdir(tmp_dir)) == 0:
@@ -40,7 +40,7 @@ def code_gen_transformation(node, context, model):
                     os.rmdir(code_gen_dir)
                     tmp_dir = tmp.mkdtemp(prefix="code_gen_" + str(node.op_type) + "_")
                     inst.tmp_dir = tmp_dir
-                    inst.code_generation(context)
+                    inst.code_generation(model)
                     if os.path.isdir(tmp_dir):
                         if len(os.listdir(tmp_dir)) == 0:
                             raise Exception("Code was not generated!")
@@ -63,11 +63,6 @@ class CodeGen(Transformation):
     """Code generation for all nodes in model"""
 
     def apply(self, model):
-        W = model.get_initializer("weights")
-        T = model.get_initializer("thresh")
-        context = {}
-        context["weights"] = W
-        context["thresh"] = T
         for node in model.graph.node:
-            code_gen_transformation(node, context, model)
+            code_gen_transformation(node, model)
         return (model, False)
diff --git a/src/finn/transformation/compilation_transformation.py b/src/finn/transformation/fpgadataflow/compilation_transformation.py
similarity index 100%
rename from src/finn/transformation/compilation_transformation.py
rename to src/finn/transformation/fpgadataflow/compilation_transformation.py
diff --git a/tests/transformation/test_code_gen_trafo.py b/tests/transformation/test_code_gen_trafo.py
index 01768a36c8322b39ad1cde286fde008ab2b7c6d9..b5205c5211007c18a564b7a4c20e37de207b0708 100644
--- a/tests/transformation/test_code_gen_trafo.py
+++ b/tests/transformation/test_code_gen_trafo.py
@@ -6,21 +6,18 @@ from onnx import TensorProto, helper
 import finn.core.utils as util
 from finn.core.datatype import DataType
 from finn.core.modelwrapper import ModelWrapper
-from finn.transformation.code_gen_transformation import CodeGen
+from finn.transformation.fpgadataflow.code_gen_transformation import CodeGen
 
 
 def test_code_gen_trafo():
     idt = wdt = odt = DataType.BIPOLAR
-    tdt = DataType.UINT32
     mw = 8
     mh = 8
     pe = 4
     simd = 4
     wmem = mw * mh // (pe * simd)
-    assert mw * mh == wmem * pe * simd
     nf = mh // pe
     sf = mw // simd
-    tmem = nf
 
     inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, sf, simd])
     outp = helper.make_tensor_value_info("outp", TensorProto.FLOAT, [1, nf, pe])
@@ -39,7 +36,7 @@ def test_code_gen_trafo():
         SIMD=simd,
         PE=pe,
         WMEM=wmem,
-        TMEM=tmem,
+        TMEM=0,
         inputDataType=idt.name,
         weightDataType=wdt.name,
         outputDataType=odt.name,
@@ -54,11 +51,8 @@ def test_code_gen_trafo():
     model.set_tensor_datatype("inp", idt)
     model.set_tensor_datatype("outp", odt)
     model.set_tensor_datatype("weights", wdt)
-    W = util.gen_finn_dt_tensor(wdt, (mh, mw))
+    W = util.gen_finn_dt_tensor(wdt, (mw, mh))
     model.set_initializer("weights", W)
-    model.set_tensor_datatype("thresh", tdt)
-    T = np.zeros((1, 1))
-    model.set_initializer("thresh", T)
 
     model = model.transform(CodeGen())
     for node in model.graph.node:
diff --git a/tests/transformation/test_compilation_trafo.py b/tests/transformation/test_compilation_trafo.py
index 563c4d4cca4ea1e5b13fc5723ca1f36c948c4007..9578d93d6d5bea2869e99ba442f0048a933d41c7 100644
--- a/tests/transformation/test_compilation_trafo.py
+++ b/tests/transformation/test_compilation_trafo.py
@@ -6,19 +6,17 @@ from onnx import TensorProto, helper
 import finn.core.utils as util
 from finn.core.datatype import DataType
 from finn.core.modelwrapper import ModelWrapper
-from finn.transformation.code_gen_transformation import CodeGen
-from finn.transformation.compilation_transformation import Compilation
+from finn.transformation.fpgadataflow.code_gen_transformation import CodeGen
+from finn.transformation.fpgadataflow.compilation_transformation import Compilation
 
 
 def test_compilation_trafo():
     idt = wdt = odt = DataType.BIPOLAR
-    tdt = DataType.UINT32
     mw = 8
     mh = 8
     pe = 4
     simd = 4
     wmem = mw * mh // (pe * simd)
-    assert mw * mh == wmem * pe * simd
     nf = mh // pe
     sf = mw // simd
     tmem = nf
@@ -40,7 +38,7 @@ def test_compilation_trafo():
         SIMD=simd,
         PE=pe,
         WMEM=wmem,
-        TMEM=tmem,
+        TMEM=0,
         inputDataType=idt.name,
         weightDataType=wdt.name,
         outputDataType=odt.name,
@@ -55,11 +53,8 @@ def test_compilation_trafo():
     model.set_tensor_datatype("inp", idt)
     model.set_tensor_datatype("outp", odt)
     model.set_tensor_datatype("weights", wdt)
-    W = util.gen_finn_dt_tensor(wdt, (mh, mw))
+    W = util.gen_finn_dt_tensor(wdt, (mw, mh))
     model.set_initializer("weights", W)
-    model.set_tensor_datatype("thresh", tdt)
-    T = np.zeros((1, 1))
-    model.set_initializer("thresh", T)
 
     model = model.transform(CodeGen())
     model = model.transform(Compilation())