test_graph.py 2.2 KB
Newer Older
1
import stuett
matthmey's avatar
matthmey committed
2
import datetime as dt
3
4
5
6
7
8
9
10
11
from tests.stuett.sample_data import *

import pytest

# Helper: a plain function (not a StuettNode subclass) decorated with
# ``stuett.dat`` so the tests can chain it between stuett nodes.
# NOTE(review): callers also pass ``dask_key_name=``, which presumably is
# consumed by the decorator rather than by this function — confirm.
@stuett.dat
def bypass(x):
    """Identity: hand ``x`` back untouched."""
    unchanged = x
    return unchanged

matthmey's avatar
matthmey committed
12

13
14
class MyNode(stuett.core.StuettNode):
    """Toy node used by the configuration tests.

    Calling it adds 4 to its input; configuring it increments a
    ``start_time`` entry of the request by 1.
    """

    @stuett.dat
    def __call__(self, x):
        """Return ``x`` shifted by the constant offset 4."""
        offset = 4
        return x + offset

    def configure(self, requests=None):
        """Let the base class build the request, then advance ``start_time``.

        If the request produced by ``super().configure`` has a
        ``"start_time"`` key, it is incremented by one in place; the
        (possibly modified) request is returned either way.
        """
        updated = super().configure(requests)
        if "start_time" not in updated:
            return updated
        updated["start_time"] += 1
        return updated

matthmey's avatar
matthmey committed
24

25
26
class MyMerge(stuett.core.StuettNode):
    """Toy node that combines its two inputs by addition."""

    @stuett.dat
    def __call__(self, x, y):
        """Return the sum ``x + y``."""
        total = x + y
        return total

matthmey's avatar
matthmey committed
30

31
32
class MySource(stuett.data.DataSource):
    """Toy data source whose output is simply the requested start time."""

    @stuett.dat
    def __call__(self, request=None):
        """Return ``request["start_time"]``.

        ``request`` must provide a ``"start_time"`` key; with the default
        ``None`` the lookup raises ``TypeError``.
        """
        start_time = request["start_time"]
        return start_time

36
37

class TestConfiguration(object):
matthmey's avatar
matthmey committed
38
    def test_configuration(self):
        """Build a node -> bypass -> node graph and configure it.

        The configured graph is not checked yet (see the TODO at the end).
        """
        node = MyNode()

        # assemble a stuett graph; the first node call receives a plain dict
        graph = bypass(node({"start_time": 0, "end_time": -1}))
        graph = node(graph)

        # empty configuration: no request values are supplied
        config = dict()

        configured_graph = stuett.core.configuration(graph, config)

        # TODO: finalize test
53
54
55
56
57
58
59
60
61
62
63

    def test_datasource(self):
        """A configured source -> bypass -> node graph evaluates to 5.

        MySource returns the start_time it is asked for and MyNode adds 4.
        With start_time 0 in the configuration, the expected 5 presumably
        comes from MyNode.configure shifting start_time to 1 before
        MySource reads it.
        """
        source = MySource()
        node = MyNode()

        graph = node(bypass(source()))

        # request parameters handed to stuett.core.configuration
        settings = {"start_time": 0, "end_time": 1}

        result = stuett.core.configuration(graph, settings).compute()

        assert result == 5

    def test_merging(self):
        """A graph whose branch is merged back in evaluates to 14.

        Topology (dask key names): "source" -> "bypass" -> "node", then
        "branch" applies MyNode again on top of "node", and "merge" adds
        "branch" and "node". The expected 14 equals (s + 4 + 4) + (s + 4)
        with s == 1. That presumably means the source sees start_time 1
        after configuration; confirm against stuett's request merging.
        """
        source = MySource()
        node = MyNode()
        merge = MyMerge()

        # build the stuett graph with explicit dask key names
        x = source(dask_key_name="source")
        x = bypass(x, dask_key_name="bypass")
        x = node(x, dask_key_name="node")
        x_b = node(x, dask_key_name="branch")  # second MyNode on top of "node"
        x = merge(x_b, x, dask_key_name="merge")

        # request parameters handed to stuett.core.configuration
        config = {"start_time": 0, "end_time": 1}

        configured = stuett.core.configuration(x, config)
        result = configured.compute()

        assert result == 14
96
97
98

        # TODO: Test default_merge

matthmey's avatar
matthmey committed
99
        # TODO: