graph.py 14 KB
Newer Older
matthmey's avatar
matthmey committed
1
2
'''MIT License

matthmey's avatar
matthmey committed
3
4
Copyright (c) 2019, Swiss Federal Institute of Technology (ETH Zurich), Matthias Meyer

matthmey's avatar
matthmey committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.'''

matthmey's avatar
matthmey committed
24
import dask
25
from dask.core import get_dependencies, flatten
matthmey's avatar
matthmey committed
26
27
28
import numpy as np
import copy

29
30
31
32
33

class Node(object):
    """Base class for all task-graph nodes.

    A node participates in two phases: a configuration phase, where
    requests are propagated backwards through the DAG via ``configure``,
    and an execution phase, where ``forward`` is invoked (directly or as
    a dask-delayed task) via ``__call__``.
    """

    def __init__(self):
        pass

    def configure(self, requests):
        """Answer: what does this node need to fulfil its successors' requests?

        During configuration the request travels from the end to the
        beginning of the DAG; each node's ``configure`` is called with the
        requests collected from its successors and its return value is
        handed to its predecessors.

        Implementations must not mutate the node's internal state here
        (that would break thread-safety). To receive the request at call
        time instead, add a ``'requires_request'`` key to the returned
        request; it will then be passed to ``__call__``.

        Arguments:
            requests {List} -- list of request dictionaries.

        Returns:
            The (possibly updated) request. Updates must be made on a
            copy of the input. Multiple incoming requests must be merged
            by subclasses; this default implementation only accepts a
            single request and returns the list unchanged. Returning an
            empty dict removes this node's dependencies from the graph.

        Raises:
            RuntimeError -- if ``requests`` is not a list, or contains
                more than one entry (merging needs a custom subclass).
        """
        if not isinstance(requests, list):
            raise RuntimeError("Please provide a **list** of request")
        if len(requests) > 1:
            raise RuntimeError(
                "Default configuration function cannot handle "
                "multiple requests. Please provide a custom "
                "configuration implementation"
            )
        return requests

    def __call__(self, data=None, request=None, delayed=False):
        """Execute ``forward``, either eagerly or as a dask-delayed task."""
        if not delayed:
            return self.forward(data=data, request=request)
        # NOTE: the delayed path passes positionally, the eager path by keyword
        return dask.delayed(self.forward)(data, request)

    def forward(self, data, request):
        """Compute this node's output; must be provided by subclasses."""
        raise NotImplementedError

    def get_config(self):
        """Return a dictionary of configurations to recreate the state."""
        raise NotImplementedError()

matthmey's avatar
matthmey committed
93
94

class StuettNode(Node):  # TODO: define where this class should be (maybe not here)
    """Node with time-segment-aware default request merging."""

    def __init__(self, **kwargs):
        # Capture the constructor arguments as this node's config.
        self.config = locals().copy()

        # Flatten arbitrarily nested "kwargs" dictionaries into the top
        # level (e.g. when a subclass forwards its own **kwargs through).
        while "kwargs" in self.config:
            inner = self.config["kwargs"]
            self.config.update(inner)
            if "kwargs" not in inner:
                break

        del self.config["kwargs"]
        del self.config["self"]

    def configure(self, requests):
        """Default configure for stuett nodes.

        Merges multiple requests by taking the union of their time
        segments: the merged ``start_time`` is the minimum and the merged
        ``end_time`` is the maximum over all requests that carry the key.

        Arguments:
            requests {list} -- list of request dictionaries.

        Returns:
            dict -- copy of the first request, updated with the merged
                    time bounds of all remaining requests.

        Raises:
            RuntimeError -- if ``requests`` is not a list.
        """
        if not isinstance(requests, list):
            raise RuntimeError("Please provide a list of request")

        # For time requests we just use the union of both time segments
        merged = requests[0].copy()
        combine = {"start_time": np.minimum, "end_time": np.maximum}
        for other in requests[1:]:
            for key in ("start_time", "end_time"):
                if key not in other:
                    continue
                if key in merged:
                    merged[key] = combine[key](merged[key], other[key])
                else:
                    merged[key] = other[key]

        return merged

matthmey's avatar
matthmey committed
136
137

def configuration(delayed, request, keys=None, default_merge=None):
    """ Configures each node of the graph by propagating the request from outputs
        to inputs.
        Each node checks if it can fulfil the request and what it needs to fulfil the request.
        If a node requires additional configurations to fulfil the request it can set the
        'requires_request' flag in the returned request and this function will add the
        return request as a new input to the node's __call__(). See also Node.configure()

    Arguments:
        delayed {dask.delayed or list}  -- Delayed object or list of delayed objects
        request {dict or list}          -- request (dict), or list of requests
                                           (one per output key, same order)
        default_merge {callable}        -- request merge function used for plain
                                           (non-Node) tasks that receive more than
                                           one request

    Keyword Arguments:
        keys {key or list} -- dask keys to treat as the graph outputs
                              (default: {None}, i.e. the keys extracted
                              from `delayed`)

    Raises:
        RuntimeError: if `request` is a list whose length differs from the
                      number of output keys
        RuntimeError: if a non-Node task receives multiple requests and no
                      callable `default_merge` was supplied

    Returns:
        dask.delayed or list -- Config-optimized delayed object or list of delayed objects
    """

    # Normalize to a list of collections for graph extraction.
    if not isinstance(delayed, list):
        collections = [delayed]
    else:
        collections = delayed

    # dsk = dask.base.collections_to_dsk(collections)
    # NOTE(review): _extract_graph_and_keys is a private dask API and may
    # change between dask versions.
    dsk, dsk_keys = dask.base._extract_graph_and_keys(collections)
    dependencies, dependants = dask.core.get_deps(dsk)

    if keys is None:
        keys = dsk_keys

    if not isinstance(keys, (list, set)):
        keys = [keys]
    out_keys = []
    seen = set()

    # Start the backward traversal from the output keys.
    work = list(set(flatten(keys)))

    if isinstance(request, list):
        if len(request) != len(work):
            raise RuntimeError(
                "When passing multiple request items "
                "The number of request items must be same "
                "as the number of keys"
            )

        # One request per output key; each node accumulates a list of requests.
        requests = {work[i]: [request[i]] for i in range(len(request))}
    else:
        # Same request for every output key.
        requests = {k: [request] for k in work}

    # remove[k] == True means node k's subtree is no longer needed
    # (all propagated requests were None) and is dropped from the graph.
    remove = {k: False for k in work}
    # input_requests[k] is the request that must be injected as the last
    # argument of task k's call at execution time ('requires_request').
    input_requests = {}
    while work:
        new_work = []
        out_keys += work
        deps = []
        for k in work:
            # if k not in requests:
            #     # there wasn't any request stored use initial config
            #     requests[k] = [config]

            # check if we have collected all dependencies so far
            # we will come back to this node another time
            # TODO: make a better check for the case when dependants[k] is a set, also: why is it a set in the first place..?
            if (
                k in dependants
                and len(dependants[k]) != len(requests[k])
                and not isinstance(dependants[k], set)
            ):
                # print(f'Waiting at {k}', dependants[k], requests[k])
                continue

            # print(f"configuring {k}",requests[k])
            # set configuration for this node k
            # If we create a delayed object from a class, `self` will be dsk[k][1]
            argument_is_node = None
            if isinstance(dsk[k], tuple):
                # check the first argument
                # # TODO: make tests for all use cases and then remove for-loop
                for ain in range(1):
                    if hasattr(dsk[k][ain], "__self__"):
                        if isinstance(dsk[k][ain].__self__, Node):
                            argument_is_node = ain
            # Check if we get a node of type Node class
            if argument_is_node is not None:
                # current_requests = [r for r in requests[k] if r]                    # get all requests belonging to this node
                current_requests = requests[k]
                new_request = dsk[k][argument_is_node].__self__.configure(
                    current_requests
                )  # Call the class configuration function
                if not isinstance(
                    new_request, list
                ):  # prepare the request return value
                    new_request = [new_request]
            else:  # We didn't get a Node class so there is no
                # custom configuration function: pass through
                if len(requests[k]) > 1:
                    if callable(default_merge):
                        new_request = default_merge(requests[k])
                    else:
                        raise RuntimeError(
                            "No valid default merger supplied. Cannot merge requests. "
                            "Either convert your function to a class Node or provide "
                            "a default merger"
                        )
                else:
                    new_request = requests[k]

            # A node can opt in to receiving the final request at call time;
            # record a deep copy and strip the marker key.
            if (
                new_request[0] is not None
                and "requires_request" in new_request[0]
                and new_request[0]["requires_request"] == True
            ):
                del new_request[0]["requires_request"]
                # TODO: check if we need a deepcopy here!
                input_requests[k] = copy.deepcopy(new_request[0])

            # update dependencies: propagate this node's request to its inputs
            current_deps = get_dependencies(dsk, k, as_list=True)
            for i, d in enumerate(current_deps):
                if d in requests:
                    requests[d] += new_request
                    # a dependency stays removed only while *every* request is None
                    remove[d] = remove[d] and (new_request[0] is None)
                else:
                    requests[d] = new_request
                    # if we received None
                    remove[d] = new_request[0] is None

                # only configure each node once in a round!
                if d not in new_work and d not in work:  # TODO: verify this
                    new_work.append(
                        d
                    )  # TODO: Do we need to configure dependency if we'll remove it?

        work = new_work

    # Assembling the configured new graph
    out = {k: dsk[k] for k in out_keys if not remove[k]}
    # After we have aquired all requests we can input the required_requests as a input node to the requiring node
    # we assume that the last argument is the request
    for k in input_requests:
        # Here we assume that we always receive the same tuple of (bound method, data, request)
        # If the interface changes this will break #TODO: check for all cases
        if isinstance(out[k][2], tuple):
            # TODO: find a better inversion of to_task_dask()
            if out[k][2][0] == dict:
                # The request argument was serialized by dask as
                # (dict, [[key, value], ...]); rebuild it and merge in ours.
                my_dict = {item[0]: item[1] for item in out[k][2][1]}
                my_dict.update(input_requests[k])
                out[k] = out[k][:2] + (my_dict,)
            else:
                # replace the last entry
                out[k] = out[k][:2] + (input_requests[k],)
        else:
            # replace the last entry
            out[k] = out[k][:2] + (input_requests[k],)

    # convert to delayed object
    from dask.delayed import Delayed # TODO: move somewhere else

    in_keys = list(flatten(keys))
    # print(in_keys)

    if len(in_keys) > 1:
        collection = [Delayed(key=key, dsk=out) for key in in_keys]
    else:
        collection = Delayed(key=in_keys[0], dsk=out)
        if isinstance(delayed, list):
            # mirror the input shape: a list in yields a list out
            collection = [collection]

    return collection


matthmey's avatar
matthmey committed
315
def optimize_freeze(dsk, keys, request_key="request"):
    """ Return new dask with tasks removed which are unnecessary because a later stage
    reads from cache
    ``keys`` may be a single key or list of keys.

    The graph is walked backwards from ``keys``; whenever a task's callable
    is a ``Freezer`` whose cache already holds the current request, that
    task's inputs are pruned and the task is rewired to read from the
    ``request_key`` node instead.

    Returns
    -------
    dsk: culled dask graph
    dependencies: Dict mapping {key: [deps]}.  Useful side effect to accelerate
        other optimizations, notably fuse.

    Raises
    ------
    RuntimeError -- if ``request_key`` is not a key of ``dsk``.
    """
    if not isinstance(keys, (list, set)):
        keys = [keys]

    if request_key not in dsk:
        raise RuntimeError(
            f"Please provide a task graph which includes '{request_key}'"
        )

    request = dsk[request_key]

    def _reads_from_cache(task):
        # A task short-circuits to cache when its callable is a Freezer
        # that reports the current request as cached.
        return (
            isinstance(task, tuple)
            and isinstance(task[0], Freezer)
            and task[0].is_cached(request)
        )

    visited = set()
    kept_keys = []
    frozen_keys = []
    dependencies = dict()

    # Breadth-first walk from the requested output keys towards the inputs,
    # stopping at any node that is served from cache.
    frontier = list(set(flatten(keys)))
    while frontier:
        next_frontier = []
        kept_keys += frontier
        step_deps = []
        for key in frontier:
            if _reads_from_cache(dsk[key]):
                frozen_keys.append(key)
            else:
                step_deps.append((key, get_dependencies(dsk, key, as_list=True)))

        dependencies.update(step_deps)
        for _, dep_list in step_deps:
            for dep in dep_list:
                if dep not in visited:
                    visited.add(dep)
                    next_frontier.append(dep)
        frontier = next_frontier

    out = {key: dsk[key] for key in kept_keys}

    # finally we need to replace the input of the caching nodes with the request
    out.update({key: (out[key][0], request_key) for key in frozen_keys})

    return out, dependencies