import datetime as dt

import pytest

import stuett
from tests.stuett.sample_data import *


# helper function for non-stuett node
@stuett.dat
def bypass(x):
    """Identity step: hand ``x`` back untouched.

    Used in the tests to put a plain (non-stuett-node) function between stuett
    nodes; the ``stuett.dat`` decorator presumably turns it into a graph step
    -- confirm against stuett's docs.
    """
    passthrough = x
    return passthrough


class MyNode(stuett.core.StuettNode):
    """Test node: adds 4 to its input and shifts a requested ``start_time`` by one."""

    def __init__(self):
        super().__init__()

    def forward(self, data=None, request=None):
        """Print the incoming data/request and return ``data + 4``."""
        print(data, request)
        shifted = data + 4
        return shifted

    def configure(self, requests=None):
        """Run the base-class configuration, then bump ``start_time`` by one.

        The increment is only applied when the configured requests contain a
        ``start_time`` key; otherwise they are returned as the base class left them.
        """
        configured = super().configure(requests)
        if "start_time" not in configured:
            return configured
        # in-place add kept on purpose (same semantics as the original for any type)
        configured["start_time"] += 1
        return configured


class MyMerge(stuett.core.StuettNode):
    """Test node that sums the first two inputs of a merge."""

    def forward(self, data, request):
        """Return ``data[0] + data[1]``; ``request`` is ignored."""
        left = data[0]
        right = data[1]
        return left + right


class MySource(stuett.data.DataSource):
    """Test data source whose output is the ``start_time`` it was asked for."""

    def __init__(self, start_time=None):
        super().__init__(start_time=start_time)

    def forward(self, data=None, request=None):
        """Echo ``request["start_time"]``; ``data`` is not used."""
        start = request["start_time"]
        return start


def test_configuration():
    """Build a node -> bypass -> node graph and run it through configuration.

    The graph is computed once as built and once after
    ``stuett.core.configuration`` with an empty config. Both results and both
    task graphs are printed for inspection; there is no assertion yet (see TODO).
    """
    node = MyNode()

    # create a stuett graph
    x_in = node(data=5, delayed=True)  # data input
    x = bypass(x_in)
    x = node(x, delayed=True)

    print(x.compute())

    # x.visualize()
    # Inspect the task graph through the public dask collection protocol
    # instead of the private ``dask.base._extract_graph_and_keys`` helper.
    print(dict(x.__dask_graph__()))

    # create a configuration file
    config = {}

    # configure the graph
    # BUG: somehow config cuts the graph off
    x_configured = stuett.core.configuration(x, config)

    print(dict(x_configured.__dask_graph__()))

    result = x_configured.compute()
    print(result)

    # TODO: finalize test


# Only run directly as a script; pytest collects test_configuration on its own,
# and an unguarded call would execute (and possibly fail) during collection.
if __name__ == "__main__":
    test_configuration()


def test_datasource():
    """Configure a source -> bypass -> node chain and check the computed value.

    ``MyNode.configure`` adds one to ``start_time`` (presumably on its way
    upstream to the source), and ``MyNode.forward`` adds 4, so a requested
    ``start_time`` of 0 is expected to produce 5.
    """
    src = MySource()
    adder = MyNode()

    # wire up the stuett graph: source -> bypass -> node
    graph = adder(bypass(src(delayed=True)), delayed=True)

    # request window handed to the configuration step
    window = {"start_time": 0, "end_time": 1}

    result = stuett.core.configuration(graph, window).compute(
        scheduler="single-threaded", rerun_exceptions_locally=True
    )

    assert result == 5


# Only run directly as a script; pytest collects test_datasource on its own.
if __name__ == "__main__":
    test_datasource()


def test_merging():
    """Configure a graph whose branch is merged back together and check the sum.

    Graph: source -> bypass -> node -> node (branch), and both node outputs
    feed ``MyMerge``, which adds them.
    """
    source = MySource()
    node = MyNode()
    merge = MyMerge()

    # create a stuett graph
    x = source(delayed=True)
    x = bypass(x)
    x = node(x, delayed=True)
    x_b = node(x, delayed=True)
    x = merge([x_b, x], delayed=True)

    # create a configuration file
    config = {"start_time": 0, "end_time": 1}

    # configure the graph
    configured = stuett.core.configuration(x, config)

    x_configured = configured.compute()

    assert x_configured == 14

    # TODO: Test default_merge


def test_merging_named_keys():
    """Same merge graph as ``test_merging``, but with explicit dask key names.

    Known limitation: configuration cannot handle "apply" nodes in the dask
    graph. When a branch has to merge requests and the apply node is the one
    that receives both requests, configuration asks for a default merge
    handler. None is supplied, and none should be needed, because the
    underlying function already performs the merge.

    Failures are reported as an expected failure (xfail) rather than being
    swallowed, so the limitation stays visible in the test report.
    """
    source = MySource()
    node = MyNode()
    merge = MyMerge()

    try:
        x = source(delayed=True)(dask_key_name="source")
        x = bypass(x)(dask_key_name="bypass")
        x = node(x, delayed=True)(dask_key_name="node")
        x_b = node(x, delayed=True)(dask_key_name="branch")  # branch
        x = merge([x_b, x], delayed=True)(dask_key_name="merge")

        # create a configuration file
        config = {"start_time": 0, "end_time": 1}

        # print the task graph via the public dask collection protocol
        print(dict(x.__dask_graph__()))

        # configure the graph
        configured = stuett.core.configuration(x, config)

        x_configured = configured.compute()

        assert x_configured == 14
    except Exception as err:
        # Known failure: surface it as xfail instead of silently passing.
        pytest.xfail(
            "configuration cannot handle dask 'apply' nodes that merge requests: "
            f"{err}"
        )

# Only run directly as a script; pytest collects test_merging on its own.
if __name__ == "__main__":
    test_merging()