Skip to content
Snippets Groups Projects
Commit ccbe08d0 authored by HenniOVP's avatar HenniOVP
Browse files

Added NodeLocalTransformation class

parent e18785c8
No related branches found
No related tags found
No related merge requests found
......@@ -48,6 +48,7 @@ Guide to writing FINN transformations
"""
from abc import ABC, abstractmethod
import multiprocessing as mp
class Transformation(ABC):
......@@ -60,3 +61,48 @@ class Transformation(ABC):
@abstractmethod
def apply(self, model):
pass
class NodeLocalTransformation(Transformation):
"""
Parent class for transformations, which which can be executed localy to one node.
This class can then automatically parallelize the transformation.
Transformations sublcassing NodeLocalTransformation must fill the abstract method
applyNodeLocal().
* NUM_DEFAULT_WORKERS (int or None) number of parallel workers. Default is 1. None: Use all available cores.
"""
def __init__(self, NUM_DEFAULT_WORKERS=1):
super().__init__()
if NUM_DEFAULT_WORKERS == None:
self._num_workers = mp.cpu_count()
else:
assert NUM_DEFAULT_WORKERS > 0, "Number of workers (NUM_DEFAULT_WORKERS) to use must be larger than 0."
self._num_workers = NUM_DEFAULT_WORKERS
@abstractmethod
def applyNodeLocal(self, node):
pass
def apply(self, model):
# Remove old nodes from the current model
old_nodes = []
for i in range(len(model.graph.node)):
old_nodes.append(model.graph.node.pop())
# Execute transformation in parallel
with mp.Pool(self._num_workers) as p:
new_nodes_and_bool = p.map(self.applyNodeLocal, old_nodes, chunksize=1)
# extract nodes and check if the transformation needs to run again
# Note: .pop() had initially reversed the node order
run_again = False
for node, run in reversed(new_nodes_and_bool):
# Reattach new nodes to old model
model.graph.node.append(node)
if run == True:
run_again = True
return (model, run_again)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment