
Commit 9c624c2d authored by matthmey

update to segmented datasets

parent 7cb63d92
@@ -36,6 +36,7 @@ lttb = "^0.2.0"
 pyarrow = "^0.15.1"
 torch = "^1.3.1"
 torchvision = "^0.4.2"
+tqdm = "^4.39.0"
 # Optional dependencies (extras)
...
@@ -59,3 +59,9 @@ def read_csv_with_store(store, filename):
     StreamReader = codecs.getreader("utf-8")
     string_buffer = StreamReader(bytes_buffer)
     return pd.read_csv(string_buffer)
+
+
+def indexers_to_request(indexers):
+    request = {"start_" + k: v.start for k, v in indexers.items()}
+    request.update({"end_" + k: v.stop for k, v in indexers.items()})
+    return request
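A quick sketch (not part of the commit) of what the new helper produces, assuming pandas Timestamps as slice bounds and that the package is importable as stuett:

    import pandas as pd
    from stuett.convenience import indexers_to_request  # assumed import path

    indexers = {"time": slice(pd.Timestamp("2017-08-06 09:52"),
                              pd.Timestamp("2017-08-06 10:10"))}
    request = indexers_to_request(indexers)
    # {'start_time': Timestamp('2017-08-06 09:52:00'),
    #  'end_time': Timestamp('2017-08-06 10:10:00')}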
@@ -23,6 +23,7 @@ SOFTWARE."""
 from ..global_config import get_setting, setting_exists, set_setting
 from ..core.graph import StuettNode, configuration
 from ..convenience import read_csv_with_store, to_csv_with_store, DirectoryStore
+from ..convenience import indexers_to_request as i2r

 import os
@@ -35,6 +36,9 @@ from obsplus import obspy_to_array
 from copy import deepcopy
 import io
+
+from tqdm import tqdm
+
 import zarr
 import xarray as xr
 from PIL import Image
@@ -342,9 +346,10 @@ class SeismicSource(DataSource):
             starttime = x.starttime.values.reshape((-1,))[0]
             for s in x.starttime.values.reshape((-1,)):
                 if s != starttime:
-                    raise RuntimeError(
-                        "Please make sure that starttime of each seimsic channel is equal"
-                    )
+                    warnings.warn("Not all starttimes of each seismic channel is equal")
+                    # raise RuntimeError(
+                    #     "Please make sure that starttime of each seismic channel is equal"
+                    # )

             # change time coords from relative to absolute time
             starttime = obspy.UTCDateTime(starttime).datetime
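The hard failure is relaxed to a warning so streams whose channels have slightly different start times still load. A minimal sketch of the pattern (this assumes warnings is imported elsewhere in the module; the values are made up):

    import warnings

    starttimes = [0.00, 0.00, 0.02]  # hypothetical per-channel start times
    if any(s != starttimes[0] for s in starttimes):
        warnings.warn("Not all starttimes of each seismic channel is equal")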
@@ -358,6 +363,7 @@ class SeismicSource(DataSource):
             ) # TODO: change when xarray #3291 is fixed
             del x.attrs["stats"]

+            # x.rename({'seed_id':'channels'}) #TODO: rename seed_id to channels
             return x

     def process_seismic_data(
@@ -615,7 +621,7 @@ class MHDSLRFilenames(DataSource):
             start_time=start_time,
             end_time=end_time,
             force_write_to_remote=force_write_to_remote,
-            as_pandas=as_pandas
+            as_pandas=as_pandas,
         )

     def forward(self, data=None, request=None):
@@ -730,16 +736,14 @@ class MHDSLRFilenames(DataSource):
                 start_time = imglist_df.index[0]
                 loc = imglist_df.index.get_loc(start_time, method="nearest")
-                output_df = imglist_df.iloc[loc:loc+1
-                ]
+                output_df = imglist_df.iloc[loc : loc + 1]
             else:
                 # if end_time.tzinfo is None:
                 #     end_time = end_time.replace(tzinfo=timezone.utc)
                 if start_time > imglist_df.index[-1] or end_time < imglist_df.index[0]:
                     # return empty dataframe
                     output_df = imglist_df[0:0]
-                else:
-                    if start_time < imglist_df.index[0]:
-                        start_time = imglist_df.index[0]
-                    if end_time > imglist_df.index[-1]:
+
+                if start_time < imglist_df.index[0]:
+                    start_time = imglist_df.index[0]
+                if end_time > imglist_df.index[-1]:
@@ -752,12 +756,12 @@ class MHDSLRFilenames(DataSource):
                         + 1
                     ]

-        if not request['as_pandas']:
-            output_df = output_df[['filename']] # TODO: do not get rid of end_time
-            output_df.index.rename('time',inplace=True)
+        if not request["as_pandas"]:
+            output_df = output_df[["filename"]]  # TODO: do not get rid of end_time
+            output_df.index.rename("time", inplace=True)
             # output = output_df.to_xarray(dims=["time"])
             output = xr.Dataset.from_dataframe(output_df).to_array()
-            print(output)
+            # print(output)
             # output = xr.DataArray(output_df['filename'], dims=["time"])
         else:
             output = output_df
@@ -991,13 +995,17 @@ class MHDSLRImages(MHDSLRFilenames):
         start_time=None,
         end_time=None,
     ):
+        if store is None and base_directory is not None:
+            store = DirectoryStore(base_directory)
+
         super().__init__(
-            base_directory=base_directory,
+            base_directory=None,
             store=store,
             method=method,
             start_time=start_time,
             end_time=end_time,
         )
         self.config["output_format"] = output_format

     def forward(self, data=None, request=None):
@@ -1017,14 +1025,22 @@ class MHDSLRImages(MHDSLRFilenames):
             images = []
             times = []
             for timestamp, element in filenames.iterrows():
-                filename = Path(self.config["base_directory"]).joinpath(element.filename)
-                img = Image.open(filename)
+                key = element.filename
+                img = Image.open(io.BytesIO(self.config["store"][key]))
+                img = np.array(img.convert("RGB"))
                 images.append(np.array(img))
                 times.append(timestamp)

-            images = np.array(images)
+            if images:
+                images = np.array(images)
+            else:
+                images = np.empty((0, 0, 0, 0))
+
             data = xr.DataArray(
-                images, coords={"time": times}, dims=["time", "x", "y", "c"], name="Image"
+                images,
+                coords={"time": times},
+                dims=["time", "x", "y", "channels"],
+                name="Image",
             )
             data.attrs["format"] = "jpg"
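Image loading now goes through the configured store instead of the local filesystem, so the same code path works for a directory store or a remote one. A minimal sketch (not part of the commit) of the read path, assuming a MutableMapping-style store with a hypothetical key:

    import io
    import numpy as np
    from PIL import Image

    image_bytes = open("example.jpg", "rb").read()  # hypothetical JPEG file
    store = {"2017-08-06/20170806_095212.JPG": image_bytes}

    img = Image.open(io.BytesIO(store["2017-08-06/20170806_095212.JPG"]))
    img = np.array(img.convert("RGB"))  # shape (height, width, 3), dtype uint8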
@@ -1034,8 +1050,9 @@ class MHDSLRImages(MHDSLRFilenames):
             images = []
             times = []
             for timestamp, element in filenames.iterrows():
-                filename = Path(self.config["base_directory"]).joinpath(element.filename)
-                img = Image.open(filename)
+                key = element.filename
+                img = Image.open(io.BytesIO(self.config["store"][key]))
+                img = np.array(img.convert("RGB"))
                 img_base64 = base64.b64encode(img.tobytes())
                 images.append(img_base64)
                 times.append(timestamp)
@@ -1170,15 +1187,12 @@ class BoundingBoxAnnotation(DataSource):
         **kwargs,
     ):
         super().__init__(
-            filename=filename,
-            store=store,
-            converters=converters,
-            kwargs=kwargs,
+            filename=filename, store=store, converters=converters, kwargs=kwargs,
         )

     def forward(self, data=None, request=None):
-        if request['store'] is not None:
-            csv = read_csv_with_store(request['store'],request['filename'])
+        if request["store"] is not None:
+            csv = read_csv_with_store(request["store"], request["filename"])
         else:
             csv = pd.read_csv(request["filename"])
@@ -1262,21 +1276,21 @@ def check_overlap(data0, data1, index_dim, dims=[]):
     return overlap_indices


-def get_dataset_slices(ds, dims, dataset_slice=None, stride={}):
+def get_dataset_slices(dims, dataset_slice, stride={}):
     # thanks to xbatcher: https://github.com/rabernat/xbatcher/
     dim_slices = []
     for dim in dims:
-        if dataset_slice is None:
-            segment_start = 0
-            segment_end = ds.sizes[dim]
-        else:
-            segment_start = dataset_slice[dim].start
-            segment_end = dataset_slice[dim].stop
+        # if dataset_slice is None:
+        #     segment_start = 0
+        #     segment_end = ds.sizes[dim]
+        # else:
+        segment_start = dataset_slice[dim].start
+        segment_end = dataset_slice[dim].stop

         size = dims[dim]
         _stride = stride.get(dim, size)

-        if ds[dim].dtype == "datetime64[ns]":
+        if isinstance(dims[dim], pd.Timedelta) or isinstance(dims[dim], dt.timedelta):
             # TODO: change when xarray #3291 is fixed
             iterator = pd.date_range(
                 segment_start, segment_end, freq=_stride
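With the dataset argument gone, get_dataset_slices derives everything from the requested slice, the window size per dimension, and an optional stride. A minimal sketch (not part of the commit) of the windowing idea for a time dimension, with assumed values:

    import pandas as pd

    segment_start = pd.Timestamp("2017-08-06 00:00")
    segment_end = pd.Timestamp("2017-08-06 01:00")
    size = pd.Timedelta("10 minutes")  # the batch_dims entry for "time"
    stride = size                      # defaults to the window size

    starts = pd.date_range(segment_start, segment_end, freq=stride)
    slices = [{"time": slice(s, s + size)} for s in starts]
    # each dict can be passed to xarray's .sel()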
@@ -1310,112 +1324,109 @@ import torch
 # from numba import jit
 # @jit(nopython=True)
-def get_label_slices(l):
+def annotations_to_slices(annotations):
     # build label slices
     label_coords = []
     start_identifier = "start_"
-    for coord in l.coords:
+    for coord in annotations.coords:
         # print(coord)
         if coord.startswith(start_identifier):
             label_coord = coord[len(start_identifier) :]
-            if "end_" + label_coord in l.coords:
+            if "end_" + label_coord in annotations.coords:
                 label_coords += [label_coord]

     label_slices = []
-    for i in range(len(l)):  # TODO: This does not scale! It is really really slow
-        print(i,len(l))
+    for i in range(len(annotations)):  # TODO: see if we can still optimize this
         selector = {}
         for coord in label_coords:
             selector[coord] = slice(
-                l[i]["start_" + coord].values, l[i]["end_" + coord].values
+                annotations["start_" + coord].values[i],
+                annotations["end_" + coord].values[i],
             )
         label_slices.append(selector)
     return label_coords, label_slices
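The renamed annotations_to_slices pairs each start_* coordinate with its end_* counterpart and emits one slice indexer per annotation; indexing the coordinate arrays directly instead of selecting one annotation per loop iteration is presumably what addresses the old scaling TODO. A sketch (not part of the commit) with a made-up annotation array:

    import pandas as pd
    import xarray as xr

    annotations = xr.DataArray(
        ["rockfall", "mountaineer"],
        dims="index",
        coords={
            "start_time": ("index", pd.to_datetime(["2017-08-06 09:00", "2017-08-06 10:00"])),
            "end_time": ("index", pd.to_datetime(["2017-08-06 09:10", "2017-08-06 10:05"])),
        },
    )
    label_slices = [
        {"time": slice(annotations["start_time"].values[i], annotations["end_time"].values[i])}
        for i in range(len(annotations))
    ]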

 class SegmentedDataset(Dataset):
     def __init__(
         self,
-        data,
+        data,  # TODO: currently it should be named data source
         label,
-        dim="time",
+        index_dim="time",
         discard_empty=True,
-        trim=True,
         dataset_slice=None,
         batch_dims={},
-        pad=False,
         mode="segments",
     ):
-        """ trim ... trim the dataset to the available labels
-        dataset_slice: which part of the dataset to use
-           from xarray documentation: [...] slices are treated as inclusive of both the start and stop values, unlike normal Python indexing.
-        """
+        self.data = data
+        self.label = label
+        self.index_dim = index_dim
+        self.discard_empty = discard_empty
+        self.dataset_slice = dataset_slice
+        self.batch_dims = batch_dims
+        self.__mode = mode
+
+    def compute_label_list(self):
+        """ discard_empty ... trim the dataset to the available labels
+        dataset_slice: which part of the dataset to use
+           from xarray documentation: [...] slices are treated as inclusive of both the start and stop values, unlike normal Python indexing.
+        """
+        if self.dataset_slice is None or not isinstance(self.dataset_slice, dict):
+            raise RuntimeError("No dataset_slice requested")
+
         # load annotation source and datasource
         # define a dataset index containing all indices of the datasource (e.g. timestamps or time period) which should be in this dataset
         # TODO: this is inefficient for large datasets
-        # TODO: add dataset_slice to call of data
-        if isinstance(data,StuettNode):
-            request = { "start_"+k:v.start for k, v in dataset_slice.items()}
-            request.update({ "end_"+k:v.stop for k, v in dataset_slice.items()})
-            d = data(request)
+        if isinstance(self.label, StuettNode):
+            l = self.label()
         else:
-            d = data
-
-        if isinstance(label,StuettNode):
-            l = label()
-        else:
-            l = label
-
-        # print(d['time'])
-        # print(l['time'])
-
-        if dataset_slice is None or not isinstance(dataset_slice, dict):
-            raise RuntimeError("No dataset_slice requested")
-
-        requested_coords = dataset_slice.keys()
-
-        d = d.sortby(dim)
-        l = l.sortby([l["start_" + dim], l["end_" + dim]])
+            l = self.label
+        l = l.sortby([l["start_" + self.index_dim], l["end_" + self.index_dim]])

-        # restrict it to the available labels
-        # indices = check_overlap(d,l)
-        slices = get_dataset_slices(d, batch_dims, dataset_slice=dataset_slice)
-        label_coords, label_slices = get_label_slices(l)
-        # print(label_slices, slices)
-        # print(len(l),len(label_slices))
+        requested_coords = self.dataset_slice.keys()
+
+        slices = get_dataset_slices(self.batch_dims, self.dataset_slice)
+        label_coords, label_slices = annotations_to_slices(l)

         # filter out where we do not get data for the slice
         # TODO: Is there a way to not loop through the whole dataset?
         # and the whole label set? but still use xarray indexing
         label_dict = {}
-        if mode == "segments":
-            overlaps = check_overlap(slices, label_slices, "time", requested_coords)
-            for o in overlaps:
-                if discard_empty and d.sel(slices[o[0]]).size == 0:
-                    continue
+        self.classes = []
+        if self.__mode == "segments":
+            overlaps = check_overlap(
+                slices, label_slices, self.index_dim, requested_coords
+            )
+            for o in tqdm(overlaps):
+                # if discard_empty and d.sel(slices[o[0]]).size == 0:
+                if self.discard_empty:
+                    # we need to load every single piece to check if it is empty
+                    # TODO: loop through dims in batch_dim and check if they are correct
+                    if self.get_data(slices[o[0]]).size == 0:
+                        continue
                 # TODO: maybe this can be done faster (and cleaner)
                 i = o[0]
                 j = o[1]
                 label = str(l[j].values)
+                if label not in self.classes:
+                    self.classes.append(label)
                 if i not in label_dict:
-                    label_dict[i] = {"id": i, "indexers": slices[i], "labels": [label]}
-                else:
+                    label_dict[i] = {"indexers": slices[i], "labels": [label]}
+                elif label not in label_dict[i]["labels"]:
                     label_dict[i] = {
                         "indexers": label_dict[i]["indexers"],
                         "labels": label_dict[i]["labels"] + [label],
                     }
-        elif mode == "points":
+        elif self.__mode == "points":
             for i in range(len(slices)):
-                x = d.sel(slices[i])
+                # x = d.sel(slices[i])
+                x = self.get_data(slices[i])
                 if x.size == 0:
                     continue
                 for j in range(len(label_slices)):  # TODO: this might get really slow
@@ -1423,21 +1434,26 @@ class SegmentedDataset(Dataset):
                     if y.size > 0:
                         # TODO: maybe this can be done faster (and cleaner)
                         label = str(l[j].values)
+                        if label not in self.classes:
+                            self.classes.append(label)
                         if i not in label_dict:
                             label_dict[i] = {
-                                "id": i,
                                 "indexers": slices[i],
                                 "labels": [label],
                             }
-                        else:
+                        elif label not in label_dict[i]["labels"]:
                             label_dict[i] = {
                                 "indexers": label_dict[i]["indexers"],
                                 "labels": label_dict[i]["labels"] + [label],
                             }

-        label_list = [label_dict[key] for key in label_dict]
+        label_list = pd.DataFrame([label_dict[key] for key in label_dict])
+        self.classes = {class_name: i for i, class_name in enumerate(self.classes)}
         # print(label_list)
+
+        return label_list
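compute_label_list now returns the label table as a pandas DataFrame and keeps a class-name-to-index mapping on the dataset. A sketch (not part of the commit) of the two-step class mapping, with assumed labels:

    classes = []
    for label in ["rockfall", "mountaineer", "rockfall"]:
        if label not in classes:
            classes.append(label)  # collect unique labels in encounter order
    classes = {class_name: i for i, class_name in enumerate(classes)}
    # {'rockfall': 0, 'mountaineer': 1}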
         def check():
             from plotly.subplots import make_subplots
             import plotly.graph_objects as go
@@ -1529,13 +1545,20 @@ class SegmentedDataset(Dataset):
             fig.show()

-        self.label_list = label_list
-        self.data = data
-
         # TODO: verify that the segments have the same length?
         # TODO: get statistics to what was left out

+    def get_data(self, indexers):
+        # TODO: works only for datasources, not for chains
+        if isinstance(self.data, StuettNode):
+            request = i2r(indexers)
+            d = self.data(request)
+            return d.sel(indexers)
+        else:  # expecting a xarray
+            d = self.data
+            return d.sel(indexers)

     def __len__(self):
+        raise NotImplementedError()
         return len(self.label_list)

...
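The new get_data handles both input kinds: a StuettNode is called with the flat start_/end_ request built by i2r so the source loads only the requested window, and the result (or a plain xarray object) is then trimmed with .sel. A minimal sketch (not part of the commit) of the node path, with a hypothetical data_node:

    import pandas as pd

    indexers = {"time": slice(pd.Timestamp("2017-08-06 09:00"),
                              pd.Timestamp("2017-08-06 10:00"))}
    request = {"start_time": indexers["time"].start,  # what i2r(indexers) builds
               "end_time": indexers["time"].stop}
    # d = data_node(request)  # node loads only the requested window
    # d = d.sel(indexers)     # then trim exactly to the slice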
@@ -194,8 +194,6 @@ def test_mhdslrimage():
     config["output_format"] = "base64"
     data = node(config)

-    # TODO: assert data
-
     from PIL import Image

     img = Image.open(base_dir.joinpath("2017-08-06", "20170806_095212.JPG"))
@@ -203,6 +201,18 @@ def test_mhdslrimage():
     assert data[0].values == img_base64

+    # Check a period where there is no image
+    start_time = dt.datetime(2017, 8, 6, 9, 55, 12, tzinfo=dt.timezone.utc)
+    end_time = dt.datetime(2017, 8, 6, 10, 10, 10, tzinfo=dt.timezone.utc)
+    config = {
+        "start_time": start_time,
+        "end_time": end_time,
+    }
+    data = node(config)
+    # print(data)
+    assert data.shape == (0, 0, 0, 0)
+
+
 # test_mhdslrimage()