test_graph.py 2.19 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import stuett
import datetime as dt 
from tests.stuett.sample_data import *

import pytest

# Stand-in for a non-stuett node: a plain function (not a StuettNode
# subclass) decorated with stuett.dat that hands its input straight back.
@stuett.dat
def bypass(x):
    """Return ``x`` unchanged."""
    passthrough = x
    return passthrough

class MyNode(stuett.core.StuettNode):
    """Test node: adds 4 to its input and increments ``start_time`` when configured."""

    @stuett.dat
    def __call__(self, x):
        """Return ``x + 4``."""
        shifted = x + 4
        return shifted

    def configure(self, requests=None):
        """Run the base-class configuration, then bump ``start_time`` by one.

        The requests object returned by ``StuettNode.configure`` is modified
        in place (only when it contains ``'start_time'``) and returned.
        """
        updated = super().configure(requests)
        if 'start_time' not in updated:
            return updated
        updated['start_time'] += 1
        return updated

class MyMerge(stuett.core.StuettNode):
    """Test node that combines its two inputs with ``+``."""

    @stuett.dat
    def __call__(self, x, y):
        """Return ``x + y``."""
        combined = x + y
        return combined

class MySource(stuett.data.DataSource):
    """Test data source whose output is the ``start_time`` it is asked for."""

    @stuett.dat
    def __call__(self, request=None):
        """Return ``request['start_time']``.

        NOTE(review): ``request`` is subscripted directly, so this relies on
        stuett supplying a request mapping at compute time (the tests call
        ``source()`` with no arguments) -- confirm.
        """
        start = request['start_time']
        return start

class TestConfiguration:
    """Tests for ``stuett.core.configuration`` on small hand-built graphs."""

    def test_configuration(self):
        """Configure a graph whose first node is fed a literal request dict.

        Currently only checks that ``configuration`` runs without raising;
        nothing is computed or asserted yet.
        """
        node = MyNode()

        # create a stuett graph: node -> bypass -> node
        x = node({'start_time': 0, 'end_time': -1})
        x = bypass(x)
        x = node(x)

        # create an (empty) configuration
        config = {}

        # configure the graph
        x_configured = stuett.core.configuration(x, config)

        # TODO: finalize test (assert on x_configured)

    def test_datasource(self):
        """Configure and compute a source -> bypass -> node graph."""
        source = MySource()
        node = MyNode()

        # create a stuett graph: source -> bypass -> node
        x = source()
        x = bypass(x)
        x = node(x)

        # create the configuration (request) passed to the graph
        config = {'start_time': 0, 'end_time': 1}

        # configure the graph
        configured = stuett.core.configuration(x, config)

        x_configured = configured.compute()

        # MySource returns the start_time it receives and MyNode adds 4, so
        # 5 means the source saw start_time 1 although config asks for 0.
        # The extra 1 presumably comes from MyNode.configure incrementing
        # start_time on the request before it reaches the source -- confirm
        # against stuett's propagation order.
        assert x_configured == 5

    def test_merging(self):
        """Configure and compute a graph where one node feeds two consumers."""
        source = MySource()
        node = MyNode()
        merge = MyMerge()

        # create a stuett graph with explicit dask keys:
        #   source -> bypass -> node -> merge
        #                       node -> branch -> merge
        x = source(dask_key_name='source')
        x = bypass(x, dask_key_name='bypass')
        x = node(x, dask_key_name='node')
        x_b = node(x, dask_key_name='branch')  # second consumer of 'node'
        x = merge(x_b, x, dask_key_name='merge')

        # create the configuration (request) passed to the graph
        config = {'start_time': 0, 'end_time': 1}

        # configure the graph
        configured = stuett.core.configuration(x, config)

        x_configured = configured.compute()

        # With s = source output: node = s + 4, branch = s + 8, so
        # merge = 2*s + 12 and 14 means the source saw start_time 1. That
        # value depends on how stuett combines the requests arriving at
        # 'node' from 'branch' and 'merge' (see default_merge TODO below).
        assert x_configured == 14

        # TODO: Test default_merge