import copy

import dask
import numpy as np
from dask.core import get_dependencies, flatten


class Node(object):
    def __init__(self):
        pass

    def configure(self, requests):
        """Before a task graph is executed, each node is configured.

        The request is propagated from the end to the beginning of the DAG
        and each node's `configure` routine is called. The request can be
        updated to reflect additional requirements. The return value gets
        passed to the predecessors. Essentially the following question must
        be answered: What do I need to fulfil the request of my successor?

        Here, you must not configure the internal parameters of the Node,
        otherwise it would not be thread-safe. You can, however, introduce a
        new key 'requires_request' in the request being returned. This
        request will then be passed as an argument to the __call__ function.

        Best practice is to configure the Node on initialization with
        runtime-independent configurations and define all runtime-dependent
        configurations here.

        Arguments:
            requests {list} -- List of requests (i.e. dictionaries).

        Returns:
            dict -- The (updated) request. If updated, modifications must be
                made on a copy of the input. The return value must be a
                dictionary. If multiple requests are input to this function
                they must be merged. If nothing needs to be requested, an
                empty dictionary can be returned; this removes all
                dependencies of this node from the task graph.
        """
        if not isinstance(requests, list):
            raise RuntimeError('Please provide a **list** of requests')

        if len(requests) > 1:
            raise RuntimeError('Default configuration function cannot handle '
                               'multiple requests. Please provide a custom '
                               'configuration implementation')

        return requests

    @dask.delayed
    def __call__(self, x, request=None):
        raise NotImplementedError()

    def get_config(self):
        """Return a dictionary of configurations to recreate the state."""
        raise NotImplementedError()


class StuettNode(Node):  # TODO: define where this class should be (maybe not here)
    def configure(self, requests):
        """Default configure for stuett nodes.

        Expects two keys per request (*start_time* and *end_time*).
        If multiple requests are passed, they will be merged:
        start_time = minimum of all requests' start_time
        end_time   = maximum of all requests' end_time

        Arguments:
            requests {list} -- List of requests

        Returns:
            dict -- Original request or merged requests
        """
        if not isinstance(requests, list):
            raise RuntimeError('Please provide a list of requests')

        # For time requests we just use the union of both time segments
        new_request = requests[0].copy()
        key_func = {'start_time': np.minimum, 'end_time': np.maximum}
        for r in requests[1:]:
            for key in ['start_time', 'end_time']:
                if key in r:
                    if key in new_request:
                        new_request[key] = key_func[key](new_request[key], r[key])
                    else:
                        new_request[key] = r[key]

        return new_request
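
# --- Illustrative example (not part of the original API) --------------------
# A minimal sketch of how a StuettNode subclass could use the default
# time-window merging together with the 'requires_request' flag. The class
# name and its behaviour are assumptions for demonstration purposes only.
class _ExamplePassthroughNode(StuettNode):
    def configure(self, requests):
        # Merge start_time/end_time of all incoming requests with the
        # default StuettNode behaviour ...
        request = super().configure(requests)
        # ... and ask configuration() to pass the merged request to
        # __call__ at runtime (work on a copy to stay thread-safe).
        request = copy.deepcopy(request)
        request['requires_request'] = True
        return request

    @dask.delayed
    def __call__(self, x, request=None):
        # A real node would load or process data for the span
        # [request['start_time'], request['end_time']] here.
        return request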
def configuration(delayed, request, keys=None, default_merge=None):
    """Configures each node of the graph by propagating the request from
    outputs to inputs.

    Each node checks if it can fulfil the request and what it needs to fulfil
    the request. If a node requires additional configurations to fulfil the
    request, it can set the 'requires_request' flag in the returned request
    and this function will add the returned request as a new input to the
    node's __call__(). See also Node.configure().

    Arguments:
        delayed {dask.delayed or list} -- Delayed object or list of delayed objects
        request {dict or list} -- Request (dict) or list of requests
        default_merge {callable} -- Request merge function for non-Node tasks

    Keyword Arguments:
        keys {list} -- Keys to configure; defaults to the keys of the delayed
            objects (default: {None})

    Raises:
        RuntimeError: If the number of requests does not match the number of keys
        RuntimeError: If multiple requests for a non-Node task cannot be merged

    Returns:
        dask.delayed or list -- Config-optimized delayed object or list of
            delayed objects
    """
    if not isinstance(delayed, list):
        collections = [delayed]
    else:
        collections = delayed

    # dsk = dask.base.collections_to_dsk(collections)
    dsk, dsk_keys = dask.base._extract_graph_and_keys(collections)
    dependencies, dependants = dask.core.get_deps(dsk)

    if keys is None:
        keys = dsk_keys

    # print('dsk', dsk.layers)
    # print('keys', keys)
    if not isinstance(keys, (list, set)):
        keys = [keys]

    out_keys = []
    seen = set()

    work = list(set(flatten(keys)))

    if isinstance(request, list):
        if len(request) != len(work):
            raise RuntimeError("When passing multiple request items, "
                               "the number of request items must be the same "
                               "as the number of keys")
        requests = {work[i]: [request[i]] for i in range(len(request))}
    else:
        requests = {k: [request] for k in work}

    remove = {k: False for k in work}
    input_requests = {}
    while work:
        new_work = []
        out_keys += work
        deps = []
        for k in work:
            # if k not in requests:
            #     # there wasn't any request stored, use initial config
            #     requests[k] = [config]

            # Check if we have collected all dependencies so far;
            # we will come back to this node another time.
            # TODO: make a better check for the case when dependants[k] is a
            #       set, also: why is it a set in the first place..?
            if k in dependants and len(dependants[k]) != len(requests[k]) and not isinstance(dependants[k], set):
                # print(f'Waiting at {k}', dependants[k], requests[k])
                continue

            # print(f"configuring {k}", requests[k])
            # Set the configuration for this node k.
            # If we create a delayed object from a class, `self` will be dsk[k][1]
            if isinstance(dsk[k], tuple) and isinstance(dsk[k][1], Node):
                # We got a node of type Node class
                # current_requests = [r for r in requests[k] if r]  # get all requests belonging to this node
                current_requests = requests[k]
                # Call the class configuration function
                new_request = dsk[k][1].configure(current_requests)
                # Prepare the request return value
                if not isinstance(new_request, list):
                    new_request = [new_request]
            else:
                # We didn't get a Node class so there is no
                # custom configuration function: pass through
                if len(requests[k]) > 1:
                    if callable(default_merge):
                        new_request = default_merge(requests[k])
                    else:
                        raise RuntimeError("No valid default merger supplied. Cannot merge requests. "
                                           "Either convert your function to a class Node or provide "
                                           "a default merger")
                else:
                    new_request = requests[k]

            if new_request[0].get('requires_request', False):
                del new_request[0]['requires_request']
                input_requests[k] = copy.deepcopy(new_request[0])  # TODO: check if we need a deepcopy here!

            # Update dependencies
            current_deps = get_dependencies(dsk, k, as_list=True)
            for i, d in enumerate(current_deps):
                if d in requests:
                    requests[d] += new_request
                    remove[d] = remove[d] and (not new_request[0])
                else:
                    requests[d] = new_request
                    remove[d] = (not new_request[0])  # if we received an empty dictionary, flag deps for removal

                # Only configure each node once in a round!
                if d not in new_work and d not in work:  # TODO: verify this
                    # TODO: Do we need to configure the dependency if we'll remove it?
                    new_work.append(d)

        work = new_work

    # Assemble the configured new graph
    out = {k: dsk[k] for k in out_keys if not remove[k]}

    # After we have acquired all requests we can add the required requests
    # as an input node to the requiring node
    for k in input_requests:
        out[k] += (input_requests[k],)

    # Convert back to a delayed object
    from dask.delayed import Delayed

    in_keys = list(flatten(keys))
    # print(in_keys)
    if len(in_keys) > 1:
        collection = [Delayed(key=key, dsk=out) for key in in_keys]
    else:
        collection = Delayed(key=in_keys[0], dsk=out)
        if isinstance(delayed, list):
            collection = [collection]

    return collection
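
# --- Illustrative usage sketch (assumptions, not part of the original API) ---
# How the pieces above could be wired together: build a one-node delayed graph
# from the hypothetical _ExamplePassthroughNode defined earlier and let
# `configuration` propagate a time request from the output key towards the
# inputs. The request values are made up for demonstration.
def _example_configuration_usage():
    node = _ExamplePassthroughNode()
    graph = node(None)  # a dask.delayed object wrapping the node call

    request = {'start_time': np.datetime64('2017-08-01'),
               'end_time': np.datetime64('2017-08-02')}

    # Propagate the request through the graph; since the node sets
    # 'requires_request', the merged request is appended as an extra
    # argument to its __call__ task.
    configured = configuration(graph, request)
    return configured.compute()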
class Freezer(Node):
    def __init__(self, caching=True):
        self.caching = caching

    @dask.delayed
    def __call__(self, x):
        """If caching is enabled, load a cached result or store the input
        data and pass it through.

        Arguments:
            x {xarray or dict} -- Either the xarray data to be passed through
                (and cached) or a request dictionary containing information
                about the data to be loaded

        Returns:
            xarray -- Data loaded from cache or input data passed through
        """
        if isinstance(x, dict):
            if self.is_cached(x) and self.caching:
                # TODO: load from cache and return it
                pass
            elif not self.caching:
                raise RuntimeError(f'If caching is disabled cannot perform request {x}')
            else:
                raise RuntimeError(f'Result is not cached but cached result is requested with {x}')

        if self.caching:
            # TODO: store the input data
            pass

        return x

    def configure(self, requests):
        if self.caching:
            return [{}]

        return config_conflict(requests)


def optimize_freeze(dsk, keys, request_key='request'):
    """Return a new dask graph with tasks removed which are unnecessary
    because a later stage reads from cache.

    ``keys`` may be a single key or list of keys.

    Examples
    --------

    Returns
    -------
    dsk: culled dask graph
    dependencies: Dict mapping {key: [deps]}. Useful side effect to accelerate
        other optimizations, notably fuse.
    """
    if not isinstance(keys, (list, set)):
        keys = [keys]

    out_keys = []
    seen = set()
    dependencies = dict()

    if request_key not in dsk:
        raise RuntimeError(f"Please provide a task graph which includes '{request_key}'")

    request = dsk[request_key]

    def is_cached(task, request):
        if isinstance(task, tuple):
            if isinstance(task[0], Freezer):
                return task[0].is_cached(request)
        return False

    work = list(set(flatten(keys)))
    cached_keys = []
    while work:
        new_work = []
        out_keys += work
        deps = []
        for k in work:
            if is_cached(dsk[k], request):
                cached_keys.append(k)
            else:
                deps.append((k, get_dependencies(dsk, k, as_list=True)))

        dependencies.update(deps)
        for _, deplist in deps:
            for d in deplist:
                if d not in seen:
                    seen.add(d)
                    new_work.append(d)

        work = new_work

    out = {k: dsk[k] for k in out_keys}

    # Finally we need to replace the input of the caching nodes with the request
    cached = {k: (out[k][0], request_key) for k in cached_keys}
    out.update(cached)

    return out, dependencies
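
# --- Illustrative usage sketch (assumptions, not part of the original API) ---
# A minimal, hand-built task graph showing the call pattern `optimize_freeze`
# expects: a graph containing a request key and a task whose first element is
# a Freezer. The loader function, the subclass, and the graph layout are made
# up for demonstration; Freezer.is_cached is still a TODO above, so a trivial
# implementation is provided here just to exercise the culling logic.
def _example_optimize_freeze_usage():
    class _CachedFreezer(Freezer):
        """Hypothetical subclass that pretends every request is cached."""
        def is_cached(self, request):
            return True

    def load(request):
        # Hypothetical loader; a real node would fetch data for the request.
        return request

    freezer = _CachedFreezer()
    dsk = {
        'request': {'start_time': '2017-08-01', 'end_time': '2017-08-02'},
        'loaded': (load, 'request'),
        'frozen': (freezer, 'loaded'),
    }

    # Because the Freezer reports its result as cached, optimize_freeze drops
    # the upstream 'loaded' task and rewires 'frozen' to take the 'request'
    # key as its input.
    culled_dsk, dependencies = optimize_freeze(dsk, 'frozen')
    return culled_dsk, dependencies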