From 9dfdd5659fb4e7015c66731bd2e9a7c677594c83 Mon Sep 17 00:00:00 2001 From: Yaman Umuroglu <maltanar@gmail.com> Date: Wed, 22 Apr 2020 23:24:29 +0100 Subject: [PATCH] [Transform] change nWorkers default behavior, comments --- src/finn/transformation/__init__.py | 45 +++++++++++-------- .../fpgadataflow/hlssynth_ipgen.py | 13 +++--- 2 files changed, 33 insertions(+), 25 deletions(-) diff --git a/src/finn/transformation/__init__.py b/src/finn/transformation/__init__.py index d6cacb739..e9f5fe15f 100644 --- a/src/finn/transformation/__init__.py +++ b/src/finn/transformation/__init__.py @@ -48,6 +48,7 @@ Guide to writing FINN transformations """ from abc import ABC, abstractmethod +from finn.util.basic import get_num_default_workers import multiprocessing as mp @@ -65,44 +66,50 @@ class Transformation(ABC): class NodeLocalTransformation(Transformation): """ - Parent class for transformations, which which can be executed localy to one node. + Parent class for transformations, which can be executed locally to one node + by accessing and modifying the attributes of only that node. This class can then automatically parallelize the transformation. - Transformations sublcassing NodeLocalTransformation must fill the abstract method - applyNodeLocal(). - - * NUM_DEFAULT_WORKERS (int or None) number of parallel workers. Default is 1. None: Use all available cores. - + Transformations sublcassing NodeLocalTransformation must implement the + abstract method applyNodeLocal(). + + To control the degree of parallelization, specify the num_workers argument + in the constructor, using one of the following values: + * None: use NUM_DEFAULT_WORKERS environment variable + * 0: use all available CPU cores + * (any other int>0): set number of parallel workers """ - - def __init__(self, NUM_DEFAULT_WORKERS=1): + + def __init__(self, num_workers=None): super().__init__() - if NUM_DEFAULT_WORKERS == None: - self._num_workers = mp.cpu_count() + if num_workers is None: + self._num_workers = get_num_default_workers() else: - assert NUM_DEFAULT_WORKERS > 0, "Number of workers (NUM_DEFAULT_WORKERS) to use must be larger than 0." - self._num_workers = NUM_DEFAULT_WORKERS - + self._num_workers = num_workers + assert self._num_workers >= 0, "Number of workers must be nonnegative." + if self._num_workers == 0: + self._num_workers = mp.cpu_count() + @abstractmethod def applyNodeLocal(self, node): pass - - def apply(self, model): + + def apply(self, model): # Remove old nodes from the current model old_nodes = [] for i in range(len(model.graph.node)): old_nodes.append(model.graph.node.pop()) - + # Execute transformation in parallel with mp.Pool(self._num_workers) as p: new_nodes_and_bool = p.map(self.applyNodeLocal, old_nodes, chunksize=1) - + # extract nodes and check if the transformation needs to run again # Note: .pop() had initially reversed the node order run_again = False for node, run in reversed(new_nodes_and_bool): # Reattach new nodes to old model model.graph.node.append(node) - if run == True: + if run is True: run_again = True - + return (model, run_again) diff --git a/src/finn/transformation/fpgadataflow/hlssynth_ipgen.py b/src/finn/transformation/fpgadataflow/hlssynth_ipgen.py index ac8e810e4..2a40b3c23 100644 --- a/src/finn/transformation/fpgadataflow/hlssynth_ipgen.py +++ b/src/finn/transformation/fpgadataflow/hlssynth_ipgen.py @@ -39,13 +39,14 @@ class HLSSynth_IPGen(NodeLocalTransformation): This transformation calls Vivado HLS for synthesis, so it will run for some time (several minutes) - - * NUM_DEFAULT_WORKERS (int or None) number of parallel workers. Default is 1. None: Use all available cores. + + * num_workers (int or None) number of parallel workers, see documentation in + NodeLocalTransformation for more details. """ - def __init__(self, NUM_DEFAULT_WORKERS=1): - super().__init__(NUM_DEFAULT_WORKERS=NUM_DEFAULT_WORKERS) - + def __init__(self, num_workers=None): + super().__init__(num_workers=num_workers) + def applyNodeLocal(self, node): op_type = node.op_type if node.domain == "finn": @@ -76,5 +77,5 @@ class HLSSynth_IPGen(NodeLocalTransformation): raise Exception( "Custom op_type %s is currently not supported." % op_type ) - + return (node, False) -- GitLab