"""MIT License

Copyright (c) 2019, Swiss Federal Institute of Technology (ETH Zurich), Matthias Meyer

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE."""

from ..global_config import get_setting, setting_exists, set_setting
from ..core.graph import StuettNode, configuration
from ..convenience import read_csv_with_store, to_csv_with_store, DirectoryStore
from ..convenience import indexers_to_request as i2r

import os

import dask
import logging
import obspy
from obspy.clients.fdsn import Client
from obspy import UTCDateTime
from obsplus import obspy_to_array
from copy import deepcopy
import io

from tqdm import tqdm


import zarr
import xarray as xr
from PIL import Image
import base64
import re

from pathlib import Path
import warnings

# TODO: revisit the following packages
import numpy as np
import pandas as pd
import datetime as dt


class DataSource(StuettNode):
    def __init__(self, **kwargs):
        super().__init__(kwargs=kwargs)

    def __call__(self, data=None, request=None, delayed=False):
        if data is not None:
            if request is None:
                request = data
            else:
                warnings.warn(
                    "Two inputs (data, request) provided to the DataSource but it can only handle a request. Choosing request. "
                )

        # DataSource only requires a request
        # Therefore merge permanent-config and request
        config = self.config.copy()  # TODO: do we need a deep copy?
        if request is not None:
            config.update(request)

        # TODO: change when rewriting for general indices
        if "start_time" in config and config["start_time"] is not None:
            config["start_time"] = pd.to_datetime(
                config["start_time"], utc=True
            ).tz_localize(
                None
            )  # TODO: change when xarray #3291 is fixed
        if "end_time" in config and config["end_time"] is not None:
            config["end_time"] = pd.to_datetime(
                config["end_time"], utc=True
            ).tz_localize(
                None
            )  # TODO: change when xarray #3291 is fixed

        if delayed:
            return dask.delayed(self.forward)(None, config)
        else:
            return self.forward(None, config)

    def configure(self, requests=None):
        """ Default configure for DataSource nodes
            Same as configure from StuettNode but adds is_source flag

        Arguments:
            request {list} -- List of requests

        Returns:
            dict -- Original request or merged requests 
        """
        requests = super().configure(requests)  # merging request here
        requests["requires_request"] = True

        return requests
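
    # Hedged usage sketch (SomeSource stands for any hypothetical DataSource
    # subclass that implements forward). Calling the node merges its permanent
    # config with the request; delayed=True returns a dask.delayed task instead
    # of the result:
    #
    #     node = SomeSource(start_time="2017-08-01", end_time="2017-08-02")
    #     x = node()                 # eager evaluation
    #     task = node(delayed=True)
    #     x = task.compute()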


class GSNDataSource(DataSource):
    def __init__(
        self,
        deployment=None,
        vsensor=None,
        position=None,
        start_time=None,
        end_time=None,
        **kwargs,
    ):
        super().__init__(
            deployment=deployment,
            position=position,
            vsensor=vsensor,
            start_time=start_time,
            end_time=end_time,
            kwargs=kwargs,
        )

    def forward(self, data=None, request=None):
        # code from Samuel Weber
        #### 1 - DEFINE VSENSOR-DEPENDENT COLUMNS ####
        metadata = Path(get_setting("metadata_directory")).joinpath(
            "vsensor_metadata/{:s}_{:s}.csv".format(
                request["deployment"], request["vsensor"]
            )
        )
        # if not metadata.exists():

        colnames = pd.read_csv(metadata, skiprows=0,)
        columns_old = colnames["colname_old"].values
        columns_new = colnames["colname_new"].values
        columns_unit = colnames["unit"].values
        if len(columns_old) != len(columns_new):
            warnings.warn(
                "WARNING: Length of 'columns_old' ({:d}) is not equal length of 'columns_new' ({:d})".format(
                    len(columns_old), len(columns_new)
                )
            )
        if len(columns_old) != len(columns_unit):
            warnings.warn(
                "WARNING: Length of 'columns_old' ({:d}) is not equal length of 'columns_unit' ({:d})".format(
                    len(columns_old), len(columns_unit)
                )
            )

        unit = dict(zip(columns_new, columns_unit))
        #### 2 - DEFINE CONDITIONS AND CREATE HTTP QUERY ####

        # Set server
        server = get_setting("permasense_server")

        # Create virtual_sensor
        virtual_sensor = request["deployment"] + "_" + request["vsensor"]

        # Create query and add time as well as position selection
        query = (
            "vs[1]={:s}"
            "&time_format=iso"
            "&timeline=generation_time"
            "&field[1]=All"
            "&from={:s}"
            "&to={:s}"
            "&c_vs[1]={:s}"
            "&c_join[1]=and"
            "&c_field[1]=position"
            "&c_min[1]={:02d}"
            "&c_max[1]={:02d}"
        ).format(
            virtual_sensor,
            pd.to_datetime(request["start_time"], utc=True).strftime(
                "%d/%m/%Y+%H:%M:%S"
            ),
            pd.to_datetime(request["end_time"], utc=True).strftime("%d/%m/%Y+%H:%M:%S"),
            virtual_sensor,
            int(request["position"]) - 1,
            request["position"],
        )
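        # e.g. deployment="matterhorn", vsensor="temperature", position=3 (hypothetical
        # values) yields a query for virtual sensor "matterhorn_temperature" with the
        # position field constrained between 02 and 03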

        # query extension for images
        if request["vsensor"] == "binary__mapped":
            query = (
                query
                + "&vs[2]={:s}&field[2]=relative_file&c_join[2]=and&c_vs[2]={:s}&c_field[2]=file_complete&c_min[2]=0&c_max[2]=1&vs[3]={:s}&field[3]=generation_time&c_join[3]=and&c_vs[3]={:s}&c_field[3]=file_size&c_min[3]=2000000&c_max[3]=%2Binf&download_format=csv".format(
                    virtual_sensor, virtual_sensor, virtual_sensor, virtual_sensor
                )
            )

        # Construct url:
        url = server + "multidata?" + query
        # if self.verbose:
        #     print('The GSN http-query is:\n{:s}'.format(url))

        #### 3 - ACCESS DATA AND CREATE PANDAS DATAFRAME ####
        d = pd.read_csv(
            url, skiprows=2
        )  # skip header lines (first 2) for import: skiprows=2
        df = pd.DataFrame(columns=columns_new)
        df.time = pd.to_datetime(d.generation_time, utc=True)

        # if depo in ['mh25', 'mh27old', 'mh27new', 'mh30', 'jj22', 'jj25', 'mh-nodehealth']:
        # d = d.convert_objects(convert_numeric=True)  # TODO: rewrite to remove warning
        for k in list(df):
            df[k] = pd.to_numeric(df[k], errors="ignore")

        #        df = pd.DataFrame(columns=columns_new)
        #        df.time = timestamp2datetime(d['generation_time']/1000)
        for i in range(len(columns_old)):
            if columns_new[i] != "time":
                setattr(df, columns_new[i], getattr(d, columns_old[i]))

        df = df.sort_values(by="time")

        # Remove columns with names 'delete'
        try:
            df.drop(["delete"], axis=1, inplace=True)
        except KeyError:
            pass

        # Remove columns with only 'null'
        df = df.replace(r"null", np.nan, regex=True)
        isnull = df.isnull().all()
        df.drop(df.columns[isnull], axis=1, inplace=True)

        df = df.set_index("time")
        df = df.sort_index(axis=1)

        x = xr.DataArray(df, dims=["time", "name"], name="CSV")

        try:
            unit_coords = []
            for name in x.coords["name"].values:
                # name = re.sub(r"\[.*\]", "", name).lstrip().rstrip()
                u = unit[str(name)]
                u = re.findall(r"\[(.*)\]", u)[0]

                # name_coords.append(name)
                unit_coords.append(u)

            x = x.assign_coords({"unit": ("name", unit_coords)})
        except:
            # TODO: add a warning or test explicitly if units exist
            pass

        return x
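
# Hedged usage sketch for GSNDataSource (assumptions: the global settings
# "permasense_server" and "metadata_directory" are configured, and the chosen
# deployment/vsensor/position exist on the GSN server; values are hypothetical):
#
#     node = GSNDataSource(deployment="matterhorn", vsensor="temperature", position=3)
#     x = node(request={"start_time": "2017-08-01", "end_time": "2017-08-02"})
#     # x: xarray.DataArray with dims ("time", "name") and a "unit" coordinate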


class SeismicSource(DataSource):
    def __init__(
        self,
        path=None,
        store=None,
        station=None,
        channel=None,
        start_time=None,
        end_time=None,
        use_arclink=False,
        return_obspy=False,
        **kwargs,
    ):  # TODO: update description
        """ Seismic data source to get data from permasense
            The user can predefine the source's settings or provide them in a request
            Predefined setting should never be updated (to stay thread safe), but will be ignored
            if request contains settings

        Keyword Arguments:
            config {dict}       -- Configuration for the seismic source (default: {{}})
            use_arclink {bool}  -- If true, downloads the data from the arclink service (authentication required) (default: {False})
            return_obspy {bool} -- By default an xarray is returned. If true, an obspy stream will be returned (default: {False})
282
        """
        super().__init__(
            path=path,
            store=store,
            station=station,
            channel=channel,
            start_time=start_time,
            end_time=end_time,
            use_arclink=use_arclink,
            return_obspy=return_obspy,
            kwargs=kwargs,
        )

    def forward(self, data=None, request=None):
        config = request

        if config["use_arclink"]:
            try:
                arclink = get_setting("arclink")
            except KeyError as err:
                raise RuntimeError(
                    f"The following error occurred:\n{err}. "
                    "Please provide either the credentials to access arclink or a path to the dataset"
                )

            arclink_user = arclink["user"]
            arclink_password = arclink["password"]
            fdsn_client = Client(
                base_url="http://arclink.ethz.ch",
                user=arclink_user,
                password=arclink_password,
            )
            x = fdsn_client.get_waveforms(
                network="4D",
                station=config["station"],
                location="A",
                channel=config["channel"],
                starttime=UTCDateTime(config["start_time"]),
                endtime=UTCDateTime(config["end_time"]),
                attach_response=True,
            )

            # TODO: potentially resample

        else:  # 20180914 is last full day available in permasense_vault
            # logging.info('Loading seismic with fdsn')
            x = self.get_obspy_stream(
                config,
                config["path"],
                config["start_time"],
                config["end_time"],
                config["station"],
                config["channel"],
            )

        x = x.slice(UTCDateTime(config["start_time"]), UTCDateTime(config["end_time"]))

        # TODO: remove response x.remove_response(output=vel)
        # x = self.process_seismic_data(x)

        if not config["return_obspy"]:
            x = obspy_to_array(x)

            # we assume that all starttimes are equal
            starttime = x.starttime.values.reshape((-1,))[0]
            for s in x.starttime.values.reshape((-1,)):
                if s != starttime:
                    warnings.warn("Not all starttimes of each seismic channel is equal")
                    # raise RuntimeError(
                    #     "Please make sure that starttime of each seismic channel is equal"
                    # )

            # change time coords from relative to absolute time
            starttime = obspy.UTCDateTime(starttime).datetime
            starttime = pd.to_datetime(starttime, utc=True).tz_localize(
                None
            )  # TODO: change when xarray #3291 is fixed
            timedeltas = pd.to_timedelta(x["time"].values, unit="seconds")
            xt = starttime + timedeltas
            x["time"] = pd.to_datetime(xt, utc=True).tz_localize(
                None
            )  # TODO: change when xarray #3291 is fixed
            del x.attrs["stats"]

            # x.rename({'seed_id':'channels'}) #TODO: rename seed_id to channels
        return x
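
    # Hedged usage sketch (assumption: a local miniseed archive laid out as
    # <station>/<year>/<channel>.D/4D.<station>.A.<channel>.D.<YYYYmmdd_HHMMSS>.miniseed,
    # matching get_obspy_stream below; the channel value is hypothetical):
    #
    #     seismic = SeismicSource(path="/path/to/archive", station="MH36", channel=["EHE"])
    #     x = seismic(request={"start_time": "2017-08-01 10:00", "end_time": "2017-08-01 10:01"})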

    def process_seismic_data(
        self,
        stream,
        remove_response=True,
        unit="VEL",
        station_inventory=None,
        detrend=True,
        taper=False,
        pre_filt=(0.025, 0.05, 45.0, 49.0),
        water_level=60,
        apply_filter=True,
        freqmin=0.002,
        freqmax=50,
        resample=False,
        resample_rate=250,
        rotation_angle=None,
    ):
        # author: Samuel Weber
        if station_inventory is None:
            station_inventory = Path(get_setting("metadata_directory")).joinpath(
                "inventory_stations__MH.xml"
            )

        # print(station_inventory)
        inv = obspy.read_inventory(str(station_inventory))
        # st = stream.copy()
        st = stream
        st.attach_response(inv)

        if detrend:
            st.detrend("demean")
            st.detrend("linear")

        if taper:
            st.taper(max_percentage=0.05)

        if remove_response:
            if hasattr(st[0].stats, "response"):
                st.remove_response(
                    output=unit,
                    pre_filt=pre_filt,
                    plot=False,
                    zero_mean=False,
                    taper=False,
                    water_level=water_level,
                )
            else:
                st.remove_response(
                    output=unit,
                    inventory=inv,
                    pre_filt=pre_filt,
                    plot=False,
                    zero_mean=False,
                    taper=False,
                    water_level=water_level,
                )

        if detrend:
            st.detrend("demean")
            st.detrend("linear")

        if taper:
            st.taper(max_percentage=0.05)

        if apply_filter:
            st.filter("bandpass", freqmin=freqmin, freqmax=freqmax)

        if resample:
            st.resample(resample_rate)

        if rotation_angle is None:
            rotation_angle = np.nan
        if not np.isnan(rotation_angle):
            st.rotate("NE->RT", back_azimuth=int(rotation_angle), inventory=inv)

        return st
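
    # Hedged usage sketch (assumptions: `st` is an obspy Stream from this source
    # and a StationXML inventory "inventory_stations__MH.xml" exists in the
    # metadata directory, as defaulted above; `source` is a SeismicSource instance):
    #
    #     st = source.process_seismic_data(st, remove_response=True, unit="VEL")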

    def get_obspy_stream(
        self,
        request,
        path,
        start_time,
        end_time,
        station,
        channels,
        pad=False,
        verbose=False,
        fill=0,
        fill_sampling_rate=1000,
        old_stationname=False,
    ):
        """    
        Loads the microseismic data for the given timeframe into a miniseed file.

        Arguments:
            start_time {datetime} -- start timestamp of the desired obspy stream
            end_time {datetime} -- end timestamp of the desired obspy stream
        
        Keyword Arguments:
            pad {bool} -- If True, the data will be zero-padded where it is not consistent
            fill {} -- Value used to fill gaps/errors in the seismic stream (e.g. numpy.nan or 0). If None, no fill is used
            verbose {bool} -- If info should be printed

        Returns:
            obspy stream -- obspy stream with up to three channels
                            the stream's channels will be sorted alphabetically
        """

        if not isinstance(channels, list):
            channels = [channels]

        # We will get the full hours seismic data and trim it to the desired length afterwards
        tbeg_hours = pd.to_datetime(start_time).replace(
            minute=0, second=0, microsecond=0
        )
        timerange = pd.date_range(start=tbeg_hours, end=end_time, freq="H")
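        # e.g. start_time="2017-08-01 10:30", end_time="2017-08-01 12:15" yields the
        # full hours 10:00, 11:00 and 12:00; each hourly file is loaded below and the
        # first/last segments are trimmed back to the requested interval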

        non_existing_files_ts = []  # keep track of nonexisting files

        # drawback of memmap files is that we need to calculate the size beforehand
        stream = obspy.Stream()

        idx = 0
        ## loop through all hours
        for i in range(len(timerange)):
            # start = time.time()
            h = timerange[i]

            st_list = obspy.Stream()

            datayear = timerange[i].strftime("%Y")
            if old_stationname:
                station = (
                    "MHDL" if station == "MH36" else "MHDT"
                )  # TODO: do not hardcode it
            filenames = {}
            for channel in channels:
                # filenames[channel] = Path(station).joinpath(
                #     datayear,
                #     "%s.D/" % channel,
                #     "4D.%s.A.%s.D." % (station, channel)
                #     + timerange[i].strftime("%Y%m%d_%H%M%S")
                #     + ".miniseed",
                # )
                filenames[channel] = [
                    station,
                    datayear,
                    "%s.D" % channel,
                    "4D.%s.A.%s.D." % (station, channel)
                    + timerange[i].strftime("%Y%m%d_%H%M%S")
                    + ".miniseed",
                ]
                # print(filenames[channel])

                # Load either from store or from filename
                if request["store"] is not None:
                    # get the file relative to the store
                    store = request["store"]
                    filename = "/".join(filenames[channel])
                    st = obspy.read(io.BytesIO(store[str(filename)]))
                else:
                    datadir = Path(path)
                    if not datadir.is_dir():
                        # TODO: should this be an error or only a warning. In a periodic execution this could stop the whole script
                        raise IOError(
                            "Cannot find the path {}. Please provide a correct path to the permasense geophone directory".format(
                                datadir
                            )
                        )
                    if not datadir.joinpath(*filenames[channel]).exists():
                        non_existing_files_ts.append(timerange[i])

                        warnings.warn(
                            RuntimeWarning(
                                "Warning file not found for {} {}".format(
                                    timerange[i], datadir.joinpath(*filenames[channel])
                                )
                            )
                        )

                        if fill is not None:
                            # create empty stream of fill value
                            arr = (
                                np.ones(
                                    (
                                        int(
                                            np.ceil(
                                                fill_sampling_rate
                                                * dt.timedelta(hours=1).total_seconds()
                                            )
                                        ),
                                    )
                                )
                                * fill
                            )
                            st = obspy.Stream([obspy.Trace(arr)])
                        else:
                            continue
                    else:
                        filename = str(datadir.joinpath(*filenames[channel]))
                        st = obspy.read(filename)

                st_list += st

            stream_h = st_list.merge(method=0, fill_value=fill)
            start_time_0 = start_time if i == 0 else h
            end_time_0 = (
                end_time if i == len(timerange) - 1 else h + dt.timedelta(hours=1)
            )
            segment_h = stream_h.trim(
                obspy.UTCDateTime(start_time_0),
                obspy.UTCDateTime(end_time_0),
                pad=pad,
                fill_value=fill,
            )
            stream += segment_h

        stream = stream.merge(method=0, fill_value=fill)

        stream.sort(
            keys=["channel"]
        )  # TODO: change this so that the order of the input channels list is maintained

        return stream


class MHDSLRFilenames(DataSource):
    def __init__(
        self,
        base_directory=None,
        store=None,
        method="directory",
        start_time=None,
        end_time=None,
        force_write_to_remote=False,
        as_pandas=True,
    ):
        """ Fetches the DSLR images from the Matterhorn deployment, returns the image
            filename(s) corresponding to the end and start time provided in either the
            config dict or as a request to the __call__() function.          
        
        Arguments:
            StuettNode {[type]} -- [description]

        Keyword Arguments:
            base_directory {[type]} -- [description]
            method {str}     -- [description] (default: {'directory'})
        """
        if store is None and base_directory is not None:
            store = DirectoryStore(base_directory)

        super().__init__(
            base_directory=base_directory,
            store=store,
            method=method,
            start_time=start_time,
            end_time=end_time,
            force_write_to_remote=force_write_to_remote,
            as_pandas=as_pandas,
        )

    def forward(self, data=None, request=None):
        """Retrieves the images for the selected time period from the server. If only a start_time timestamp is provided, 
          the file with the corresponding date will be loaded if available. For periods (when start and end time are given) 
          all available images are indexed first to provide an efficient retrieval.
        
        Arguments:
            start_time {datetime} -- If only start_time is given, the nearest available image is returned. If end_time is also provided, a dataframe is returned containing image filenames from the first image after start_time until the last image before end_time.

        Keyword Arguments:
            end_time {datetime} -- End time of the selected period. See start_time for a description. (default: {None})

        Returns:
            dataframe -- Dataframe containing the image filenames of the selected period.
        """
        config = request
        methods = ["directory", "web"]
        if config["method"].lower() not in methods:
            raise RuntimeError(
                f"The {config['method']} output_format is not supported. Allowed formats are {methods}"
            )

        if (
            config["base_directory"] is None and config["store"] is None
        ) and config["method"].lower() != "web":
            raise RuntimeError("Please provide a base_directory containing the images")

        if config["method"].lower() == "web":  # TODO: implement
            raise NotImplementedError("web fetch has not been implemented yet")

        start_time = config["start_time"]
        end_time = config["end_time"]

        if start_time is None:
            raise RuntimeError("Please provide a least start time")

        # if it is not timezone aware make it
        # if start_time.tzinfo is None:
        #     start_time = start_time.replace(tzinfo=timezone.utc)

        # If there is no tmp_dir we can try to load the file directly, otherwise
        # there will be a warning later in this function and the user should
        # set a tmp_dir
        if end_time is None and (
            not setting_exists("user_dir") or not os.path.isdir(get_setting("user_dir"))
        ):
            image_filename = self.get_image_filename(start_time)
            if image_filename is not None:
                return image_filename

        # If we have already loaded the dataframe in the current session we can use it
        if setting_exists("image_list_df"):
            imglist_df = get_setting("image_list_df")
        else:
            filename = "image_integrity.csv"
            success = False
            # first try to load it from remote via store
            if config["store"] is not None:
                if filename in config["store"]:
                    imglist_df = read_csv_with_store(config["store"], filename)
                    success = True
                elif config["force_write_to_remote"]:
                    # try to reload it and write to remote
                    imglist_df = self.image_integrity_store(config["store"])
                    try:
                        to_csv_with_store(config["store"], filename, imglist_df)
                        success = True
                    except Exception as e:
                        print(e)

            # Otherwise we need to load the filename dataframe from disk
            if (
                not success
                and setting_exists("user_dir")
                and os.path.isdir(get_setting("user_dir"))
            ):
                imglist_filename = os.path.join(get_setting("user_dir"), "") + filename

                # If it does not exist in the temporary folder of our application
                # We are going to create it
                if os.path.isfile(imglist_filename):
                    # imglist_df = pd.read_parquet(
                    #     imglist_filename
                    # )  # TODO: avoid too many different formats
                    imglist_df = pd.read_csv(imglist_filename)
                else:
                    # we are going to load the full list => no arguments
                    imglist_df = self.image_integrity_store(config["store"])
                    # imglist_df.to_parquet(imglist_filename)
                    imglist_df.to_csv(imglist_filename, index=False)
            elif not success:
                # if there is no tmp_dir we can load the image list but
                # we should warn the user that this is inefficient
                imglist_df = self.image_integrity_store(config["store"])
                warnings.warn(
                    "No temporary directory was set. You can speed up multiple runs of your application by setting a temporary directory"
                )

            # TODO: make the index timezone aware
            imglist_df.set_index("start_time", inplace=True)
            imglist_df.index = pd.to_datetime(imglist_df.index, utc=True).tz_localize(
                None
            )  # TODO: change when xarray #3291 is fixed
            imglist_df.sort_index(inplace=True)

            set_setting("image_list_df", imglist_df)

        output_df = None
        if end_time is None:
            if start_time < imglist_df.index[0]:
                start_time = imglist_df.index[0]

            loc = imglist_df.index.get_loc(start_time, method="nearest")
            output_df = imglist_df.iloc[loc : loc + 1]
        else:
            # if end_time.tzinfo is None:
            #     end_time = end_time.replace(tzinfo=timezone.utc)
            if start_time > imglist_df.index[-1] or end_time < imglist_df.index[0]:
                # return empty dataframe
                output_df = imglist_df[0:0]
            else:
                if start_time < imglist_df.index[0]:
                    start_time = imglist_df.index[0]
                if end_time > imglist_df.index[-1]:
                    end_time = imglist_df.index[-1]

                output_df = imglist_df.iloc[
                    imglist_df.index.get_loc(
                        start_time, method="bfill"
                    ) : imglist_df.index.get_loc(end_time, method="ffill")
                    + 1
                ]

        if not request["as_pandas"]:
            output_df = output_df[["filename"]]  # TODO: do not get rid of end_time
            output_df.index.rename("time", inplace=True)
            # output = output_df.to_xarray(dims=["time"])
            output = xr.Dataset.from_dataframe(output_df).to_array()
            # print(output)
            # output = xr.DataArray(output_df['filename'], dims=["time"])
        else:
            output = output_df
        return output
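
    # Hedged usage sketch (assumption: a store or base_directory laid out as
    # <YYYY-MM-DD>/<YYYYmmdd_HHMMSS>.JPG, the layout image_integrity_store expects):
    #
    #     node = MHDSLRFilenames(base_directory="/path/to/images")
    #     df = node(request={"start_time": "2017-08-01", "end_time": "2017-08-02"})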

    # TODO: write test for image_integrity_store
    def image_integrity_store(
        self, store, start_time=None, end_time=None, delta_seconds=0
    ):
        """ Checks which images are available on the permasense server
        
        Keyword Arguments:
            start_time {[type]} -- datetime object giving the lower bound of the time range which should be checked. 
                                   If None there is no lower bound. (default: {None})
            end_time {[type]} --   datetime object giving the upper bound of the time range which should be checked.
                                   If None there is no upper bound (default: {None})
            delta_seconds {int} -- Determines the 'duration' of an image in the output dataframe.
                                   start_time  = image_time-delta_seconds
                                   end_time    = image_time+delta_seconds (default: {0})

        Returns:
            DataFrame -- Returns a pandas dataframe listing filename (relative to self.base_directory),
                         start_time and end_time. start_time and end_time can vary depending on the delta_seconds parameter
        """
        """ Checks which images are available on the permasense server
793

matthmey's avatar
matthmey committed
794
795
796
797
798
799
800
801
802
803
        Arguments:
            start_time:   
            end_time:   
            delta_seconds:  Determines the 'duration' of an image in the output dataframe.
                            start_time  = image_time+delta_seconds
                            end_time    = image_time-delta_seconds
        Returns:
            DataFrame -- 
        """
        if start_time is None:
            # a random year which is before permasense installation started
            start_time = dt.datetime(1900, 1, 1)
        if end_time is None:
            end_time = dt.datetime.utcnow()

        tbeg_days = start_time.replace(hour=0, minute=0, second=0)
        tend_days = end_time.replace(hour=23, minute=59, second=59)

        delta_t = dt.timedelta(seconds=delta_seconds)
        num_filename_errors = 0
        images_list = []

        for key in store.keys():
            try:
                pathkey = Path(key)
                datekey = pathkey.parent.name
                dir_date = pd.to_datetime(str(datekey), format="%Y-%m-%d")
            except:
                # we do not care for files not matching our format
                continue

            if pd.isnull(dir_date):
                continue

            # limit the search to the explicit time range
            if dir_date < tbeg_days or dir_date > tend_days:
                continue

            # print(file.stem)
            start_time_str = pathkey.stem
            try:
                _start_time = pd.to_datetime(start_time_str, format="%Y%m%d_%H%M%S")
                if start_time <= _start_time and _start_time <= end_time:
                    images_list.append(
                        {
                            "filename": str(key),
                            "start_time": _start_time - delta_t,
                            "end_time": _start_time + delta_t,
                        }
                    )
            except ValueError:
                # try old naming convention
                try:
                    _start_time = pd.to_datetime(start_time_str, format="%Y-%m-%d_%H%M%S")

                    if start_time <= _start_time and _start_time <= end_time:
                        images_list.append(
                            {
                                "filename": str(img_file.relative_to(base_directory)),
                                "start_time": _start_time - delta_t,
                                "end_time": _start_time + delta_t,
                            }
                        )
                except ValueError:
                    num_filename_errors += 1
                    warnings.warn(
                        "Permasense data integrity, the following is not a valid image filename and will be ignored: %s"
                        % key
                    )
                    continue

        segments = pd.DataFrame(images_list)
        segments.drop_duplicates(inplace=True, subset="start_time")
        segments.start_time = pd.to_datetime(segments.start_time, utc=True)
        segments.end_time = pd.to_datetime(segments.end_time, utc=True)
        segments.sort_values("start_time")

        return segments

    def image_integrity(
        self, base_directory, start_time=None, end_time=None, delta_seconds=0
    ):
        store = DirectoryStore(base_directory)
        return self.image_integrity_store(store, start_time, end_time, delta_seconds)

    def get_image_filename(self, timestamp):
        """ Checks wether an image exists for exactly the time of timestamp and returns its filename

            timestamp: datetime object for which the filename should be returned

        # Returns
            The filename if the file exists, None if there is no file
        """

        datadir = self.base_directory
        new_filename = (
            datadir
            + timestamp.strftime("%Y-%m-%d")
            + "/"
            + timestamp.strftime("%Y%m%d_%H%M%S")
            + ".JPG"
        )
        old_filename = (
            datadir
            + timestamp.strftime("%Y-%m-%d")
            + "/"
            + timestamp.strftime("%Y-%m-%d_%H%M%S")
            + ".JPG"
        )

        if os.path.isfile(new_filename):
            return new_filename
        elif os.path.isfile(old_filename):
            return old_filename
        else:
            return None

    def get_nearest_image_url(self, IMGparams, imgdate, floor=False):
        if floor:
            date_beg = imgdate - dt.timedelta(hours=4)
            date_end = imgdate
        else:
            date_beg = imgdate
            date_end = imgdate + dt.timedelta(hours=4)

        vs = []  # predefine vs list
        field = []  # predefine field list
        c_vs = []
        c_field = []
        c_join = []
        c_min = []
        c_max = []

        vs = vs + ["matterhorn_binary__mapped"]
        field = field + ["ALL"]
        # select only data from one sensor (position 3)
        c_vs = c_vs + ["matterhorn_binary__mapped"]
        c_field = c_field + ["position"]
        c_join = c_join + ["and"]
        c_min = c_min + ["18"]
        c_max = c_max + ["20"]

        c_vs = c_vs + ["matterhorn_binary__mapped"]
        c_field = c_field + ["file_complete"]
        c_join = c_join + ["and"]
        c_min = c_min + ["0"]
        c_max = c_max + ["1"]

        # create url which retrieves the csv data file
        url = "http://data.permasense.ch/multidata?"
        url = url + "time_format=" + "iso"
        url = url + "&from=" + date_beg.strftime("%d/%m/%Y+%H:%M:%S")
        url = url + "&to=" + date_end.strftime("%d/%m/%Y+%H:%M:%S")
        for i in range(0, len(vs), 1):
            url = url + "&vs[%d]=%s" % (i, vs[i])
            url = url + "&field[%d]=%s" % (i, field[i])

        for i in range(0, len(c_vs), 1):
            url = url + "&c_vs[%d]=%s" % (i, c_vs[i])
            url = url + "&c_field[%d]=%s" % (i, c_field[i])
            url = url + "&c_join[%d]=%s" % (i, c_join[i])
            url = url + "&c_min[%d]=%s" % (i, c_min[i])
            url = url + "&c_max[%d]=%s" % (i, c_max[i])

        url = url + "&timeline=%s" % ("generation_time")

        # print(url)
        d = pd.read_csv(url, skiprows=2)

        # print(d)

        # print(type(d['#data'].values))
        d["data"] = [s.replace("&amp;", "&") for s in d["data"].values]

        d.sort_values(by="generation_time")
        d["generation_time"] = pd.to_datetime(d["generation_time"], utc=True)

        if floor:
            data_str = d["data"].iloc[0]
            data_filename = d["relative_file"].iloc[0]
            # print(d['generation_time'].iloc[0])
            img_timestamp = d["generation_time"].iloc[0]
        else:
            data_str = d["data"].iloc[-1]
            data_filename = d["relative_file"].iloc[-1]
            # print(d['generation_time'].iloc[-1])
            img_timestamp = d["generation_time"].iloc[-1]

        file_extension = data_filename[-3:]
        base_url = "http://data.permasense.ch"
        # print(base_url + data_str)

        return base_url + data_str, img_timestamp, file_extension


class MHDSLRImages(MHDSLRFilenames):
    def __init__(
        self,
        base_directory=None,
        store=None,
        method="directory",
        output_format="xarray",
        start_time=None,
        end_time=None,
    ):
        if store is None and base_directory is not None:
            store = DirectoryStore(base_directory)

        super().__init__(
            base_directory=None,
            store=store,
            method=method,
            start_time=start_time,
            end_time=end_time,
        )

        self.config["output_format"] = output_format

    def forward(self, data=None, request=None):
        filenames = super().forward(request=request)

        if request["output_format"] is "xarray":
            return self.construct_xarray(filenames)
        elif request["output_format"] is "base64":
            return self.construct_base64(filenames)
        else:
            output_formats = ["xarray", "base64"]
            raise RuntimeError(
                f"The {request['output_format']} output_format is not supported. Allowed formats are {output_formats}"
            )

    def construct_xarray(self, filenames):
        images = []
        times = []
        for timestamp, element in filenames.iterrows():
            key = element.filename
            img = Image.open(io.BytesIO(self.config["store"][key]))
            img = np.array(img.convert("RGB"))
            images.append(np.array(img))
            times.append(timestamp)

        if images:
            images = np.array(images)
        else:
            images = np.empty((0, 0, 0, 0))

        data = xr.DataArray(
            images,
            coords={"time": times},
            dims=["time", "x", "y", "channels"],
            name="Image",
        )
        data.attrs["format"] = "jpg"

        return data

    def construct_base64(self, filenames):
        images = []
        times = []
        for timestamp, element in filenames.iterrows():
            key = element.filename
            img = Image.open(io.BytesIO(self.config["store"][key]))
            img = np.array(img.convert("RGB"))
            img_base64 = base64.b64encode(img.tobytes())
            images.append(img_base64)
            times.append(timestamp)

        images = np.array(images).reshape((-1, 1))
        data = xr.DataArray(
            images, coords={"time": times}, dims=["time", "base64"], name="Base64Image"
        )
        data.attrs["format"] = "jpg"

        return data
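
# Hedged sketch for decoding construct_base64 output again (assumption: the
# original images were RGB uint8 arrays of known shape (h, w, 3)):
#
#     raw = base64.b64decode(data.isel(time=0).item())
#     img = np.frombuffer(raw, dtype=np.uint8).reshape((h, w, 3))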


class Freezer(StuettNode):
    def __init__(self, store):
        self.store = store

    def configure(self, requests):
        """ 

        Arguments:
            request {list} -- List of requests

        Returns:
            dict -- Original, updated or merged request(s) 
        """
        requests = super().configure(requests)

        # TODO: check if data is available for requested time period

        # TODO: check how we need to update the boundaries such that we get data that fits and that is available
        # TODO: with only one start/end_time it might be inefficient for the case where [unavailable,available,unavailable] since we need to load everything
        #      one option could be to duplicate the graph by returning multiple requests...

        # TODO: make a distinction between requested start_time and freeze_output_start_time

        # TODO: add node specific hash to freeze_output_start_time (there might be multiple in the graph) <- probably not necessary because we receive a copy of the request which is unique to this node
        # TODO: maybe the configuration method must add (and delete) the node name in the request?

        # we always require a request to crop out the right time period
        requests["requires_request"] = True

        return requests

    def to_zarr(self, x):
        x = x.to_dataset(name="frozen")
        # x = x.chunk({name: x[name].shape for name in list(x.dims)})
        # zarr_dataset = zarr.open(self.store, mode='r')
        x.to_zarr(self.store, append_dim="time")

    def open_zarr(self, requests):
        ds_zarr = xr.open_zarr(self.store)
        print("read", ds_zarr)

    def forward(self, data=None, request=None):

        self.to_zarr(data)
        self.open_zarr(request)

        return data
        # TODO: check request start_time and load the data which is available, store the data which is not available
        # TODO: crop
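
# Hedged sketch of the Freezer round trip (assumption: `store` is any
# zarr-compatible store, e.g. zarr.DirectoryStore("frozen.zarr"); the node
# appends incoming data along "time" and reads the frozen dataset back):
#
#     freezer = Freezer(store=zarr.DirectoryStore("frozen.zarr"))
#     out = freezer(data=x, request={})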


class CsvSource(DataSource):
    def __init__(
        self, filename=None, store=None, start_time=None, end_time=None, **kwargs
    ):
        super().__init__(
            filename=filename,
            store=store,
            start_time=start_time,
            end_time=end_time,
            kwargs=kwargs,
        )

    def forward(self, data=None, request=None):
        # TODO: Implement properly
        if request["store"] is not None:
            # get the file relative to the store
            store = request["store"]
            filename = request["filename"]
            # csv = pd.read_csv(io.StringIO(str(store[filename],'utf-8')))
            csv = read_csv_with_store(store, filename)
        else:
            csv = pd.read_csv(request["filename"])
        csv.set_index("time", inplace=True)
        csv.index = pd.to_datetime(csv.index, utc=True).tz_localize(
            None
        )  # TODO: change when xarray #3291 is fixed

        x = xr.DataArray(csv, dims=["time", "name"], name="CSV")

        try:
            unit_coords = []
            name_coords = []
            for name in x.coords["name"].values:
                unit = re.findall(r"\[(.*)\]", name)[0]
                name = re.sub(r"\[.*\]", "", name).lstrip().rstrip()

                name_coords.append(name)
                unit_coords.append(unit)

            x.coords["name"] = name_coords
            x = x.assign_coords({"unit": ("name", unit_coords)})
        except:
            # TODO: add a warning or test explicitly if units exist
            pass

        if "start_time" not in request:
            request["start_time"] = x.coords["time"][0]
        if "end_time" not in request:
            request["end_time"] = x.coords["time"][-1]

        x = x.sel(time=slice(request["start_time"], request["end_time"]))

        return x
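
# Hedged sketch of the column convention CsvSource parses (assumption: a CSV
# header like "temperature [degC]" yields name "temperature" and unit "degC"
# via the regex above; the filename is hypothetical):
#
#     csv_node = CsvSource(filename="timeseries.csv")
#     x = csv_node(request={"start_time": "2017-08-01", "end_time": "2017-08-02"})
#     # x.coords["name"] -> "temperature", x.coords["unit"] -> "degC"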


def to_datetime(x):
    return pd.to_datetime(x, utc=True).tz_localize(
        None
    )  # TODO: change when xarray #3291 is fixed


class BoundingBoxAnnotation(DataSource):
    def __init__(
        self,
        filename=None,
        store=None,
        converters={"start_time": to_datetime, "end_time": to_datetime},
        **kwargs,
    ):
        super().__init__(
            filename=filename, store=store, converters=converters, kwargs=kwargs,
        )

    def forward(self, data=None, request=None):
        if request["store"] is not None:
            csv = read_csv_with_store(request["store"], request["filename"])
        else:
            csv = pd.read_csv(request["filename"])

        targets = xr.DataArray(csv["__target"], dims=["index"], name="Annotation")

        for key in csv:
            if key == "__target":
                continue
            targets = targets.assign_coords({key: ("index", csv[key])})

        for key in request["converters"]:
            if key in csv:
                converter = request["converters"][key]
                if not callable(converter):
                    raise RuntimeError("Please provide a callable as column converter")
                targets = targets.assign_coords(
                    {key: ("index", converter(targets[key]))}
                )

        return targets
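
# Hedged usage sketch (assumption: the annotation CSV has a "__target" column
# plus coordinate columns such as start_time/end_time, which the default
# converters turn into timezone-naive UTC timestamps; the filename is hypothetical):
#
#     ann = BoundingBoxAnnotation(filename="annotations.csv")
#     targets = ann()   # xarray.DataArray with one entry per annotation row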


def check_overlap(data0, data1, index_dim, dims=[]):
    """ checks the overlap of two lists of xarray indexers (dicts with slices per dimension).
        Both list must be sorted by the index_dim dimension!
    
    Arguments:
        data0 {list} -- First list of xarray indexers 
        data1 {list} -- Second list of xarray indexers. Must be sorted by primarily by slice.start and secondary by slice.stop
        index_dim {string} -- Default dimension to iterate over. Indexers lists must be sorted by this dimension
    
    Keyword Arguments:
        dims {list} -- List of additional dimensions which should be checked for overlap (default: {[]})
    
    Returns:
        list -- List of indices of which items overlap. [...,[i,j],...] where item i of indexers0 overlaps with item j of indexers1
    """

    overlap_indices = []
    # print(data0.head())
    num_overlaps = 0
    start_idx = 0
    for i in range(len(data0)):
        # data0_df = data0.iloc[i]
        data0_start = data0[i][index_dim].start
        data0_end = data0[i][index_dim].stop
        # print(data0_df['start_time'])
        ext = []
        for j in range(start_idx, len(data1)):
            # data1_df = data1.iloc[j]
            data1_start = data1[j][index_dim].start
            data1_end = data1[j][index_dim].stop
            # first condition: data0 ends before data1 starts; since data1 is sorted,
            # all following items of data1 start even later -> no overlap possible
            cond0 = data0_end < data1_start
            if cond0:
                break

            # second condition: data0 is after data1, all items before data1 can be ignored (sorted list data0)
            cond1 = data0_start > data1_end

            if cond1:
                # This only holds if data1 is sorted by both start and end index
                start_idx = j

            if not (cond0 or cond1):
                # overlap on index dimension
                # check other dimensions
                overlap = True
                for dim in dims:
                    if dim == index_dim:
                        continue
                    d0_start = data0[i][dim].start
                    d0_end = data0[i][dim].stop
                    d1_start = data1[j][dim].start
                    d1_end = data1[j][dim].stop
                    if (d0_end < d1_start) or (d0_start > d1_end):
                        overlap = False

                if overlap:
                    overlap_indices.append([int(i), int(j)])

    return overlap_indices
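
# Minimal worked example for check_overlap (indexers are dicts of slices, here
# only over the "time" dimension; both lists sorted by slice start):
#
#     a = [{"time": slice(0, 10)}, {"time": slice(20, 30)}]
#     b = [{"time": slice(5, 15)}, {"time": slice(40, 50)}]
#     check_overlap(a, b, "time")   # -> [[0, 0]], i.e. a[0] overlaps b[0]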


def get_dataset_slices(dims, dataset_slice, stride={}):
    # thanks to xbatcher: https://github.com/rabernat/xbatcher/
    dim_slices = []
    for dim in dims:
        # if dataset_slice is None:
        #     segment_start = 0
        #     segment_end = ds.sizes[dim]
        # else:
        segment_start = dataset_slice[dim].start
        segment_end = dataset_slice[dim].stop

        size = dims[dim]
        _stride = stride.get(dim, size)