
Commit 8b642f83 authored by matthmey

working on proper configuration

parent ff3830b7
@@ -2,23 +2,5 @@ from __future__ import absolute_import
from . import data
from . import global_config
from .convenience import *
import dask
# TODO: make it a proper decorator with arguments etc
def dat(x):
""" Helper function to tranform input callable into
dask.delayed object
From low german 'stütt dat!' which means "support it!"
Arguments:
x {callable} -- Any input callable which is supported
by dask.delayed
Returns:
dask.delayed object --
"""
return dask.delayed(x)
from pandas import to_timedelta
import dask
# TODO: make it a proper decorator with arguments etc
def dat(x):
""" Helper function to tranform input callable into
dask.delayed object
From low german 'stütt dat!' which means "support it!"
Arguments:
x {callable} -- Any input callable which is supported
by dask.delayed
Returns:
dask.delayed object --
"""
return dask.delayed(x)
\ No newline at end of file
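For reference, a minimal usage sketch of the dat helper; the add_one function below is a hypothetical example and not part of the package:

import stuett

def add_one(x):
    # any plain callable can be wrapped
    return x + 1

lazy_add_one = stuett.dat(add_one)   # equivalent to dask.delayed(add_one)
task = lazy_add_one(41)              # builds a task graph, nothing runs yet
print(task.compute())                # -> 42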
@@ -55,11 +55,11 @@ class Node(object):
def __call__(self, data=None, request=None, delayed=False):
if delayed:
return dask.delayed(self.forward)(data=data, request=request)
return dask.delayed(self.forward)(data, request)
else:
return self.forward(data=data, request=request)
def forward(self, x, request):
def forward(self, data, request):
raise NotImplementedError
def get_config(self):
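For illustration, a minimal sketch of how a Node subclass is called either eagerly or as a dask.delayed task, assuming the Node base class shown above; the AddOne class is a hypothetical example:

class AddOne(Node):
    def forward(self, data=None, request=None):
        return data + 1

node = AddOne()
eager = node(data=1)                # calls forward immediately -> 2
lazy = node(data=1, delayed=True)   # returns a dask.delayed task instead
print(lazy.compute())               # -> 2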
@@ -192,12 +192,20 @@ def configuration(delayed, request, keys=None, default_merge=None):
# set configuration for this node k
# If we create a delayed object from a class, `self` will be dsk[k][1]
if isinstance(dsk[k], tuple) and isinstance(
dsk[k][1], Node
): # Check if we get a node of type Node class
print(dsk[k])
print(type(dsk[k][1]))
argument_is_node = None
if isinstance(dsk[k],tuple):
# check the first two arguments # TODO: possibly change this to only the first argument
for ain in range(2):
if hasattr(dsk[k][ain], '__self__'):
if isinstance(dsk[k][ain].__self__, Node):
argument_is_node = ain
# Check if we get a node of type Node class
if argument_is_node is not None:
# current_requests = [r for r in requests[k] if r] # get all requests belonging to this node
current_requests = requests[k]
new_request = dsk[k][1].configure(
new_request = dsk[k][argument_is_node].__self__.configure(
current_requests
) # Call the class configuration function
if not isinstance(
@@ -223,9 +231,10 @@ def configuration(delayed, request, keys=None, default_merge=None):
and new_request[0]["requires_request"] == True
):
del new_request[0]["requires_request"]
# TODO: check if we need a deepcopy here!
input_requests[k] = copy.deepcopy(
new_request[0]
) # TODO: check if we need a deepcopy here!
)
# update dependencies
current_deps = get_dependencies(dsk, k, as_list=True)
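The check above can be read in isolation as follows; this is a simplified sketch of the same logic, assuming dask task tuples of the form (callable, arg0, arg1, ...):

def find_node_argument(task, node_cls):
    # Return the index (0 or 1) of the first entry that is a bound method
    # of a node_cls instance, or None if the task does not involve a node.
    if not isinstance(task, tuple):
        return None
    for ain in range(min(2, len(task))):
        entry = task[ain]
        if hasattr(entry, "__self__") and isinstance(entry.__self__, node_cls):
            return ain
    return None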
@@ -250,8 +259,13 @@ def configuration(delayed, request, keys=None, default_merge=None):
# Assembling the configured new graph
out = {k: dsk[k] for k in out_keys if not remove[k]}
# After we have acquired all requests, we can pass the required requests as an input to the requiring node
# we assume that the last argument is the request
for k in input_requests:
out[k] += (input_requests[k],)
if isinstance(out[k][-1],dict):
out[k][-1].update(input_requests[k])
else:
# replace the last entry
out[k] = out[k][:-1] + (input_requests[k],)
# convert to delayed object
from dask.delayed import Delayed
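The injection step above assumes the last entry of each task tuple is the request; as a standalone sketch:

def inject_request(task, configured_request):
    # If the task already carries a request dict as its last argument,
    # merge the configured request into it; otherwise replace the last entry.
    if isinstance(task[-1], dict):
        task[-1].update(configured_request)
        return task
    return task[:-1] + (configured_request,)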
@@ -268,45 +282,6 @@ def configuration(delayed, request, keys=None, default_merge=None):
return collection
class Freezer(Node):
def __init__(self, caching=True):
self.caching = caching
@dask.delayed
def __call__(self, x):
"""If caching is enabled load a cached result or stores the input data and returns it
Arguments:
x {xarray or dict} -- Either the xarray data to be passed through (and cached)
or request dictionary containing information about the data
to be loaded
Returns:
xarray -- Data loaded from cache or input data passed through
"""
if isinstance(x, dict):
if self.is_cached(x) and self.caching:
# TODO: load from cache and return it
pass
elif not self.caching:
raise RuntimeError(f"If caching is disabled cannot perform request {x}")
else:
raise RuntimeError(
f"Result is not cached but cached result is requested with {x}"
)
if self.caching:
# TODO: store the input data
pass
return x
def configure(self, requests):
if self.caching:
return [{}]
return config_conflict(requests)
def optimize_freeze(dsk, keys, request_key="request"):
""" Return new dask with tasks removed which are unnecessary because a later stage
......
from .management import *
from .processing import *
# from .collection import *
from .collection import *
import permasense
import numpy as np
import pandas as pd
import pyarrow as pa
import warnings
import datetime as dt
from .management import DataSource
from ..core import configuration
class DataCollector(DataSource):
def __init__(self, data_paths=[], granularities=[]):
"""Add and choose data path according to its granularity.
The data collector returns the data path given an index segment (index_end - index_start).
The index segment is compared against the given granularities and the mapped data path is
returned. For example, for a time series where the index is a datetime object, the timedelta
of (end_time - start_time) is compared against the given list of granularity timedeltas.
Keyword Arguments:
data_paths {list} -- a list of data paths, e.g. the leaves of a dask graph (default: {[]})
granularities {list} -- a list of sorted granularities (default: {[]})
"""
super().__init__()
self.data_paths = data_paths
self.granularities = granularities
if len(self.data_paths) != len(self.granularities):
raise ValueError("Each granularity must have a corresponding data path")
if len(self.granularities) > 1 and not self.is_sorted(self.granularities):
raise ValueError('Granularities must be sorted in ascending order')
def forward(self, data=None, request=None):
if len(self.data_paths) != len(self.granularities):
raise ValueError("Each granularity must have a corresponding data path")
if len(self.granularities) > 1 and not self.is_sorted(self.granularities):
raise ValueError('Granularities must be sorted in ascending order')
# TODO: change to generic indices or slices
granularity = request['end_time'] - request['start_time']
data_path = None
for i in range(len(self.granularities)):
print(i,granularity,'<',self.granularities[i],self.data_paths[i])
if granularity < self.granularities[i]:
data_path = self.data_paths[i]
break
if data_path is None:
raise AttributeError('No data path can be used for this timeframe')
return data_path
def is_sorted(self, l):
"""Check whether a list is sorted
Arguments:
l {list} -- the list to check
Returns:
[bool] -- True if the list is sorted, False otherwise
"""
return all(a <= b for a, b in zip(l, l[1:]))
\ No newline at end of file
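A usage sketch for DataCollector, mirroring the test further below; the two data paths are placeholder strings standing in for, e.g., a raw and a downsampled node:

import pandas as pd

raw_path = "raw_node"                  # placeholder for a delayed raw-data node
downsampled_path = "downsampled_node"  # placeholder for a delayed downsampled node

collector = DataCollector(
    data_paths=[raw_path, downsampled_path],
    granularities=[pd.to_timedelta(180, "s"), pd.to_timedelta(2, "d")],
)

request = {"start_time": pd.Timestamp("2017-08-01"),
           "end_time": pd.Timestamp("2017-08-02")}
# one day is not < 180 s but is < 2 days, so the downsampled path is returned
assert collector.forward(request=request) == downsampled_path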
@@ -9,6 +9,7 @@ import obspy
from obspy.clients.fdsn import Client
from obspy import UTCDateTime
from obsplus import obspy_to_array
from copy import deepcopy
import zarr
import xarray as xr
@@ -31,6 +32,7 @@ class DataSource(StuettNode):
super().__init__(kwargs=kwargs)
def __call__(self, data=None, request=None, delayed=False):
print(data,request)
if data is not None:
if request is None:
request = data
@@ -40,15 +42,25 @@
)
# DataSource only require a request
# Therefore merge permanent config and request
# Therefore merge permanent-config and request
config = self.config.copy() # TODO: do we need a deep copy?
if request is not None:
config.update(request)
# TODO: change when rewriting for general indices
if 'start_time' in config and config['start_time'] is not None:
config['start_time'] = pd.to_datetime(config['start_time'], utc=True).tz_localize(
None
) # TODO: change when xarray #3291 is fixed
if 'end_time' in config and config['end_time'] is not None:
config['end_time'] = pd.to_datetime(config['end_time'], utc=True).tz_localize(
None
) # TODO: change when xarray #3291 is fixed
if delayed:
return dask.delayed(self.forward)(config)
return dask.delayed(self.forward)(None,config)
else:
return self.forward(config)
return self.forward(None,config)
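The timestamp handling above parses any input into a timezone-aware UTC timestamp and then strips the timezone (the TODO notes this as a workaround until xarray #3291 is fixed); in isolation:

import pandas as pd

ts = pd.to_datetime("2017-08-01 10:00:00+02:00", utc=True).tz_localize(None)
# -> Timestamp('2017-08-01 08:00:00'): converted to UTC, then made timezone-naive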
def configure(self, requests=None):
""" Default configure for DataSource nodes
@@ -71,7 +83,7 @@ class GSNDataSource(DataSource):
):
super().__init__(deployment=deployment, position=position, vsensor=vsensor, start_time=start_time, end_time=end_time, kwargs=kwargs)
def forward(self, request):
def forward(self, data=None, request=None):
#### 1 - DEFINE VSENSOR-DEPENDENT COLUMNS ####
colnames = pd.read_csv(
Path(get_setting("metadata_directory")).joinpath("vsensor_metadata/{:s}_{:s}.csv".format(
@@ -229,7 +241,7 @@ class SeismicSource(DataSource):
kwargs=kwargs,
)
def forward(self, request=None):
def forward(self, data=None, request=None):
config = request
if config["use_arclink"]:
@@ -532,7 +544,7 @@ class MHDSLRFilenames(DataSource):
end_time=end_time,
)
def forward(self, request=None):
def forward(self, data=None, request=None):
"""Retrieves the images for the selected time period from the server. If only a start_time timestamp is provided,
the file with the corresponding date will be loaded if available. For periods (when start and end time are given)
all available images are indexed first to provide an efficient retrieval.
@@ -877,7 +889,7 @@ class MHDSLRImages(MHDSLRFilenames):
)
self.config["output_format"] = output_format
def forward(self, request):
def forward(self, data=None, request=None):
filenames = super().forward(request=request)
if request["output_format"] is "xarray":
@@ -983,7 +995,8 @@ class CsvSource(DataSource):
filename=filename, start_time=start_time, end_time=end_time, kwargs=kwargs
)
def forward(self, request):
def forward(self, data=None, request=None):
print(request)
csv = pd.read_csv(request["filename"])
csv.set_index("time", inplace=True)
csv.index = pd.to_datetime(csv.index, utc=True).tz_localize(
@@ -1041,7 +1054,7 @@ class BoundingBoxAnnotation(DataSource):
kwargs=kwargs,
)
def forward(self, request):
def forward(self, data=None, request=None):
csv = pd.read_csv(request["filename"])
targets = xr.DataArray(csv["__target"], dims=["index"], name="Annotation")
@@ -1184,7 +1197,7 @@ class LabeledDataset(DataSource):
# go through all items of the datasource
pass
def forward(self, request=None):
def forward(self, data=None, request=None):
pass
@@ -1212,6 +1225,6 @@ class PytorchDataset(DataSource): # TODO: extends pytorch dataset
# go through all items of the datasource
pass
def forward(self, request):
def forward(self, data=None, request=None):
return x
import stuett
from pathlib import Path
import pandas as pd
test_data_dir = Path(__file__).absolute().parent.joinpath("..", "data")
stuett.global_config.set_setting("user_dir", test_data_dir.joinpath("user_dir/"))
def test_collector():
filename = Path(test_data_dir).joinpath(
"timeseries", "MH30_temperature_rock_2017.csv"
)
node = stuett.data.CsvSource(filename)
minmax_rate2 = stuett.data.MinMaxDownsampling(rate=2, dim="time")
print('creating delayed node')
x = node(delayed=True)
print('downsampled delayed node')
downsampled = minmax_rate2(x,delayed=True)
print('DataCollector node')
data_paths = [x,downsampled]
granularities = [stuett.to_timedelta(180,'s'), stuett.to_timedelta(2,'d')]
collector_node = stuett.data.DataCollector(data_paths, granularities)
print('Instantiating DataCollector node')
request = {'start_time':'2017-08-01', 'end_time':'2017-08-02'}
path = collector_node(request=request)
# print(type(path))
print('Configuration node')
import dask
dsk, dsk_keys = dask.base._extract_graph_and_keys([path])
path = stuett.core.configuration(path,request)
dsk, dsk_keys = dask.base._extract_graph_and_keys([path])
print(dsk)
print(type(path))
print('executing delayed node')
print(path.compute())
# request = {'start_time':'2017-08-01 10:01:00', 'end_time':'2017-08-01 10:02:00'}
# path = collector_node(request=request)
# path = stuett.core.configuration(path,request)
# print(path.compute())
test_collector()
\ No newline at end of file
@@ -11,9 +11,11 @@ def bypass(x):
class MyNode(stuett.core.StuettNode):
@stuett.dat
def __call__(self, x):
return x + 4
def __init__(self):
super().__init__()
def forward(self, data=None, request=None):
return data + 4
def configure(self, requests=None):
requests = super().configure(requests)
@@ -22,78 +24,87 @@ class MyNode(stuett.core.StuettNode):
return requests
class MyMerge(stuett.core.StuettNode):
@stuett.dat
def __call__(self, x, y):
return x + y
# class MyMerge(stuett.core.StuettNode):
# def forward(self, x, y):
# return x + y
class MySource(stuett.data.DataSource):
@stuett.dat
def __call__(self, request=None):
def __init__(self,start_time=None):
super().__init__(start_time=start_time)
def forward(self, request=None):
return request["start_time"]
class TestConfiguration(object):
def test_configuration(self):
node = MyNode()
def test_configuration():
node = MyNode()
# create a stuett graph
x = node({"start_time": 0, "end_time": -1})
x = bypass(x)
x = node(x)
# create a configuration file
config = {}
# configure the graph
x_configured = stuett.core.configuration(x, config)
# TODO: finalize test
def test_datasource():
source = MySource()
node = MyNode()
# create a stuett graph
x = node({"start_time": 0, "end_time": -1})
x = bypass(x)
x = node(x)
# create a stuett graph
x = source(delayed=True)
x = bypass(x)
x = node(x,delayed=True)
# create a configuration file
config = {}
# create a configuration file
config = {"start_time": 0, "end_time": 1}
# configure the graph
x_configured = stuett.core.configuration(x, config)
# TODO: finalize test
def test_datasource(self):
source = MySource()
node = MyNode()
# configure the graph
configured = stuett.core.configuration(x, config)
# create a stuett graph
x = source()
x = bypass(x)
x = node(x)
import dask
dsk, dsk_keys = dask.base._extract_graph_and_keys([configured])
# create a configuration file
config = {"start_time": 0, "end_time": 1}
print(dsk)
x_configured = configured.compute(scheduler='single-threaded',rerun_exceptions_locally=True)
# configure the graph
configured = stuett.core.configuration(x, config)
assert x_configured == 5
x_configured = configured.compute()
test_datasource()
assert x_configured == 5
def test_merging(self):
source = MySource()
node = MyNode()
merge = MyMerge()
# def test_merging():
# source = MySource()
# node = MyNode()
# merge = MyMerge()
# create a stuett graph
import dask
# # create a stuett graph
# import dask
x = source(dask_key_name="source")
x = bypass(x, dask_key_name="bypass")
x = node(x, dask_key_name="node")
x_b = node(x, dask_key_name="branch") # branch
x = merge(x_b, x, dask_key_name="merge")
# x = source(dask_key_name="source")
# x = bypass(x, dask_key_name="bypass")
# x = node(x, dask_key_name="node")
# x_b = node(x, dask_key_name="branch") # branch
# x = merge(x_b, x, dask_key_name="merge")
# create a configuration file
config = {"start_time": 0, "end_time": 1}
# # create a configuration file
# config = {"start_time": 0, "end_time": 1}
# configure the graph
configured = stuett.core.configuration(x, config)
# # configure the graph
# configured = stuett.core.configuration(x, config)
x_configured = configured.compute()
# x_configured = configured.compute()
assert x_configured == 14
# assert x_configured == 14
# TODO: Test default_merge
# # TODO: Test default_merge
# TODO:
# # TODO: