Commit 4a241f1f authored by Ard Kastrati's avatar Ard Kastrati
Browse files

Made some changes in the code to handle exceptions.

parent 66c2265c
# -*- coding: utf-8 -*-
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from IOHelper import get_data
import torch.utils.data
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
......@@ -26,10 +23,8 @@ class Net(nn.Module):
x = self.fc3(x)
return x
def run(verbose=True):
def run(trainX, trainY):
#load the data
trainX, trainY = get_data()
print(trainX.shape, trainY.shape)
dataset = torch.utils.data.TensorDataset(trainX, trainY)
trainloader = torch.utils.data.DataLoader(dataset, batch_size=2)
......@@ -67,7 +62,7 @@ def train(trainloader, net, optimizer, criterion, epoch=50):
running_loss += loss.item()
if i % 200 == 0: # print every 2000 mini-batches
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 2000))
(epoch + 1, i + 1, running_loss / 200))
running_loss = 0.0
print('Finished Training')
......@@ -12,11 +12,16 @@ def get_data(verbose=True):
:param variable: variable of the matlab file
:return: the data as numpy array / Tensor?
"""
train_x = torch_data(data_path=params['data_path'], filename=CNN_params['trainX_filename'], variable1=CNN_params['trainX_variable1'], variable2=CNN_params['trainX_variable2'], verbose=verbose, detailed_verbose=True)
train_y = torch_data(data_path=params['data_path'], filename=CNN_params['trainY_filename'], variable1=CNN_params['trainY_variable1'], variable2=CNN_params['trainY_variable2'], verbose=verbose, detailed_verbose=True)
train_x = torch_data(data_path=params['data_path'], filename=CNN_params['trainX_filename'],
variable1=CNN_params['trainX_variable1'], variable2=CNN_params['trainX_variable2'],
verbose=verbose, detailed_verbose=True)
train_y = torch_data(data_path=params['data_path'], filename=CNN_params['trainY_filename'],
variable1=CNN_params['trainY_variable1'], variable2=CNN_params['trainY_variable2'],
verbose=verbose, detailed_verbose=True)
train_y = train_y.type(torch.LongTensor)
return train_x, train_y
def torch_data(data_path, filename, variable1, variable2, verbose=True, detailed_verbose=False):
"""
Extract data from the file.
......@@ -33,19 +38,25 @@ def torch_data(data_path, filename, variable1, variable2, verbose=True, detailed
full_data = np.array([])
for i in range(len(trials)):
if detailed_verbose: print("Trying trial", trials[i])
next_trial = load_matlab_trial(datapath=data_path, trial=trials[i], filename=filename, variable1=variable1, variable2=variable2)
if full_data.size == 0:
full_data = next_trial
else:
full_data = np.concatenate((full_data, next_trial))
if verbose: print(np.shape(full_data))
try:
next_trial = load_matlab_trial(datapath=data_path, trial=trials[i], filename=filename, variable1=variable1,
variable2=variable2)
if full_data.size == 0:
full_data = next_trial
else:
full_data = np.concatenate((full_data, next_trial))
if detailed_verbose: print(np.shape(full_data))
except:
print "Trying other trials..."
if verbose: print("data loaded.")
if verbose: print("Data loaded.")
if verbose: print("Tensoring data...")
full_data_tensor = torch.from_numpy(full_data)
if verbose: print("Data in tensor form.")
return full_data_tensor
def load_matlab_trial(datapath, trial, filename, variable1, variable2):
"""
Load the data from Matlab
......@@ -54,18 +65,28 @@ def load_matlab_trial(datapath, trial, filename, variable1, variable2):
:param variable: variable of the matlab file
:return: the data as numpy array
"""
data = sio.loadmat(datapath + trial + "/" + filename)[variable1][variable2][0][0]
try:
data = sio.loadmat(datapath + trial + "/" + filename)[variable1][variable2][0][0]
except:
print "Trial " + trial + " could not be opened. "
raise Exception
if len(np.shape(data)) == 3:
data = np.swapaxes(data, 0, 2)
data = np.swapaxes(data, 1, 2)
else:
data = data - 1 # data needs to be between 0 and 1
data = data - 1 # data needs to be between 0 and 1
return data
def extract_trials():
"""
Extracts the trials from the root directory
"""
my_list = os.listdir(params['data_path'])
try:
my_list = os.listdir(params['data_path'])
except:
print "Server unreachable. Cannot list the directories. Did you (maybe) forget to connect to the server by VPN? Is your root directory set correctly in config.py file? :)"
raise Exception
trials = [name for name in my_list if len(name) == 3]
return trials
return trials
\ No newline at end of file
from config import general_params as params
from IOHelper import get_data
import time
from run import run
import CNN
import IOHelper as IO
def main():
start_time = time.time()
try:
trainX, trainY = IO.get_data(verbose=True)
except:
return
if params['model'] == 'CNN-1':
print("Started running CNN-1. If you want to run other methods please choose another model in the config.py file.")
run()
CNN.run(trainX, trainY)
else:
raise Exception('Please choose one of the following models in the config.py file')
print 'Cannot start the program. Please choose one model in the config.py file'
print("--- Runtime: %s seconds ---" % (time.time() - start_time))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment