Commit c445dbd1 authored by Lukas Wolf's avatar Lukas Wolf
Browse files

added basenet

parent 103169b4
import tensorflow as tf
import tensorflow.keras as keras
from config import config
from keras.callbacks import CSVLogger
from utils.utils import train_val_split
class prediction_history(tf.keras.callbacks.Callback):
def __init__(self, validation_data):
self.validation_data = validation_data
self.predhis = []
self.targets = validation_data[1]
def on_epoch_end(self, epoch, logs={}):
y_pred = self.model.predict(self.validation_data[0])
class BaseNet:
def __init__(self, epochs=50, verbose=True):
self.epochs = epochs
self.verbose = verbose
if config['split']:
self.model = self._split_model()
self.model = self._build_model()
self.model.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(), metrics=['accuracy'])
if self.verbose:
# abstract method
def _split_model(self):
# abstract method
def _build_model(self):
def get_model(self):
return self.model
def fit(self, x, y, subjectID):
csv_logger = CSVLogger(config['batches_log'], append=True, separator=';')
ckpt_dir = config['model_dir'] + '/' + config['model'] + '_' + 'best_model.h5'
ckpt = tf.keras.callbacks.ModelCheckpoint(ckpt_dir, verbose=1, monitor='val_accuracy', save_best_only=True, mode='auto')
X_train, X_val, y_train, y_val = train_val_split(x, y, 0.2, subjectID)
prediction_ensemble = prediction_history((X_val, y_val))
if config['model'] == 'eegnet':
# early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=20)
hist =, y_train, verbose=1, validation_data=(X_val, y_val),
epochs=self.epochs, callbacks=[csv_logger, ckpt, prediction_ensemble])
hist =, y_train, verbose=2, batch_size=self.batch_size, validation_data=(X_val, y_val),
epochs=self.epochs, callbacks=[csv_logger, ckpt, prediction_ensemble])
return hist, prediction_ensemble
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment