'''MIT License

Copyright (c) 2019, Swiss Federal Institute of Technology (ETH Zurich)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.'''

from ..global_config import get_setting, setting_exists, set_setting
from ..core.graph import StuettNode

import os

import dask
import logging
import obspy
from obspy.clients.fdsn import Client
from obspy import UTCDateTime
from obsplus import obspy_to_array
from copy import deepcopy
import io

import zarr
import xarray as xr
from PIL import Image
import base64
import re

from pathlib import Path
import warnings

# TODO: revisit the following packages
import numpy as np
import pandas as pd
import datetime as dt


class DataSource(StuettNode):
    def __init__(self, **kwargs):
        super().__init__(kwargs=kwargs)

    def __call__(self, data=None, request=None, delayed=False):
        if data is not None:
            if request is None:
                request = data
            else:
                warnings.warn(
                    "Two inputs (data, request) provided to the DataSource but it can only handle a request. Choosing request."
                )

        # A DataSource only requires a request
        # Therefore merge permanent-config and request
        config = self.config.copy()  # TODO: do we need a deep copy?
        if request is not None:
            config.update(request)

        # TODO: change when rewriting for general indices
        if "start_time" in config and config["start_time"] is not None:
            config["start_time"] = pd.to_datetime(
                config["start_time"], utc=True
            ).tz_localize(
                None
            )  # TODO: change when xarray #3291 is fixed
        if "end_time" in config and config["end_time"] is not None:
            config["end_time"] = pd.to_datetime(
                config["end_time"], utc=True
            ).tz_localize(
                None
            )  # TODO: change when xarray #3291 is fixed

        if delayed:
            return dask.delayed(self.forward)(None, config)
        else:
            return self.forward(None, config)

    def configure(self, requests=None):
        """ Default configure for DataSource nodes
            Same as configure from StuettNode but adds is_source flag

        Arguments:
            request {list} -- List of requests

        Returns:
            dict -- Original request or merged requests 
        """
        requests = super().configure(requests)  # merging request here
        requests["requires_request"] = True

        return requests

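# Usage sketch (hypothetical subclass): a DataSource merges its permanent config
# with the per-call request, the request taking precedence, so the same node can
# serve different time windows without mutating its configuration:
#
#   source = SomeDataSource(start_time="2017-01-01", end_time="2017-01-02")
#   x = source()                          # uses the configured time window
#   x = source(request={"start_time": "2017-06-01", "end_time": "2017-06-02"})
#   d = source(request={"start_time": "2017-06-01"}, delayed=True)  # dask.delayed instead
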
class GSNDataSource(DataSource):
    def __init__(
        self,
        deployment=None,
        vsensor=None,
        position=None,
        start_time=None,
        end_time=None,
        **kwargs,
    ):
        super().__init__(
            deployment=deployment,
            position=position,
            vsensor=vsensor,
            start_time=start_time,
            end_time=end_time,
            kwargs=kwargs,
        )

    def forward(self, data=None, request=None):
        #### 1 - DEFINE VSENSOR-DEPENDENT COLUMNS ####
        colnames = pd.read_csv(
            Path(get_setting("metadata_directory")).joinpath(
                "vsensor_metadata/{:s}_{:s}.csv".format(
                    request["deployment"], request["vsensor"]
                )
            ),
            skiprows=0,
        )
        columns_old = colnames["colname_old"].values
        columns_new = colnames["colname_new"].values
        columns_unit = colnames["unit"].values
        if len(columns_old) != len(columns_new):
            warnings.warn(
                "WARNING: Length of 'columns_old' ({:d}) is not equal length of 'columns_new' ({:d})".format(
                    len(columns_old), len(columns_new)
                )
            )
        if len(columns_old) != len(columns_unit):
            warnings.warn(
                "WARNING: Length of 'columns_old' ({:d}) is not equal length of 'columns_unit' ({:d})".format(
                    len(columns_old), len(columns_unit)
                )
            )

        unit = dict(zip(columns_new, columns_unit))
        #### 2 - DEFINE CONDITIONS AND CREATE HTTP QUERY ####

        # Set server
        server = get_setting("permasense_server")

        # Create virtual_sensor
        virtual_sensor = request["deployment"] + "_" + request["vsensor"]

        # Create query and add time as well as position selection
        query = (
            "vs[1]={:s}"
            "&time_format=iso"
            "&timeline=generation_time"
            "&field[1]=All"
            "&from={:s}"
            "&to={:s}"
            "&c_vs[1]={:s}"
            "&c_join[1]=and"
            "&c_field[1]=position"
            "&c_min[1]={:02d}"
            "&c_max[1]={:02d}"
        ).format(
            virtual_sensor,
            pd.to_datetime(request["start_time"], utc=True).strftime(
                "%d/%m/%Y+%H:%M:%S"
            ),
            pd.to_datetime(request["end_time"], utc=True).strftime("%d/%m/%Y+%H:%M:%S"),
            virtual_sensor,
            int(request["position"]) - 1,
            request["position"],
        )

        # query extension for images
        if request["vsensor"] == "binary__mapped":
            query = (
                query
                + "&vs[2]={:s}&field[2]=relative_file&c_join[2]=and&c_vs[2]={:s}&c_field[2]=file_complete&c_min[2]=0&c_max[2]=1&vs[3]={:s}&field[3]=generation_time&c_join[3]=and&c_vs[3]={:s}&c_field[3]=file_size&c_min[3]=2000000&c_max[3]=%2Binf&download_format=csv".format(
                    virtual_sensor, virtual_sensor, virtual_sensor, virtual_sensor
                )
            )

        # Construct url:
        url = server + "multidata?" + query
        # if self.verbose:
        #     print('The GSN http-query is:\n{:s}'.format(url))

        #### 3 - ACCESS DATA AND CREATE PANDAS DATAFRAME ####
        d = pd.read_csv(url, skiprows=2)  # skip the first two header lines
        df = pd.DataFrame(columns=columns_new)
        df.time = pd.to_datetime(d.generation_time, utc=True)

        # if depo in ['mh25', 'mh27old', 'mh27new', 'mh30', 'jj22', 'jj25', 'mh-nodehealth']:
        # d = d.convert_objects(convert_numeric=True)  # TODO: rewrite to remove warning
        for k in list(df):
            df[k] = pd.to_numeric(df[k], errors="ignore")

        #        df = pd.DataFrame(columns=columns_new)
        #        df.time = timestamp2datetime(d['generation_time']/1000)
        for i in range(len(columns_old)):
            if columns_new[i] != "time":
                df[columns_new[i]] = d[columns_old[i]]

        df = df.sort_values(by="time")

        # Remove columns named 'delete'
        df.drop(columns=["delete"], errors="ignore", inplace=True)

        # Remove columns containing only 'null'
        df = df.replace(r"null", np.nan, regex=True)
        df = df.dropna(axis=1, how="all")

        df = df.set_index("time")
        df = df.sort_index(axis=1)

        x = xr.DataArray(df, dims=["time", "name"], name="CSV")

        try:
            unit_coords = []
            for name in x.coords["name"].values:
                # name = re.sub(r"\[.*\]", "", name).lstrip().rstrip()
                u = unit[str(name)]
                u = re.findall(r"\[(.*)\]", u)[0]

                # name_coords.append(name)
                unit_coords.append(u)

            x = x.assign_coords({"unit": ("name", unit_coords)})
        except (KeyError, IndexError):
            # TODO: add a warning or test explicitly if units exist
            pass

        return x

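# Example usage (illustrative values; deployment/vsensor/position must match an
# existing GSN virtual sensor, and the `permasense_server` and
# `metadata_directory` settings must be configured):
#
#   node = GSNDataSource(
#       deployment="matterhorn", vsensor="temperature", position=3,
#       start_time="2017-07-01", end_time="2017-07-02",
#   )
#   da = node()  # xr.DataArray with dims ("time", "name") and a "unit" coordinate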

class SeismicSource(DataSource):
    def __init__(
        self,
        path=None,
        store=None,
        station=None,
        channel=None,
        start_time=None,
        end_time=None,
        use_arclink=False,
        return_obspy=False,
        **kwargs,
    ):  # TODO: update description
        """ Seismic data source to get data from permasense
            The user can predefine the source's settings or provide them in a request
            Predefined setting should never be updated (to stay thread safe), but will be ignored
            if request contains settings

        Keyword Arguments:
            config {dict}       -- Configuration for the seismic source (default: {{}})
            use_arclink {bool}  -- If true, downloads the data from the arclink service (authentication required) (default: {False})
            return_obspy {bool} -- By default an xarray is returned. If true, an obspy stream will be returned (default: {False})
276
        """
        super().__init__(
            path=path,
            store=store,
            station=station,
            channel=channel,
            start_time=start_time,
            end_time=end_time,
            use_arclink=use_arclink,
            return_obspy=return_obspy,
            kwargs=kwargs,
        )

    def forward(self, data=None, request=None):
        config = request

        if config["use_arclink"]:
            try:
                arclink = get_setting("arclink")
            except KeyError as err:
                raise RuntimeError(
                    f"The following error occured \n{err}. "
                    "Please provide either the credentials to access arclink or a path to the dataset"
                )

            arclink_user = arclink["user"]
            arclink_password = arclink["password"]
            fdsn_client = Client(
                base_url="http://arclink.ethz.ch",
                user=arclink_user,
                password=arclink_password,
            )
            x = fdsn_client.get_waveforms(
                network="4D",
                station=config["station"],
                location="A",
                channel=config["channel"],
                starttime=UTCDateTime(config["start_time"]),
                endtime=UTCDateTime(config["end_time"]),
                attach_response=True,
            )

            # TODO: potentially resample

        else:  # 20180914 is last full day available in permasense_vault
            # logging.info('Loading seismic with fdsn')
            x = self.get_obspy_stream(
                config,
                config["path"],
                config["start_time"],
                config["end_time"],
                config["station"],
                config["channel"],
            )

        x = x.slice(UTCDateTime(config["start_time"]), UTCDateTime(config["end_time"]))

        # TODO: remove response x.remove_response(output=vel)
        # x = self.process_seismic_data(x)

        if not config["return_obspy"]:
            x = obspy_to_array(x)

            # we assume that all starttimes are equal
            starttime = x.starttime.values.reshape((-1,))[0]
            for s in x.starttime.values.reshape((-1,)):
                if s != starttime:
                    raise RuntimeError(
                        "Please make sure that the starttime of each seismic channel is equal"
                    )

            # change time coords from relative to absolute time
            starttime = obspy.UTCDateTime(starttime).datetime
            starttime = pd.to_datetime(starttime, utc=True).tz_localize(
                None
            )  # TODO: change when xarray #3291 is fixed
            timedeltas = pd.to_timedelta(x["time"].values, unit="seconds")
            xt = starttime + timedeltas
            x["time"] = pd.to_datetime(xt, utc=True).tz_localize(
                None
            )  # TODO: change when xarray #3291 is fixed
            del x.attrs["stats"]

        return x

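# Example usage (illustrative values; requires either `use_arclink=True` with
# arclink credentials configured, or a `path`/`store` pointing at the miniseed
# archive):
#
#   node = SeismicSource(path="/path/to/geophones", station="MH36",
#                        channel=["EHE", "EHN", "EHZ"])
#   da = node(request={"start_time": "2017-08-01 10:00:00",
#                      "end_time": "2017-08-01 10:01:00"})
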
    def process_seismic_data(
        self,
        stream,
        remove_response=True,
        unit="VEL",
        station_inventory=None,
        detrend=True,
        taper=False,
        pre_filt=(0.025, 0.05, 45.0, 49.0),
        water_level=60,
        apply_filter=True,
        freqmin=0.002,
        freqmax=50,
        resample=False,
        resample_rate=250,
        rotation_angle=None,
    ):
        # author: Samuel Weber
        if station_inventory is None:
            station_inventory = Path(get_setting("metadata_directory")).joinpath(
                "inventory_stations__MH.xml"
            )

        # print(station_inventory)
        inv = obspy.read_inventory(str(station_inventory))
        # st = stream.copy()
        st = stream
        st.attach_response(inv)

        if detrend:
            st.detrend("demean")
            st.detrend("linear")

        if taper:
            st.taper(max_percentage=0.05)

        if remove_response:
            if hasattr(st[0].stats, "response"):
                st.remove_response(
                    output=unit,
                    pre_filt=pre_filt,
                    plot=False,
                    zero_mean=False,
                    taper=False,
                    water_level=water_level,
                )
            else:
                st.remove_response(
                    output=unit,
                    inventory=inv,
                    pre_filt=pre_filt,
                    plot=False,
                    zero_mean=False,
                    taper=False,
                    water_level=water_level,
                )

        if detrend:
            st.detrend("demean")
            st.detrend("linear")

        if taper:
            st.taper(max_percentage=0.05)

        if apply_filter:
            st.filter("bandpass", freqmin=freqmin, freqmax=freqmax)

        if resample:
            st.resample(resample_rate)

        if rotation_angle is None:
            rotation_angle = np.nan
        if not np.isnan(rotation_angle):
            st.rotate("NE->RT", back_azimuth=int(rotation_angle), inventory=inv)

        return st

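# Usage sketch (assumes the station inventory XML referenced above exists in the
# metadata directory); typical post-processing of a raw stream:
#
#   st = source.get_obspy_stream(request, path, start_time, end_time, "MH36", ["EHE"])
#   st = source.process_seismic_data(st, remove_response=True, unit="VEL")
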
    def get_obspy_stream(
        self,
        request,
        path,
        start_time,
        end_time,
        station,
        channels,
        pad=False,
        verbose=False,
        fill=0,
        fill_sampling_rate=1000,
        old_stationname=False,
    ):
        """    
        Loads the microseismic data for the given timeframe into a miniseed file.

        Arguments:
            start_time {datetime} -- start timestamp of the desired obspy stream
            end_time {datetime} -- end timestamp of the desired obspy stream
        
        Keyword Arguments:
            pad {bool} -- If True, the data will be zero-padded where the stream is not consistent
            fill {} -- Value used to fill gaps in the seismic stream (e.g. numpy.nan or 0). If None, no fill is used
            verbose {bool} -- Whether info should be printed

        Returns:
            obspy stream -- obspy stream with up to three channels
                            the stream's channels will be sorted alphabetically
        """

        if not isinstance(channels, list):
            channels = [channels]

        # We load full hours of seismic data and trim to the desired length afterwards
        tbeg_hours = pd.to_datetime(start_time).replace(
            minute=0, second=0, microsecond=0
        )
        timerange = pd.date_range(start=tbeg_hours, end=end_time, freq="H")

        non_existing_files_ts = []  # keep track of nonexisting files

        # drawback of memmap files is that we need to calculate the size beforehand
        stream = obspy.Stream()

        idx = 0
        ## loop through all hours
        for i in range(len(timerange)):
            # start = time.time()
            h = timerange[i]

            st_list = obspy.Stream()

            datayear = timerange[i].strftime("%Y")
            if old_stationname:
                station = (
                    "MHDL" if station == "MH36" else "MHDT"
                )  # TODO: do not hardcode it
            filenames = {}
            for channel in channels:
                # filenames[channel] = Path(station).joinpath(
                #     datayear,
                #     "%s.D/" % channel,
                #     "4D.%s.A.%s.D." % (station, channel)
                #     + timerange[i].strftime("%Y%m%d_%H%M%S")
                #     + ".miniseed",
                # )
                filenames[channel] = [
                    station,
                    datayear,
                    "%s.D" % channel,
                    "4D.%s.A.%s.D." % (station, channel)
                    + timerange[i].strftime("%Y%m%d_%H%M%S")
                    + ".miniseed",
                ]
                # print(filenames[channel])

                # Load either from store or from filename
                if request["store"] is not None:
                    # get the file relative to the store
                    store = request["store"]
                    filename = "/".join(filenames[channel])
                    st = obspy.read(io.BytesIO(store[str(filename)]))
                else:
                    if not Path(path).is_dir():
                        # TODO: should this be an error or only a warning? In a periodic execution this could stop the whole script
                        raise IOError(
                            "Cannot find the path {}. Please provide a correct path to the permasense geophone directory".format(
                                path
                            )
                        )
                    datadir = Path(path)
                    if not datadir.joinpath(*filenames[channel]).exists():
                        non_existing_files_ts.append(timerange[i])

                        warnings.warn(
                            RuntimeWarning(
                                "Warning file not found for {} {}".format(
                                    timerange[i], datadir.joinpath(*filenames[channel])
                                )
                            )
                        )

                        if fill is not None:
                            # create empty stream of fill value
                            arr = (
                                np.ones(
                                    (
                                        int(
                                            np.ceil(
                                                fill_sampling_rate
                                                * dt.timedelta(hours=1).total_seconds()
                                            )
                                        ),
                                    )
                                )
                                * fill
                            )
                            st = obspy.Stream([obspy.Trace(arr)])
                        else:
                            continue
                    else:
                        filename = str(datadir.joinpath(*filenames[channel]))
                        st = obspy.read(filename)

                st_list += st

            stream_h = st_list.merge(method=0, fill_value=fill)
            start_time_0 = start_time if i == 0 else h
            end_time_0 = (
                end_time if i == len(timerange) - 1 else h + dt.timedelta(hours=1)
            )
            segment_h = stream_h.trim(
                obspy.UTCDateTime(start_time_0),
                obspy.UTCDateTime(end_time_0),
                pad=pad,
                fill_value=fill,
            )
            stream += segment_h

        stream = stream.merge(method=0, fill_value=fill)

        stream.sort(
            keys=["channel"]
        )  # TODO: change this so that the order of the input channels list is maintained

        return stream

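# The archive layout assumed by the filename construction above is
# station/year/<channel>.D/4D.<station>.A.<channel>.D.<YYYYMMDD_HHMMSS>.miniseed,
# e.g. (illustrative):
#
#   MH36/2017/EHE.D/4D.MH36.A.EHE.D.20170801_100000.miniseed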

class MHDSLRFilenames(DataSource):
    def __init__(
        self,
        base_directory=None,
        store=None,
        method="directory",
        start_time=None,
        end_time=None,
    ):
        """ Fetches the DSLR images from the Matterhorn deployment, returns the image
            filename(s) corresponding to the end and start time provided in either the
            config dict or as a request to the __call__() function.          
        
        Arguments:
            StuettNode {[type]} -- [description]

        Keyword Arguments:
            base_directory {[type]} -- [description]
            method {str}     -- [description] (default: {'directory'})
        """
        super().__init__(
            base_directory=base_directory,
            store=store,
            method=method,
            start_time=start_time,
            end_time=end_time,
        )

    def forward(self, data=None, request=None):
        """Retrieves the images for the selected time period from the server. If only a start_time timestamp is provided, 
          the file with the corresponding date will be loaded if available. For periods (when start and end time are given) 
          all available images are indexed first to provide an efficient retrieval.
        
        Arguments:
            start_time {datetime} -- If only start_time is given, the nearest available image is returned. If end_time is also provided, a dataframe is returned containing image filenames from the first image after start_time until the last image before end_time.
        
        Keyword Arguments:
            end_time {datetime} -- end time of the selected period. see start_time for a description. (default: {None})
        
        Returns:
            dataframe -- Dataframe containing the image filenames of the selected period.
622
        """
        config = request
        methods = ["directory", "web"]
        if config["method"].lower() not in methods:
            raise RuntimeError(
                f"The {config['method']} output_format is not supported. Allowed formats are {methods}"
            )

        if config["base_directory"] is None and output_format.lower() != "web":
            raise RuntimeError("Please provide a base_directory containing the images")

        if config["method"].lower() == "web":  # TODO: implement
            raise NotImplementedError("web fetch has not been implemented yet")

        start_time = config["start_time"]
        end_time = config["end_time"]

        if start_time is None:
            raise RuntimeError("Please provide a least start time")

        # if it is not timezone aware make it
        # if start_time.tzinfo is None:
        #     start_time = start_time.replace(tzinfo=timezone.utc)

        # If there is no tmp_dir we can try to load the file directly, otherwise
        # there will be a warning later in this function and the user should
        # set a tmp_dir
        if end_time is None and (
            not setting_exists("user_dir") or not os.path.isdir(get_setting("user_dir"))
        ):
            image_filename = self.get_image_filename(start_time)
            if image_filename is not None:
                return image_filename

        # If we have already loaded the dataframe in the current session we can use it
        if setting_exists("image_list_df"):
            imglist_df = get_setting("image_list_df")
        else:
            if setting_exists("user_dir") and os.path.isdir(get_setting("user_dir")):
                # Otherwise we need to load the filename dataframe from disk
                imglist_filename = (
                    os.path.join(get_setting("user_dir"), "")
                    + "full_image_integrity.parquet"
                )

                # If it does not exist in the temporary folder of our application
                # We are going to create it
                if os.path.isfile(imglist_filename):
                    imglist_df = pd.read_parquet(
                        imglist_filename
                    )  # TODO: avoid too many different formats
                else:
                    # we are going to load the full list => no arguments
                    imglist_df = self.image_integrity(config["base_directory"])
                    imglist_df.to_parquet(imglist_filename)
            else:
                # if there is no tmp_dir we can load the image list but
                # we should warn the user that this is inefficient
                imglist_df = self.image_integrity(config["base_directory"])
                warnings.warn(
                    "No temporary directory was set. You can speed up multiple runs of your application by setting a temporary directory"
                )

            # TODO: make the index timezone aware
            imglist_df.set_index("start_time", inplace=True)
            imglist_df.index = pd.to_datetime(imglist_df.index, utc=True).tz_localize(
                None
            )  # TODO: change when xarray #3291 is fixed
            imglist_df.sort_index(inplace=True)

            set_setting("image_list_df", imglist_df)

        if end_time is None:
            if start_time < imglist_df.index[0]:
                start_time = imglist_df.index[0]

            return imglist_df.iloc[
                imglist_df.index.get_loc(start_time, method="nearest")
            ]
        else:
            # if end_time.tzinfo is None:
            #     end_time = end_time.replace(tzinfo=timezone.utc)
            if start_time > imglist_df.index[-1] or end_time < imglist_df.index[0]:
                # return empty dataframe
                return imglist_df[0:0]

            if start_time < imglist_df.index[0]:
                start_time = imglist_df.index[0]
            if end_time > imglist_df.index[-1]:
                end_time = imglist_df.index[-1]

            return imglist_df.iloc[
                imglist_df.index.get_loc(
                    start_time, method="bfill"
                ) : imglist_df.index.get_loc(end_time, method="ffill")
                + 1
            ]

    def image_integrity(
        self, base_directory, start_time=None, end_time=None, delta_seconds=0
    ):
        """ Checks which images are available on the permasense server
        
        Keyword Arguments:
            start_time {[type]} -- datetime object giving the lower bound of the time range which should be checked. 
                                   If None there is no lower bound. (default: {None})
            end_time {[type]} --   datetime object giving the upper bound of the time range which should be checked.
                                   If None there is no upper bound (default: {None})
            delta_seconds {int} -- Determines the 'duration' of an image in the output dataframe.
                                   start_time  = image_time+delta_seconds
                                   end_time    = image_time-delta_seconds (default: {0})
        
        Returns:
            DataFrame -- Returns a pandas dataframe with the columns filename (relative to base_directory),
                         start_time and end_time. start_time and end_time can vary depending on the delta_seconds parameter
        """
        """ Checks which images are available on the permasense server
739

matthmey's avatar
matthmey committed
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
        Arguments:
            start_time:   
            end_time:   
            delta_seconds:  Determines the 'duration' of an image in the output dataframe.
                            start_time  = image_time+delta_seconds
                            end_time    = image_time-delta_seconds
        Returns:
            DataFrame -- 
        """
        if start_time is None:
            start_time = dt.datetime(
                1900, 1, 1
            )  # arbitrary early date, before the permasense installation started
        if end_time is None:
            end_time = dt.datetime.utcnow()

        tbeg_days = start_time.replace(hour=0, minute=0, second=0)
        tend_days = end_time.replace(hour=23, minute=59, second=59)

        delta_t = dt.timedelta(seconds=delta_seconds)
        num_filename_errors = 0
        images_list = []
        p = Path(base_directory)

        if not p.is_dir():
            warnings.warn(
                "Could not find the permasense image dataset. Please make sure it is available at {}".format(
                    str(p)
                )
            )

        for dir in p.glob("*/"):
            dir_date = dt.datetime.strptime(str(dir.name), "%Y-%m-%d")

            # limit the search to the explicit time range
            if dir_date < tbeg_days or dir_date > tend_days:
                continue

            for img_file in dir.glob("*"):
                # print(file.stem)
                start_time_str = img_file.stem

                try:
                    img_time = dt.datetime.strptime(start_time_str, "%Y%m%d_%H%M%S")
                    if start_time <= img_time <= end_time:
                        images_list.append(
                            {
                                "filename": str(img_file.relative_to(base_directory)),
                                "start_time": img_time - delta_t,
                                "end_time": img_time + delta_t,
                            }
                        )
                except ValueError:
                    # try old naming convention
                    try:
                        img_time = dt.datetime.strptime(
                            start_time_str, "%Y-%m-%d_%H%M%S"
                        )
                        if start_time <= img_time <= end_time:
                            images_list.append(
                                {
                                    "filename": str(
                                        img_file.relative_to(base_directory)
                                    ),
                                    "start_time": img_time - delta_t,
                                    "end_time": img_time + delta_t,
                                }
                            )
                    except ValueError:
                        num_filename_errors += 1
                        warnings.warn(
                            "Permasense data integrity, the following is not a valid image filename and will be ignored: %s"
                            % img_file
                        )
                        continue

        segments = pd.DataFrame(images_list)
        segments.drop_duplicates(inplace=True, subset="start_time")
        segments.start_time = pd.to_datetime(segments.start_time, utc=True)
        segments.end_time = pd.to_datetime(segments.end_time, utc=True)
        segments.sort_values("start_time")

        return segments

    def get_image_filename(self, timestamp):
        """ Checks wether an image exists for exactly the time of timestamp and returns its filename

            timestamp: datetime object for which the filename should be returned

        Returns:
            The filename if the file exists, None if there is no file
        """

        datadir = self.config["base_directory"]
        new_filename = (
            datadir
            + timestamp.strftime("%Y-%m-%d")
            + "/"
            + timestamp.strftime("%Y%m%d_%H%M%S")
            + ".JPG"
        )
        old_filename = (
            datadir
            + timestamp.strftime("%Y-%m-%d")
            + "/"
            + timestamp.strftime("%Y-%m-%d_%H%M%S")
            + ".JPG"
        )

        if os.path.isfile(new_filename):
            return new_filename
        elif os.path.isfile(old_filename):
            return old_filename
        else:
            return None

    def get_nearest_image_url(self, IMGparams, imgdate, floor=False):
        if floor:
            date_beg = imgdate - dt.timedelta(hours=4)
            date_end = imgdate
        else:
            date_beg = imgdate
            date_end = imgdate + dt.timedelta(hours=4)

        # predefine the query lists
        vs = []
        field = []
        c_vs = []
        c_field = []
        c_join = []
        c_min = []
        c_max = []

        vs = vs + ["matterhorn_binary__mapped"]
        field = field + ["ALL"]
        # select only data from one sensor (position 3)
        c_vs = c_vs + ["matterhorn_binary__mapped"]
        c_field = c_field + ["position"]
        c_join = c_join + ["and"]
        c_min = c_min + ["18"]
        c_max = c_max + ["20"]

        c_vs = c_vs + ["matterhorn_binary__mapped"]
        c_field = c_field + ["file_complete"]
        c_join = c_join + ["and"]
        c_min = c_min + ["0"]
        c_max = c_max + ["1"]

        # create url which retrieves the csv data file
        url = "http://data.permasense.ch/multidata?"
        url = url + "time_format=" + "iso"
        url = url + "&from=" + date_beg.strftime("%d/%m/%Y+%H:%M:%S")
        url = url + "&to=" + date_end.strftime("%d/%m/%Y+%H:%M:%S")
        for i in range(0, len(vs), 1):
            url = url + "&vs[%d]=%s" % (i, vs[i])
            url = url + "&field[%d]=%s" % (i, field[i])

        for i in range(0, len(c_vs), 1):
            url = url + "&c_vs[%d]=%s" % (i, c_vs[i])
            url = url + "&c_field[%d]=%s" % (i, c_field[i])
            url = url + "&c_join[%d]=%s" % (i, c_join[i])
            url = url + "&c_min[%d]=%s" % (i, c_min[i])
            url = url + "&c_max[%d]=%s" % (i, c_max[i])

        url = url + "&timeline=%s" % ("generation_time")

        # print(url)
        d = pd.read_csv(url, skiprows=2)

        # print(d)

        # print(type(d['#data'].values))
        d["data"] = [s.replace("&amp;", "&") for s in d["data"].values]

        d.sort_values(by="generation_time")
        d["generation_time"] = pd.to_datetime(d["generation_time"], utc=True)

        if floor:
            data_str = d["data"].iloc[0]
            data_filename = d["relative_file"].iloc[0]
            # print(d['generation_time'].iloc[0])
            img_timestamp = d["generation_time"].iloc[0]
        else:
            data_str = d["data"].iloc[-1]
            data_filename = d["relative_file"].iloc[-1]
            # print(d['generation_time'].iloc[-1])
            img_timestamp = d["generation_time"].iloc[-1]

        file_extension = data_filename[-3:]
        base_url = "http://data.permasense.ch"
        # print(base_url + data_str)

        return base_url + data_str, img_timestamp, file_extension


class MHDSLRImages(MHDSLRFilenames):
    def __init__(
        self,
        base_directory=None,
        store=None,
        method="directory",
        output_format="xarray",
        start_time=None,
        end_time=None,
    ):
        super().__init__(
            base_directory=base_directory,
950
            store=store,
            method=method,
            start_time=start_time,
            end_time=end_time,
        )
        self.config["output_format"] = output_format

    def forward(self, data=None, request=None):
        filenames = super().forward(request=request)

        if request["output_format"] is "xarray":
            return self.construct_xarray(filenames)
        elif request["output_format"] is "base64":
            return self.construct_base64(filenames)
        else:
            output_formats = ["xarray", "base64"]
            raise RuntimeError(
                f"The {request['output_format']} output_format is not supported. Allowed formats are {output_formats}"
            )

    def construct_xarray(self, filenames):
        images = []
        times = []
        for timestamp, element in filenames.iterrows():
            filename = Path(self.config["base_directory"]).joinpath(element.filename)
            img = Image.open(filename)
            images.append(np.array(img))
            times.append(timestamp)

        images = np.array(images)
        data = xr.DataArray(
            images, coords={"time": times}, dims=["time", "x", "y", "c"], name="Image"
        )
        data.attrs["format"] = "jpg"

        return data

    def construct_base64(self, filenames):
        images = []
        times = []
        for timestamp, element in filenames.iterrows():
            filename = Path(self.config["base_directory"]).joinpath(element.filename)
            img = Image.open(filename)
            img_base64 = base64.b64encode(img.tobytes())
            images.append(img_base64)
            times.append(timestamp)

        images = np.array(images).reshape((-1, 1))
        data = xr.DataArray(
            images, coords={"time": times}, dims=["time", "base64"], name="Base64Image"
        )
        data.attrs["format"] = "jpg"

        return data


class Freezer(StuettNode):
    def __init__(self, store):
        self.store = store

    def configure(self, requests):
        """ 

        Arguments:
            request {list} -- List of requests

        Returns:
            dict -- Original, updated or merged request(s) 
        """
        requests = super().configure(requests)

        # TODO: check if data is available for requested time period

        # TODO: check how we need to update the boundaries such that we get data that fits and that is available
        # TODO: with only one start/end_time it might be inefficient for the case where [unavailable,available,unavailable] since we need to load everything
        #      one option could be to duplicate the graph by returning multiple requests...

        # TODO: make a distinction between requested start_time and freeze_output_start_time

        # TODO: add node specific hash to freeze_output_start_time (there might be multiple in the graph) <- probably not necessary because we receive a copy of the request which is unique to this node
        # TODO: maybe the configuration method must add (and delete) the node name in the request?

        # we always require a request to crop out the right time period
        requests["requires_request"] = True

        return requests

    def to_zarr(self, x):
        x = x.to_dataset(name="frozen")
        # x = x.chunk({name: x[name].shape for name in list(x.dims)})
        # zarr_dataset = zarr.open(self.store, mode='r')
        x.to_zarr(self.store, append_dim="time")

    def open_zarr(self, requests):
        ds_zarr = xr.open_zarr(self.store)
        print("read", ds_zarr)

    def forward(self, data=None, request=None):

        self.to_zarr(data)
        self.open_zarr(request)

        return data
        # TODO: check request start_time and load the data which is available, store the data which is not available
        # TODO: crop


class CsvSource(DataSource):
    def __init__(
        self, filename=None, store=None, start_time=None, end_time=None, **kwargs
    ):
        super().__init__(
            filename=filename,
            store=store,
            start_time=start_time,
            end_time=end_time,
            kwargs=kwargs,
        )

    def forward(self, data=None, request=None):
        # TODO: Implement properly
        if request["store"] is not None:
            # get the file relative to the store
            store = request["store"]
            filename = request["filename"]
            csv = pd.read_csv(io.StringIO(str(store[filename], "utf-8")))
        else:
            csv = pd.read_csv(request["filename"])
        csv.set_index("time", inplace=True)
        csv.index = pd.to_datetime(csv.index, utc=True).tz_localize(
            None
        )  # TODO: change when xarray #3291 is fixed

        x = xr.DataArray(csv, dims=["time", "name"], name="CSV")

        try:
            unit_coords = []
            name_coords = []
            for name in x.coords["name"].values:
                unit = re.findall(r"\[(.*)\]", name)[0]
                name = re.sub(r"\[.*\]", "", name).lstrip().rstrip()

                name_coords.append(name)
                unit_coords.append(unit)

            x.coords["name"] = name_coords
            x = x.assign_coords({"unit": ("name", unit_coords)})
        except (IndexError, TypeError):
            # TODO: add a warning or test explicitly if units exist
            pass

        if "start_time" not in request:
            request["start_time"] = x.coords["time"][0]
        if "end_time" not in request:
            request["end_time"] = x.coords["time"][-1]

        x = x.sel(time=slice(request["start_time"], request["end_time"]))

        return x

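# Example (hypothetical file): a CSV with a "time" column and bracketed units in
# the remaining headers, e.g. "temperature [degC]", yields a DataArray with a
# cleaned "name" dimension and a matching "unit" coordinate:
#
#   node = CsvSource(filename="timeseries.csv")
#   x = node()  # full file
#   x = node(request={"start_time": "2017-08-01", "end_time": "2017-08-02"})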

def to_datetime(x):
    return pd.to_datetime(x, utc=True).tz_localize(
        None
    )  # TODO: change when xarray #3291 is fixed


class BoundingBoxAnnotation(DataSource):
    def __init__(
        self,
        filename=None,
        start_time=None,
        end_time=None,
        converters={"start_time": to_datetime,"end_time": to_datetime},
        **kwargs,
    ):
        super().__init__(
            filename=filename,
            start_time=start_time,
            end_time=end_time,
            converters=converters,
            kwargs=kwargs,
        )

    def forward(self, data=None, request=None):
        csv = pd.read_csv(request["filename"])

        targets = xr.DataArray(csv["__target"], dims=["index"], name="Annotation")

        for key in csv:
            if key == "__target":
                continue
            targets = targets.assign_coords({key: ("index", csv[key])})

        for key in request["converters"]:
            if key in csv:
                converter = request["converters"][key]
                if not callable(converter):
                    raise RuntimeError("Please provide a callable as column converter")
                targets = targets.assign_coords({key: ("index", converter(targets[key]))})

        return targets

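# Sketch of the expected annotation CSV (hypothetical content): a "__target"
# column holding the label, plus coordinate columns; the default converters parse
# start_time/end_time into timezone-naive timestamps:
#
#   __target,start_time,end_time
#   rockfall,2017-08-01 12:00:00,2017-08-01 12:10:00
#
#   targets = BoundingBoxAnnotation(filename="annotations.csv")()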

def check_overlap(data0, data1, index_dim, dims=[]):
    """ checks the overlap of two lists of xarray indexers (dicts with slices per dimension).
        Both list must be sorted by the index_dim dimension!
    
    Arguments:
        data0 {list} -- First list of xarray indexers 
        data1 {list} -- Second list of xarray indexers. Must be sorted by primarily by slice.start and secondary by slice.stop
        index_dim {string} -- Default dimension to iterate over. Indexers lists must be sorted by this dimension
    
    Keyword Arguments:
        dims {list} -- List of additional dimensions which should be checked for overlap (default: {[]})
    
    Returns:
        list -- List of indices of which items overlap. [...,[i,j],...] where item i of indexers0 overlaps with item j of indexers1
    """

    overlap_indices = []
    # print(data0.head())
    num_overlaps = 0
    start_idx = 0
    for i in range(len(data0)):
        # data0_df = data0.iloc[i]
        data0_start = data0[i][index_dim].start
        data0_end = data0[i][index_dim].stop
        # print(data0_df['start_time'])
        ext = []
        for j in range(start_idx, len(data1)):
            # data1_df = data1.iloc[j]
            data1_start = data1[j][index_dim].start
            data1_end = data1[j][index_dim].stop
            cond0 = data0_end < data1_start
            if cond0:
                break

            # second condition: data0 is after data1, all items before data1 can be ignored (sorted list data0)
            cond1 = data0_start > data1_end


            if cond1:
                # This only holds if data1 is sorted by both start and end index
                start_idx = j

            if not (cond0 or cond1):
                # overlap on index dimension
                # check other dimensions
                overlap = True
                for dim in dims:
                    if dim == index_dim:
                        continue
                    d0_start = data0[i][dim].start
                    d0_end = data0[i][dim].stop
                    d1_start = data1[j][dim].start
                    d1_end = data1[j][dim].stop
                    if ((d0_end < d1_start) or (d0_start > d1_end)):
                        overlap = False

                if overlap:
                    overlap_indices.append([int(i), int(j)])

    return overlap_indices

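# Minimal sketch of check_overlap on integer slices (both lists sorted as the
# docstring requires):
#
#   a = [{"time": slice(0, 10)}, {"time": slice(20, 30)}]
#   b = [{"time": slice(5, 15)}, {"time": slice(40, 50)}]
#   check_overlap(a, b, "time")  # -> [[0, 0]]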

def get_dataset_slices(ds, dims, dataset_slice=None, stride={}):
    # thanks to xbatcher: https://github.com/rabernat/xbatcher/
    dim_slices = []
    for dim in dims:
        if dataset_slice is None:
            segment_start = 0
            segment_end = ds.sizes[dim]
        else:
            segment_start = dataset_slice[dim].start
            segment_end = dataset_slice[dim].stop

        size = dims[dim]
        _stride = stride.get(dim, size)

        if ds[dim].dtype == "datetime64[ns]":
            # TODO: change when xarray #3291 is fixed
            iterator = pd.date_range(
                segment_start, segment_end, freq=_stride
            ).tz_localize(None)
            segment_end = pd.to_datetime(segment_end).tz_localize(None)
        else:
            iterator = range(segment_start, segment_end, _stride)

        slices = []
        # TODO include hopsize/overlapping windows
        for start in iterator:
            end = start + size
            if end <= segment_end:
                slices.append(slice(start, end))

        dim_slices.append(slices)

    import itertools
    all_slices = []
    for slices in itertools.product(*dim_slices):
        selector = {key: slice for key, slice in zip(dims, slices)}
        all_slices.append(selector)
        
    return np.array(all_slices)

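# Sketch: slicing a 10-sample series into windows of 4 with stride 2 (integer
# dimension; datetime dimensions go through pd.date_range instead):
#
#   ds = xr.Dataset({"x": ("time", np.arange(10))})
#   get_dataset_slices(ds, {"time": 4}, stride={"time": 2})
#   # -> [{"time": slice(0, 4)}, {"time": slice(2, 6)},
#   #     {"time": slice(4, 8)}, {"time": slice(6, 10)}]
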
class SegmentedDataset(DataSource):
    def __init__(
        self, data, label, dim="time", discard_empty = True, trim=True, dataset_slice=None, batch_dims={}, pad=False, mode = 'segments'

1252
    ):

        """ trim ... trim the dataset to the available labels
            dataset_slice: which part of the dataset to use

            from xarray documentation: [...] slices are treated as inclusive of both the start and stop values, unlike normal Python indexing.
        """
        # load annotation source and datasource

        # define a dataset index containing all indices of the datasource (e.g. timestamps or time periods) which should be in this dataset
        d = data()
        l = label()

        d = d.sortby(dim)
        l = l.sortby([l["start_" + dim], l["end_" + dim]])

        # TODO (trim): restrict the dataset to the available labels
        # indices = check_overlap(d,l)
        slices = get_dataset_slices(d, batch_dims, dataset_slice=dataset_slice)

        def get_label_slices(l):
            # Collect every coordinate that follows the "start_<dim>"/"end_<dim>"
            # naming convention and build one slice selector per label entry.
            label_coords = []
            start_identifier = "start_"
            for coord in l.coords:
                if coord.startswith(start_identifier):
                    label_coord = coord[len(start_identifier):]
                    if "end_" + label_coord in l.coords:
                        label_coords += [label_coord]

            label_slices = []
            for i in range(len(l)):
                selector = {}
                for coord in label_coords:
                    selector[coord] = slice(
                        l[i]["start_" + coord].values, l[i]["end_" + coord].values
                    )
                label_slices.append(selector)

            return label_coords, label_slices
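
        # For example (illustrative): a label entry with coordinates
        # start_time=2017-08-01 and end_time=2017-08-02 becomes the selector
        # {"time": slice("2017-08-01", "2017-08-02")}, ready for xarray's .sel().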

        label_coords, label_slices = get_label_slices(l)

        # filter out slices for which the data source does not return any data
        # TODO: Is there a way to avoid looping through the whole dataset
        #       and the whole label set while still using xarray indexing?

        label_dict = {}
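        # label_dict maps slice index -> item dict; each item has the form
        # (illustrative): {"id": 3, "indexers": {"time": slice(t0, t1)},
        # "labels": ["rockfall", "noise"]}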
        if mode == 'segments':
            overlaps = check_overlap(slices, label_slices, dim, label_coords)
            for o in overlaps:
                if discard_empty and d.sel(slices[o[0]]).size == 0:
                    continue
                # TODO: maybe this can be done faster (and cleaner)
                i, j = o
                label = str(l[j].values)
                if i not in label_dict:
                    label_dict[i] = {'id': i, 'indexers': slices[i], 'labels': [label]}
                else:
                    label_dict[i]['labels'].append(label)

        elif mode == 'points':
            for i in range(len(slices)):
                x = d.sel(slices[i])
                if x.size == 0:
                    continue
                for j in range(len(label_slices)):  # TODO: this might get really slow
                    y = x.sel(label_slices[j])
                    if y.size > 0:
                        # TODO: maybe this can be done faster (and cleaner)
                        label = str(l[j].values)
                        if i not in label_dict:
                            label_dict[i] = {'id': i, 'indexers': slices[i], 'labels': [label]}
                        else:
                            label_dict[i]['labels'].append(label)

        label_list = list(label_dict.values())
        print(label_list)  # debug output


        # Debug visualization: draw the dataset slices, the label slices and the
        # labeled items as translucent rectangles on a shared date axis.
        import plotly.graph_objects as go

        fig = go.Figure(
            layout=dict(
                title=dict(text="Dataset and label slices"),
                xaxis={'type': "date"},
                xaxis_range=[
                    pd.to_datetime('2017-08-01'),
                    pd.to_datetime('2017-08-06'),
                ],
            )
        )

        for sl in [slices, label_slices, label_list]:
            for item in sl:
                label = "None"
                if 'indexers' in item:
                    label = item['labels']
                    item = item['indexers']
                fig.add_trace(
                    go.Scatter(
                        x=[
                            pd.to_datetime(item['time'].start),
                            pd.to_datetime(item['time'].stop),
                            pd.to_datetime(item['time'].stop),
                            pd.to_datetime(item['time'].start),
                        ],
                        y=[0, 0, 1, 1],
                        fill='toself',
                        fillcolor='darkviolet',
                        mode='lines',
                        hoveron='points+fills',  # select where hover is active
                        line_color='darkviolet',
                        showlegend=False,
                        opacity=0.5,
                        text=str(label),
                        hoverinfo='text+x+y',
                    )
                )

        fig.show()

        # TODO: gather statistics about what was left out

    def forward(self, data=None, request=None):
        pass


class PytorchDataset(DataSource):  # TODO: extend torch.utils.data.Dataset
    def __init__(self, source=None):
        """ Creates a pytorch like dataset from a data source and a label source.
1415
1416
1417
1418
1419
            
        Arguments:
            DataSource {[type]} -- [description]
            config {dict} -- configuration for labels
        """
        super().__init__(source=source)

    def build_dataset(self):
        # load the annotation source and the data source

        # define a dataset index containing all indices of the data source
        # (e.g. timestamps or time periods) which should be in this dataset

        # go through the dataset index and check the overlap of data source
        # indices and annotation indices

        # generate a new annotation set with regard to the data source indices
        # (requires generating an empty annotation set and adding the new labels to it)

        # if wanted, generate intermediate frozen results of the data source and annotations

        # go through all items of the data source
        pass
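
    # A minimal sketch of the torch.utils.data.Dataset interface this class is
    # meant to grow into (hypothetical; `self.items` could for instance hold
    # the label_list built by SegmentedDataset):
    #
    #     class TorchWrapper(torch.utils.data.Dataset):
    #         def __len__(self):
    #             return len(self.items)
    #
    #         def __getitem__(self, idx):
    #             item = self.items[idx]
    #             return item["indexers"], item["labels"]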

    def forward(self, data=None, request=None):
        # TODO: not implemented yet; pass the data through unchanged
        return data