To receive notifications about scheduled maintenance, please subscribe to the mailing-list gitlab-operations@sympa.ethz.ch. You can subscribe to the mailing-list at https://sympa.ethz.ch

Commit 82aeec18 authored by Ard Kastrati's avatar Ard Kastrati
Browse files

Restructured the code and added the code to load the data.

parent a987f147
from config import general_params as params
from config import CNN_params as CNN_params
import numpy as np
import scipy.io as sio
import os
# import torch as torch
train_data_path = params['train_data_path']
def get_data(verbose=True):
"""
Load the data for training.
:param datapath: matlab data directory
:param variable: variable of the matlab file
:return: the data as numpy array / Tensor?
"""
train_data = torch_training_data(train_data_path, verbose)
return train_data
def torch_training_data(train_data_path, verbose=True):
"""
Extract training data from the file.
:param train_data_path: name of the file to open.
:param verbose: boolean; if true, it prints information about
the status of the program.
:return: a numpy array of shape ...?
"""
if verbose: print("Loading data... ")
if verbose: print("Extracting trials...")
trials = extract_trials()
if verbose: print(len(trials), " trials found.")
full_data = []
for i in range(5):
next_trial = load_matlab_trial(datapath=train_data_path, trial=trials[i], filename=CNN_params['filename'], variable=CNN_params['data_variable'])
full_data.append(next_trial.tolist()[0][0])
if verbose: print("data loaded.")
print(np.shape(full_data))
if verbose: print("Tensoring data..")
if verbose: print("not implemented yet")
return full_data
def load_matlab_trial(datapath, trial, filename, variable):
"""
Load the data from Matlab
:param datapath: matlab data directory
:param filename: name of the file from which the data must be loaded
:param variable: variable of the matlab file
:return: the data as numpy array
"""
data = sio.loadmat(datapath + trial + "/" + filename)[filename][variable]
return data
def extract_trials():
"""
Extracts the trials from the root directory
"""
my_list = os.listdir(train_data_path)
trials = [name for name in my_list if len(name) == 3]
return trials
\ No newline at end of file
# configuration used by the training and evaluation methods
# let's keep it here to have a clean code on other methods that we try
general_params = {}
CNN_params = {}
"""
Models:
CNN-1: First try: CNN to predict movement towards left or right (prosaccade) with 1 second data.
RNN? Transformers, Attention, etc etc
"""
# general parameters
general_params['model'] = "CNN-1"
general_params['train_data_path'] = "/Volumes/methlab/ETH_AS/preprocessed2/"
# CNN - 1
CNN_params['data_variable'] = "data"
CNN_params['filename'] = "EEGprocue"
from config import general_params as params
from IOHelper import get_data
import time
def main():
start_time = time.time()
if params['model'] == 'CNN-1':
print("Started running CNN-1. If you want to run other methods please choose another model in the config.py file.")
get_data()
else:
raise Exception('Please choose one of the following models in the config.py file')
print("--- Runtime: %s seconds ---" % (time.time() - start_time))
if __name__=='__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment