To receive notifications about scheduled maintenance, please subscribe to the mailing-list gitlab-operations@sympa.ethz.ch. You can subscribe to the mailing-list at https://sympa.ethz.ch

test_graph.py 3.73 KB
Newer Older
1
import stuett
matthmey's avatar
matthmey committed
2
import datetime as dt
3
4
5
6
7
8
9
10
11
from tests.stuett.sample_data import *

import pytest

# helper function for non-stuett node
@stuett.dat
def bypass(x):
    """Hand the argument back unchanged.

    Used in the tests below as a plain decorated function (not a
    ``StuettNode``) placed between nodes of a graph.
    """
    return x

matthmey's avatar
matthmey committed
12

13
class MyNode(stuett.core.StuettNode):
    """Test node: adds 4 to its input and shifts ``start_time`` by one."""

    def __init__(self):
        super().__init__()

    def forward(self, data=None, request=None):
        """Print ``data`` and ``request``, then return ``data + 4``."""
        # The print is kept on purpose: it shows what reaches the node
        # when the graph is computed.
        print(data, request)
        return data + 4

    def configure(self, requests=None):
        """Run the base ``configure`` and increment ``start_time`` if present.

        Returns the object produced by ``super().configure`` after the
        optional in-place increment of its ``"start_time"`` entry.
        """
        requests = super().configure(requests)

        if "start_time" in requests:
            requests["start_time"] += 1

        return requests

matthmey's avatar
matthmey committed
27

matthmey's avatar
matthmey committed
28
29
30
class MyMerge(stuett.core.StuettNode):
    """Test node that combines two inputs with ``+``."""

    def forward(self, data, request):
        """Return the sum of the first two items of ``data``."""
        first = data[0]
        second = data[1]
        return first + second
31

matthmey's avatar
matthmey committed
32

33
class MySource(stuett.data.DataSource):
    """Test data source whose output is the ``start_time`` of its request."""

    def __init__(self, start_time=None):
        super().__init__(start_time=start_time)

    def forward(self, data=None, request=None):
        """Return ``request["start_time"]``; ``data`` is ignored."""
        return request["start_time"]

40

matthmey's avatar
matthmey committed
41
42
43
44
def test_configuration():
    """Build a small delayed graph, configure it with an empty config, compute it.

    This is still a smoke test: it prints the raw and the configured dask
    graphs and their results but asserts nothing yet (see the TODO at the
    end and the BUG note on the configuration call).
    """
    # Imported locally once; the graph dumps below need it.
    import dask

    node = MyNode()

    # create a stuett graph
    x_in = node(data=5, delayed=True)  # data input
    x = bypass(x_in)
    x = node(x, delayed=True)

    print(x.compute())

    # x.visualize()
    # NOTE: _extract_graph_and_keys is a private dask helper (leading
    # underscore); it is used here only to dump the graph for debugging.
    dsk, _ = dask.base._extract_graph_and_keys([x])
    print(dict(dsk))

    # create a configuration file
    config = {}

    # configure the graph
    x_configured = stuett.core.configuration(
        x, config
    )  # BUG: somehow config cuts the graph off

    dsk, _ = dask.base._extract_graph_and_keys([x_configured])
    print(dsk)

    # Separate name for the computed value so the graph handle `x` is not
    # shadowed by its result.
    result = x_configured.compute()

    print(result)
    # TODO: finalize test

matthmey's avatar
matthmey committed
75
76
77
78

# Run directly only when the file is executed as a script; under pytest the
# test is collected and run by the framework, so calling it at import time
# would execute it a second time during collection.
if __name__ == "__main__":
    test_configuration()


matthmey's avatar
matthmey committed
79
80
81
def test_datasource():
    """Check that a configured source -> bypass -> node graph computes 5.

    With ``start_time`` 0 in the config, ``MyNode.configure`` raises it to 1,
    ``MySource.forward`` returns that value and ``MyNode.forward`` adds 4.
    """
    source = MySource()
    node = MyNode()

    # create a stuett graph
    x = source(delayed=True)
    x = bypass(x)
    x = node(x, delayed=True)

    # create a configuration file
    config = {"start_time": 0, "end_time": 1}

    # configure the graph
    configured = stuett.core.configuration(x, config)

    # Single-threaded scheduler with local re-raise keeps failures easy to
    # trace back to the node that raised them.
    x_configured = configured.compute(
        scheduler="single-threaded", rerun_exceptions_locally=True
    )

    assert x_configured == 5


# Run directly only when the file is executed as a script; under pytest the
# test is collected and run by the framework, so calling it at import time
# would execute it a second time during collection.
if __name__ == "__main__":
    test_datasource()


def test_merging():
    """Check a graph with a branch that is merged again by ``MyMerge``.

    Graph: source -> bypass -> node -> (x, node -> x_b) -> merge(x_b, x).
    If the source emits ``s``, then x = s + 4, x_b = s + 8 and the merge
    yields 2 * s + 12, so the asserted 14 corresponds to the source seeing
    ``start_time`` 1.

    The second half repeats the graph with explicit ``dask_key_name`` calls.
    That variant is known to fail, so it runs best-effort inside a
    ``try``/``except`` that only prints the error.
    """
    # Imported locally once; only the graph dump in the second half needs it.
    import dask

    source = MySource()
    node = MyNode()
    merge = MyMerge()

    # create a stuett graph
    x = source(delayed=True)
    x = bypass(x)
    x = node(x, delayed=True)
    x_b = node(x, delayed=True)
    x = merge([x_b, x], delayed=True)

    # create a configuration file
    config = {"start_time": 0, "end_time": 1}

    # configure the graph
    configured = stuett.core.configuration(x, config)

    x_configured = configured.compute()

    assert x_configured == 14

    # TODO: Test default_merge

    try:
        # This still fails. Configuration cannot handle "apply" nodes in the
        # dask graph. The problem arises if we have a branch that needs to
        # merge requests and the apply node is the one which receives the two
        # requests. Then it requires a default merge handler, which is not
        # supplied (and does not need to be) because the underlying function
        # in this case supplies the merger.
        x = source(delayed=True)(dask_key_name="source")
        x = bypass(x)(dask_key_name="bypass")
        x = node(x, delayed=True)(dask_key_name="node")
        x_b = node(x, delayed=True)(dask_key_name="branch")  # branch
        x = merge([x_b, x], delayed=True)(dask_key_name="merge")

        # create a configuration file
        config = {"start_time": 0, "end_time": 1}

        dsk, _ = dask.base._extract_graph_and_keys([x])
        print(dict(dsk))

        # configure the graph
        configured = stuett.core.configuration(x, config)

        x_configured = configured.compute()

        assert x_configured == 14
    except Exception as e:
        # Deliberately broad: this section documents a known limitation and
        # must not fail the test, so any error (including the assertion) is
        # only reported.
        print(e)
165

matthmey's avatar
matthmey committed
166
# Run directly only when the file is executed as a script; under pytest the
# test is collected and run by the framework, so calling it at import time
# would execute it a second time during collection.
if __name__ == "__main__":
    test_merging()