import dask
from dask.core import get_dependencies, flatten
import numpy as np
import copy


class Node(object):
    def __init__(self):
        pass

    def configure(self, requests):
        """ Before a task graph is executed each node is configured.
            The request is propagated from the end to the beginning 
            of the DAG and each nodes "configure" routine is called.
            The request can be updated to reflect additional requirements,
            The return value gets passed to predecessors.

            Essentially the following question must be answered:
            What do I need to fulfil the request of my successor?

            Here, you must not configure the internal parameters of the
            Node otherwise it would not be thread-safe. You can however
            introduce a new key 'requires_request' in the request being 
            returned. This request will then be passed as an argument
            to the __call__ function.

            Best practice is to configure the Node on initialization with
            runtime independent configurations and define all runtime
            dependant configurations here.
        
        Arguments:
            requests {List} -- List of requests (i.e. dictionaries).

        
        Returns:
            dict -- The (updated) request. If updated modifications
                    must be made on a copy of the input. The return value
                    must be a dictionary. 
                    If multiple requests are input to this function they 
                    must be merged.
                    If nothing needs to be requested an empty dictionary
                    can be return. This removes all dependencies of this
                    node from the task graph.
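
        Example:
            A minimal sketch of a custom configuration (``MyNode`` is a
            hypothetical subclass, not part of this module)::

                class MyNode(Node):
                    def configure(self, requests):
                        # work on a copy, never modify the input in-place
                        request = dict(requests[0])
                        # ask to receive this request at runtime as an
                        # argument to __call__
                        request["requires_request"] = True
                        return request

                    def forward(self, data=None, request=None):
                        return data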

        """
        if not isinstance(requests, list):
            raise RuntimeError("Please provide a **list** of requests")
        if len(requests) > 1:
            raise RuntimeError(
                "Default configuration function cannot handle "
                "multiple requests. Please provide a custom "
                "configuration implementation"
            )
        return requests

    def __call__(self, data=None, request=None, delayed=False):
        if delayed:
            return dask.delayed(self.forward)(data=data, request=request)
        else:
            return self.forward(data=data, request=request)

    def forward(self, data=None, request=None):
        raise NotImplementedError

    def get_config(self):
        """ returns a dictionary of configurations to recreate the state
        """
        raise NotImplementedError()


class StuettNode(Node):  # TODO: define where this class should be (maybe not here)
    def __init__(self, **kwargs):
        self.config = locals().copy()
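        # flatten any nested "kwargs" dictionaries into self.config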
        while "kwargs" in self.config:
            if "kwargs" not in self.config["kwargs"]:
                self.config.update(self.config["kwargs"])
                break
            self.config.update(self.config["kwargs"])

        del self.config["kwargs"]
        del self.config["self"]

    def configure(self, requests):
        """ Default configure for stuett nodes
            Expects two keys per request (*start_time* and *tend*)
            If multiple requests are passed, they will be merged
            start_time = minimum of all requests' start_time
            end_time = maximum of all requests' end_time
        
        Arguments:
            request {list} -- List of requests

        Returns:
            dict -- Original request or merged requests 
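
        Example (``node`` is any StuettNode instance; the values are illustrative)::

            from datetime import datetime
            node.configure([
                {"start_time": datetime(2017, 1, 1), "end_time": datetime(2017, 1, 3)},
                {"start_time": datetime(2017, 1, 2), "end_time": datetime(2017, 1, 5)},
            ])
            # -> {"start_time": datetime(2017, 1, 1), "end_time": datetime(2017, 1, 5)}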
        """
        if not isinstance(requests, list):
            raise RuntimeError("Please provide a list of requests")

        # For time requests we just use the union of both time segments
        new_request = requests[0].copy()

        key_func = {"start_time": np.minimum, "end_time": np.maximum}
        for r in requests[1:]:
            for key in ["start_time", "end_time"]:
                if key in r:
                    if key in new_request:
                        new_request[key] = key_func[key](new_request[key], r[key])
                    else:
                        new_request[key] = r[key]

        return new_request


def configuration(delayed, request, keys=None, default_merge=None):
    """ Configures each node of the graph by propagating the request from outputs
        to inputs.
        Each node checks if it can fulfil the request and what it needs to fulfil the request.
        If a node requires additional configurations to fulfil the request it can set the
        'requires_request' flag in the returned request and this function will add the 
        return request as a a new input to the node's __call__(). See also Node.configure()
    
    Arguments:
        delayed {dask.delayed or list}  -- Delayed object or list of delayed objects
        request {dict or list}           -- request (dict), list of requests
        default_merge {callable}        -- request merge function 
    
    Keyword Arguments:
        keys {[type]} -- [description] (default: {None})
    
    Raises:
        RuntimeError: [description]
        RuntimeError: [description]
    
    Returns:
        dask.delayed or list -- Config-optimized delayed object or list of delayed objects
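
    Example:
        A sketch; ``delayed_node`` is assumed to be the delayed output of a
        stuett node and ``start``/``end`` are placeholder timestamps::

            request = {"start_time": start, "end_time": end}
            configured = configuration(delayed_node, request)
            result = configured.compute()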
    """

    if not isinstance(delayed, list):
        collections = [delayed]
    else:
        collections = delayed

    # dsk = dask.base.collections_to_dsk(collections)
    dsk, dsk_keys = dask.base._extract_graph_and_keys(collections)
    dependencies, dependants = dask.core.get_deps(dsk)

    if keys is None:
        keys = dsk_keys

    if not isinstance(keys, (list, set)):
        keys = [keys]
    out_keys = []
    seen = set()

    work = list(set(flatten(keys)))

    if isinstance(request, list):
        if len(request) != len(work):
            raise RuntimeError(
                "When passing multiple request items, "
                "the number of request items must be the same "
                "as the number of keys"
            )

        requests = {work[i]: [request[i]] for i in range(len(request))}
    else:
        requests = {k: [request] for k in work}

    remove = {k: False for k in work}
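    # `remove` flags dependencies that receive only empty requests from all of
    # their successors; they are dropped when the new graph is assembled below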
    input_requests = {}
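    # breadth-first traversal from the output keys towards the inputs,
    # propagating (and merging) the requests along the way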
    while work:
        new_work = []
        out_keys += work
        deps = []
        for k in work:
            # if k not in requests:
            #     # there wasn't any request stored use initial config
            #     requests[k] = [config]

            # check if we have collected all dependencies so far
            # we will come back to this node another time
            # TODO: make a better check for the case when dependants[k] is a set, also: why is it a set in the first place..?
            if (
                k in dependants
                and len(dependants[k]) != len(requests[k])
                and not isinstance(dependants[k], set)
            ):
                # print(f'Waiting at {k}', dependants[k], requests[k])
                continue

            # print(f"configuring {k}",requests[k])
            # set configuration for this node k

            # If we create a delayed object from a class, `self` will be dsk[k][1]
            if isinstance(dsk[k], tuple) and isinstance(
                dsk[k][1], Node
            ):  # Check if we get a node of type Node class
                # current_requests = [r for r in requests[k] if r]                    # get all requests belonging to this node
                current_requests = requests[k]
                new_request = dsk[k][1].configure(
                    current_requests
                )  # Call the class configuration function
                if not isinstance(
                    new_request, list
                ):  # prepare the request return value
                    new_request = [new_request]
            else:  # We didn't get a Node class so there is no
                # custom configuration function: pass through
                if len(requests[k]) > 1:
                    if callable(default_merge):
                        new_request = default_merge(requests[k])
                    else:
                        raise RuntimeError(
                            "No valid default merger supplied. Cannot merge requests. "
                            "Either convert your function to a class Node or provide "
                            "a default merger"
                        )
                else:
                    new_request = requests[k]

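            # if the node asked for its request at runtime ('requires_request'),
            # remember it so it can later be appended as an additional input to
            # the node's task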
            if (
                "requires_request" in new_request[0]
                and new_request[0]["requires_request"] == True
            ):
                del new_request[0]["requires_request"]
                input_requests[k] = copy.deepcopy(
                    new_request[0]
                )  # TODO: check if we need a deepcopy here!

            # update dependencies
            current_deps = get_dependencies(dsk, k, as_list=True)
            for i, d in enumerate(current_deps):
                if d in requests:
                    requests[d] += new_request
                    remove[d] = remove[d] and (not new_request[0])
                else:
                    requests[d] = new_request
                    remove[d] = not new_request[
                        0
                    ]  # if we received an empty dictionary flag deps for removal

                # only configure each node once in a round!
                if d not in new_work and d not in work:  # TODO: verify this
                    new_work.append(
                        d
                    )  # TODO: Do we need to configure dependency if we'll remove it?

        work = new_work

    # Assembling the configured new graph
    out = {k: dsk[k] for k in out_keys if not remove[k]}
    # After we have acquired all requests we can add the required request as an additional input to the requiring node
    for k in input_requests:
        out[k] += (input_requests[k],)

    # convert to delayed object
    from dask.delayed import Delayed

    in_keys = list(flatten(keys))
    # print(in_keys)
    if len(in_keys) > 1:
        collection = [Delayed(key=key, dsk=out) for key in in_keys]
    else:
        collection = Delayed(key=in_keys[0], dsk=out)
        if isinstance(delayed, list):
            collection = [collection]

    return collection


class Freezer(Node):
    def __init__(self, caching=True):
        self.caching = caching

    @dask.delayed
    def __call__(self, x):
        """If caching is enabled load a cached result or stores the input data and returns it
        
        Arguments:
            x {xarray or dict} -- Either the xarray data to be passed through (and cached)
                                  or request dictionary containing information about the data
                                  to be loaded
        
        Returns:
            xarray -- Data loaded from cache or input data passed through
        """

        if isinstance(x, dict):
            if self.is_cached(x) and self.caching:
                # TODO: load from cache and return it
                pass
            elif not self.caching:
                raise RuntimeError(f"Cannot perform request {x} because caching is disabled")
            else:
                raise RuntimeError(
                    f"Result is not cached but cached result is requested with {x}"
                )

        if self.caching:
            # TODO: store the input data
            pass

        return x

    def configure(self, requests):
        if self.caching:
            return [{}]
        return config_conflict(requests)


def optimize_freeze(dsk, keys, request_key="request"):
    """ Return new dask with tasks removed which are unnecessary because a later stage 
    reads from cache
    ``keys`` may be a single key or list of keys.
    Examples
    --------
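    A sketch; assumes ``dsk`` contains a ``Freezer`` task and a ``'request'`` entry:

    >>> culled_dsk, dependencies = optimize_freeze(dsk, keys)  # doctest: +SKIP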

    Returns
    -------
    dsk: culled dask graph
    dependencies: Dict mapping {key: [deps]}.  Useful side effect to accelerate
        other optimizations, notably fuse.
    """
    if not isinstance(keys, (list, set)):
        keys = [keys]
    out_keys = []
    seen = set()
    dependencies = dict()

    if request_key not in dsk:
        raise RuntimeError(
            f"Please provide a task graph which includes '{request_key}'"
        )

    request = dsk[request_key]

    def is_cached(task, request):
        if isinstance(task, tuple):
            if isinstance(task[0], Freezer):
                return task[0].is_cached(request)
        return False

    work = list(set(flatten(keys)))
    cached_keys = []
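    # traverse the graph from the output keys; stop at cached Freezer nodes so
    # that their (now unnecessary) dependencies are culled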
    while work:
        new_work = []
        out_keys += work
        deps = []
        for k in work:
            if is_cached(dsk[k], request):
                cached_keys.append(k)
            else:
                deps.append((k, get_dependencies(dsk, k, as_list=True)))

        dependencies.update(deps)
        for _, deplist in deps:
            for d in deplist:
                if d not in seen:
                    seen.add(d)
                    new_work.append(d)
        work = new_work

    out = {k: dsk[k] for k in out_keys}

    # finally we need to replace the input of the caching nodes with the request
    cached = {k: (out[k][0], request_key) for k in cached_keys}
    out.update(cached)

    return out, dependencies