From ccbe08d0b899b1e3a8b59c50fc5c169a327e9b69 Mon Sep 17 00:00:00 2001
From: HenniOVP <hendrikborras@web.de>
Date: Wed, 22 Apr 2020 13:30:27 +0200
Subject: [PATCH] Added NodeLocalTransformation class

---
 src/finn/transformation/__init__.py | 46 +++++++++++++++++++++++++++++
 1 file changed, 46 insertions(+)

diff --git a/src/finn/transformation/__init__.py b/src/finn/transformation/__init__.py
index a4e0bcf33..d6cacb739 100644
--- a/src/finn/transformation/__init__.py
+++ b/src/finn/transformation/__init__.py
@@ -48,6 +48,7 @@ Guide to writing FINN transformations
 """
 
 from abc import ABC, abstractmethod
+import multiprocessing as mp
 
 
 class Transformation(ABC):
@@ -60,3 +61,48 @@ class Transformation(ABC):
     @abstractmethod
     def apply(self, model):
         pass
+
+
+class NodeLocalTransformation(Transformation):
+    """
+    Parent class for transformations, which which can be executed localy to one node.
+    This class can then automatically parallelize the transformation.
+    Transformations sublcassing NodeLocalTransformation must fill the abstract method
+    applyNodeLocal().
+    
+    * NUM_DEFAULT_WORKERS (int or None) number of parallel workers. Default is 1. None: Use all available cores.
+    
+    """
+    
+    def __init__(self, NUM_DEFAULT_WORKERS=1):
+        super().__init__()
+        if NUM_DEFAULT_WORKERS == None:
+            self._num_workers = mp.cpu_count()
+        else:
+            assert NUM_DEFAULT_WORKERS > 0, "Number of workers (NUM_DEFAULT_WORKERS) to use must be larger than 0."
+            self._num_workers = NUM_DEFAULT_WORKERS
+    
+    @abstractmethod
+    def applyNodeLocal(self, node):
+        pass
+    
+    def apply(self, model):        
+        # Remove old nodes from the current model
+        old_nodes = []
+        for i in range(len(model.graph.node)):
+            old_nodes.append(model.graph.node.pop())
+        
+        # Execute transformation in parallel
+        with mp.Pool(self._num_workers) as p:
+            new_nodes_and_bool = p.map(self.applyNodeLocal, old_nodes, chunksize=1)
+            
+        # extract nodes and check if the transformation needs to run again
+        # Note: .pop() had initially reversed the node order
+        run_again = False
+        for node, run in reversed(new_nodes_and_bool):
+            # Reattach new nodes to old model
+            model.graph.node.append(node)
+            if run == True:
+                run_again = True
+        
+        return (model, run_again)
-- 
GitLab