To receive notifications about scheduled maintenance, please subscribe to the mailing-list gitlab-operations@sympa.ethz.ch. You can subscribe to the mailing-list at https://sympa.ethz.ch

Commit 0d13cb6c authored by Lukas Wolf's avatar Lukas Wolf
Browse files

eegnet fix

parent a56c05f4
......@@ -24,8 +24,10 @@ class Prediction_history:
self.model = model
#@timing_decorator
@profile
#@profile
def on_epoch_end(self):
print("Enter test on epoch end:")
get_gpu_memory()
with torch.no_grad():
y_pred = []
for batch, (x, y) in enumerate(self.dataloader):
......@@ -38,6 +40,7 @@ class Prediction_history:
# Remove batch from GPU
del x
del y
torch.cuda.empty_cache()
self.predhis.append(y_pred)
class BaseNet(nn.Module):
......@@ -147,14 +150,13 @@ class BaseNet(nn.Module):
patience +=1
else:
patience = 0
# Plot metrics and save metrics
# Plot and save metrics
plot_metrics(metrics['train_loss'], metrics['val_loss'], output_dir=config['model_dir'], metric='loss', model_number=self.model_number)
if config['task'] == 'prosaccade-clf':
plot_metrics(metrics['train_acc'], metrics['val_acc'], output_dir=config['model_dir'], metric='accuracy', model_number=self.model_number)
# Save metrics
df = pd.DataFrame.from_dict(metrics)
df.to_csv(path_or_buf=config['model_dir'] + '/metrics/' + 'metrics_model_nb_' + str(self.model_number) + '.csv')
pd.DataFrame.from_dict(metrics).to_csv(path_or_buf=config['model_dir'] + '/metrics/' + 'metrics_model_nb_' + str(self.model_number) + '.csv')
# Save model
if config['save_models']:
ckpt_dir = config['model_dir'] + '/best_models/' + config['model'] + '_nb_{}_'.format(self.model_number) + 'best_model.pth'
torch.save(self.state_dict(), ckpt_dir)
......
......@@ -5,6 +5,7 @@ from torch_models.BaseNetTorch import BaseNet
from config import config
import logging
from torch_models.Modules import Pad_Conv2d, Pad_Pool2d
from torch_models.torch_utils.utils import get_gpu_memory
class EEGNet(BaseNet):
"""
......@@ -28,22 +29,20 @@ class EEGNet(BaseNet):
super().__init__(input_shape=input_shape, epochs=epochs, model_number=model_number, batch_size=batch_size)
# Block 1: 2dconv and depthwise conv
#TODO: build only self.block1 and self.block2 for forward pass
self.padconv1 = Pad_Conv2d(kernel=(1,64))
self.padconv1 = Pad_Conv2d(kernel=(1,self.kernel_size))
self.conv1 = nn.Conv2d(
in_channels=1,
out_channels=32,
kernel_size=(1, 64)
out_channels=self.F1,
kernel_size=(1, self.kernel_size)
)
self.batchnorm1 = nn.BatchNorm2d(32, False)
self.pad_depthwise1 = Pad_Conv2d(kernel=(1,64))
self.batchnorm1 = nn.BatchNorm2d(self.F1, False)
self.depthwise_conv1 = nn.Conv2d(
in_channels=32,
out_channels=32*self.D,
groups=32,
kernel_size=(1,64)
in_channels=self.F1,
out_channels=self.F1 * self.D,
groups=self.F1,
kernel_size=(self.channels, 1)
)
self.batchnorm1_2 = nn.BatchNorm2d(32*self.D)
self.batchnorm1_2 = nn.BatchNorm2d(self.F1 * self.D)
self.activation1 = nn.ELU()
self.padpool1 = Pad_Conv2d(kernel=(1,16))
self.avgpool1 = nn.AvgPool2d(kernel_size=(1,16), stride=1)
......@@ -52,13 +51,13 @@ class EEGNet(BaseNet):
# Block 2: separable conv = depthwise + pointwise
self.pad_depthwise2 = Pad_Conv2d(kernel=(1,64))
self.depthwise_conv2 = nn.Conv2d(
in_channels=32*self.D,
out_channels=32*self.D,
groups=32,
in_channels=self.F1 * self.D,
out_channels=self.F2,
groups=self.F1*self.D,
kernel_size=(1,64)
)
self.pointwise_conv2 = nn.Conv2d( # no need for padding or pointwise
in_channels=32*self.D,
in_channels=self.F2,
out_channels=4,
kernel_size=1
)
......@@ -71,52 +70,36 @@ class EEGNet(BaseNet):
logging.info(f"Number of model parameters: {sum(p.numel() for p in self.parameters())}")
logging.info(f"Number of trainable parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad)}")
if verbose and self.model_number == 0:
print(self)
def forward(self, x):
    """
    Forward pass of EEGNet.

    Args:
        x: input tensor — assumed shape (batch, 1, channels, timesamples)
           based on the Conv2d layer definitions; TODO confirm against the
           dataloader.

    Returns:
        Output of ``self.output_layer`` applied to the flattened features.
    """
    # Block 1: temporal convolution followed by depthwise (spatial) convolution.
    x = self.padconv1(x)
    x = self.conv1(x)
    x = self.batchnorm1(x)
    x = self.pad_depthwise1(x)
    x = self.depthwise_conv1(x)
    x = self.batchnorm1_2(x)
    x = self.activation1(x)
    x = self.padpool1(x)
    x = self.avgpool1(x)
    x = self.dropout1(x)
    # Block 2: separable convolution = depthwise conv + pointwise conv.
    x = self.pad_depthwise2(x)
    x = self.depthwise_conv2(x)
    x = self.pointwise_conv2(x)
    x = self.batchnorm2(x)
    x = self.activation2(x)
    x = self.padpool2(x)
    x = self.avgpool2(x)
    x = self.dropout2(x)
    # Flatten per sample. Use the tensor's actual batch dimension instead of
    # self.batch_size so the last (possibly smaller) batch of an epoch does
    # not crash the reshape. Debug print()/logging.info() calls removed from
    # this hot path.
    x = x.view(x.size(0), -1)
    return self.output_layer(x)
"""
......@@ -146,4 +129,4 @@ class EEGNet(BaseNet):
Return number of features passed into the output layer of the network
"""
# 4 channels before the output layer
return 4 * self.channels * self.timesamples
\ No newline at end of file
return 4 * self.timesamples
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment