
Commit 9c624c2d authored by matthmey

update to segmented datasets

parent 7cb63d92
......@@ -36,6 +36,7 @@ lttb = "^0.2.0"
pyarrow = "^0.15.1"
torch = "^1.3.1"
torchvision = "^0.4.2"
tqdm = "^4.39.0"
# Optional dependencies (extras)
......
......@@ -59,3 +59,9 @@ def read_csv_with_store(store, filename):
StreamReader = codecs.getreader("utf-8")
string_buffer = StreamReader(bytes_buffer)
return pd.read_csv(string_buffer)
def indexers_to_request(indexers):
request = {"start_" + k: v.start for k, v in indexers.items()}
request.update({"end_" + k: v.stop for k, v in indexers.items()})
return request
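
For reference, a minimal usage sketch of this helper (values hypothetical): it flattens a dict of slice indexers, as used for xarray .sel(), into the flat start_*/end_* keys that data source requests use.

import datetime as dt

indexers = {"time": slice(dt.datetime(2017, 8, 6), dt.datetime(2017, 8, 7))}
request = indexers_to_request(indexers)
# request == {"start_time": dt.datetime(2017, 8, 6),
#             "end_time": dt.datetime(2017, 8, 7)}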
......@@ -23,6 +23,7 @@ SOFTWARE."""
from ..global_config import get_setting, setting_exists, set_setting
from ..core.graph import StuettNode, configuration
from ..convenience import read_csv_with_store, to_csv_with_store, DirectoryStore
from ..convenience import indexers_to_request as i2r
import os
......@@ -35,6 +36,9 @@ from obsplus import obspy_to_array
from copy import deepcopy
import io
from tqdm import tqdm
import zarr
import xarray as xr
from PIL import Image
......@@ -342,9 +346,10 @@ class SeismicSource(DataSource):
starttime = x.starttime.values.reshape((-1,))[0]
for s in x.starttime.values.reshape((-1,)):
if s != starttime:
raise RuntimeError(
"Please make sure that starttime of each seimsic channel is equal"
)
warnings.warn("Not all starttimes of each seismic channel is equal")
# raise RuntimeError(
# "Please make sure that starttime of each seismic channel is equal"
# )
# change time coords from relative to absolute time
starttime = obspy.UTCDateTime(starttime).datetime
......@@ -358,6 +363,7 @@ class SeismicSource(DataSource):
) # TODO: change when xarray #3291 is fixed
del x.attrs["stats"]
# x.rename({'seed_id':'channels'}) #TODO: rename seed_id to channels
return x
def process_seismic_data(
......@@ -615,7 +621,7 @@ class MHDSLRFilenames(DataSource):
start_time=start_time,
end_time=end_time,
force_write_to_remote=force_write_to_remote,
as_pandas=as_pandas
as_pandas=as_pandas,
)
def forward(self, data=None, request=None):
......@@ -730,16 +736,14 @@ class MHDSLRFilenames(DataSource):
start_time = imglist_df.index[0]
loc = imglist_df.index.get_loc(start_time, method="nearest")
output_df = imglist_df.iloc[loc:loc+1
]
output_df = imglist_df.iloc[loc : loc + 1]
else:
# if end_time.tzinfo is None:
# end_time = end_time.replace(tzinfo=timezone.utc)
if start_time > imglist_df.index[-1] or end_time < imglist_df.index[0]:
# return empty dataframe
output_df = imglist_df[0:0]
else:
if start_time < imglist_df.index[0]:
start_time = imglist_df.index[0]
if end_time > imglist_df.index[-1]:
......@@ -752,12 +756,12 @@ class MHDSLRFilenames(DataSource):
+ 1
]
if not request['as_pandas']:
output_df = output_df[['filename']] # TODO: do not get rid of end_time
output_df.index.rename('time',inplace=True)
if not request["as_pandas"]:
output_df = output_df[["filename"]] # TODO: do not get rid of end_time
output_df.index.rename("time", inplace=True)
# output = output_df.to_xarray(dims=["time"])
output = xr.Dataset.from_dataframe(output_df).to_array()
print(output)
# print(output)
# output = xr.DataArray(output_df['filename'], dims=["time"])
else:
output = output_df
......@@ -991,13 +995,17 @@ class MHDSLRImages(MHDSLRFilenames):
start_time=None,
end_time=None,
):
if store is None and base_directory is not None:
store = DirectoryStore(base_directory)
super().__init__(
base_directory=base_directory,
base_directory=None,
store=store,
method=method,
start_time=start_time,
end_time=end_time,
)
self.config["output_format"] = output_format
def forward(self, data=None, request=None):
......@@ -1017,14 +1025,22 @@ class MHDSLRImages(MHDSLRFilenames):
images = []
times = []
for timestamp, element in filenames.iterrows():
filename = Path(self.config["base_directory"]).joinpath(element.filename)
img = Image.open(filename)
key = element.filename
img = Image.open(io.BytesIO(self.config["store"][key]))
img = np.array(img.convert("RGB"))
images.append(np.array(img))
times.append(timestamp)
if images:
images = np.array(images)
else:
images = np.empty((0, 0, 0, 0))
data = xr.DataArray(
images, coords={"time": times}, dims=["time", "x", "y", "c"], name="Image"
images,
coords={"time": times},
dims=["time", "x", "y", "channels"],
name="Image",
)
data.attrs["format"] = "jpg"
......@@ -1034,8 +1050,9 @@ class MHDSLRImages(MHDSLRFilenames):
images = []
times = []
for timestamp, element in filenames.iterrows():
filename = Path(self.config["base_directory"]).joinpath(element.filename)
img = Image.open(filename)
key = element.filename
img = Image.open(io.BytesIO(self.config["store"][key]))
img = np.array(img.convert("RGB"))
img_base64 = base64.b64encode(img.tobytes())
images.append(img_base64)
times.append(timestamp)
......@@ -1170,15 +1187,12 @@ class BoundingBoxAnnotation(DataSource):
**kwargs,
):
super().__init__(
filename=filename,
store=store,
converters=converters,
kwargs=kwargs,
filename=filename, store=store, converters=converters, kwargs=kwargs,
)
def forward(self, data=None, request=None):
if request['store'] is not None:
csv = read_csv_with_store(request['store'],request['filename'])
if request["store"] is not None:
csv = read_csv_with_store(request["store"], request["filename"])
else:
csv = pd.read_csv(request["filename"])
......@@ -1262,21 +1276,21 @@ def check_overlap(data0, data1, index_dim, dims=[]):
return overlap_indices
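
As used further down in SegmentedDataset, check_overlap(slices, label_slices, index_dim, dims) returns overlap_indices, i.e. index pairs (i, j) such that slices[i] and label_slices[j] overlap. A sketch with hypothetical slices (the exact return type is an assumption inferred from its usage below):

import pandas as pd

data_slices = [{"time": slice(pd.Timestamp("2017-08-06 10:00"),
                              pd.Timestamp("2017-08-06 10:24"))}]
label_slices = [{"time": slice(pd.Timestamp("2017-08-06 10:10"),
                               pd.Timestamp("2017-08-06 10:15"))}]
overlaps = check_overlap(data_slices, label_slices, "time", ["time"])
# expected: [(0, 0)], since the label interval lies inside the data window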
def get_dataset_slices(ds, dims, dataset_slice=None, stride={}):
def get_dataset_slices(dims, dataset_slice, stride={}):
# thanks to xbatcher: https://github.com/rabernat/xbatcher/
dim_slices = []
for dim in dims:
if dataset_slice is None:
segment_start = 0
segment_end = ds.sizes[dim]
else:
# if dataset_slice is None:
# segment_start = 0
# segment_end = ds.sizes[dim]
# else:
segment_start = dataset_slice[dim].start
segment_end = dataset_slice[dim].stop
size = dims[dim]
_stride = stride.get(dim, size)
if ds[dim].dtype == "datetime64[ns]":
if isinstance(dims[dim], pd.Timedelta) or isinstance(dims[dim], dt.timedelta):
# TODO: change when xarray #3291 is fixed
iterator = pd.date_range(
segment_start, segment_end, freq=_stride
......@@ -1310,112 +1324,109 @@ import torch
# from numba import jit
# @jit(nopython=True)
def get_label_slices(l):
def annotations_to_slices(annotations):
# build label slices
label_coords = []
start_identifier = "start_"
for coord in l.coords:
for coord in annotations.coords:
# print(coord)
if coord.startswith(start_identifier):
label_coord = coord[len(start_identifier) :]
if "end_" + label_coord in l.coords:
if "end_" + label_coord in annotations.coords:
label_coords += [label_coord]
label_slices = []
for i in range(len(l)): # TODO: This does not scale! It is really really slow
print(i,len(l))
for i in range(len(annotations)): # TODO: see if this can still be optimized
selector = {}
for coord in label_coords:
selector[coord] = slice(
l[i]["start_" + coord].values, l[i]["end_" + coord].values
annotations["start_" + coord].values[i],
annotations["end_" + coord].values[i],
)
label_slices.append(selector)
return label_coords, label_slices
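
A sketch of the annotation format this function expects (names and values hypothetical): a DataArray of labels carrying paired start_*/end_* coordinates, from which one selector dict per annotation is built.

import pandas as pd
import xarray as xr

annotations = xr.DataArray(
    ["rockfall", "noise"],
    dims="index",
    coords={
        "start_time": ("index", pd.to_datetime(["2017-08-06 10:00", "2017-08-06 11:00"])),
        "end_time": ("index", pd.to_datetime(["2017-08-06 10:05", "2017-08-06 11:02"])),
    },
)
label_coords, label_slices = annotations_to_slices(annotations)
# label_coords == ["time"]
# label_slices[0] == {"time": slice(start_time[0], end_time[0])}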
class SegmentedDataset(Dataset):
def __init__(
self,
data,
data, # TODO: currently it should be named data_source
label,
dim="time",
index_dim="time",
discard_empty=True,
trim=True,
dataset_slice=None,
batch_dims={},
pad=False,
mode="segments",
):
""" trim ... trim the dataset to the available labels
self.data = data
self.label = label
self.index_dim = index_dim
self.discard_empty = discard_empty
self.dataset_slice = dataset_slice
self.batch_dims = batch_dims
self.__mode = mode
def compute_label_list(self):
""" discard_empty ... trim the dataset to the available labels
dataset_slice: which part of the dataset to use
from xarray documentation: [...] slices are treated as inclusive of both the start and stop values, unlike normal Python indexing.
"""
if self.dataset_slice is None or not isinstance(self.dataset_slice, dict):
raise RuntimeError("No dataset_slice requested")
# load annotation source and datasource
# define a dataset index containing all indices of the data source (e.g. timestamps or time periods) which should be in this dataset
# TODO: this is inefficient for large datasets
# TODO: add dataset_slice to call of data
if isinstance(data,StuettNode):
request = { "start_"+k:v.start for k, v in dataset_slice.items()}
request.update({ "end_"+k:v.stop for k, v in dataset_slice.items()})
d = data(request)
if isinstance(self.label, StuettNode):
l = self.label()
else:
d = data
if isinstance(label,StuettNode):
l = label()
else:
l = label
# print(d['time'])
# print(l['time'])
if dataset_slice is None or not isinstance(dataset_slice, dict):
raise RuntimeError("No dataset_slice requested")
requested_coords = dataset_slice.keys()
l = self.label
l = l.sortby([l["start_" + self.index_dim], l["end_" + self.index_dim]])
d = d.sortby(dim)
l = l.sortby([l["start_" + dim], l["end_" + dim]])
requested_coords = self.dataset_slice.keys()
# restrict it to the available labels
slices = get_dataset_slices(self.batch_dims, self.dataset_slice)
label_coords, label_slices = annotations_to_slices(l)
# indices = check_overlap(d,l)
slices = get_dataset_slices(d, batch_dims, dataset_slice=dataset_slice)
label_coords, label_slices = get_label_slices(l)
# print(label_slices, slices)
# print(len(l),len(label_slices))
# filter out where we do not get data for the slice
# TODO: Is there a way to not loop through the whole dataset?
# and the whole label set? but still use xarray indexing
label_dict = {}
if mode == "segments":
overlaps = check_overlap(slices, label_slices, "time", requested_coords)
for o in overlaps:
if discard_empty and d.sel(slices[o[0]]).size == 0:
self.classes = []
if self.__mode == "segments":
overlaps = check_overlap(
slices, label_slices, self.index_dim, requested_coords
)
for o in tqdm(overlaps):
# if discard_empty and d.sel(slices[o[0]]).size == 0:
if self.discard_empty:
# we need to load every single piece to check if it is empty
# TODO: loop through dims in batch_dim and check if they are correct
if self.get_data(slices[o[0]]).size == 0:
continue
# TODO: maybe this can be done faster (and cleaner)
i = o[0]
j = o[1]
label = str(l[j].values)
if label not in self.classes:
self.classes.append(label)
if i not in label_dict:
label_dict[i] = {"id": i, "indexers": slices[i], "labels": [label]}
else:
label_dict[i] = {"indexers": slices[i], "labels": [label]}
elif label not in label_dict[i]["labels"]:
label_dict[i] = {
"indexers": label_dict[i]["indexers"],
"labels": label_dict[i]["labels"] + [label],
}
elif mode == "points":
elif self.__mode == "points":
for i in range(len(slices)):
x = d.sel(slices[i])
# x = d.sel(slices[i])
x = self.get_data(slices[i])
if x.size == 0:
continue
for j in range(len(label_slices)): # TODO: this might get really slow
......@@ -1423,21 +1434,26 @@ class SegmentedDataset(Dataset):
if y.size > 0:
# TODO: maybe this can be done faster (and cleaner)
label = str(l[j].values)
if label not in self.classes:
self.classes.append(label)
if i not in label_dict:
label_dict[i] = {
"id": i,
"indexers": slices[i],
"labels": [label],
}
else:
elif label not in label_dict[i]["labels"]:
label_dict[i] = {
"indexers": label_dict[i]["indexers"],
"labels": label_dict[i]["labels"] + [label],
}
label_list = [label_dict[key] for key in label_dict]
label_list = pd.DataFrame([label_dict[key] for key in label_dict])
self.classes = {class_name: i for i, class_name in enumerate(self.classes)}
# print(label_list)
return label_list
def check():
from plotly.subplots import make_subplots
import plotly.graph_objects as go
......@@ -1529,13 +1545,20 @@ class SegmentedDataset(Dataset):
fig.show()
self.label_list = label_list
self.data = data
# TODO: verify that the segments have the same length?
# TODO: get statistics to what was left out
def get_data(self, indexers):
# TODO: works only for datasources, not for chains
if isinstance(self.data, StuettNode):
request = i2r(indexers)
d = self.data(request)
return d.sel(indexers)
else: # expecting a xarray
d = self.data
return d.sel(indexers)
def __len__(self):
raise NotImplementedError()
return len(self.label_list)
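
For context, a hedged usage sketch (node names and values hypothetical, mirroring test_datasets below): the dataset is built from a data node and an annotation node, restricted to dataset_slice and segmented into batch_dims-sized windows.

import pandas as pd

dataset = SegmentedDataset(
    data_node,        # StuettNode data source or xarray.DataArray
    annotation_node,  # annotations with start_time/end_time coords
    dataset_slice={"time": slice(pd.Timestamp("2017-08-06"),
                                 pd.Timestamp("2017-08-07"))},
    batch_dims={"time": pd.to_timedelta(24, "m")},
)
label_list = dataset.compute_label_list()  # pandas DataFrame: one row per labelled segment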
......
......@@ -194,8 +194,6 @@ def test_mhdslrimage():
config["output_format"] = "base64"
data = node(config)
# TODO: assert data
from PIL import Image
img = Image.open(base_dir.joinpath("2017-08-06", "20170806_095212.JPG"))
......@@ -203,6 +201,18 @@ def test_mhdslrimage():
assert data[0].values == img_base64
# Check a period where there is no image
start_time = dt.datetime(2017, 8, 6, 9, 55, 12, tzinfo=dt.timezone.utc)
end_time = dt.datetime(2017, 8, 6, 10, 10, 10, tzinfo=dt.timezone.utc)
config = {
"start_time": start_time,
"end_time": end_time,
}
data = node(config)
# print(data)
assert data.shape == (0, 0, 0, 0)
# test_mhdslrimage()
......@@ -230,7 +240,7 @@ def test_csv():
# TODO: test with start and end time
test_csv()
# test_csv()
def test_annotations():
......@@ -272,7 +282,7 @@ def test_datasets():
batch_dims={"time": pd.to_timedelta(24, "m")},
)
x = dataset[0]
# x = dataset[0]
test_datasets()