# preparator.py
import os
import numpy as np
import h5py
import scipy.io
import pandas as pd
import re
from tqdm import tqdm


class Preparator:
    """Cut event-locked segments out of per-subject EEG files and merge them.

    Configure an instance with :meth:`extract_data_at_events` (and optionally
    :meth:`blocks`, :meth:`addFilter`, :meth:`addLabel`, :meth:`ignoreEvent`),
    then call :meth:`run` to walk the load directory and write one ``.npz``
    archive containing the arrays ``EEG`` and ``labels``.
    """

    # (DataFrame column, field name inside the file's ``event`` struct) for
    # every per-event field that is read as a single scalar. The string-valued
    # ``type`` field needs its own decoding and is handled separately.
    _SCALAR_EVENT_FIELDS = (
        ('latency', 'latency'),
        ('amplitude', 'sac_amplitude'),
        ('start_x', 'sac_startpos_x'),
        ('end_x', 'sac_endpos_x'),
        ('start_y', 'sac_startpos_y'),
        ('end_y', 'sac_endpos_y'),
        ('duration', 'duration'),
        ('avgpos_x', 'fix_avgpos_x'),
        ('avgpos_y', 'fix_avgpos_y'),
        ('endtime', 'endtime'),
    )

    def __init__(self, load_directory='./', save_directory='./', load_file_pattern='*', save_file_name='all.npz', verbose=False):
        """Store the I/O configuration and reset all extraction settings.

        Args:
            load_directory: Directory whose sub-directories are scanned by
                :meth:`run`.
            save_directory: Directory the merged archive is written to.
            load_file_pattern: Regular expression applied with ``re.match`` to
                each file name. A string that is not a valid regular
                expression (such as the default ``'*'``) is interpreted as a
                shell-style glob instead.
            save_file_name: File name of the merged archive.
            verbose: When true, the helper methods print intermediate results.
        """
        self.load_directory = load_directory
        self.save_directory = save_directory
        try:
            self.load_file_pattern = re.compile(load_file_pattern)
        except re.error:
            # '*' and similar globs are not valid regexes; translating them
            # keeps the documented default usable instead of raising here.
            import fnmatch
            self.load_file_pattern = re.compile(fnmatch.translate(load_file_pattern))
        self.save_file_name = save_file_name
        self.extract_pattern = None
        self.start_time = None
        self.length_time = None
        self.start_channel = None
        self.end_channel = None
        self.on_blocks = None
        self.off_blocks = None
        self.filters = []
        self.ignore_events = []
        self.labels = []
        self.verbose = verbose
        self.padding = True
        print("Preparator is initialized with: ")
        print("Directory to load data: " + self.load_directory)
        print("Directory to save data: " + self.save_directory)
        print("Looking for file that match: " + load_file_pattern)
        print("Will store the merged file with name: " + self.save_file_name)

    def extract_data_at_events(self, extract_pattern, name_start_time, start_time, name_length_time, length_time, start_channel, end_channel, padding=True):
        """Define which events are extracted and how each segment is cut.

        Args:
            extract_pattern: Sequence of collections of event types; entry
                ``i`` lists the types accepted for the event ``i`` rows after
                the candidate event.
            name_start_time: Human-readable description of ``start_time``
                (only printed).
            start_time: Callable taking the events DataFrame and returning a
                Series with the first sample of each segment.
            name_length_time: Human-readable description of ``length_time``
                (only printed).
            length_time: Number of samples per segment.
            start_channel: First channel to keep (1-based, inclusive).
            end_channel: Last channel to keep (1-based, inclusive).
            padding: Whether segments that would run past the event following
                the pattern are cut there and padded by reflection.
        """
        self.extract_pattern = extract_pattern
        self.start_time = start_time
        self.length_time = length_time
        self.start_channel = start_channel
        self.end_channel = end_channel
        self.padding = padding

        print("Preparator is instructed to look for events that match structure: " + str(self.extract_pattern))
        print("Time dimension -- Cut start info: " + name_start_time)
        print("Time dimension -- Cut length info: " + name_length_time)
        print("Channel dimension -- Cut start info: " + str(start_channel))
        print("Channel dimension -- Cut end info: " + str(end_channel))

    def blocks(self, on_blocks, off_blocks):
        """Set the event types that switch extraction on and off.

        Block filtering is only active when both arguments are not ``None``.
        """
        self.on_blocks = on_blocks
        self.off_blocks = off_blocks
        print("Blocks to be used are: " + str(on_blocks))
        print("Blocks to be ignored are: " + str(off_blocks))

    def addFilter(self, name, f):
        """Register ``f(events) -> boolean Series``; events must pass every filter."""
        self.filters.append((name, f))
        print('Preparator is instructed to use filter: ' + name)

    def addLabel(self, name, f):
        """Register ``f(events) -> Series`` whose selected values become a label column."""
        self.labels.append((name, f))
        print('Preparator is instructed to use label: ' + name)

    def ignoreEvent(self, name, f):
        """Register ``f(events) -> boolean Series`` marking rows to drop before matching."""
        self.ignore_events.append((name, f))
        print('Preparator is instructed to ignore the event: ' + name)

    def run(self):
        """Process every matching file and save the merged arrays.

        Each sub-directory of ``load_directory`` is scanned (in sorted order)
        for file names matching ``load_file_pattern``. The counter written
        into the first label column is incremented once per processed file.

        Raises:
            NotImplementedError: If a matching file is not an HDF5 file.
            ValueError: If no trials were extracted from any file.
        """
        print("Starting collecting data.")
        all_EEG = []
        all_labels = []
        subj_counter = 1

        progress = tqdm(sorted(os.listdir(self.load_directory)))
        for subject in progress:
            cur_dir = os.path.join(self.load_directory, subject)
            if not os.path.isdir(cur_dir):
                continue

            for f in sorted(os.listdir(cur_dir)):
                if not self.load_file_pattern.match(f):
                    continue

                progress.set_description('Loading ' + f)
                path = os.path.join(cur_dir, f)
                if not h5py.is_hdf5(path):
                    # EEG = scipy.io.loadmat(cur_dir + f)['sEEG'][0]
                    # events = self._load_v5_events(EEG)
                    raise NotImplementedError("Matlab v5 files cannot be loaded. I still have to implement this.")
                if self.verbose: print("It is a HDF5 file. All is fine.")

                # Open the file once and close it deterministically; everything
                # that reads from the (lazy) h5py group happens inside the block.
                with h5py.File(path, 'r') as hdf5file:
                    # The data struct is taken to be the second top-level key.
                    # NOTE(review): presumably the first key is MATLAB's
                    # '#refs#' group -- confirm against the input files.
                    EEG = hdf5file[list(hdf5file.keys())[1]]
                    events = self._load_hdf5_events(EEG)

                    events = self._ignore_events(events)
                    if self.verbose: print(events)
                    select = self._filter_blocks(events)
                    select &= self._filter_events(events)
                    trials = self._extract_events(EEG, events, select)
                    labels = self._extract_labels(events, select, subj_counter)

                subj_counter += 1
                if len(trials) == 0:
                    # An empty result has shape (0,) and would break the final
                    # concatenation with the 3-D arrays of the other files.
                    continue
                all_EEG.append(trials)
                all_labels.append(labels)

        if not all_EEG:
            raise ValueError("No trials were extracted from any file in " + str(self.load_directory))

        # save the concatenated arrays
        print('Saving data...')
        EEG = np.concatenate(all_EEG, axis=0)
        labels = np.concatenate(all_labels, axis=0)
        print("Shapes of EEG are: ")
        print(EEG.shape)
        print("Shapes of labels are: ")
        print(labels.shape)
        np.savez(os.path.join(self.save_directory, self.save_file_name), EEG=EEG, labels=labels)

    # THIS IS NOT FINISHED
    def _load_v5_events(self, EEG):
        """Build the events DataFrame from a ``scipy.io.loadmat``-style struct.

        Not used by :meth:`run` yet (Matlab v5 loading is not implemented).
        Reads ``EEG[0]['event'][0][field]`` for each field and takes
        ``el[0][0]`` of every entry as the scalar value.
        """
        if self.verbose: print("Loading the events from the subject. ")
        raw = EEG[0]['event'][0]
        # extract the useful event data
        events = pd.DataFrame()
        events['type'] = [''.join(map(str, el[0][0:])).strip() for el in raw['type']]
        for column, field in self._SCALAR_EVENT_FIELDS:
            events[column] = [el[0][0] for el in raw[field]]

        if self.verbose:
            print("Events loaded are: ")
            print(events)
        return events

    def _load_hdf5_events(self, EEG):
        """Build the events DataFrame from an open h5py group.

        ``EEG['event'][field][:, 0]`` is read as a column of object references
        that are dereferenced through ``EEG[ref]``. The ``type`` entries are
        decoded from integer character codes; every other field contributes
        the scalar at ``[0, 0]``.
        """
        if self.verbose: print("Loading the events from the subject. ")
        raw = EEG['event']
        # extract the useful event data
        events = pd.DataFrame()
        events['type'] = [''.join(map(chr, EEG[ref][:, 0])).strip() for ref in raw['type'][:, 0]]
        for column, field in self._SCALAR_EVENT_FIELDS:
            events[column] = [EEG[ref][0, 0] for ref in raw[field][:, 0]]

        if self.verbose:
            print("Events loaded are: ")
            print(events)

        return events

    def _filter_blocks(self, events):
        """Return a boolean Series marking events inside an "on" block.

        Selection starts as true, is switched on by any type in ``on_blocks``
        and off by any type in ``off_blocks``; every other event inherits the
        state of the event before it. All events are selected when either
        block list is ``None``.
        """
        if self.verbose: print("Filtering the blocks: ")

        if self.on_blocks is None or self.off_blocks is None:
            return pd.Series(True, index=events.index, dtype=bool)

        inside = True
        flags = []
        for event in events['type']:
            if event in self.on_blocks:
                inside = True
            elif event in self.off_blocks:
                inside = False
            flags.append(inside)
        select = pd.Series(flags, index=events.index, dtype=bool)
        if self.verbose: print(list(zip(range(1, len(select) + 1), select)))
        return select

    def _ignore_events(self, events):
        """Return ``events`` without the rows flagged by any ignore function.

        The original index labels of the remaining rows are kept.
        """
        ignore = pd.Series(False, index=events.index, dtype=bool)
        for name, f in self.ignore_events:
            if self.verbose: print("Applying: " + name)
            ignore |= f(events)
            if self.verbose: print(list(zip(range(1, len(ignore) + 1), ignore)))
        select = ~ignore.astype(bool)
        if self.verbose: print(list(zip(range(1, len(select) + 1), select)))
        return events.loc[select]

    def _filter_events(self, events):
        """Return a boolean Series marking events that start the extract pattern.

        An event is kept when, for every ``i``, the event ``i`` rows later has
        a type listed in ``extract_pattern[i]``, and every registered filter
        function also returns true for it.
        """
        if self.verbose: print("Filtering the events: ")
        select = pd.Series(True, index=events.index, dtype=bool)
        for i, event in enumerate(self.extract_pattern):
            select &= events['type'].shift(-i).isin(event)
        if self.verbose: print(list(zip(range(1, len(select) + 1), select)))

        for name, f in self.filters:
            if self.verbose: print("Applying filter: " + name)
            select &= f(events)
            if self.verbose: print(list(zip(range(1, len(select) + 1), select)))
        return select

    def _extract_events(self, EEG, events, select):
        """Cut one segment of ``length_time`` samples per selected event.

        Args:
            EEG: Mapping whose ``'data'`` entry converts to a 2-D array indexed
                as ``[sample, channel]``.
            events: Events DataFrame (needs ``type`` and ``latency``).
            select: Boolean Series over ``events`` choosing the events to cut.

        Returns:
            Array stacking the segments along the first axis. Sample positions
            are treated as 1-based and channels ``start_channel``..
            ``end_channel`` (1-based, inclusive) are kept.
        """
        if self.verbose: print("Extracting data from the interested events: ")

        all_trials = []
        # extract the useful data
        if self.verbose: print(EEG['data'])
        data = np.array(EEG['data'], dtype='float')

        start = self.start_time(events).loc[select]
        length = events['type'].apply(lambda x: self.length_time).loc[select]
        # Latency of the first event after the matched pattern; NaN when the
        # pattern reaches the end of the recording (comparison below is then
        # false, so no padding is attempted).
        end_block = events['latency'].shift(-len(self.extract_pattern)).loc[select]
        if self.verbose: print(start)
        if self.verbose: print(length)
        if self.verbose: print(end_block)

        channels = slice(self.start_channel - 1, self.end_channel)
        for s, l, e in zip(start, length, end_block):
            first = int(s - 1)
            if s + l > e and self.padding:
                # Need to pad, since the required length reaches past the next block.
                if self.verbose: print(str(s) + ", " + str(l) + ", " + str(e) + " that is need to pad")
                unpadded_data = data[first:int(e - 1), channels]
                # Pad up to exactly the number of rows the unpadded branch
                # would return, measured from the rows actually available, so
                # non-integer latencies cannot yield segments of another length.
                padding_size = int(s + l - 1) - first - unpadded_data.shape[0]
                append_data = np.pad(unpadded_data, pad_width=((0, padding_size), (0, 0)), mode='reflect')
                if self.verbose: print(append_data)
            else:
                append_data = data[first:int(s + l - 1), channels]
            all_trials.append(append_data)

        all_trials = np.array(all_trials)
        # if self.verbose: print("Extracted all this data from this participant.")
        # if self.verbose: print(all_trials)
        return all_trials

    def _extract_labels(self, events, select, subj_counter):
        """Return the label matrix for the selected events.

        Column 0 holds ``subj_counter`` for every row; each registered label
        function appends one further column, in registration order.
        """
        if self.verbose: print("Extracting the labels for each trial.")
        if self.verbose: print("Appending the subject counter. ")
        # append subject IDs to the full list and then the labels
        nr_trials = events.loc[select].shape[0]
        labels = np.full((nr_trials, 1), subj_counter)
        if self.verbose: print(labels)

        for name, f in self.labels:
            if self.verbose: print("Appending the next label: " + name)
            column = np.asarray(f(events).loc[select]).reshape(-1, 1)
            labels = np.concatenate((labels, column), axis=1)
            if self.verbose: print(labels)
        return labels