
Commit 7cb63d92 authored by matthmey

work on datasets

parent b22376df
@@ -27,7 +27,7 @@ toolz = "^0.10.0"
 obspy = "^1.1.1"
 numpy = "1.16.5"
 appdirs = "^1.4.3"
-obsplus = {git = "https://github.com/niosh-mining/obsplus" }
+obsplus = { git = "https://github.com/niosh-mining/obsplus" }
 zarr = "^2.3.2"
 xarray = { git = "https://github.com/niowniow/xarray.git", branch = "strided_rolling" }
 pillow = "^6.2.1"
@@ -595,6 +595,7 @@ class MHDSLRFilenames(DataSource):
         start_time=None,
         end_time=None,
         force_write_to_remote=False,
+        as_pandas=True,
     ):
         """ Fetches the DSLR images from the Matterhorn deployment, returns the image
            filename(s) corresponding to the end and start time provided in either the
@@ -614,6 +615,7 @@ class MHDSLRFilenames(DataSource):
             start_time=start_time,
             end_time=end_time,
             force_write_to_remote=force_write_to_remote,
+            as_pandas=as_pandas
         )

     def forward(self, data=None, request=None):
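Note: the new as_pandas flag defaults to True, so existing callers keep getting a pandas DataFrame back. A hypothetical usage sketch; only the parameters visible in this diff are taken from the code, everything else is assumed:

# Hypothetical sketch -- constructor arguments beyond those shown in this
# diff (start_time, end_time, as_pandas) are assumptions.
node = MHDSLRFilenames(
    start_time="2017-08-06 08:00",
    end_time="2017-08-06 12:00",
    as_pandas=False,  # ask for an xarray result instead of a DataFrame
)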
@@ -722,32 +724,45 @@ class MHDSLRFilenames(DataSource):
             set_setting("image_list_df", imglist_df)

+        output_df = None
         if end_time is None:
             if start_time < imglist_df.index[0]:
                 start_time = imglist_df.index[0]
-            return imglist_df.iloc[
-                imglist_df.index.get_loc(start_time, method="nearest")
-            ]
+            loc = imglist_df.index.get_loc(start_time, method="nearest")
+            output_df = imglist_df.iloc[loc : loc + 1]
         else:
             # if end_time.tzinfo is None:
             #     end_time = end_time.replace(tzinfo=timezone.utc)
-            if start_time > imglist_df.index[-1] or end_time < imglist_df.index[0]:
-                # return empty dataframe
-                return imglist_df[0:0]
-            if start_time < imglist_df.index[0]:
-                start_time = imglist_df.index[0]
-            if end_time > imglist_df.index[-1]:
-                end_time = imglist_df.index[-1]
-            return imglist_df.iloc[
-                imglist_df.index.get_loc(start_time, method="bfill")
-                : imglist_df.index.get_loc(end_time, method="ffill")
-                + 1
-            ]
+            if start_time > imglist_df.index[-1] or end_time < imglist_df.index[0]:
+                # requested range lies outside the available images: empty dataframe
+                # (guarded with else below, since without the early return the
+                # empty result would otherwise be overwritten)
+                output_df = imglist_df[0:0]
+            else:
+                if start_time < imglist_df.index[0]:
+                    start_time = imglist_df.index[0]
+                if end_time > imglist_df.index[-1]:
+                    end_time = imglist_df.index[-1]
+                output_df = imglist_df.iloc[
+                    imglist_df.index.get_loc(start_time, method="bfill")
+                    : imglist_df.index.get_loc(end_time, method="ffill")
+                    + 1
+                ]
+
+        if not request["as_pandas"]:
+            output_df = output_df[["filename"]]  # TODO: do not get rid of end_time
+            output_df.index.rename("time", inplace=True)
+            # output = output_df.to_xarray(dims=["time"])
+            output = xr.Dataset.from_dataframe(output_df).to_array()
+            # print(output)
+            # output = xr.DataArray(output_df["filename"], dims=["time"])
+        else:
+            output = output_df
+
+        return output
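Note: a minimal, self-contained sketch (made-up data) of the two mechanisms used above: the nearest-timestamp lookup, which relies on Index.get_loc(..., method=...) and therefore requires pandas < 2.0, and the DataFrame-to-DataArray conversion via xr.Dataset.from_dataframe(...).to_array():

import pandas as pd
import xarray as xr

# Assumed stand-in for the real image list: filenames indexed by timestamp.
imglist_df = pd.DataFrame(
    {"filename": ["img_0001.jpg", "img_0002.jpg", "img_0003.jpg"]},
    index=pd.to_datetime(
        ["2017-08-06 08:00", "2017-08-06 08:10", "2017-08-06 08:20"]
    ),
)

# end_time is None: pick the single row closest to start_time.
loc = imglist_df.index.get_loc(pd.Timestamp("2017-08-06 08:04"), method="nearest")
output_df = imglist_df.iloc[loc : loc + 1]

# as_pandas=False: keep only the filename column and convert to xarray.
output_df = output_df[["filename"]]
output_df.index = output_df.index.rename("time")  # same effect as the inplace rename above
output = xr.Dataset.from_dataframe(output_df).to_array()
print(output)  # DataArray with dims ("variable", "time")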
     # TODO: write test for image_integrity_store
     def image_integrity_store(
         self, store, start_time=None, end_time=None, delta_seconds=0
@@ -1150,21 +1165,22 @@ class BoundingBoxAnnotation(DataSource):
     def __init__(
         self,
         filename=None,
-        start_time=None,
-        end_time=None,
+        store=None,
         converters={"start_time": to_datetime, "end_time": to_datetime},
         **kwargs,
     ):
         super().__init__(
             filename=filename,
-            start_time=start_time,
-            end_time=end_time,
+            store=store,
             converters=converters,
             kwargs=kwargs,
         )

     def forward(self, data=None, request=None):
-        csv = pd.read_csv(request["filename"])
+        if request["store"] is not None:
+            csv = read_csv_with_store(request["store"], request["filename"])
+        else:
+            csv = pd.read_csv(request["filename"])

         targets = xr.DataArray(csv["__target"], dims=["index"], name="Annotation")
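Note: read_csv_with_store appears in this diff only by its call signature. The following is a hypothetical sketch of such a helper, assuming store behaves like a dict-like mapping from file names to raw bytes (as zarr/fsspec stores do); it is not the project's actual implementation:

import io
import pandas as pd

def read_csv_with_store(store, filename):
    # Pull the raw bytes out of the store and let pandas parse them.
    return pd.read_csv(io.BytesIO(store[filename]))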
@@ -1291,7 +1307,31 @@ def get_dataset_slices(ds, dims, dataset_slice=None, stride={}):
 from torch.utils.data import Dataset
 import torch

+# from numba import jit
+# @jit(nopython=True)
+def get_label_slices(l):
+    # collect the dimensions that have paired start_*/end_* coordinates
+    label_coords = []
+    start_identifier = "start_"
+    for coord in l.coords:
+        # print(coord)
+        if coord.startswith(start_identifier):
+            label_coord = coord[len(start_identifier) :]
+            if "end_" + label_coord in l.coords:
+                label_coords += [label_coord]
+
+    # build one {dim: slice(start, end)} selector per label
+    label_slices = []
+    for i in range(len(l)):  # TODO: this does not scale, it is really slow
+        # print(i, len(l))
+        selector = {}
+        for coord in label_coords:
+            selector[coord] = slice(
+                l[i]["start_" + coord].values, l[i]["end_" + coord].values
+            )
+        label_slices.append(selector)
+
+    return label_coords, label_slices
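Note: a small, self-contained illustration of what the module-level get_label_slices added above returns, using made-up labels with paired start_*/end_* coordinates:

import pandas as pd
import xarray as xr

labels = xr.DataArray(
    ["rockfall", "noise"],
    dims="index",
    coords={
        "start_time": ("index", pd.to_datetime(["2017-08-06 08:00", "2017-08-06 09:00"])),
        "end_time": ("index", pd.to_datetime(["2017-08-06 08:05", "2017-08-06 09:02"])),
    },
)

label_coords, label_slices = get_label_slices(labels)
print(label_coords)     # ['time']
print(label_slices[0])  # {'time': slice(<start of label 0>, <end of label 0>)}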
 class SegmentedDataset(Dataset):
     def __init__(
         self,
@@ -1318,8 +1358,18 @@ class SegmentedDataset(Dataset):
         # TODO: this is inefficient for large datasets
         # TODO: add dataset_slice to call of data
-        d = data()
-        l = label()
+        if isinstance(data, StuettNode):
+            # turn the dataset_slice into a start_*/end_* request for the node
+            request = {"start_" + k: v.start for k, v in dataset_slice.items()}
+            request.update({"end_" + k: v.stop for k, v in dataset_slice.items()})
+            d = data(request)
+        else:
+            d = data
+
+        if isinstance(label, StuettNode):
+            l = label()
+        else:
+            l = label

         # print(d['time'])
         # print(l['time'])
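Note: the request built above simply turns each dimension slice into start_*/end_* keys. A runnable sketch with an assumed time-only dataset_slice:

dataset_slice = {"time": slice("2017-08-06 08:00", "2017-08-06 12:00")}
request = {"start_" + k: v.start for k, v in dataset_slice.items()}
request.update({"end_" + k: v.stop for k, v in dataset_slice.items()})
print(request)  # {'start_time': '2017-08-06 08:00', 'end_time': '2017-08-06 12:00'}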
@@ -1337,31 +1387,9 @@ class SegmentedDataset(Dataset):
         # indices = check_overlap(d,l)
         slices = get_dataset_slices(d, batch_dims, dataset_slice=dataset_slice)

-        def get_label_slices(l):
-            # build label slices
-            label_coords = []
-            start_identifier = "start_"
-            for coord in l.coords:
-                # print(coord)
-                if coord.startswith(start_identifier):
-                    label_coord = coord[len(start_identifier) :]
-                    if "end_" + label_coord in l.coords:
-                        label_coords += [label_coord]
-
-            label_slices = []
-            for i in range(len(l)):
-                selector = {}
-                for coord in label_coords:
-                    selector[coord] = slice(
-                        l[i]["start_" + coord].values, l[i]["end_" + coord].values
-                    )
-                label_slices.append(selector)
-
-            return label_coords, label_slices
-
         label_coords, label_slices = get_label_slices(l)
-        print(label_slices, slices)
+        # print(label_slices, slices)
         # print(len(l),len(label_slices))

         # filter out where we do not get data for the slice
@@ -1408,7 +1436,7 @@ class SegmentedDataset(Dataset):
             }

         label_list = [label_dict[key] for key in label_dict]
-        print(label_list)
+        # print(label_list)

         def check():
             from plotly.subplots import make_subplots
@@ -1509,31 +1537,32 @@ class SegmentedDataset(Dataset):
         # TODO: get statistics to what was left out

     def __len__(self):
+        raise NotImplementedError()
         return len(self.label_list)

     def __getitem__(self, idx):
-        if torch.is_tensor(idx):
-            idx = idx.tolist()
-        print(idx)
-        segment = self.label_list[idx]
-        print(segment)
-        indexers = segment["indexers"]
-        request = {
-            "start_time": segment["indexers"]["time"].start,
-            "end_time": segment["indexers"]["time"].stop,
-        }
-        data = configuration(self.data(delayed=True), request)
-        x = data.compute()
-        torch.Tensor(x.values)
-        print(x)
-        return segment
+        raise NotImplementedError()
+        # if torch.is_tensor(idx):
+        #     idx = idx.tolist()
+        # print(idx)
+        # segment = self.label_list[idx]
+        # print(segment)
+        # indexers = segment["indexers"]
+        # request = {
+        #     "start_time": segment["indexers"]["time"].start,
+        #     "end_time": segment["indexers"]["time"].stop,
+        # }
+        # data = configuration(self.data(delayed=True), request)
+        # x = data.compute()
+        # torch.Tensor(x.values)
+        # print(x)
+        # return segment

 class PytorchDataset(DataSource):  # TODO: extends pytorch dataset