from ..global_config import get_setting, setting_exists, set_setting
from ..core.graph import StuettNode

import os

import dask
import logging
import obspy
from obspy.clients.fdsn import Client
from obspy import UTCDateTime
from obsplus import obspy_to_array

import zarr
import xarray as xr
from PIL import Image
import base64

from pathlib import Path
import warnings

# TODO: revisit the following packages
import numpy as np
import pandas as pd
import datetime as dt


class DataSource(StuettNode):
    def __init__(self):
        pass

    def configure(self, requests=None):
        """ Default configure for DataSource nodes
            Same as configure from StuettNode but adds is_source flag

        Arguments:
            request {list} -- List of requests

        Returns:
            dict -- Original request or merged requests 
        """
        requests = super().configure(requests)
        requests["requires_request"] = True

        return requests
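
# A hedged usage sketch (the node and request keys are illustrative, not
# prescribed by this module): configure() merges the incoming requests via
# StuettNode.configure and flags the node as a source that needs a request.
#
#     merged = my_source.configure([{"start_time": t0}, {"end_time": t1}])
#     assert merged["requires_request"] is True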


class SeismicSource(DataSource):
    def __init__(self, config=None, use_arclink=False, return_obspy=False):
        """ Seismic data source to get data from permasense
            The user can predefine the source's settings or provide them in a request
            Predefined setting should never be updated (to stay thread safe), but will be ignored
            if request contains settings

        Keyword Arguments:
            config {dict}       -- Configuration for the seismic source (default: {None})
            use_arclink {bool}  -- If true, downloads the data from the arclink service (authentication required) (default: {False})
            return_obspy {bool} -- By default an xarray is returned. If true, an obspy stream will be returned (default: {False})
        """
        # copy the config so predefined settings are never mutated
        self.config = dict(config) if config is not None else {}
        self.use_arclink = use_arclink
        self.return_obspy = return_obspy

        if "source" not in self.config:
            self.config["source"] = None

        if use_arclink:
            arclink = get_setting("arclink")
            arclink_user = arclink["user"]
            arclink_password = arclink["password"]
            self.fdsn_client = Client(
                base_url="http://arclink.ethz.ch",
                user=arclink_user,
                password=arclink_password,
            )

    @dask.delayed
    def __call__(self, request=None):

        config = self.config.copy()
        if request is not None:
            config.update(request)

        if self.use_arclink:
            # logging.info('Loading seismic with fdsn')
            x = self.fdsn_client.get_waveforms(
                network="4D",
                station=config["station"],
                location="A",
                channel=config["channel"],
                starttime=UTCDateTime(config["start_time"]),
                endtime=UTCDateTime(config["end_time"]),
                attach_response=True,
            )
            # TODO: remove response x.remove_response(output=vel)
            # TODO: slice start_time / end_time
            # TODO: potentially resample
        else:  # 20180914 is last full day available in permasense_vault
            # logging.info('Loading seismic from permasense_vault')
            x = self.get_obspy_stream(
                config["start_time"],
                config["end_time"],
                config["station"],
                config["channel"],
            )

        if not self.return_obspy:
            x = obspy_to_array(x)

            # change time coords from relative to absolute time
            starttime = obspy.UTCDateTime(x.starttime.values).datetime
            starttime = pd.to_datetime(starttime, utc=True)
            timedeltas = pd.to_timedelta(x["time"].values, unit="seconds")
            xt = starttime + timedeltas
            x["time"] = pd.to_datetime(xt, utc=True)
            del x.attrs["stats"]

        return x

    def get_obspy_stream(
        self,
        start_time,
        end_time,
        station,
        channels,
        pad=False,
        verbose=False,
        fill=0,
        fill_sampling_rate=1000,
    ):
        """    
        Loads the microseismic data for the given timeframe into a miniseed file.

        Arguments:
            start_time {datetime} -- start timestamp of the desired obspy stream
            end_time {datetime} -- end timestamp of the desired obspy stream
            station {str} -- name of the station, e.g. "MH36"
            channels {str or list} -- channel name or list of channel names to load
        
        Keyword Arguments:
            pad {bool} -- If True, the data will be zero padded if it does not cover the full timeframe (default: {False})
            fill {} -- Value (e.g. 0 or numpy.nan) used to fill gaps and missing files in the seismic stream. If None, gaps are not filled (default: {0})
            fill_sampling_rate {int} -- sampling rate used to synthesize fill traces for missing files (default: {1000})
            verbose {bool} -- If True, info is printed (default: {False})

        Returns:
            obspy stream -- obspy stream with one trace per requested channel;
                            the stream's channels are sorted alphabetically
        """

        if not isinstance(channels, list):
            channels = [channels]
        datadir = (
            os.path.join(get_setting("permasense_vault_dir"), "")
            + "geophones/binaries/PS/%s/" % station
        )

        if not os.path.isdir(datadir):
            # TODO: should this be an error or only a warning. In a period execution this could stop the whole script
            raise IOError(
                "Cannot find the path {}. Please provide a correct path to the permasense_vault directory".format(
                    datadir
                )
            )

        # We will get the seismic data for full hours and trim it to the desired length afterwards
        tbeg_hours = pd.to_datetime(start_time).replace(
            minute=0, second=0, microsecond=0
        )
        timerange = pd.date_range(start=tbeg_hours, end=end_time, freq="H")

        non_existing_files_ts = []  # keep track of nonexisting files

        # drawback of memmap files is that we need to calculate the size beforehand
        stream = obspy.Stream()

        # loop through all hours
        for i in range(len(timerange)):
            # start = time.time()
            h = timerange[i]

            st_list = obspy.Stream()

            datayear = timerange[i].strftime("%Y/")
            sn = "MHDL" if station == "MH36" else "MHDT"  # TODO: do not hardcode it
            filenames = {}
            for channel in channels:
                filenames[channel] = (
                    datadir
                    + datayear
                    + "%s.D/PS.%s.A.%s.D." % (channel, sn, channel)
                    + timerange[i].strftime("%Y%m%d_%H%M%S")
                    + ".miniseed"
                )
                # print(filenames[channel])
                if not os.path.isfile(filenames[channel]):
                    non_existing_files_ts.append(timerange[i])

                    warnings.warn(
                        RuntimeWarning(
                            "File not found for {}: {}".format(
                                timerange[i], filenames[channel]
                            )
                        )
                    )

                    if fill is not None:
                        # create empty stream of fill value
                        num_samples = int(
                            np.ceil(
                                fill_sampling_rate
                                * dt.timedelta(hours=1).total_seconds()
                            )
                        )
                        arr = np.ones(num_samples) * fill
                        # give the synthetic trace the correct sampling rate and
                        # start time so it aligns and merges with the real data
                        st = obspy.Stream(
                            [
                                obspy.Trace(
                                    arr,
                                    header={
                                        "sampling_rate": fill_sampling_rate,
                                        "starttime": obspy.UTCDateTime(h),
                                        "channel": channel,
                                    },
                                )
                            ]
                        )
                    else:
                        continue
                else:
                    st = obspy.read(filenames[channel])

                st_list += st

            stream_h = st_list.merge(method=0, fill_value=fill)
            start_time_0 = start_time if i == 0 else h
            end_time_0 = (
                end_time if i == len(timerange) - 1 else h + dt.timedelta(hours=1)
            )
            segment_h = stream_h.trim(
                obspy.UTCDateTime(start_time_0),
                obspy.UTCDateTime(end_time_0),
                pad=pad,
                fill_value=fill,
            )
            stream += segment_h

        stream = stream.merge(method=0, fill_value=fill)

        stream.sort(
            keys=["channel"]
        )  # TODO: change this so that the order of the input channels list is maintained

        return stream
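
# Example usage (a minimal sketch; assumes the package is importable as
# `stuett`, the "permasense_vault_dir" setting points to a local copy of the
# vault and the station/channel values exist in the archive -- the channel
# codes below are illustrative):
#
#     from stuett.global_config import set_setting
#     set_setting("permasense_vault_dir", "/path/to/permasense_vault/")
#     seismic_source = SeismicSource()
#     node = seismic_source(
#         {
#             "station": "MH36",
#             "channel": ["EHE", "EHN", "EHZ"],
#             "start_time": "2017-08-01 10:00:00",
#             "end_time": "2017-08-01 10:01:00",
#         }
#     )
#     data = node.compute()  # __call__ is dask.delayed; compute() triggers loading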


class MHDSLRFilenames(StuettNode):
    def __init__(
        self, base_directory=None, method="directory", start_time=None, end_time=None
    ):
        """ Fetches the DSLR images from the Matterhorn deployment, returns the image
            filename(s) corresponding to the end and start time provided in either the
            config dict or as a request to the __call__() function.          
        
        Arguments:
            StuettNode {[type]} -- [description]

        Keyword Arguments:
            base_directory {[type]} -- [description]
            method {str}     -- [description] (default: {'directory'})
        """
        self.config = locals().copy()  # map the arguments to the config file
        del self.config["self"]

    def __call__(self, request=None):
        """Retrieves the images for the selected time period from the server. If only a start_time timestamp is provided, 
          the file with the corresponding date will be loaded if available. For periods (when start and end time are given) 
          all available images are indexed first to provide an efficient retrieval.
        
        Arguments:
            start_time {datetime} -- If only start_time is given, the nearest available image is returned. If end_time is also provided, a dataframe is returned containing the image filenames from the first image after start_time until the last image before end_time.
        
        Keyword Arguments:
            end_time {datetime} -- end time of the selected period. see start_time for a description. (default: {None})
        
        Returns:
            dataframe -- Dataframe containing the image filenames of the selected period.
        """
        config = self.config.copy()  # TODO: do we need a deep copy?
        if request is not None:
            config.update(request)

        methods = ["directory", "web"]
        if config["method"].lower() not in methods:
            raise RuntimeError(
                f"The {config['method']} method is not supported. Allowed methods are {methods}"
            )

        if config["base_directory"] is None and output_format.lower() != "web":
            raise RuntimeError("Please provide a base_directory containing the images")

        if config["method"].lower() == "web":  # TODO: implement
            raise NotImplementedError("web fetch has not been implemented yet")

        start_time = config["start_time"]
        end_time = config["end_time"]

        if start_time is None:
            raise RuntimeError("Please provide a least start time")

        # if it is not timezone aware make it
        # if start_time.tzinfo is None:
        #     start_time = start_time.replace(tzinfo=timezone.utc)

        # If there is no user_dir we can try to load the file directly, otherwise
        # there will be a warning later in this function and the user should
        # set a user_dir
        if end_time is None and (
            not setting_exists("user_dir") or not os.path.isdir(get_setting("user_dir"))
        ):
            image_filename = self.get_image_filename(start_time)
            if image_filename is not None:
                return image_filename

        # If we have already loaded the dataframe in the current session we can use it
        if setting_exists("image_list_df"):
            imglist_df = get_setting("image_list_df")
        else:
            if setting_exists("user_dir") and os.path.isdir(get_setting("user_dir")):
                # Otherwise we need to load the filename dataframe from disk
                imglist_filename = (
                    os.path.join(get_setting("user_dir"), "")
                    + "full_image_integrity.parquet"
                )

                # If it does not exist in the user directory of our application
                # we are going to create it
                if os.path.isfile(imglist_filename):
                    imglist_df = pd.read_parquet(
                        imglist_filename
                    )  # TODO: avoid too many different formats
                else:
                    # we are going to load the full list => no arguments
                    imglist_df = self.image_integrity(config["base_directory"])
                    imglist_df.to_parquet(imglist_filename)
            else:
                # if there is no user_dir we can load the image list but
                # we should warn the user that this is inefficient
                imglist_df = self.image_integrity(config["base_directory"])
                warnings.warn(
                    "No user directory was set. You can speed up multiple runs of your application by setting a user directory (user_dir)"
                )

            # TODO: make the index timezone aware
            imglist_df.set_index("start_time", inplace=True)
            imglist_df.index = pd.to_datetime(imglist_df.index, utc=True)
            imglist_df.sort_index(inplace=True)

            set_setting("image_list_df", imglist_df)

        if end_time is None:
            if start_time < imglist_df.index[0]:
                start_time = imglist_df.index[0]
            return imglist_df.iloc[
                imglist_df.index.get_loc(start_time, method="nearest")
            ]
        else:
            # if end_time.tzinfo is None:
            #     end_time = end_time.replace(tzinfo=timezone.utc)

            if start_time < imglist_df.index[0]:
                start_time = imglist_df.index[0]
            if end_time > imglist_df.index[-1]:
                end_time = imglist_df.index[-1]

            return imglist_df.iloc[
                imglist_df.index.get_loc(
                    start_time, method="bfill"
                ) : imglist_df.index.get_loc(end_time, method="ffill")
                + 1
            ]

    def image_integrity(
        self, base_directory, start_time=None, end_time=None, delta_seconds=0
    ):
        """ Checks which images are available on the permasense server
        
        Keyword Arguments:
            start_time {[type]} -- datetime object giving the lower bound of the time range which should be checked. 
                                   If None there is no lower bound. (default: {None})
            end_time {[type]} --   datetime object giving the upper bound of the time range which should be checked.
                                   If None there is no upper bound (default: {None})
            delta_seconds {int} -- Determines the 'duration' of an image in the output dataframe.
                                   start_time  = image_time+delta_seconds
                                   end_time    = image_time-delta_seconds (default: {0})
        
        Returns:
            DataFrame -- pandas dataframe with the columns filename (relative to base_directory),
                         start_time and end_time; start_time and end_time can vary depending on the delta_seconds parameter
        """
        """ Checks which images are available on the permasense server
394

matthmey's avatar
matthmey committed
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
        Arguments:
            start_time:   
            end_time:   
            delta_seconds:  Determines the 'duration' of an image in the output dataframe.
                            start_time  = image_time+delta_seconds
                            end_time    = image_time-delta_seconds
        Returns:
            DataFrame -- 
        """
        if start_time is None:
            start_time = dt.datetime(
                1900, 1, 1
            )  # arbitrary date long before the permasense installation started
        if end_time is None:
            end_time = dt.datetime.utcnow()

        tbeg_days = start_time.replace(hour=0, minute=0, second=0)
        tend_days = end_time.replace(hour=23, minute=59, second=59)

        delta_t = dt.timedelta(seconds=delta_seconds)
        num_filename_errors = 0
        images_list = []
        # p = Path(permasense_vault_dir + 'gsn-binaries/matterhorn/5015/camera1/')
        p = Path(base_directory)

        if not p.is_dir():
            warnings.warn(
                "Could not find the permasense image dataset. Please make sure it is available at {}".format(
                    str(p)
                )
            )

        for dir in p.glob("*/"):
            dir_date = dt.datetime.strptime(str(dir.name), "%Y-%m-%d")

            # limit the search to the explicit time range
            if dir_date < tbeg_days or dir_date > tend_days:
                continue

            for img_file in dir.glob("*"):
                # print(file.stem)
                start_time_str = img_file.stem

                # parse the image time without shadowing the start_time argument
                try:
                    img_time = dt.datetime.strptime(start_time_str, "%Y%m%d_%H%M%S")
                    if start_time <= img_time <= end_time:
                        images_list.append(
                            {
                                "filename": str(img_file.relative_to(base_directory)),
                                "start_time": img_time - delta_t,
                                "end_time": img_time + delta_t,
                            }
                        )
                except ValueError:
                    # try old naming convention
                    try:
                        img_time = dt.datetime.strptime(
                            start_time_str, "%Y-%m-%d_%H%M%S"
                        )
                        if start_time <= img_time <= end_time:
                            images_list.append(
                                {
                                    "filename": str(
                                        img_file.relative_to(base_directory)
                                    ),
                                    "start_time": img_time - delta_t,
                                    "end_time": img_time + delta_t,
                                }
                            )
                    except ValueError:
                        num_filename_errors += 1
                        warnings.warn(
                            "Permasense data integrity: the following is not a valid image filename and will be ignored: %s"
                            % img_file
                        )
                        continue

        segments = pd.DataFrame(images_list)
        segments.drop_duplicates(inplace=True, subset="start_time")
        segments.start_time = pd.to_datetime(segments.start_time, utc=True)
        segments.end_time = pd.to_datetime(segments.end_time, utc=True)
        segments.sort_values("start_time")

        return segments

    def get_image_filename(self, timestamp):
        """ Checks wether an image exists for exactly the time of timestamp and returns its filename

            timestamp: datetime object for which the filename should be returned

        # Returns
            The filename if the file exists, None if there is no file
        """

        datadir = self.config["base_directory"]
        new_filename = os.path.join(
            datadir,
            timestamp.strftime("%Y-%m-%d"),
            timestamp.strftime("%Y%m%d_%H%M%S") + ".JPG",
        )
        old_filename = os.path.join(
            datadir,
            timestamp.strftime("%Y-%m-%d"),
            timestamp.strftime("%Y-%m-%d_%H%M%S") + ".JPG",
        )

        if os.path.isfile(new_filename):
            return new_filename
        elif os.path.isfile(old_filename):
            return old_filename
        else:
            return None

    def get_nearest_image_url(self, IMGparams, imgdate, floor=False):
        if floor:
            date_beg = imgdate - dt.timedelta(hours=4)
            date_end = imgdate
        else:
            date_beg = imgdate
            date_end = imgdate + dt.timedelta(hours=4)

        # predefine the query parameter lists
        vs = []
        field = []
        c_vs = []
        c_field = []
        c_join = []
        c_min = []
        c_max = []

        vs = vs + ["matterhorn_binary__mapped"]
        field = field + ["ALL"]
        # select only data from one sensor (position 3)
        c_vs = c_vs + ["matterhorn_binary__mapped"]
        c_field = c_field + ["position"]
        c_join = c_join + ["and"]
        c_min = c_min + ["18"]
        c_max = c_max + ["20"]

        c_vs = c_vs + ["matterhorn_binary__mapped"]
        c_field = c_field + ["file_complete"]
        c_join = c_join + ["and"]
        c_min = c_min + ["0"]
        c_max = c_max + ["1"]

        # create url which retrieves the csv data file
        url = "http://data.permasense.ch/multidata?"
        url = url + "time_format=" + "iso"
        url = url + "&from=" + date_beg.strftime("%d/%m/%Y+%H:%M:%S")
        url = url + "&to=" + date_end.strftime("%d/%m/%Y+%H:%M:%S")
        for i in range(0, len(vs), 1):
            url = url + "&vs[%d]=%s" % (i, vs[i])
            url = url + "&field[%d]=%s" % (i, field[i])

        for i in range(0, len(c_vs), 1):
            url = url + "&c_vs[%d]=%s" % (i, c_vs[i])
            url = url + "&c_field[%d]=%s" % (i, c_field[i])
            url = url + "&c_join[%d]=%s" % (i, c_join[i])
            url = url + "&c_min[%d]=%s" % (i, c_min[i])
            url = url + "&c_max[%d]=%s" % (i, c_max[i])

        url = url + "&timeline=%s" % ("generation_time")

        # print(url)
        d = pd.read_csv(url, skiprows=2)

        # print(d)

        # print(type(d['#data'].values))
        d["data"] = [s.replace("&amp;", "&") for s in d["data"].values]

        d.sort_values(by="generation_time")
        d["generation_time"] = pd.to_datetime(d["generation_time"], utc=True)

        if floor:
            data_str = d["data"].iloc[0]
            data_filename = d["relative_file"].iloc[0]
            # print(d['generation_time'].iloc[0])
            img_timestamp = d["generation_time"].iloc[0]
        else:
            data_str = d["data"].iloc[-1]
            data_filename = d["relative_file"].iloc[-1]
            # print(d['generation_time'].iloc[-1])
            img_timestamp = d["generation_time"].iloc[-1]

        file_extension = data_filename[-3:]
        base_url = "http://data.permasense.ch"
        # print(base_url + data_str)

        return base_url + data_str, img_timestamp, file_extension
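
# Example usage (a minimal sketch; the base directory path is hypothetical and
# must contain per-day folders of images as expected by image_integrity()):
#
#     node = MHDSLRFilenames(base_directory="/path/to/timelapse_images/")
#     df = node(
#         {
#             "start_time": pd.to_datetime("2017-08-06", utc=True),
#             "end_time": pd.to_datetime("2017-08-07", utc=True),
#         }
#     )
#     # df is a pandas dataframe indexed by start_time with a filename column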


class MHDSLRImages(MHDSLRFilenames):
    def __init__(
        self,
        base_directory=None,
        method="directory",
        output_format="xarray",
        start_time=None,
        end_time=None,
    ):
        super().__init__(
            base_directory=base_directory,
            method=method,
            start_time=start_time,
            end_time=end_time,
        )
        self.config["output_format"] = output_format

    def __call__(self, request=None):
        config = self.config.copy()  # TODO: do we need a deep copy?
        if request is not None:
            config.update(request)

        filenames = super().__call__(request)

        if config["output_format"] is "xarray":
            return self.construct_xarray(filenames)
        elif config["output_format"] is "base64":
            return self.construct_base64(filenames)
        else:
            output_formats = ["xarray", "base64"]
            raise RuntimeError(
                f"The {config['output_format']} output_format is not supported. Allowed formats are {output_formats}"
            )

    def construct_xarray(self, filenames):
        images = []
        times = []
        for timestamp, element in filenames.iterrows():
            filename = Path(self.config["base_directory"]).joinpath(element.filename)
            img = Image.open(filename)
            images.append(np.array(img))
            times.append(timestamp)

        images = np.array(images)
        data = xr.DataArray(
            images, coords={"time": times}, dims=["time", "x", "y", "c"]
        )
        data.attrs["format"] = "jpg"

        return data

    def construct_base64(self, filenames):
        images = []
        times = []
        for timestamp, element in filenames.iterrows():
            filename = Path(self.config["base_directory"]).joinpath(element.filename)
            img = Image.open(filename)
            img_base64 = base64.b64encode(img.tobytes())
            images.append(img_base64)
            times.append(timestamp)

        images = np.array(images).reshape((-1, 1))
        data = xr.DataArray(images, coords={"time": times}, dims=["time", "base64"])
        data.attrs["format"] = "jpg"

        return data
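
# Example usage (a minimal sketch; path and period are hypothetical):
#
#     node = MHDSLRImages(base_directory="/path/to/timelapse_images/")
#     data = node(
#         {
#             "start_time": pd.to_datetime("2017-08-06", utc=True),
#             "end_time": pd.to_datetime("2017-08-07", utc=True),
#         }
#     )
#     # data is an xarray.DataArray with dims (time, x, y, c) holding the images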


class Freezer(StuettNode):
    def __init__(self, store):
        self.store = store

    def configure(self, requests):
        """ 

        Arguments:
            request {list} -- List of requests

        Returns:
            dict -- Original, updated or merged request(s) 
        """
        requests = super().configure(requests)

        # TODO: check if data is available for requested time period

        # TODO: check how we need to update the boundaries such that we get data that fits and that is available
        # TODO: with only one start/end_time it might be inefficient for the case where [unavailable,available,unavailable] since we need to load everything
        #      one option could be to duplicate the graph by returning multiple requests...

        # TODO: make a distinction between requested start_time and freeze_output_start_time

        # TODO: add node specific hash to freeze_output_start_time (there might be multiple in the graph) <- probably not necessary because we receive a copy of the request which is unique to this node
        # TODO: maybe the configuration method must add (and delete) the node name in the request?

        # we always require a request to crop out the right time period
        requests["requires_request"] = True

        return requests

    def to_zarr(self, x):
        x = x.to_dataset(name="frozen")
        # x = x.chunk({name: x[name].shape for name in list(x.dims)})
        # zarr_dataset = zarr.open(self.store, mode='r')
        x.to_zarr(self.store, append_dim="time")

    def open_zarr(self, requests):
        ds_zarr = xr.open_zarr(self.store)
        logging.debug(ds_zarr)

    @dask.delayed
    def __call__(self, x=None, requests=None):
        logging.debug("Freezer called with x=%s, requests=%s", x, requests)
        if x is not None:  # TODO: check if this is always good
            if requests is None:
                requests = x
        else:
            raise RuntimeError("No input provided")

        self.to_zarr(x)
        self.open_zarr(requests)

        return x
        # TODO: check request start_time and load the data which is available, store the data which is not available
        # TODO: crop
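
# Example usage (a minimal sketch; assumes a local zarr directory store and an
# xarray.DataArray `data` with a "time" dimension, e.g. obtained from SeismicSource):
#
#     store = zarr.DirectoryStore("frozen.zarr")
#     freezer = Freezer(store)
#     frozen = freezer(data, request).compute()  # __call__ is dask.delayed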


class CsvSource(DataSource):
    def __init__(self, config={}):
        pass

    def __call__(self, request):
        # TODO: not implemented yet
        raise NotImplementedError("CsvSource has not been implemented yet")


class LabelSource(DataSource):
    def __init__(self, config={}):
        pass

    def __call__(self, request):
        # TODO: not implemented yet
        raise NotImplementedError("LabelSource has not been implemented yet")


class PytorchDataset(DataSource):
    def __init__(self, config={}):
        """ Creates a pytorch like dataset from a data source and a label source.
            
        Arguments:
            config {dict} -- configuration for the data and label sources
        """
        self.config = config

        if "source" not in self.config:
            self.config["source"] = None

    def __call__(self, request):
        if request is None:
            raise RuntimeError("No request provided, cannot provide data")

        # TODO: not implemented yet
        raise NotImplementedError("PytorchDataset has not been implemented yet")