Commit b22376df authored by matthmey

black

parent d397d46b
'''MIT License
"""MIT License
Copyright (c) 2019, Swiss Federal Institute of Technology (ETH Zurich), Matthias Meyer
@@ -19,13 +19,13 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.'''
SOFTWARE."""
from pandas import to_datetime, to_timedelta
from zarr import DirectoryStore, ABSStore
import dask
import pandas as pd
import pandas as pd
import io, codecs
# TODO: make it a proper decorator with arguments etc
@@ -46,15 +46,16 @@ def dat(x):
return dask.delayed(x)
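# A minimal usage sketch of dat(): it simply wraps a value (or deferred call)
# in a dask delayed object; the value below is illustrative.
delayed_value = dat(42)
delayed_value.compute()  # -> 42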
def to_csv_with_store(store,filename,dataframe):
StreamWriter = codecs.getwriter('utf-8')
bytes_buffer = io.BytesIO()
def to_csv_with_store(store, filename, dataframe):
StreamWriter = codecs.getwriter("utf-8")
bytes_buffer = io.BytesIO()
string_buffer = StreamWriter(bytes_buffer)
dataframe.to_csv(string_buffer,index=False)
dataframe.to_csv(string_buffer, index=False)
store[filename] = bytes_buffer.getvalue()
def read_csv_with_store(store,filename):
def read_csv_with_store(store, filename):
bytes_buffer = io.BytesIO(store[str(filename)])
StreamReader = codecs.getreader('utf-8')
StreamReader = codecs.getreader("utf-8")
string_buffer = StreamReader(bytes_buffer)
return pd.read_csv(string_buffer)
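# A minimal usage sketch of the two helpers above, assuming a local zarr
# DirectoryStore; the store path, key and columns are illustrative only.
example_store = DirectoryStore("example_store")
example_df = pd.DataFrame({"time": ["2017-08-01"], "value": [1.0]})
to_csv_with_store(example_store, "tables/example.csv", example_df)
roundtrip_df = read_csv_with_store(example_store, "tables/example.csv")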
'''MIT License
"""MIT License
Copyright (c) 2019, Swiss Federal Institute of Technology (ETH Zurich), Matthias Meyer
@@ -19,7 +19,7 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.'''
SOFTWARE."""
import dask
from dask.core import get_dependencies, flatten
@@ -297,7 +297,7 @@ def configuration(delayed, request, keys=None, default_merge=None):
out[k] = out[k][:2] + (input_requests[k],)
# convert to delayed object
from dask.delayed import Delayed # TODO: move somewhere else
from dask.delayed import Delayed # TODO: move somewhere else
in_keys = list(flatten(keys))
# print(in_keys)
'''MIT License
"""MIT License
Copyright (c) 2019, Swiss Federal Institute of Technology (ETH Zurich), Matthias Meyer
@@ -19,7 +19,7 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.'''
SOFTWARE."""
import numpy as np
import pandas as pd
'''MIT License
"""MIT License
Copyright (c) 2019, Swiss Federal Institute of Technology (ETH Zurich), Matthias Meyer
@@ -18,7 +18,7 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.'''
SOFTWARE."""
from ..global_config import get_setting, setting_exists, set_setting
from ..core.graph import StuettNode, configuration
@@ -128,14 +128,13 @@ class GSNDataSource(DataSource):
# code from Samuel Weber
#### 1 - DEFINE VSENSOR-DEPENDENT COLUMNS ####
metadata = Path(get_setting("metadata_directory")).joinpath(
"vsensor_metadata/{:s}_{:s}.csv".format(
request["deployment"], request["vsensor"]
))
# if not metadata.exists():
colnames = pd.read_csv(metadata,
skiprows=0,
"vsensor_metadata/{:s}_{:s}.csv".format(
request["deployment"], request["vsensor"]
)
)
# if not metadata.exists():
colnames = pd.read_csv(metadata, skiprows=0,)
columns_old = colnames["colname_old"].values
columns_new = colnames["colname_new"].values
columns_unit = colnames["unit"].values
@@ -471,7 +470,6 @@ class SeismicSource(DataSource):
if not isinstance(channels, list):
channels = [channels]
# We will get the full-hour seismic data and trim it to the desired length afterwards
tbeg_hours = pd.to_datetime(start_time).replace(
@@ -506,19 +504,21 @@
# + timerange[i].strftime("%Y%m%d_%H%M%S")
# + ".miniseed",
# )
filenames[channel] = [station,
filenames[channel] = [
station,
datayear,
"%s.D" % channel,
"4D.%s.A.%s.D." % (station, channel)
+ timerange[i].strftime("%Y%m%d_%H%M%S")
+ ".miniseed",]
+ ".miniseed",
]
# print(filenames[channel])
# Load either from store or from filename
if request['store'] is not None:
if request["store"] is not None:
# get the file relative to the store
store = request["store"]
filename = '/'.join(filenames[channel])
filename = "/".join(filenames[channel])
st = obspy.read(io.BytesIO(store[str(filename)]))
else:
if not Path(path).isdir():
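# An illustrative sketch of the store key assembled above, using hypothetical
# station/channel values; only the format strings are taken from the code.
example_station, example_channel, example_year = "MH36", "EHE", "2017"
example_time = pd.to_datetime("2017-08-01 10:00:00")
example_key = "/".join(
    [
        example_station,
        example_year,
        "%s.D" % example_channel,
        "4D.%s.A.%s.D." % (example_station, example_channel)
        + example_time.strftime("%Y%m%d_%H%M%S")
        + ".miniseed",
    ]
)
# -> "MH36/2017/EHE.D/4D.MH36.A.EHE.D.20170801_100000.miniseed"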
@@ -588,7 +588,13 @@ class SeismicSource(DataSource):
class MHDSLRFilenames(DataSource):
def __init__(
self, base_directory=None, store=None, method="directory", start_time=None, end_time=None, force_write_to_remote=False,
self,
base_directory=None,
store=None,
method="directory",
start_time=None,
end_time=None,
force_write_to_remote=False,
):
""" Fetches the DSLR images from the Matterhorn deployment, returns the image
filename(s) corresponding to the end and start time provided in either the
@@ -603,7 +609,7 @@ class MHDSLRFilenames(DataSource):
"""
super().__init__(
base_directory=base_directory,
store = store,
store=store,
method=method,
start_time=start_time,
end_time=end_time,
@@ -631,7 +637,9 @@ class MHDSLRFilenames(DataSource):
f"The {config['method']} output_format is not supported. Allowed formats are {methods}"
)
if (config["base_directory"] is None and config['store'] is None) and output_format.lower() != "web":
if (
config["base_directory"] is None and config["store"] is None
) and output_format.lower() != "web":
raise RuntimeError("Please provide a base_directory containing the images")
if config["method"].lower() == "web": # TODO: implement
@@ -666,25 +674,24 @@ class MHDSLRFilenames(DataSource):
# first try to load it from remote via store
if config["store"] is not None:
if filename in config["store"]:
imglist_df = read_csv_with_store(config["store"],filename)
imglist_df = read_csv_with_store(config["store"], filename)
success = True
elif config['force_write_to_remote']:
elif config["force_write_to_remote"]:
# try to reload it and write to remote
imglist_df = self.image_integrity_store(config["store"])
try:
to_csv_with_store(config["store"],filename,imglist_df)
to_csv_with_store(config["store"], filename, imglist_df)
success = True
except Exception as e:
print(e)
# Otherwise we need to load the filename dataframe from disk
if not success and setting_exists("user_dir") and os.path.isdir(get_setting("user_dir")):
imglist_filename = (
os.path.join(get_setting("user_dir"), "")
+ filename
)
if (
not success
and setting_exists("user_dir")
and os.path.isdir(get_setting("user_dir"))
):
imglist_filename = os.path.join(get_setting("user_dir"), "") + filename
# If it does not exist in the temporary folder of our application
# We are going to create it
@@ -692,14 +699,12 @@ class MHDSLRFilenames(DataSource):
# imglist_df = pd.read_parquet(
# imglist_filename
# ) # TODO: avoid too many different formats
imglist_df = pd.read_csv(
imglist_filename
)
imglist_df = pd.read_csv(imglist_filename)
else:
# we are going to load the full list => no arguments
imglist_df = self.image_integrity(config["base_directory"])
# imglist_df.to_parquet(imglist_filename)
imglist_df.to_csv(imglist_filename,index=False)
imglist_df.to_csv(imglist_filename, index=False)
elif not success:
# if there is no tmp_dir we can load the image list but
# we should warn the user that this is inefficient
@@ -774,10 +779,8 @@ class MHDSLRFilenames(DataSource):
DataFrame --
"""
if start_time is None:
# a random year which is before permasense installation started
start_time = dt.datetime(
1900, 1, 1
)
# a random year which is before permasense installation started
start_time = dt.datetime(1900, 1, 1)
if end_time is None:
end_time = dt.datetime.utcnow()
@@ -796,14 +799,13 @@ class MHDSLRFilenames(DataSource):
except:
# we do not care for files not matching our format
continue
if pd.isnull(dir_date):
continue
# limit the search to the explicit time range
if dir_date < tbeg_days or dir_date > tend_days:
continue
# print(file.stem)
start_time_str = pathkey.stem
@@ -825,9 +827,7 @@ class MHDSLRFilenames(DataSource):
if start_time <= _start_time and _start_time <= end_time:
images_list.append(
{
"filename": str(
img_file.relative_to(base_directory)
),
"filename": str(img_file.relative_to(base_directory)),
"start_time": _start_time - delta_t,
"end_time": _start_time + delta_t,
}
@@ -852,7 +852,7 @@ class MHDSLRFilenames(DataSource):
self, base_directory, start_time=None, end_time=None, delta_seconds=0
):
store = DirectoryStore(base_directory)
return self.image_integrity_store(store,start_time,end_time,delta_seconds)
return self.image_integrity_store(store, start_time, end_time, delta_seconds)
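# A hedged usage sketch of the integrity check above: scan a local image folder
# and obtain a DataFrame with filename/start_time/end_time per image. The
# directory path is purely illustrative.
dslr_node = MHDSLRFilenames(base_directory="/data/timelapse_images")
imglist_df = dslr_node.image_integrity("/data/timelapse_images")
imglist_df[["filename", "start_time", "end_time"]].head()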
def get_image_filename(self, timestamp):
""" Checks wether an image exists for exactly the time of timestamp and returns its filename
@@ -1086,19 +1086,25 @@ class Freezer(StuettNode):
class CsvSource(DataSource):
def __init__(self, filename=None, store=None, start_time=None, end_time=None, **kwargs):
def __init__(
self, filename=None, store=None, start_time=None, end_time=None, **kwargs
):
super().__init__(
filename=filename, store=store, start_time=start_time, end_time=end_time, kwargs=kwargs
filename=filename,
store=store,
start_time=start_time,
end_time=end_time,
kwargs=kwargs,
)
def forward(self, data=None, request=None):
# TODO: Implement properly
if request['store'] is not None:
if request["store"] is not None:
# get the file relative to the store
store = request["store"]
filename = request["filename"]
# csv = pd.read_csv(io.StringIO(str(store[filename],'utf-8')))
csv = read_csv_with_store(store,filename)
csv = read_csv_with_store(store, filename)
else:
csv = pd.read_csv(request["filename"])
csv.set_index("time", inplace=True)
@@ -1146,7 +1152,7 @@ class BoundingBoxAnnotation(DataSource):
filename=None,
start_time=None,
end_time=None,
converters={"start_time": to_datetime,"end_time": to_datetime},
converters={"start_time": to_datetime, "end_time": to_datetime},
**kwargs,
):
super().__init__(
@@ -1172,7 +1178,9 @@ class BoundingBoxAnnotation(DataSource):
converter = request["converters"][key]
if not callable(converter):
raise RuntimeError("Please provide a callable as column converter")
targets = targets.assign_coords({key: ("index", converter(targets[key]))})
targets = targets.assign_coords(
{key: ("index", converter(targets[key]))}
)
return targets
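# A hedged usage sketch of the annotation source above, assuming a CSV with
# "start_time"/"end_time" columns; the filename is illustrative. Calling the
# node (as done for label() further below) applies the default to_datetime
# converters to those coordinates.
annotation_node = BoundingBoxAnnotation(filename="annotations.csv")
targets = annotation_node()
targets["start_time"]  # datetime coordinates on the "index" dimension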
@@ -1214,7 +1222,6 @@ def check_overlap(data0, data1, index_dim, dims=[]):
# second condition: data0 is after data1, all items before data1 can be ignored (sorted list data0)
cond1 = data0_start > data1_end
if cond1:
# This only holds if data1 is sorted by both start and end index
start_idx = j
@@ -1230,7 +1237,7 @@ def check_overlap(data0, data1, index_dim, dims=[]):
d0_end = data0[i][dim].stop
d1_start = data1[j][dim].start
d1_end = data1[j][dim].stop
if ((d0_end < d1_start) or (d0_start > d1_end)):
if (d0_end < d1_start) or (d0_start > d1_end):
overlap = False
if overlap:
@@ -1248,17 +1255,19 @@ def get_dataset_slices(ds, dims, dataset_slice=None, stride={}):
segment_end = ds.sizes[dim]
else:
segment_start = dataset_slice[dim].start
segment_end = dataset_slice[dim].stop
segment_end = dataset_slice[dim].stop
size = dims[dim]
_stride = stride.get(dim, size)
if ds[dim].dtype == 'datetime64[ns]':
if ds[dim].dtype == "datetime64[ns]":
# TODO: change when xarray #3291 is fixed
iterator = pd.date_range(segment_start, segment_end,freq=_stride).tz_localize(None)
iterator = pd.date_range(
segment_start, segment_end, freq=_stride
).tz_localize(None)
segment_end = pd.to_datetime(segment_end).tz_localize(None)
else:
iterator = range(segment_start,segment_end,_stride)
iterator = range(segment_start, segment_end, _stride)
slices = []
# TODO include hopsize/overlapping windows
@@ -1270,19 +1279,31 @@ def get_dataset_slices(ds, dims, dataset_slice=None, stride={}):
dim_slices.append(slices)
import itertools
all_slices = []
for slices in itertools.product(*dim_slices):
selector = {key: slice for key, slice in zip(dims, slices)}
all_slices.append(selector)
return np.array(all_slices)
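# A hedged usage sketch of get_dataset_slices(): it yields one {dim: slice}
# selector per window. The dataset, window length and stride below are
# illustrative, and xarray is assumed to be available in this module.
import xarray as xr

example_time = pd.date_range("2017-08-01", "2017-08-02", freq="10min")
example_ds = xr.Dataset(
    {"x": ("time", np.arange(len(example_time)))}, coords={"time": example_time}
)
window_selectors = get_dataset_slices(
    example_ds,
    {"time": pd.to_timedelta("6h")},  # window length per dimension
    dataset_slice={"time": slice("2017-08-01", "2017-08-02")},
    stride={"time": pd.to_timedelta("3h")},  # hop size, defaults to the window length
)
first_window = example_ds.sel(window_selectors[0])  # each entry is a dict of dim -> slice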
from torch.utils.data import Dataset
import torch
class SegmentedDataset(Dataset):
def __init__(
self, data, label, dim="time", discard_empty = True, trim=True, dataset_slice=None, batch_dims={}, pad=False, mode = 'segments'
self,
data,
label,
dim="time",
discard_empty=True,
trim=True,
dataset_slice=None,
batch_dims={},
pad=False,
mode="segments",
):
""" trim ... trim the dataset to the available labels
@@ -1290,26 +1311,26 @@ class SegmentedDataset(Dataset):
from xarray documentation: [...] slices are treated as inclusive of both the start and stop values, unlike normal Python indexing.
"""
# load annotation source and datasource
# define a dataset index containing all indices of the datasource (e.g. timestamps or time periods) which should be in this dataset
# TODO: this is inefficient for large datasets
# TODO: add dataset_slice to call of data
d = data()
d = data()
l = label()
# print(d['time'])
# print(l['time'])
if dataset_slice is None or not isinstance(dataset_slice,dict):
raise RuntimeError('No dataset_slice requested')
if dataset_slice is None or not isinstance(dataset_slice, dict):
raise RuntimeError("No dataset_slice requested")
requested_coords = dataset_slice.keys()
d = d.sortby(dim)
l = l.sortby([l["start_"+dim],l["end_"+dim]])
l = l.sortby([l["start_" + dim], l["end_" + dim]])
# restrict it to the available labels
@@ -1319,19 +1340,21 @@ class SegmentedDataset(Dataset):
def get_label_slices(l):
# build label slices
label_coords = []
start_identifier='start_'
start_identifier = "start_"
for coord in l.coords:
# print(coord)
if coord.startswith(start_identifier):
label_coord = coord[len(start_identifier):]
if 'end_'+label_coord in l.coords:
label_coord = coord[len(start_identifier) :]
if "end_" + label_coord in l.coords:
label_coords += [label_coord]
label_slices = []
for i in range(len(l)):
selector = {}
for coord in label_coords:
selector[coord] = slice(l[i]['start_'+coord].values,l[i]['end_'+coord].values)
for coord in label_coords:
selector[coord] = slice(
l[i]["start_" + coord].values, l[i]["end_" + coord].values
)
label_slices.append(selector)
return label_coords, label_slices
@@ -1342,11 +1365,11 @@ class SegmentedDataset(Dataset):
# print(len(l),len(label_slices))
# filter out where we do not get data for the slice
# TODO: Is there a way to not loop through the whole dataset?
# TODO: Is there a way to not loop through the whole dataset?
# and the whole label set? but still use xarray indexing
label_dict = {}
if mode == 'segments':
overlaps = check_overlap(slices,label_slices,'time',requested_coords)
if mode == "segments":
overlaps = check_overlap(slices, label_slices, "time", requested_coords)
for o in overlaps:
if discard_empty and d.sel(slices[o[0]]).size == 0:
continue
@@ -1355,44 +1378,68 @@ class SegmentedDataset(Dataset):
j = o[1]
label = str(l[j].values)
if i not in label_dict:
label_dict[i] = {'id': i, 'indexers':slices[i],'labels':[label]}
label_dict[i] = {"id": i, "indexers": slices[i], "labels": [label]}
else:
label_dict[i] = {'indexers':label_dict[i]['indexers'], 'labels': label_dict[i]['labels']+[label]}
label_dict[i] = {
"indexers": label_dict[i]["indexers"],
"labels": label_dict[i]["labels"] + [label],
}
elif mode == 'points':
elif mode == "points":
for i in range(len(slices)):
x = d.sel(slices[i])
if x.size == 0:
continue
for j in range(len(label_slices)): # TODO: this might get really slow
for j in range(len(label_slices)): # TODO: this might get really slow
y = x.sel(label_slices[j])
if y.size > 0:
# TODO: maybe this can be done faster (and cleaner)
label = str(l[j].values)
if i not in label_dict:
label_dict[i] = {'id': i, 'indexers':slices[i],'labels':[label]}
label_dict[i] = {
"id": i,
"indexers": slices[i],
"labels": [label],
}
else:
label_dict[i] = {'indexers':label_dict[i]['indexers'], 'labels': label_dict[i]['labels']+[label]}
label_dict[i] = {
"indexers": label_dict[i]["indexers"],
"labels": label_dict[i]["labels"] + [label],
}
label_list = [label_dict[key] for key in label_dict]
print(label_list)
def check():
from plotly.subplots import make_subplots
from plotly.subplots import make_subplots
import plotly.graph_objects as go
fig = go.Figure(layout=dict(title=dict(text="A Bar Chart"),xaxis={'type':"date"},xaxis_range=[pd.to_datetime('2017-08-01'),
pd.to_datetime('2017-08-06')]))
fig = go.Figure(
layout=dict(
title=dict(text="A Bar Chart"),
xaxis={"type": "date"},
xaxis_range=[
pd.to_datetime("2017-08-01"),
pd.to_datetime("2017-08-06"),
],
)
)
# fig = make_subplots(rows=3, cols=1, shared_xaxes=True, vertical_spacing = 0.)
# fig.update_xaxes(range=[pd.to_datetime('2017-08-01'), pd.to_datetime('2017-08-03')])
# fig.update_yaxes(range=[0, 2])
# fig.update_yaxes(range=[0, 2])
for i, sl in enumerate([slices,label_slices,label_list]):
for i, sl in enumerate([slices, label_slices, label_list]):
for item in sl:
label = "None"
if 'indexers' in item:
label = item['labels']
item = item['indexers']
points = np.array([[pd.to_datetime(item['time'].start), 1],[pd.to_datetime(item['time'].stop) , 0]])
if "indexers" in item:
label = item["labels"]
item = item["indexers"]
points = np.array(
[
[pd.to_datetime(item["time"].start), 1],
[pd.to_datetime(item["time"].stop), 0],
]
)
# fig.add_shape(
# # Line reference to the axes
# go.layout.Shape(
@@ -1412,21 +1459,32 @@ class SegmentedDataset(Dataset):
# # width=3,
# # ),
# ))
fig.add_trace(go.Scatter(x=[pd.to_datetime(item['time'].start),pd.to_datetime(item['time'].stop),pd.to_datetime(item['time'].stop),pd.to_datetime(item['time'].start)], y=[0,0,1,1],
fill='toself', fillcolor='darkviolet',
# marker={'size':0},
mode='lines',
hoveron = 'points+fills', # select where hover is active
line_color='darkviolet',
showlegend=False,
# line_width=0,
opacity=0.5,
text=str(label),
hoverinfo = 'text+x+y'))
fig.add_trace(
go.Scatter(
x=[
pd.to_datetime(item["time"].start),
pd.to_datetime(item["time"].stop),
pd.to_datetime(item["time"].stop),
pd.to_datetime(item["time"].start),
],
y=[0, 0, 1, 1],
fill="toself",
fillcolor="darkviolet",
# marker={'size':0},
mode="lines",
hoveron="points+fills", # select where hover is active
line_color="darkviolet",
showlegend=False,
# line_width=0,
opacity=0.5,
text=str(label),
hoverinfo="text+x+y",
)
)
# fig.add_trace(go.Scatter(x=points[:,0], y=points[:,1],
# fill=None,
# mode='lines',
# line_color=None, line_width=0,
# line_color=None, line_width=0,
# name="hv", line_shape='hv',
# showlegend=False,
# hovertext=str(label),
@@ -1441,7 +1499,6 @@ class SegmentedDataset(Dataset):
# fill='tonexty', # fill area between trace0 and trace1
# mode='lines', line_width=0))