import dask
from dask.core import get_dependencies, flatten
import numpy as np
import copy


class Node(object):
    def __init__(self):
        pass

    def configure(self, requests):
        """ Before a task graph is executed, each node is configured.
            The request is propagated from the end to the beginning of the DAG
            and each node's "configure" routine is called.
            The request can be updated to reflect additional requirements and
            the return value gets passed to the predecessors.

            Essentially, the following question must be answered:
            What do I need to fulfil the request of my successor?

            Here, you must not configure the internal parameters of the Node,
            otherwise it would not be thread-safe. You can however introduce a
            new key 'requires_request' in the request being returned. This
            request will then be passed as an argument to the __call__ function.

            Best practice is to configure the Node on initialization with
            runtime-independent configurations and define all runtime-dependent
            configurations here.

            Arguments:
                requests {List} -- List of requests (i.e. dictionaries).

            Returns:
                dict -- The (updated) request. If updated, modifications must be
                        made on a copy of the input. The return value must be a
                        dictionary.
                        If multiple requests are input to this function, they
                        must be merged.
                        If nothing needs to be requested, an empty dictionary
                        can be returned. This removes all dependencies of this
                        node from the task graph.
        """
        if not isinstance(requests, list):
            raise RuntimeError("Please provide a **list** of requests")

        if len(requests) > 1:
            raise RuntimeError(
                "Default configuration function cannot handle "
                "multiple requests. Please provide a custom "
                "configuration implementation"
            )

        return requests

    def __call__(self, data=None, request=None, delayed=False):
        if delayed:
            return dask.delayed(self.forward)(data, request)
        else:
            return self.forward(data=data, request=request)

    def forward(self, data, request):
        raise NotImplementedError

    def get_config(self):
        """ Returns a dictionary of configurations to recreate the state. """
        raise NotImplementedError()


class StuettNode(Node):  # TODO: define where this class should be (maybe not here)
    def __init__(self, **kwargs):
        # Store all constructor arguments as the node's configuration and
        # unwrap nested "kwargs" dictionaries.
        self.config = locals().copy()
        while "kwargs" in self.config:
            if "kwargs" not in self.config["kwargs"]:
                self.config.update(self.config["kwargs"])
                break
            self.config.update(self.config["kwargs"])

        del self.config["kwargs"]
        del self.config["self"]

    def configure(self, requests):
        """ Default configure for stuett nodes.
            Expects two keys per request (*start_time* and *end_time*).
            If multiple requests are passed, they will be merged:
            start_time = minimum of all requests' start_time
            end_time   = maximum of all requests' end_time

        Arguments:
            requests {list} -- List of requests

        Returns:
            dict -- Original request or merged requests
        """
        if not isinstance(requests, list):
            raise RuntimeError("Please provide a list of requests")

        # For time requests we just use the union of the time segments
        new_request = requests[0].copy()
        key_func = {"start_time": np.minimum, "end_time": np.maximum}
        for r in requests[1:]:
            for key in ["start_time", "end_time"]:
                if key in r:
                    if key in new_request:
                        new_request[key] = key_func[key](new_request[key], r[key])
                    else:
                        new_request[key] = r[key]

        return new_request
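

# --- Illustrative example (sketch, not part of the library) --------------------
# A minimal node showing the contract described in Node.configure() above.
# The class name `ExampleRangeNode` is hypothetical; it only demonstrates how a
# node can merge the requests of its successors and use 'requires_request' so
# that configuration() injects the merged request into its forward() call.
class ExampleRangeNode(StuettNode):
    def configure(self, requests):
        # Merge start_time/end_time of all successor requests (default
        # StuettNode behaviour works on a copy), then ask for the merged
        # request to be passed to __call__/forward at execution time.
        request = super().configure(requests)
        request["requires_request"] = True
        return request

    def forward(self, data=None, request=None):
        # Runtime-dependent work would happen here, e.g. selecting the slice
        # request["start_time"]:request["end_time"] from `data`. The sketch
        # simply returns the request it received.
        return request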


def configuration(delayed, request, keys=None, default_merge=None):
    """ Configures each node of the graph by propagating the request from
        outputs to inputs.
        Each node checks if it can fulfil the request and what it needs to
        fulfil the request. If a node requires additional configurations to
        fulfil the request, it can set the 'requires_request' flag in the
        returned request and this function will add the returned request as a
        new input to the node's __call__(). See also Node.configure().

        Arguments:
            delayed {dask.delayed or list} -- Delayed object or list of delayed objects
            request {dict or list} -- Request (dict) or list of requests
            default_merge {callable} -- Request merge function, used when multiple
                requests meet at a task that is not a Node

        Keyword Arguments:
            keys {list} -- Keys of the output nodes from which the request is
                propagated (default: {None}, i.e. the keys of `delayed`)

        Raises:
            RuntimeError: If a list of requests is passed and its length does
                not match the number of keys
            RuntimeError: If multiple requests must be merged at a task that is
                not a Node and no default_merge is provided

        Returns:
            dask.delayed or list -- Config-optimized delayed object or list of
                delayed objects
    """
    if not isinstance(delayed, list):
        collections = [delayed]
    else:
        collections = delayed

    # dsk = dask.base.collections_to_dsk(collections)
    dsk, dsk_keys = dask.base._extract_graph_and_keys(collections)
    dependencies, dependants = dask.core.get_deps(dsk)

    if keys is None:
        keys = dsk_keys
    if not isinstance(keys, (list, set)):
        keys = [keys]

    out_keys = []
    seen = set()

    work = list(set(flatten(keys)))

    if isinstance(request, list):
        if len(request) != len(work):
            raise RuntimeError(
                "When passing multiple request items, the number of "
                "request items must be the same as the number of keys"
            )
        requests = {work[i]: [request[i]] for i in range(len(request))}
    else:
        requests = {k: [request] for k in work}

    remove = {k: False for k in work}
    input_requests = {}
    while work:
        new_work = []
        out_keys += work
        deps = []
        for k in work:
            # if k not in requests:
            #     # there wasn't any request stored, use initial config
            #     requests[k] = [config]

            # Check if we have collected all dependants' requests so far;
            # otherwise we will come back to this node another time.
            # TODO: make a better check for the case when dependants[k] is a set, also: why is it a set in the first place..?
            if (
                k in dependants
                and len(dependants[k]) != len(requests[k])
                and not isinstance(dependants[k], set)
            ):
                # print(f'Waiting at {k}', dependants[k], requests[k])
                continue

            # print(f"configuring {k}", requests[k])
            # Set configuration for this node k.
            # If we create a delayed object from a class, `self` will be dsk[k][1]
            argument_is_node = None
            if isinstance(dsk[k], tuple):
                # check the first argument
                # TODO: make tests for all use cases and then remove for-loop
                for ain in range(1):
                    if hasattr(dsk[k][ain], "__self__"):
                        if isinstance(dsk[k][ain].__self__, Node):
                            argument_is_node = ain

            # Check if we got a node of type Node
            if argument_is_node is not None:
                # current_requests = [r for r in requests[k] if r]  # get all requests belonging to this node
                current_requests = requests[k]
                new_request = dsk[k][argument_is_node].__self__.configure(
                    current_requests
                )  # Call the class configuration function
                if not isinstance(new_request, list):
                    # prepare the request return value
                    new_request = [new_request]
            else:
                # We didn't get a Node class so there is no
                # custom configuration function: pass through
                if len(requests[k]) > 1:
                    if callable(default_merge):
                        new_request = default_merge(requests[k])
                    else:
                        raise RuntimeError(
                            "No valid default merger supplied. Cannot merge requests. "
                            "Either convert your function to a class Node or provide "
                            "a default merger"
                        )
                else:
                    new_request = requests[k]

            if (
                new_request[0] is not None
                and "requires_request" in new_request[0]
                and new_request[0]["requires_request"] == True
            ):
                del new_request[0]["requires_request"]
                # TODO: check if we need a deepcopy here!
                input_requests[k] = copy.deepcopy(new_request[0])

            # update dependencies
            current_deps = get_dependencies(dsk, k, as_list=True)
            for i, d in enumerate(current_deps):
                if d in requests:
                    requests[d] += new_request
                    remove[d] = remove[d] and (new_request[0] is None)
                else:
                    requests[d] = new_request
                    # if we received None
                    remove[d] = new_request[0] is None

                # only configure each node once in a round!
                if d not in new_work and d not in work:  # TODO: verify this
                    # TODO: Do we need to configure the dependency if we'll remove it?
                    new_work.append(d)

        work = new_work

    # Assembling the configured new graph
    out = {k: dsk[k] for k in out_keys if not remove[k]}

    # After we have acquired all requests, we can add the required requests as
    # an input to the requiring node.
    # We assume that the last argument is the request.
    for k in input_requests:
        # Here we assume that we always receive the same tuple of
        # (bound method, data, request). If the interface changes this will
        # break.  TODO: check for all cases
        if isinstance(out[k][2], dict):
            out[k][2].update(input_requests[k])
        else:
            # replace the last entry
            out[k] = out[k][:2] + (input_requests[k],)

    # convert to delayed object
    from dask.delayed import Delayed

    in_keys = list(flatten(keys))
    # print(in_keys)
    if len(in_keys) > 1:
        collection = [Delayed(key=key, dsk=out) for key in in_keys]
    else:
        collection = Delayed(key=in_keys[0], dsk=out)
        if isinstance(delayed, list):
            collection = [collection]

    return collection
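

# --- Illustrative usage of configuration() (sketch) ----------------------------
# Using the hypothetical ExampleRangeNode defined above: the request is attached
# to the graph outputs and propagated towards the inputs; because the node sets
# 'requires_request', the merged request is injected as the last argument of its
# forward() call before the graph is computed.
#
#     node = ExampleRangeNode()
#     d = node(delayed=True)                        # build an unconfigured delayed graph
#     d = configuration(d, {"start_time": 0, "end_time": 10})
#     result = d.compute()                          # forward(None, merged_request) runs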


def optimize_freeze(dsk, keys, request_key="request"):
    """ Return a new dask graph with tasks removed which are unnecessary
        because a later stage reads from cache.

        ``keys`` may be a single key or list of keys.

        Returns
        -------
        dsk: culled dask graph
        dependencies: Dict mapping {key: [deps]}. Useful side effect to
                      accelerate other optimizations, notably fuse.
    """
    if not isinstance(keys, (list, set)):
        keys = [keys]

    out_keys = []
    seen = set()
    dependencies = dict()

    if request_key not in dsk:
        raise RuntimeError(
            f"Please provide a task graph which includes '{request_key}'"
        )

    request = dsk[request_key]

    def is_cached(task, request):
        if isinstance(task, tuple):
            if isinstance(task[0], Freezer):
                return task[0].is_cached(request)
        return False

    work = list(set(flatten(keys)))
    cached_keys = []
    while work:
        new_work = []
        out_keys += work
        deps = []
        for k in work:
            if is_cached(dsk[k], request):
                cached_keys.append(k)
            else:
                deps.append((k, get_dependencies(dsk, k, as_list=True)))

        dependencies.update(deps)
        for _, deplist in deps:
            for d in deplist:
                if d not in seen:
                    seen.add(d)
                    new_work.append(d)

        work = new_work

    out = {k: dsk[k] for k in out_keys}

    # finally we need to replace the input of the caching nodes with the request
    cached = {k: (out[k][0], request_key) for k in cached_keys}
    out.update(cached)

    return out, dependencies
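

# --- Illustrative usage of optimize_freeze() (sketch) --------------------------
# optimize_freeze() works on a raw task-graph dict that carries the current
# request under `request_key`. It relies on a `Freezer` class (referenced in
# is_cached() above and assumed to be defined elsewhere in this module) whose
# callable instances cache results for a given request. A hypothetical graph
# could look like this, where `load_data`, `postprocess` and `freezer` are
# placeholders:
#
#     dsk = {
#         "request": {"start_time": 0, "end_time": 10},
#         "raw": (load_data, "request"),
#         "frozen": (freezer, "raw"),        # `freezer` is a Freezer instance
#         "result": (postprocess, "frozen"),
#     }
#     culled, deps = optimize_freeze(dsk, "result")
#
# If freezer.is_cached(dsk["request"]) returns True, "raw" is culled from the
# returned graph and the "frozen" task is rewritten to read from the request,
# i.e. it becomes (freezer, "request").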