From ccbe08d0b899b1e3a8b59c50fc5c169a327e9b69 Mon Sep 17 00:00:00 2001 From: HenniOVP <hendrikborras@web.de> Date: Wed, 22 Apr 2020 13:30:27 +0200 Subject: [PATCH] Added NodeLocalTransformation class --- src/finn/transformation/__init__.py | 46 +++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/src/finn/transformation/__init__.py b/src/finn/transformation/__init__.py index a4e0bcf33..d6cacb739 100644 --- a/src/finn/transformation/__init__.py +++ b/src/finn/transformation/__init__.py @@ -48,6 +48,7 @@ Guide to writing FINN transformations """ from abc import ABC, abstractmethod +import multiprocessing as mp class Transformation(ABC): @@ -60,3 +61,48 @@ class Transformation(ABC): @abstractmethod def apply(self, model): pass + + +class NodeLocalTransformation(Transformation): + """ + Parent class for transformations, which which can be executed localy to one node. + This class can then automatically parallelize the transformation. + Transformations sublcassing NodeLocalTransformation must fill the abstract method + applyNodeLocal(). + + * NUM_DEFAULT_WORKERS (int or None) number of parallel workers. Default is 1. None: Use all available cores. + + """ + + def __init__(self, NUM_DEFAULT_WORKERS=1): + super().__init__() + if NUM_DEFAULT_WORKERS == None: + self._num_workers = mp.cpu_count() + else: + assert NUM_DEFAULT_WORKERS > 0, "Number of workers (NUM_DEFAULT_WORKERS) to use must be larger than 0." + self._num_workers = NUM_DEFAULT_WORKERS + + @abstractmethod + def applyNodeLocal(self, node): + pass + + def apply(self, model): + # Remove old nodes from the current model + old_nodes = [] + for i in range(len(model.graph.node)): + old_nodes.append(model.graph.node.pop()) + + # Execute transformation in parallel + with mp.Pool(self._num_workers) as p: + new_nodes_and_bool = p.map(self.applyNodeLocal, old_nodes, chunksize=1) + + # extract nodes and check if the transformation needs to run again + # Note: .pop() had initially reversed the node order + run_again = False + for node, run in reversed(new_nodes_and_bool): + # Reattach new nodes to old model + model.graph.node.append(node) + if run == True: + run_again = True + + return (model, run_again) -- GitLab