From 9dfdd5659fb4e7015c66731bd2e9a7c677594c83 Mon Sep 17 00:00:00 2001
From: Yaman Umuroglu <maltanar@gmail.com>
Date: Wed, 22 Apr 2020 23:24:29 +0100
Subject: [PATCH] [Transform] change nWorkers default behavior, comments

---
 src/finn/transformation/__init__.py           | 45 +++++++++++--------
 .../fpgadataflow/hlssynth_ipgen.py            | 13 +++---
 2 files changed, 33 insertions(+), 25 deletions(-)

diff --git a/src/finn/transformation/__init__.py b/src/finn/transformation/__init__.py
index d6cacb739..e9f5fe15f 100644
--- a/src/finn/transformation/__init__.py
+++ b/src/finn/transformation/__init__.py
@@ -48,6 +48,7 @@ Guide to writing FINN transformations
 """
 
 from abc import ABC, abstractmethod
+from finn.util.basic import get_num_default_workers
 import multiprocessing as mp
 
 
@@ -65,44 +66,50 @@ class Transformation(ABC):
 
 class NodeLocalTransformation(Transformation):
     """
-    Parent class for transformations, which which can be executed localy to one node.
+    Parent class for transformations that can be executed locally to one node
+    by accessing and modifying the attributes of only that node.
     This class can then automatically parallelize the transformation.
-    Transformations sublcassing NodeLocalTransformation must fill the abstract method
-    applyNodeLocal().
-    
-    * NUM_DEFAULT_WORKERS (int or None) number of parallel workers. Default is 1. None: Use all available cores.
-    
+    Transformations subclassing NodeLocalTransformation must implement the
+    abstract method applyNodeLocal().
+
+    To control the degree of parallelization, specify the num_workers argument
+    in the constructor, using one of the following values:
+    * None: use NUM_DEFAULT_WORKERS environment variable
+    * 0: use all available CPU cores
+    * (any other int>0): set number of parallel workers
     """
-    
-    def __init__(self, NUM_DEFAULT_WORKERS=1):
+
+    def __init__(self, num_workers=None):
         super().__init__()
-        if NUM_DEFAULT_WORKERS == None:
-            self._num_workers = mp.cpu_count()
+        if num_workers is None:
+            self._num_workers = get_num_default_workers()
         else:
-            assert NUM_DEFAULT_WORKERS > 0, "Number of workers (NUM_DEFAULT_WORKERS) to use must be larger than 0."
-            self._num_workers = NUM_DEFAULT_WORKERS
-    
+            self._num_workers = num_workers
+        assert self._num_workers >= 0, "Number of workers must be nonnegative."
+        if self._num_workers == 0:
+            self._num_workers = mp.cpu_count()
+
     @abstractmethod
     def applyNodeLocal(self, node):
         pass
-    
-    def apply(self, model):        
+
+    def apply(self, model):
         # Remove old nodes from the current model
         old_nodes = []
         for i in range(len(model.graph.node)):
             old_nodes.append(model.graph.node.pop())
-        
+
         # Execute transformation in parallel
         with mp.Pool(self._num_workers) as p:
             new_nodes_and_bool = p.map(self.applyNodeLocal, old_nodes, chunksize=1)
-            
+
         # extract nodes and check if the transformation needs to run again
         # Note: .pop() had initially reversed the node order
         run_again = False
         for node, run in reversed(new_nodes_and_bool):
             # Reattach new nodes to old model
             model.graph.node.append(node)
-            if run == True:
+            if run is True:
                 run_again = True
-        
+
         return (model, run_again)
diff --git a/src/finn/transformation/fpgadataflow/hlssynth_ipgen.py b/src/finn/transformation/fpgadataflow/hlssynth_ipgen.py
index ac8e810e4..2a40b3c23 100644
--- a/src/finn/transformation/fpgadataflow/hlssynth_ipgen.py
+++ b/src/finn/transformation/fpgadataflow/hlssynth_ipgen.py
@@ -39,13 +39,14 @@ class HLSSynth_IPGen(NodeLocalTransformation):
 
     This transformation calls Vivado HLS for synthesis, so it will run for
     some time (several minutes)
-    
-    * NUM_DEFAULT_WORKERS (int or None) number of parallel workers. Default is 1. None: Use all available cores.
+
+    * num_workers (int or None) number of parallel workers, see documentation in
+      NodeLocalTransformation for more details.
     """
 
-    def __init__(self, NUM_DEFAULT_WORKERS=1):
-        super().__init__(NUM_DEFAULT_WORKERS=NUM_DEFAULT_WORKERS)
-    
+    def __init__(self, num_workers=None):
+        super().__init__(num_workers=num_workers)
+
     def applyNodeLocal(self, node):
         op_type = node.op_type
         if node.domain == "finn":
@@ -76,5 +77,5 @@ class HLSSynth_IPGen(NodeLocalTransformation):
                     raise Exception(
                         "Custom op_type %s is currently not supported." % op_type
                     )
-        
+
         return (node, False)
-- 
GitLab