Commit ac1b82ca authored by matthmey

bug fixes

parent 9c624c2d
```diff
@@ -614,6 +614,9 @@ class MHDSLRFilenames(DataSource):
             base_directory {[type]} -- [description]
             method {str} -- [description] (default: {'directory'})
         """
+        if store is None and base_directory is not None:
+            store = DirectoryStore(base_directory)
+
         super().__init__(
             base_directory=base_directory,
             store=store,
```
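This hunk adds a fallback: when the caller supplies only `base_directory`, a `DirectoryStore` is constructed from it so the rest of the class can rely on `store` being set. A minimal sketch of the same pattern, assuming `DirectoryStore` is zarr's directory-backed store (the diff does not show the import, so that binding is an assumption):

```python
# Hedged sketch of the store fallback; assumes zarr.DirectoryStore,
# the diff itself only shows the bare name DirectoryStore.
import zarr

def resolve_store(store=None, base_directory=None):
    # Prefer an explicitly supplied store; otherwise derive one from
    # the directory so downstream code can depend on `store` alone.
    if store is None and base_directory is not None:
        store = zarr.DirectoryStore(base_directory)
    return store

store = resolve_store(base_directory="/tmp/images")
print(type(store))  # <class 'zarr.storage.DirectoryStore'>
```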
```diff
@@ -710,13 +713,13 @@ class MHDSLRFilenames(DataSource):
                 imglist_df = pd.read_csv(imglist_filename)
             else:
                 # we are going to load the full list => no arguments
-                imglist_df = self.image_integrity(config["base_directory"])
+                imglist_df = self.image_integrity_store(config["store"])
                 # imglist_df.to_parquet(imglist_filename)
                 imglist_df.to_csv(imglist_filename, index=False)
         elif not success:
             # if there is no tmp_dir we can load the image list but
             # we should warn the user that this is inefficient
-            imglist_df = self.image_integrity(config["base_directory"])
+            imglist_df = self.image_integrity_store(config["store"])
             warnings.warn(
                 "No temporary directory was set. You can speed up multiple runs of your application by setting a temporary directory"
             )
```
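Here the integrity check switches from the path-based `image_integrity(base_directory)` to the store-based `image_integrity_store(store)`, consistent with the fallback above; the CSV caching in the temporary directory is unchanged. A hedged sketch of the surrounding control flow (method and config-key names follow the diff, everything else is illustrative):

```python
import os
import warnings
import pandas as pd

def load_image_list(self, config, imglist_filename=None):
    # Reuse a cached CSV when one exists; otherwise rebuild the image
    # list from the store and cache it if a tmp location is available.
    if imglist_filename and os.path.exists(imglist_filename):
        return pd.read_csv(imglist_filename)
    imglist_df = self.image_integrity_store(config["store"])
    if imglist_filename:
        imglist_df.to_csv(imglist_filename, index=False)
    else:
        warnings.warn(
            "No temporary directory was set. You can speed up multiple "
            "runs of your application by setting a temporary directory"
        )
    return imglist_df
```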
```diff
@@ -1357,7 +1360,7 @@ class SegmentedDataset(Dataset):
         discard_empty=True,
         dataset_slice=None,
         batch_dims={},
-        mode="segments",
+        segmentation_mode="segments",
     ):
         self.data = data
         self.label = label
```
```diff
@@ -1365,7 +1368,7 @@ class SegmentedDataset(Dataset):
         self.discard_empty = discard_empty
         self.dataset_slice = dataset_slice
         self.batch_dims = batch_dims
-        self.__mode = mode
+        self.segmentation_mode = segmentation_mode

     def compute_label_list(self):
         """ discard_empty ... trim the dataset to the available labels
```
```diff
@@ -1398,7 +1401,7 @@ class SegmentedDataset(Dataset):
         # and the whole label set? but still use xarray indexing
         label_dict = {}
         self.classes = []
-        if self.__mode == "segments":
+        if self.segmentation_mode == "segments":
             overlaps = check_overlap(
                 slices, label_slices, self.index_dim, requested_coords
             )
```
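The three hunks above (and the corresponding `elif` branch further down) rename the constructor argument `mode` to the more descriptive `segmentation_mode` and replace the name-mangled private attribute `self.__mode` with a public one. Beyond readability, the double underscore mattered: Python mangles `__mode` to `_SegmentedDataset__mode`, which hides the attribute from subclasses. A self-contained demonstration of that pitfall:

```python
class Base:
    def __init__(self):
        self.__mode = "segments"   # actually stored as _Base__mode

class Child(Base):
    def mode(self):
        return self.__mode         # looks up _Child__mode, which never exists

c = Child()
print(c._Base__mode)               # 'segments', via the mangled name
try:
    c.mode()
except AttributeError as err:
    print(err)                     # 'Child' object has no attribute '_Child__mode'
```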
```diff
@@ -1407,14 +1410,22 @@ class SegmentedDataset(Dataset):
                 if self.discard_empty:
                     # we need to load every single piece to check if it is empty
                     # TODO: loop through dims in batch_dim and check if they are correct
-                    if self.get_data(slices[o[0]]).size == 0:
-                        continue
+                    try:
+                        if self.get_data(slices[o[0]]).size == 0:
+                            continue
+                    except Exception as e:
+                        print(e)
+                        continue
+
                 # TODO: maybe this can be done faster (and cleaner)
                 i = o[0]
                 j = o[1]
-                label = str(l[j].values)
-                if label not in self.classes:
-                    self.classes.append(label)
+                label = l[j].squeeze().values
+                # print(label, type(label))
+                if l[j].notnull():
+                    label = str(label)
+                    if label not in self.classes:
+                        self.classes.append(label)
                 if i not in label_dict:
                     label_dict[i] = {"indexers": slices[i], "labels": [label]}
                 elif label not in label_dict[i]["labels"]:
```
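This hunk hardens the `discard_empty` check (failures in `get_data` now skip the segment instead of crashing, though the broad `except Exception` with `print` is a stopgap rather than proper logging) and guards label extraction with `notnull()`, so segments whose label cell is NaN no longer add the string `"nan"` to the class list. A hedged xarray sketch of the guard, with made-up data:

```python
import numpy as np
import xarray as xr

# Toy label array: one annotated segment, one without a label (NaN).
l = xr.DataArray(
    np.array([["rockfall"], [np.nan]], dtype=object),
    dims=("segment", "label"),
)

classes = []
for j in range(l.sizes["segment"]):
    label = l[j].squeeze().values
    if l[j].notnull():             # skip NaN labels instead of storing "nan"
        label = str(label)
        if label not in classes:
            classes.append(label)

print(classes)  # ['rockfall']
```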
```diff
@@ -1423,7 +1434,7 @@ class SegmentedDataset(Dataset):
                         "labels": label_dict[i]["labels"] + [label],
                     }

-        elif self.__mode == "points":
+        elif self.segmentation_mode == "points":
             for i in range(len(slices)):
                 # x = d.sel(slices[i])
                 x = self.get_data(slices[i])
```
```diff
@@ -1588,29 +1599,29 @@ class SegmentedDataset(Dataset):
 # return segment

-class PytorchDataset(DataSource):  # TODO: extends pytorch dataset
-    def __init__(self, source=None):
-        """ Creates a pytorch-like dataset from a data source and a label source.
-
-        Arguments:
-            DataSource {[type]} -- [description]
-            config {dict} -- configuration for labels
-        """
-        super().__init__(source=source)
-
-    def build_dataset(self):
-        # load annotation source and datasource
-        # define a dataset index containing all indices of the datasource (e.g. timestamps or time period) which should be in this dataset
-        # go through the dataset index and check overlap of datasource indices and annotation indices
-        # generate a new annotation set with regards to the datasource indices (requires generating an empty annotation set and adding new labels to it)
-        # if wanted, generate intermediate freeze results of datasource and annotations
-        # go through all items of the datasource
-        pass
-
-    def forward(self, data=None, request=None):
-        return x
+# class PytorchDataset(DataSource):  # TODO: extends pytorch dataset
+#     def __init__(self, source=None):
+#         """ Creates a pytorch-like dataset from a data source and a label source.
+#
+#         Arguments:
+#             DataSource {[type]} -- [description]
+#             config {dict} -- configuration for labels
+#         """
+#         super().__init__(source=source)
+#
+#     def build_dataset(self):
+#         # load annotation source and datasource
+#         # define a dataset index containing all indices of the datasource (e.g. timestamps or time period) which should be in this dataset
+#         # go through the dataset index and check overlap of datasource indices and annotation indices
+#         # generate a new annotation set with regards to the datasource indices (requires generating an empty annotation set and adding new labels to it)
+#         # if wanted, generate intermediate freeze results of datasource and annotations
+#         # go through all items of the datasource
+#         pass
+#
+#     def forward(self, data=None, request=None):
+#         return x
```
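The final hunk comments out the unfinished `PytorchDataset` stub, whose `forward` returned an undefined `x`. For orientation only, a minimal `torch.utils.data.Dataset` over precomputed segments might look like the following; this is an illustrative assumption, not part of this codebase's API:

```python
import torch
from torch.utils.data import Dataset

class SegmentsDataset(Dataset):
    """Hypothetical wrapper pairing precomputed segments with class indices."""

    def __init__(self, segments, labels):
        self.segments = segments   # e.g. a list of numpy arrays
        self.labels = labels       # integer class indices, same length

    def __len__(self):
        return len(self.segments)

    def __getitem__(self, idx):
        x = torch.as_tensor(self.segments[idx], dtype=torch.float32)
        y = torch.tensor(self.labels[idx], dtype=torch.long)
        return x, y
```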