
Commit e51d964c authored by matthmey

added csv

parent d7a6964d
@@ -64,6 +64,13 @@ class Node(object):
class StuettNode(Node): # TODO: define where this class should be (maybe not here)
def __init__(self, **kwargs):
self.config = locals().copy()
while "kwargs" in self.config and self.config["kwargs"]:
self.config.update(self.config["kwargs"])
del self.config["kwargs"]
del self.config["self"]
def configure(self, requests):
""" Default configure for stuett nodes
Expects two keys per request (*start_time* and *end_time*)
......
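The kwargs handling in StuettNode above flattens nested keyword dictionaries into one flat config dict. A minimal self-contained sketch of that pattern (class name hypothetical; the real loop may differ in details):

class ConfigNode:
    """Collect keyword arguments into one flat config dict,
    unwrapping a nested 'kwargs' entry as StuettNode does above."""

    def __init__(self, **kwargs):
        self.config = dict(kwargs)
        while "kwargs" in self.config and self.config["kwargs"]:
            nested = self.config.pop("kwargs")  # unwrap one level
            self.config.update(nested)
        self.config.pop("kwargs", None)  # drop an empty leftover entry

node = ConfigNode(station="MH36", kwargs={"channel": "EHE"})
print(node.config)  # {'station': 'MH36', 'channel': 'EHE'}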
@@ -14,6 +14,7 @@ import zarr
import xarray as xr
from PIL import Image
import base64
import re
from pathlib import Path
import warnings
@@ -25,7 +26,8 @@ import datetime as dt
class DataSource(StuettNode):
def __init__(self):
def __init__(self, **kwargs):
super().__init__(kwargs=kwargs)
pass
def configure(self, requests=None):
@@ -45,7 +47,16 @@ class DataSource(StuettNode):
class SeismicSource(DataSource):
def __init__(self, config={}, use_arclink=False, return_obspy=False):
def __init__(
self,
station=None,
channel=None,
start_time=None,
end_time=None,
use_arclink=False,
return_obspy=False,
**kwargs,
): # TODO: update description
""" Seismic data source to get data from permasense
The user can predefine the source's settings or provide them in a request
Predefined settings should never be updated (to stay thread safe), but will be ignored
@@ -56,21 +67,14 @@ class SeismicSource(DataSource):
use_arclink {bool} -- If true, downloads the data from the arclink service (authentication required) (default: {False})
return_obspy {bool} -- By default an xarray is returned. If true, an obspy stream will be returned (default: {False})
"""
self.config = config
self.use_arclink = use_arclink
self.return_obspy = return_obspy
if "source" not in self.config:
self.config["source"] = None
if use_arclink:
arclink = get_setting("arclink")
arclink_user = arclink["user"]
arclink_password = arclink["password"]
self.fdsn_client = Client(
base_url="http://arclink.ethz.ch",
user=arclink_user,
password=arclink_password,
super().__init__(
station=station,
channel=channel,
start_time=start_time,
end_time=end_time,
use_arclink=use_arclink,
return_obspy=return_obspy,
kwargs=kwargs,
)
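A hedged usage sketch of the new constructor: settings are predefined here and merged with (and overridden by) each request in __call__; station, channel, and times below are placeholder values, not from this commit:

import stuett

source = stuett.data.SeismicSource(station="MH36", channel="EHE")
request = {"start_time": "2017-08-01 10:00:00", "end_time": "2017-08-01 10:01:00"}
# x = source(request)   # dask-delayed; yields an xarray unless return_obspy=True
# x = x.compute()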
@dask.delayed
@@ -80,9 +84,16 @@ class SeismicSource(DataSource):
if request is not None:
config.update(request)
if self.use_arclink:
# logging.info('Loading seismic with fdsn')
x = self.fdsn_client.get_waveforms(
if config["use_arclink"]:
arclink = get_setting("arclink")
arclink_user = arclink["user"]
arclink_password = arclink["password"]
fdsn_client = Client(
base_url="http://arclink.ethz.ch",
user=arclink_user,
password=arclink_password,
)
x = fdsn_client.get_waveforms(
network="4D",
station=config["station"],
location="A",
@@ -92,8 +103,11 @@ class SeismicSource(DataSource):
attach_response=True,
)
# TODO: remove instrument response, e.g. x.remove_response(output="VEL")
# TODO: slice start_time / end_time
x = x.slice(
UTCDateTime(config["start_time"]), UTCDateTime(config["end_time"])
)
# TODO: potentially resample
else: # 20180914 is last full day available in permasense_vault
# logging.info('Loading seismic with fdsn')
x = self.get_obspy_stream(
@@ -103,15 +117,19 @@ class SeismicSource(DataSource):
config["channel"],
)
if not self.return_obspy:
if not config["return_obspy"]:
x = obspy_to_array(x)
# change time coords from relative to absolute time
starttime = obspy.UTCDateTime(x.starttime.values).datetime
starttime = pd.to_datetime(starttime, utc=True)
starttime = pd.to_datetime(starttime, utc=True).tz_localize(
None
) # TODO: change when xarray #3291 is fixed
timedeltas = pd.to_timedelta(x["time"].values, unit="seconds")
xt = starttime + timedeltas
x["time"] = pd.to_datetime(xt, utc=True)
x["time"] = pd.to_datetime(xt, utc=True).tz_localize(
None
) # TODO: change when xarray #3291 is fixed
del x.attrs["stats"]
return x
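The relative-to-absolute time conversion above (with tz_localize(None) as the xarray #3291 workaround) can be sketched in isolation:

import pandas as pd

start = pd.to_datetime("2017-08-04 08:12:11", utc=True).tz_localize(None)
offsets = pd.to_timedelta([0.0, 0.01, 0.02], unit="seconds")  # per-sample offsets
absolute = start + offsets  # naive DatetimeIndex aligned with the samples
print(absolute)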
@@ -261,8 +279,12 @@ class MHDSLRFilenames(StuettNode):
base_directory {[type]} -- [description]
method {str} -- [description] (default: {'directory'})
"""
self.config = locals().copy() # map the arguments to the config file
del self.config["self"]
super().__init__(
base_directory=base_directory,
method=method,
start_time=start_time,
end_time=end_time,
)
def __call__(self, request=None):
"""Retrieves the images for the selected time period from the server. If only a start_time timestamp is provided,
@@ -636,7 +658,7 @@ class MHDSLRImages(MHDSLRFilenames):
images = np.array(images)
data = xr.DataArray(
images, coords={"time": times}, dims=["time", "x", "y", "c"]
images, coords={"time": times}, dims=["time", "x", "y", "c"], name="Image"
)
data.attrs["format"] = "jpg"
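Giving the DataArray a name (name="Image") lets it be converted with to_dataset() or merged into a Dataset later. A minimal sketch of the structure being built here, with dummy pixels:

import numpy as np
import pandas as pd
import xarray as xr

times = pd.to_datetime(["2017-08-06 09:52:12"])
images = np.zeros((1, 4, 6, 3), dtype="uint8")  # (time, x, y, channel)
da = xr.DataArray(images, coords={"time": times}, dims=["time", "x", "y", "c"], name="Image")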
@@ -653,7 +675,9 @@ class MHDSLRImages(MHDSLRFilenames):
times.append(timestamp)
images = np.array(images).reshape((-1, 1))
data = xr.DataArray(images, coords={"time": times}, dims=["time", "base64"])
data = xr.DataArray(
images, coords={"time": times}, dims=["time", "base64"], name="Base64Image"
)
data.attrs["format"] = "jpg"
return data
@@ -682,7 +706,7 @@ class Freezer(StuettNode):
# TODO: make a distinction between requested start_time and freeze_output_start_time
# TODO: add node specific hash to freeze_output_start_time (there might be multiple in the graph) <- probably not necessart becaue we receive a copy of the request which is unique to this node
# TODO: add node specific hash to freeze_output_start_time (there might be multiple in the graph) <- probably not necessary because we receive a copy of the request which is unique to this node
# TODO: maybe the configuration method must add (and delete) the node name in the request?
# we always require a request to crop out the right time period
@@ -698,11 +722,9 @@
def open_zarr(self, requests):
ds_zarr = xr.open_zarr(self.store)
print(ds_zarr)
print('read',ds_zarr)
@dask.delayed
def __call__(self, x=None, requests=None):
print(x, requests)
if x is not None: # TODO: check if this is always good
if requests is None:
requests = x
@@ -718,22 +740,188 @@
class CsvSource(DataSource):
def __init__(self, config={}):
def __init__(self, filename=None, start_time=None, end_time=None, **kwargs):
super().__init__(
filename=filename, start_time=start_time, end_time=end_time, kwargs=kwargs
)
def __call__(self, request=None):
# TODO: This will stay but the rest will move to forward()
config = self.config.copy()
if request is not None:
config.update(request)
csv = pd.read_csv(config["filename"])  # use the merged config so a request can override the filename
csv.set_index("time", inplace=True)
csv.index = pd.to_datetime(csv.index, utc=True).tz_localize(
None
) # TODO: change when xarray #3291 is fixed
x = xr.DataArray(csv, dims=["time", "name"], name="CSV")
try:
unit_coords = []
name_coords = []
for name in x.coords["name"].values:
unit = re.findall(r"\[(.*)\]", name)[0]
name = re.sub(r"\[.*\]", "", name).strip()
name_coords.append(name)
unit_coords.append(unit)
x.coords["name"] = name_coords
x = x.assign_coords({"unit": ("name", unit_coords)})
except (IndexError, KeyError):  # header had no "[unit]" suffix
# TODO: add a warning or test explicitly if units exist
pass
def __call__(self, request):
if "start_time" not in config:
config["start_time"] = x.coords["time"][0]
if "end_time" not in config:
config["end_time"] = x.coords["time"][-1]
x = x.sel(time=slice(config["start_time"], config["end_time"]))
return x
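The try block above extracts units from CSV headers of the form "name [unit]". A sketch of that parsing on a single assumed header:

import re

header = "temperature [degC]"
unit = re.findall(r"\[(.*)\]", header)[0]     # 'degC'
name = re.sub(r"\[.*\]", "", header).strip()  # 'temperature'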
class LabelSource(DataSource):
def __init__(self, config={}):
def to_datetime(x):
return pd.to_datetime(x, utc=True).tz_localize(
None
) # TODO: change when xarray #3291 is fixed
class BoundingBoxAnnotation(DataSource):
def __init__(
self,
filename=None,
start_time=None,
end_time=None,
converters={"time": to_datetime},
**kwargs,
):
super().__init__(
filename=filename,
start_time=start_time,
end_time=end_time,
converters=converters,
kwargs=kwargs,
)
def __call__(self, request=None):
config = self.config.copy()
if request is not None:
config.update(request)
csv = pd.read_csv(config["filename"])  # use the merged config so a request can override the filename
targets = xr.DataArray(csv["__target"], dims=["index"], name="Annotation")
for key in csv:
if key == "__target":
continue
targets = targets.assign_coords({key: ("index", csv[key])})
for key in config["converters"]:
converter = config["converters"][key]
if not callable(converter):
raise RuntimeError("Please provide a callable as column converter")
targets = targets.assign_coords({key: ("index", converter(targets[key]))})
return targets
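A hedged usage sketch (file path hypothetical): every CSV column except __target becomes a coordinate on the returned annotation array, with the converters applied per column:

import stuett

node = stuett.data.BoundingBoxAnnotation("annotations/boundingbox_timeseries.csv")
targets = node()  # DataArray over 'index' with coords such as time, x, y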
def check_overlap(data0,data1,sort_data0=True,sort_data1=True):
if sort_data0:
data0 = data0.sortby('time')
if sort_data1:
data1 = data1.sortby('time')
# data0['start_time'] = pd.to_datetime(data0['start_time'],utc=True)
# data0['end_time'] = pd.to_datetime(data0['end_time'],utc=True)
# data1['start_time'] = pd.to_datetime(data1['start_time'],utc=True)
# data1['end_time'] = pd.to_datetime(data1['end_time'],utc=True)
data0_start_time = list(pd.to_datetime(data0['start_time'],utc=True))
data0_end_time = list(pd.to_datetime(data0['end_time'],utc=True))
data1_start_time = list(pd.to_datetime(data1['start_time'],utc=True))
data1_end_time = list(pd.to_datetime(data1['end_time'],utc=True))
overlap_indices = []
# print(data0.head())
num_overlaps = 0
start_idx = 0
for i in range(len(data0)):
# data0_df = data0.iloc[i]
data0_start = data0_start_time[i]
data0_end = data0_end_time[i]
# print(data0_df['start_time'])
label = 'no data'
ext = []
for j in range(start_idx,len(data1)):
# data1_df = data1.iloc[j]
data1_start = data1_start_time[j]
data1_end = data1_end_time[j]
# print(type(data0_df['end_time']),type(data1_df['start_time']))
# check if data0_df is completely before data1_df; then all following items will also be non-overlapping (sorted list data1)
cond0 = (data0_end < data1_start)
if cond0:
break
# if data0_df['label'] != data1_df['label']:
# continue
# second condition: data0_df starts after data1_df ends; earlier data1 items can be skipped for later (sorted) data0 rows
cond1 = (data0_start > data1_end)
if cond1:
start_idx = j
if not (cond0 or cond1):
# overlap
num_overlaps += 1
label = 'data'
overlap_indices.append([int(i),int(j)])
return overlap_indices
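A small sketch of check_overlap on two tiny interval tables (plain pandas DataFrames; sorting disabled because sortby is an xarray method):

import pandas as pd

data0 = pd.DataFrame({"start_time": ["2017-01-01 00:00", "2017-01-02 00:00"],
                      "end_time":   ["2017-01-01 12:00", "2017-01-02 12:00"]})
data1 = pd.DataFrame({"start_time": ["2017-01-01 06:00"],
                      "end_time":   ["2017-01-01 18:00"]})
# check_overlap(data0, data1, sort_data0=False, sort_data1=False)
# -> [[0, 0]]  only data0's first interval intersects data1's interval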
class LabeledDataset(DataSource):
def __init__(self,data,label,trim=True):
''' trim ... trim the dataset to the available labels
'''
# load annotation source and datasource
# define a dataset index containing all indices of the datasource (e.g. timestamps or time periods) which should be in this dataset
d = data()
l = label()
print(d['time'])
print(l['time'])
d = d.sortby('time')
l = l.sortby('time')
# for
indices = check_overlap(d,l)
print(indices)
exit()
# go through the dataset index and check overlap of datasource indices and annotation indices
# generate a new annotation set with regard to the datasource indices (requires generating an empty annotation set and adding new labels to it)
# if desired, generate intermediate freeze results of the datasource and annotations
# go through all items of the datasource
pass
def __call__(self, request):
return x
def __call__(self):
pass
class PytorchDataset(DataSource):
class PytorchDataset(DataSource): # TODO: extends pytorch dataset
def __init__(self, config={}):
""" Creates a pytorch like dataset from a data source and a label source.
@@ -746,6 +934,20 @@ class PytorchDataset(DataSource):
if "source" not in self.config:
self.config["source"] = None
def build_dataset(self):
# load annotation source and datasource
# define a dataset index containing all indices of the datasource (e.g. timestamps or time periods) which should be in this dataset
# go through the dataset index and check overlap of datasource indices and annotation indices
# generate a new annotation set with regard to the datasource indices (requires generating an empty annotation set and adding new labels to it)
# if desired, generate intermediate freeze results of the datasource and annotations
# go through all items of the datasource
pass
def __call__(self, request):
if request is None:
raise RuntimeError("No request provided, cannot provide data")
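For the "extends pytorch dataset" TODO above, a hedged sketch of the interface this class could grow into (torch usage and all names here are assumptions, not part of this commit):

from torch.utils.data import Dataset

class TorchAdapter(Dataset):
    """Wraps a stuett data source; each item is one request window."""

    def __init__(self, source, requests):
        self.source = source      # callable data source, e.g. CsvSource
        self.requests = requests  # list of {'start_time': ..., 'end_time': ...}

    def __len__(self):
        return len(self.requests)

    def __getitem__(self, i):
        return self.source(self.requests[i])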
......
user_dir/*
MHDSLR/*
!.gitkeep
\ No newline at end of file
__annotation_id,__sub_index,__target,time,x,y
5a14dcff-721f-4b7e-8cca-f10ef3e9ed3a,0,mountaineer,2017-08-04 08:12:11+00:00,81,14
5a14dcff-721f-4b7e-8cca-f10ef3e9ed3a,1,mountaineer,2017-08-04 08:12:11+00:00,82,16
7a8219fa-7581-4c0f-91a1-927167979ee9,0,no_visibility,2017-01-01 05:59:04+00:00,0,0
7a8219fa-7581-4c0f-91a1-927167979ee9,1,no_visibility,2017-01-01 06:59:04+00:00,100,100
__annotation_id,__sub_index,__target,time
5a14dcff-721f-4b7e-8cca-f10ef3e9ed3a,0,no_visibility,2017-01-01 05:59:04+00:00
5a14dcff-721f-4b7e-8cca-f10ef3e9ed3a,1,no_visibility,2017-01-01 06:01:04+00:00
7a8219fa-7581-4c0f-91a1-927167979ee9,0,snow,2017-01-01 05:59:04+00:00
7a8219fa-7581-4c0f-91a1-927167979ee9,1,snow,2017-01-01 06:01:04+00:00
75bd6584-edf8-45db-a9e6-f7c8cca52028,0,mountaineer,2017-01-21 09:59:04+00:00
7a8219fa-7581-4c0f-91a1-927167979ee9,1,mountaineer,2017-01-21 10:01:04+00:00
\ No newline at end of file
@@ -11,10 +11,8 @@ import pytest
import zarr
from pathlib import Path
test_data_dir = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "data", ""
)
stuett.global_config.set_setting("user_dir", test_data_dir + "user_dir/")
test_data_dir = Path(__file__).absolute().parent.joinpath("..", "data")
stuett.global_config.set_setting("user_dir", test_data_dir.joinpath("user_dir/"))
class TestSeismicSource(object):
@@ -33,6 +31,10 @@ class TestSeismicSource(object):
x = seismic_source(config)
x = x.compute()
print(x)
return
assert x.mean() == -55.73599312780376
# with config
@@ -46,34 +48,45 @@ class TestSeismicSource(object):
# TODO: make test to compare start_times
# class TestFreezer(object):
# @pytest.mark.slow
# def test_seismic_source_freeze(self):
# # with config
# seismic_source = stuett.data.SeismicSource(config,use_arclink=True)
@pytest.mark.slow
def test_freeze():
# with config
# seismic_source = stuett.data.SeismicSource(config,use_arclink=True)
filename = Path(test_data_dir).joinpath("timeseries", "MH30_temperature_rock_2017.csv")
node = stuett.data.CsvSource(filename)
user_dir = stuett.global_config.get_setting('user_dir')
store_name = user_dir.joinpath('frozen','test.zarr')
import shutil
shutil.rmtree(store_name,ignore_errors=True)
store = zarr.DirectoryStore(store_name)
# account_name = stuett.global_config.get_setting('azure')['account_name'] if stuett.global_config.setting_exists('azure') else "storageaccountperma8980"
# account_key = stuett.global_config.get_setting('azure')['account_key'] if stuett.global_config.setting_exists('azure') else None
# store = zarr.ABSStore(container='hackathon-on-permafrost', prefix='dataset/test.zarr', account_name=account_name, account_key=account_key, blob_service_kwargs={})
# user_dir = stuett.global_config.get_setting('user_dir')
# store_name = user_dir+'frozen/test.zarr'
# import shutil
# shutil.rmtree(store_name,ignore_errors=True)
# freezer = stuett.data.Freezer(store_name)
freezer = stuett.data.Freezer(store)
# x = freezer(seismic_source())
request = {'start_time':'2017-07-01', 'end_time':'2017-08-01'}
x = freezer(node(request))
# x_config = stuett.core.configuration(x,config)
request = {'start_time':'2017-09-01', 'end_time':'2017-10-01'}
x = freezer(node(request))
# x_config.compute()
# x = freezer()
print('final',x)
# x_config = stuett.core.configuration(x,config)
# request = {'start_time':start_time+offset, 'end_time':end_time+offset}
# x_config = stuett.core.configuration(x,request)
# x_config.compute()
# x_config.compute()
# request = {'start_time':start_time+offset, 'end_time':end_time+offset}
# x_config = stuett.core.configuration(x,request)
# x_config.compute()
test_freeze()
def test_image_filenames():
# first test without config
node = stuett.data.MHDSLRFilenames(
base_directory=os.path.join(test_data_dir, "MHDSLR")
)
node = stuett.data.MHDSLRFilenames(base_directory=test_data_dir.joinpath("MHDSLR"))
start_time = dt.datetime(2017, 8, 6, 9, 56, 12, tzinfo=dt.timezone.utc)
end_time = dt.datetime(2017, 8, 6, 10, 14, 10, tzinfo=dt.timezone.utc)
@@ -92,6 +105,9 @@ def test_image_filenames():
data = node(config)
# test_image_filenames()
def test_mhdslrimage():
base_dir = Path(test_data_dir).joinpath("MHDSLR")
node = stuett.data.MHDSLRImages(base_directory=base_dir)
@@ -110,6 +126,8 @@ def test_mhdslrimage():
config["output_format"] = "base64"
data = node(config)
# TODO: assert data
from PIL import Image
img = Image.open(base_dir.joinpath("2017-08-06", "20170806_095212.JPG"))
@@ -118,5 +136,50 @@ def test_mhdslrimage():
assert data[0].values == img_base64
# test_image_filenames()
test_mhdslrimage()
# test_mhdslrimage()
def test_csv():
filename = Path(test_data_dir).joinpath(
"timeseries", "MH30_temperature_rock_2017.csv"
)
node = stuett.data.CsvSource(filename)
x = node()
length = len(x)
print(x)
# TODO: test with start and end time
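For the TODO above, a hedged sketch continuing test_csv (dates hypothetical):

request = {"start_time": "2017-07-01", "end_time": "2017-08-01"}
x_window = node(request)
assert len(x_window) <= length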
def test_annotations():
filename = Path(test_data_dir).joinpath("annotations", "boundingbox_timeseries.csv")
node = stuett.data.BoundingBoxAnnotation(filename)
filename = Path(test_data_dir).joinpath("annotations", "boundingbox_images.csv")
node = stuett.data.BoundingBoxAnnotation(filename)
targets = node()
targets = targets.swap_dims({"index": "time"})
targets = targets.sortby("time")
# print(targets.sel(time=slice('2016-01-01','2016-01-04')))
# print(targets)
# test_annotations()
def test_datasets():
filename = Path(test_data_dir).joinpath("annotations", "boundingbox_timeseries.csv")
label = stuett.data.BoundingBoxAnnotation(filename)
filename = Path(test_data_dir).joinpath("timeseries", "MH30_temperature_rock_2017.csv")
data = stuett.data.CsvSource(filename)
dataset = stuett.data.LabeledDataset(data, label)  # argument order matches LabeledDataset(data, label)
# test_datasets()