Skip to content
Snippets Groups Projects
Commit 9dfdd565 authored by Yaman Umuroglu's avatar Yaman Umuroglu
Browse files

[Transform] change nWorkers default behavior, comments

parent 347538ae
No related branches found
No related tags found
No related merge requests found
......@@ -48,6 +48,7 @@ Guide to writing FINN transformations
"""
from abc import ABC, abstractmethod
from finn.util.basic import get_num_default_workers
import multiprocessing as mp
......@@ -65,44 +66,50 @@ class Transformation(ABC):
class NodeLocalTransformation(Transformation):
"""
Parent class for transformations, which which can be executed localy to one node.
Parent class for transformations, which can be executed locally to one node
by accessing and modifying the attributes of only that node.
This class can then automatically parallelize the transformation.
Transformations sublcassing NodeLocalTransformation must fill the abstract method
applyNodeLocal().
* NUM_DEFAULT_WORKERS (int or None) number of parallel workers. Default is 1. None: Use all available cores.
Transformations sublcassing NodeLocalTransformation must implement the
abstract method applyNodeLocal().
To control the degree of parallelization, specify the num_workers argument
in the constructor, using one of the following values:
* None: use NUM_DEFAULT_WORKERS environment variable
* 0: use all available CPU cores
* (any other int>0): set number of parallel workers
"""
def __init__(self, NUM_DEFAULT_WORKERS=1):
def __init__(self, num_workers=None):
super().__init__()
if NUM_DEFAULT_WORKERS == None:
self._num_workers = mp.cpu_count()
if num_workers is None:
self._num_workers = get_num_default_workers()
else:
assert NUM_DEFAULT_WORKERS > 0, "Number of workers (NUM_DEFAULT_WORKERS) to use must be larger than 0."
self._num_workers = NUM_DEFAULT_WORKERS
self._num_workers = num_workers
assert self._num_workers >= 0, "Number of workers must be nonnegative."
if self._num_workers == 0:
self._num_workers = mp.cpu_count()
@abstractmethod
def applyNodeLocal(self, node):
pass
def apply(self, model):
def apply(self, model):
# Remove old nodes from the current model
old_nodes = []
for i in range(len(model.graph.node)):
old_nodes.append(model.graph.node.pop())
# Execute transformation in parallel
with mp.Pool(self._num_workers) as p:
new_nodes_and_bool = p.map(self.applyNodeLocal, old_nodes, chunksize=1)
# extract nodes and check if the transformation needs to run again
# Note: .pop() had initially reversed the node order
run_again = False
for node, run in reversed(new_nodes_and_bool):
# Reattach new nodes to old model
model.graph.node.append(node)
if run == True:
if run is True:
run_again = True
return (model, run_again)
......@@ -39,13 +39,14 @@ class HLSSynth_IPGen(NodeLocalTransformation):
This transformation calls Vivado HLS for synthesis, so it will run for
some time (several minutes)
* NUM_DEFAULT_WORKERS (int or None) number of parallel workers. Default is 1. None: Use all available cores.
* num_workers (int or None) number of parallel workers, see documentation in
NodeLocalTransformation for more details.
"""
def __init__(self, NUM_DEFAULT_WORKERS=1):
super().__init__(NUM_DEFAULT_WORKERS=NUM_DEFAULT_WORKERS)
def __init__(self, num_workers=None):
super().__init__(num_workers=num_workers)
def applyNodeLocal(self, node):
op_type = node.op_type
if node.domain == "finn":
......@@ -76,5 +77,5 @@ class HLSSynth_IPGen(NodeLocalTransformation):
raise Exception(
"Custom op_type %s is currently not supported." % op_type
)
return (node, False)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment