From 3c01c64a62313d39f9e8664fa629fda48ea64b10 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Wed, 22 Apr 2020 23:33:47 +0100
Subject: [PATCH] [Transform] convert Compile to a NodeLocalTransformation

---
 .../transformation/fpgadataflow/compile.py    | 28 ++++++++++---------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/src/finn/transformation/fpgadataflow/compile.py b/src/finn/transformation/fpgadataflow/compile.py
index e577c3af6..a76ab6832 100644
--- a/src/finn/transformation/fpgadataflow/compile.py
+++ b/src/finn/transformation/fpgadataflow/compile.py
@@ -28,28 +28,30 @@
 
 import finn.custom_op.registry as registry
 import finn.util.basic as util
-from finn.transformation import Transformation
+from finn.transformation import NodeLocalTransformation
 
 
-class Compile(Transformation):
+class Compile(NodeLocalTransformation):
     """For every node: compile C++ code in node attribute "code_gen_dir_npysim"
     and save path to executables in node attribute "executable_path".
     All nodes in the graph must have the fpgadataflow backend attribute.
 
     To use these executables, exec_mode must be set to "npysim" (using transformation
     SetExecMode) and the model has to be executed using execute_onnx() from
-    finn.core.onnx_exec"""
+    finn.core.onnx_exec
 
-    def __init__(self):
-        super().__init__()
+    * num_workers (int or None) number of parallel workers, see documentation in
+      NodeLocalTransformation for more details.
+    """
 
-    def apply(self, model):
-        for node in model.graph.node:
-            op_type = node.op_type
-            if node.domain == "finn":
-                backend_attribute = util.get_by_name(node.attribute, "backend")
-                if backend_attribute is None:
-                    continue
+    def __init__(self, num_workers=None):
+        super().__init__(num_workers=num_workers)
+
+    def applyNodeLocal(self, node):
+        op_type = node.op_type
+        if node.domain == "finn":
+            backend_attribute = util.get_by_name(node.attribute, "backend")
+            if backend_attribute is not None:
                 backend_value = backend_attribute.s.decode("UTF-8")
                 if backend_value == "fpgadataflow":
                     try:
@@ -74,4 +76,4 @@ class Compile(Transformation):
                         raise Exception(
                             "Custom op_type %s is currently not supported." % op_type
                         )
-        return (model, False)
+        return (node, False)
-- 
GitLab