IOHelper.py 4.68 KB
Newer Older
1
from config import config
2
3
4
import numpy as np
import scipy.io as sio
import os
5
import pickle
Ard Kastrati's avatar
Ard Kastrati committed
6
import h5py
Ard Kastrati's avatar
Ard Kastrati committed
7
import logging
8

Ard Kastrati's avatar
Ard Kastrati committed
9
def get_mat_data(data_dir, verbose=True):
    """
    Read the training inputs and labels from two HDF5 files and reorder their axes.

    The file names and dataset names are taken from ``config``
    (``trainX_file`` / ``trainX_variable`` and ``trainY_file`` / ``trainY_variable``).

    :param data_dir: prefix that is concatenated directly with the configured
        file names, so it has to end with a path separator
    :param verbose: if true, log progress messages and the array shapes
    :return: tuple ``(X, y)``; ``X`` with its three axes reversed (and its last
        two axes swapped again when ``config['downsampled']`` is truthy),
        ``y`` with its two axes swapped
    """

    def read_dataset(file_key, variable_key):
        # Materialise the whole dataset in memory before the file is closed.
        with h5py.File(data_dir + config[file_key], 'r') as handle:
            return handle[config[variable_key]][:]

    X = read_dataset('trainX_file', 'trainX_variable')
    if verbose:
        logging.info("X training loaded.")
        logging.info(X.shape)

    y = read_dataset('trainY_file', 'trainY_variable')
    if verbose:
        logging.info("y training loaded.")
        logging.info(y.shape)

    if verbose:
        logging.info("Setting the shapes")

    # Reverse the axis order of both arrays
    # (presumably undoing the column-major layout of MATLAB v7.3 files - TODO confirm).
    X = X.transpose(2, 1, 0)
    y = y.transpose(1, 0)
    if config['downsampled']:
        X = X.transpose(0, 2, 1)

    if verbose:
        logging.info(X.shape)
        logging.info(y.shape)

    return X, y

def get_pickle_data(data_dir, verbose=True):
    """
    Load the pickled training data ``x.pkl`` and ``y.pkl``.

    :param data_dir: prefix that is concatenated directly with ``'x.pkl'`` and
        ``'y.pkl'``, so it has to end with a path separator
    :param verbose: if true, log a progress message and the ``.shape`` of each
        loaded object (the objects must then expose a ``shape`` attribute)
    :return: tuple ``(x, y)`` with the two unpickled objects
    """
    # NOTE(security): pickle.load can execute arbitrary code; only use this on
    # files that were produced by a trusted source.

    # Context managers guarantee the handles are closed even if unpickling fails.
    with open(data_dir + 'x.pkl', 'rb') as pkl_file_x:
        x = pickle.load(pkl_file_x)
    if verbose:
        logging.info("X training loaded.")
        logging.info(x.shape)

    with open(data_dir + 'y.pkl', 'rb') as pkl_file_y:
        y = pickle.load(pkl_file_y)
    if verbose:
        logging.info("Y training loaded.")
        logging.info(y.shape)

    return x, y

def store(x, y, clip=True, max_samples=10000):
    """
    Pickle ``x`` and ``y`` to ``x_clip.pkl`` and ``y_clip.pkl`` in the current
    working directory.

    :param x: sliceable object holding the inputs
    :param y: sliceable object holding the labels
    :param clip: if true, only the first ``max_samples`` entries of ``x`` and
        ``y`` are stored; otherwise both are stored in full (the file names
        stay the same in either case)
    :param max_samples: number of leading entries kept when ``clip`` is true
    """
    if clip:
        x = x[:max_samples]
        y = y[:max_samples]

    # Context managers make sure the files are flushed and closed even when
    # pickling raises.
    with open('x_clip.pkl', 'wb') as output_x:
        pickle.dump(x, output_x)

    with open('y_clip.pkl', 'wb') as output_y:
        pickle.dump(y, output_y)

Ard Kastrati's avatar
Ard Kastrati committed
59
def collect_data(verbose=True):
    """
        Load the training inputs and labels described in ``config['cnn']``.

        Both are read through ``collect_trial_data`` from ``config['data_dir']``
        with detailed logging switched on.
        :param verbose: forwarded to ``collect_trial_data``
        :return: tuple ``(train_x, train_y)`` with the two results of
                 ``collect_trial_data``
    """
    # Where the inputs live and how they are addressed inside the file.
    x_source = {
        'data_path': config['data_dir'],
        'filename': config['cnn']['trainX_filename'],
        'variable1': config['cnn']['trainX_variable1'],
        'variable2': config['cnn']['trainX_variable2'],
    }
    train_x = collect_trial_data(verbose=verbose, detailed_verbose=True, **x_source)

    # Same for the labels; looked up only after the inputs have been loaded.
    y_source = {
        'data_path': config['data_dir'],
        'filename': config['cnn']['trainY_filename'],
        'variable1': config['cnn']['trainY_variable1'],
        'variable2': config['cnn']['trainY_variable2'],
    }
    train_y = collect_trial_data(verbose=verbose, detailed_verbose=True, **y_source)

    return train_x, train_y
73

74

Ard Kastrati's avatar
Ard Kastrati committed
75
def collect_trial_data(data_path, filename, variable1, variable2, verbose=True, detailed_verbose=False,
                       max_trials=20):
    """
    Load one variable from several trials and stack the results.

    The trial names come from ``extract_trials()``; each trial is loaded with
    ``load_matlab_trial`` and concatenated along the first axis. Trials that
    cannot be loaded or concatenated are skipped.

    :param data_path: directory prefix handed to ``load_matlab_trial``
    :param filename: name of the file to load inside every trial directory
    :param variable1: first-level variable name inside the file
    :param variable2: second-level (field) name inside ``variable1``
    :param verbose: boolean; if true, log information about the status of
        the program
    :param detailed_verbose: boolean; if true, additionally log every trial
        that is tried and the accumulated shape after each successful load
    :param max_trials: upper bound on the number of trials that are tried
        (``None`` means all of them)
    :return: the concatenated data as a numpy array, or an empty array if no
        trial could be loaded
    """
    # NOTE(review): extract_trials() lists config['data_dir'], while the trials
    # are loaded from data_path - confirm both always point to the same place.
    if verbose: logging.info("Loading data... ")
    if verbose: logging.info("Extracting trials...")
    trials = extract_trials()
    # logging expects a format string first; extra positional args fill its placeholders.
    if verbose: logging.info("%d trials found.", len(trials))

    full_data = np.array([])
    # Slicing (instead of indexing up to a fixed count) is safe when fewer
    # trials than max_trials exist.
    for trial in trials[:max_trials]:
        if detailed_verbose: logging.info("Trying trial %s", trial)
        try:
            next_trial = load_matlab_trial(datapath=data_path, trial=trial, filename=filename,
                                           variable1=variable1, variable2=variable2)
            if full_data.size == 0:
                full_data = next_trial
            else:
                full_data = np.concatenate((full_data, next_trial))
            if detailed_verbose: logging.info(np.shape(full_data))
        except Exception:
            # Best effort: a broken or mismatching trial must not abort the
            # whole load. KeyboardInterrupt/SystemExit still propagate.
            logging.info("Trying other trials...")

    if verbose: logging.info("Data loaded.")
    return full_data
105

106

Ard Kastrati's avatar
Ard Kastrati committed
107
def load_matlab_trial(datapath, trial, filename, variable1, variable2):
    """
    Load the data of a single trial from a Matlab file.

    The file ``datapath + trial + "/" + filename`` is read with
    ``scipy.io.loadmat`` and the entry ``[variable1][variable2][0][0]`` is
    extracted.

    :param datapath: matlab data directory; concatenated directly with the
        trial name, so it has to end with a path separator
    :param trial: name of the trial sub-directory
    :param filename: name of the file from which the data must be loaded
    :param variable1: name of the top-level variable in the matlab file
    :param variable2: name of the field inside ``variable1``
    :return: the data as numpy array; 3-dimensional data is returned with its
        axes reordered to ``(2, 0, 1)``, anything else is returned minus 1
    :raises Exception: if the file cannot be read or the variables are
        missing; the original error is attached as ``__cause__``
    """
    try:
        data = sio.loadmat(datapath + trial + "/" + filename)[variable1][variable2][0][0]
    except Exception as err:
        logging.info("Trial %s could not be opened. ", trial)
        # Keep raising a plain Exception for existing callers, but chain the
        # real cause so the failure can be diagnosed.
        raise Exception("Trial " + str(trial) + " could not be opened.") from err

    if np.ndim(data) == 3:
        # Same result as swapaxes(0, 2) followed by swapaxes(1, 2):
        # the last axis becomes the first one.
        data = np.transpose(data, (2, 0, 1))
    else:
        data = data - 1  # y needs to be 0 or 1 (instead of 1 and 2)

    return data

128

129
130
131
132
def extract_trials():
    """
        Extracts the trials from the root directory.

        Every entry of ``config['data_dir']`` whose name is exactly three
        characters long is treated as a trial.
        :return: list with the names of the trials, in the order returned by
                 ``os.listdir``
        :raises Exception: if the root directory cannot be listed; the
                 original error is attached as ``__cause__``
    """
    try:
        my_list = os.listdir(config['data_dir'])
    except Exception as err:
        logging.info("Server unreachable. Cannot list the directories. Did you (maybe) forget to connect to the server by VPN? Is your root directory set correctly in config.py file? :)")
        # Still a plain Exception for existing callers, but with a message and
        # the underlying error chained instead of being discarded.
        raise Exception("Cannot list the trial directories of config['data_dir'].") from err
    # Trial names are identified purely by their length of three characters.
    trials = [name for name in my_list if len(name) == 3]
    return trials