
Commit 7cb63d92 authored by matthmey

work on datasets

parent b22376df
@@ -27,7 +27,7 @@ toolz = "^0.10.0"
 obspy = "^1.1.1"
 numpy = "1.16.5"
 appdirs = "^1.4.3"
-obsplus = {git = "https://github.com/niosh-mining/obsplus" }
+obsplus = { git = "https://github.com/niosh-mining/obsplus" }
 zarr = "^2.3.2"
 xarray = { git = "https://github.com/niowniow/xarray.git", branch = "strided_rolling" }
 pillow = "^6.2.1"
@@ -595,6 +595,7 @@ class MHDSLRFilenames(DataSource):
         start_time=None,
         end_time=None,
         force_write_to_remote=False,
+        as_pandas=True,
     ):
         """ Fetches the DSLR images from the Matterhorn deployment, returns the image
            filename(s) corresponding to the end and start time provided in either the
@@ -614,6 +615,7 @@ class MHDSLRFilenames(DataSource):
             start_time=start_time,
             end_time=end_time,
             force_write_to_remote=force_write_to_remote,
+            as_pandas=as_pandas
         )

     def forward(self, data=None, request=None):
@@ -722,32 +724,45 @@ class MHDSLRFilenames(DataSource):
                     set_setting("image_list_df", imglist_df)

+        output_df = None
         if end_time is None:
             if start_time < imglist_df.index[0]:
                 start_time = imglist_df.index[0]
-            return imglist_df.iloc[
-                imglist_df.index.get_loc(start_time, method="nearest")
+            loc = imglist_df.index.get_loc(start_time, method="nearest")
+            output_df = imglist_df.iloc[loc:loc+1
             ]
         else:
             # if end_time.tzinfo is None:
             #     end_time = end_time.replace(tzinfo=timezone.utc)
             if start_time > imglist_df.index[-1] or end_time < imglist_df.index[0]:
                 # return empty dataframe
-                return imglist_df[0:0]
+                output_df = imglist_df[0:0]
             if start_time < imglist_df.index[0]:
                 start_time = imglist_df.index[0]
             if end_time > imglist_df.index[-1]:
                 end_time = imglist_df.index[-1]
-            return imglist_df.iloc[
+            output_df = imglist_df.iloc[
                 imglist_df.index.get_loc(
                     start_time, method="bfill"
                 ) : imglist_df.index.get_loc(end_time, method="ffill")
                 + 1
             ]

+        if not request['as_pandas']:
+            output_df = output_df[['filename']]  # TODO: do not get rid of end_time
+            output_df.index.rename('time', inplace=True)
+            # output = output_df.to_xarray(dims=["time"])
+            output = xr.Dataset.from_dataframe(output_df).to_array()
+            print(output)
+            # output = xr.DataArray(output_df['filename'], dims=["time"])
+        else:
+            output = output_df
+
+        return output

     # TODO: write test for image_integrity_store
     def image_integrity_store(
         self, store, start_time=None, end_time=None, delta_seconds=0
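For context, a minimal sketch (not part of the commit) of the pandas-to-xarray conversion the new `as_pandas=False` branch relies on; the timestamps and filenames below are made up:

import pandas as pd
import xarray as xr

# Hypothetical image list, indexed by timestamp like imglist_df above
df = pd.DataFrame(
    {"filename": ["2017-01-01_120000.JPG", "2017-01-01_120400.JPG"]},
    index=pd.to_datetime(["2017-01-01 12:00:00", "2017-01-01 12:04:00"]),
)
df.index.rename("time", inplace=True)

# Same pattern as the new code: Dataset.from_dataframe keeps the index as a
# coordinate, and to_array() stacks the data variables into a single DataArray.
da = xr.Dataset.from_dataframe(df[["filename"]]).to_array()
print(da.dims)  # ('variable', 'time')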
@@ -1150,21 +1165,22 @@ class BoundingBoxAnnotation(DataSource):
     def __init__(
         self,
         filename=None,
         start_time=None,
         end_time=None,
         store=None,
         converters={"start_time": to_datetime, "end_time": to_datetime},
         **kwargs,
     ):
         super().__init__(
             filename=filename,
             start_time=start_time,
             end_time=end_time,
             store=store,
             converters=converters,
             kwargs=kwargs,
         )

     def forward(self, data=None, request=None):
-        csv = pd.read_csv(request["filename"])
+        if request['store'] is not None:
+            csv = read_csv_with_store(request['store'], request['filename'])
+        else:
+            csv = pd.read_csv(request["filename"])

         targets = xr.DataArray(csv["__target"], dims=["index"], name="Annotation")
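For orientation, a sketch of the annotation file layout this forward pass assumes: a `__target` column plus the `start_time`/`end_time` columns handled by the converters. The label values here are made up:

# Hypothetical annotations.csv consumed by BoundingBoxAnnotation.forward:
#
#   start_time,end_time,__target
#   2017-01-01 12:00:00,2017-01-01 12:10:00,rockfall
#   2017-01-02 08:00:00,2017-01-02 08:05:00,noise
#
# With a store configured, the same file is read via
# read_csv_with_store(store, filename) instead of pd.read_csv, so annotations
# can live in a remote or zarr-backed store as well as on the local filesystem.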
@@ -1291,7 +1307,31 @@ def get_dataset_slices(ds, dims, dataset_slice=None, stride={}):
 from torch.utils.data import Dataset
 import torch

+# from numba import jit
+# @jit(nopython=True)
+def get_label_slices(l):
+    # build label slices
+    label_coords = []
+    start_identifier = "start_"
+    for coord in l.coords:
+        # print(coord)
+        if coord.startswith(start_identifier):
+            label_coord = coord[len(start_identifier) :]
+            if "end_" + label_coord in l.coords:
+                label_coords += [label_coord]
+
+    label_slices = []
+    for i in range(len(l)):  # TODO: This does not scale! It is really really slow
+        print(i, len(l))
+        selector = {}
+        for coord in label_coords:
+            selector[coord] = slice(
+                l[i]["start_" + coord].values, l[i]["end_" + coord].values
+            )
+        label_slices.append(selector)
+
+    return label_coords, label_slices
+
+
 class SegmentedDataset(Dataset):
     def __init__(
         self,
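A small, self-contained example (not from the commit) of the coordinate layout get_label_slices expects: one entry per label, with paired start_*/end_* coordinates along the same dimension. The values are made up:

import numpy as np
import pandas as pd
import xarray as xr

# Two labels with paired start_time/end_time coordinates
l = xr.DataArray(
    np.array(["rockfall", "noise"]),
    dims=["index"],
    coords={
        "start_time": ("index", pd.to_datetime(["2017-01-01", "2017-01-03"])),
        "end_time": ("index", pd.to_datetime(["2017-01-02", "2017-01-04"])),
    },
)

label_coords, label_slices = get_label_slices(l)
# label_coords == ["time"]
# label_slices[0] == {"time": slice(datetime64 2017-01-01, datetime64 2017-01-02)}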
@@ -1318,8 +1358,18 @@ class SegmentedDataset(Dataset):
         # TODO: this is inefficient for large datasets
         # TODO: add dataset_slice to call of data
-        d = data()
-        l = label()
+        if isinstance(data, StuettNode):
+            request = {"start_" + k: v.start for k, v in dataset_slice.items()}
+            request.update({"end_" + k: v.stop for k, v in dataset_slice.items()})
+            d = data(request)
+        else:
+            d = data
+
+        if isinstance(label, StuettNode):
+            l = label()
+        else:
+            l = label

         # print(d['time'])
         # print(l['time'])
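To illustrate the new request construction above, a minimal sketch with a made-up dataset_slice; each dimension's slice bounds become start_*/end_* keys in the request passed to the StuettNode:

dataset_slice = {"time": slice("2017-01-01", "2017-01-31")}
request = {"start_" + k: v.start for k, v in dataset_slice.items()}
request.update({"end_" + k: v.stop for k, v in dataset_slice.items()})
# request == {"start_time": "2017-01-01", "end_time": "2017-01-31"}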
@@ -1337,31 +1387,9 @@ class SegmentedDataset(Dataset):
         # indices = check_overlap(d,l)
         slices = get_dataset_slices(d, batch_dims, dataset_slice=dataset_slice)

-        def get_label_slices(l):
-            # build label slices
-            label_coords = []
-            start_identifier = "start_"
-            for coord in l.coords:
-                # print(coord)
-                if coord.startswith(start_identifier):
-                    label_coord = coord[len(start_identifier) :]
-                    if "end_" + label_coord in l.coords:
-                        label_coords += [label_coord]
-
-            label_slices = []
-            for i in range(len(l)):
-                selector = {}
-                for coord in label_coords:
-                    selector[coord] = slice(
-                        l[i]["start_" + coord].values, l[i]["end_" + coord].values
-                    )
-                label_slices.append(selector)
-
-            return label_coords, label_slices
-
         label_coords, label_slices = get_label_slices(l)
-        print(label_slices, slices)
+        # print(label_slices, slices)
         # print(len(l),len(label_slices))

         # filter out where we do not get data for the slice
@@ -1408,7 +1436,7 @@ class SegmentedDataset(Dataset):
         }
         label_list = [label_dict[key] for key in label_dict]

-        print(label_list)
+        # print(label_list)

         def check():
             from plotly.subplots import make_subplots
@@ -1509,31 +1537,32 @@ class SegmentedDataset(Dataset):
         # TODO: get statistics to what was left out

     def __len__(self):
-        raise NotImplementedError()
         return len(self.label_list)

     def __getitem__(self, idx):
-        raise NotImplementedError()
-        # if torch.is_tensor(idx):
-        #     idx = idx.tolist()
+        if torch.is_tensor(idx):
+            idx = idx.tolist()

-        print(idx)
+        # print(idx)

-        segment = self.label_list[idx]
+        # segment = self.label_list[idx]

-        print(segment)
-        indexers = segment["indexers"]
+        # print(segment)
+        # indexers = segment["indexers"]

-        request = {
-            "start_time": segment["indexers"]["time"].start,
-            "end_time": segment["indexers"]["time"].stop,
-        }
-        data = configuration(self.data(delayed=True), request)
-        x = data.compute()
+        # request = {
+        #     "start_time": segment["indexers"]["time"].start,
+        #     "end_time": segment["indexers"]["time"].stop,
+        # }
+        # data = configuration(self.data(delayed=True), request)
+        # x = data.compute()

-        torch.Tensor(x.values)
-        print(x)
+        # torch.Tensor(x.values)
+        # print(x)

-        return segment
+        # return segment


 class PytorchDataset(DataSource):  # TODO: extends pytorch dataset
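Since SegmentedDataset subclasses torch.utils.data.Dataset, it plugs into a standard DataLoader. The full __init__ signature is collapsed in this diff, so the constructor arguments below (node, label_node, batch_dims value type) are assumptions, not the confirmed API:

import pandas as pd
from torch.utils.data import DataLoader

# Hypothetical usage; node and label_node stand in for a StuettNode data
# source and an annotation source with start_*/end_* coordinates.
dataset = SegmentedDataset(
    node,
    label_node,
    dataset_slice={"time": slice("2017-01-01", "2017-01-31")},
    batch_dims={"time": pd.to_timedelta("10 minutes")},
)
loader = DataLoader(dataset, batch_size=4, shuffle=True)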