Commit 9e7c0330 authored by okiss
Browse files

plot all ensemble

parent 7b1c42d1
...@@ -8,13 +8,14 @@ from keras.callbacks import CSVLogger ...@@ -8,13 +8,14 @@ from keras.callbacks import CSVLogger
import logging import logging
class prediction_history(tf.keras.callbacks.Callback): class prediction_history(tf.keras.callbacks.Callback):
def __init__(self): def __init__(self,validation_data):
self.validation_data = validation_data
self.predhis = [] self.predhis = []
self.targets = [] self.targets = validation_data[1]
def on_epoch_end(self, epoch, logs={}): def on_epoch_end(self, epoch, logs={}):
y_pred = self.model.predict(self.validation_data[0]) y_pred = self.model.predict(self.validation_data[0])
self.targets.append(self.validation_data[1])
self.predhis.append(y_pred) self.predhis.append(y_pred)
...@@ -116,7 +117,7 @@ class ConvNet(ABC): ...@@ -116,7 +117,7 @@ class ConvNet(ABC):
ckpt_dir = config['model_dir'] + '/' + config['model'] + '_' + 'best_model.h5' ckpt_dir = config['model_dir'] + '/' + config['model'] + '_' + 'best_model.h5'
ckpt = tf.keras.callbacks.ModelCheckpoint(ckpt_dir, verbose=1, monitor='val_accuracy', save_best_only=True, mode='auto') ckpt = tf.keras.callbacks.ModelCheckpoint(ckpt_dir, verbose=1, monitor='val_accuracy', save_best_only=True, mode='auto')
X_train, X_val, y_train, y_val = train_test_split(x, y, test_size=0.2, random_state=42) X_train, X_val, y_train, y_val = train_test_split(x, y, test_size=0.2, random_state=42)
prediction_ensemble = prediction_history() prediction_ensemble = prediction_history((X_val,y_val))
hist = self.model.fit(X_train, y_train, verbose=1, batch_size=self.batch_size, validation_data=(X_val,y_val), hist = self.model.fit(X_train, y_train, verbose=1, batch_size=self.batch_size, validation_data=(X_val,y_val),
epochs=self.epochs, callbacks=[csv_logger, ckpt, early_stop, prediction_ensemble]) epochs=self.epochs, callbacks=[csv_logger, ckpt, early_stop, prediction_ensemble])
return hist, prediction_ensemble return hist, prediction_ensemble
...@@ -45,25 +45,25 @@ def run(trainX, trainY): ...@@ -45,25 +45,25 @@ def run(trainX, trainY):
logging.info('Cannot start the program. Please choose one model in the config.py file') logging.info('Cannot start the program. Please choose one model in the config.py file')
hist, pred_ensemble = classifier.fit(trainX,trainY) hist, pred_ensemble = classifier.fit(trainX,trainY)
# if i == 0: if i == 0:
# pred = pred_ensemble.predhis pred = pred_ensemble.predhis
# else: else:
# for j, pred_epoch in enumerate(pred_ensemble.predhis): for j, pred_epoch in enumerate(pred_ensemble.predhis):
# pred[j] = (np.array(pred[j])+np.array(pred_epoch)) pred[j] = (np.array(pred[j])+np.array(pred_epoch))
# for j, pred_epoch in enumerate(pred): for j, pred_epoch in enumerate(pred):
# pred_epoch = (pred_epoch/config['ensemble']).tolist() pred_epoch = (pred_epoch/config['ensemble']).tolist()
# loss.append(bce(pred_ensemble.targets[j],pred_epoch).numpy()) loss.append(bce(pred_ensemble.targets,pred_epoch).numpy())
# pred_epoch = np.round(pred_epoch,0) pred_epoch = np.round(pred_epoch,0)
# accuracy.append(np.mean((np.array(pred_epoch).reshape(-1)+np.array(pred_ensemble.targets[j]).reshape(-1)-1)**2)) accuracy.append(np.mean((np.array(pred_epoch).reshape(-1)+np.array(pred_ensemble.targets).reshape(-1)-1)**2))
# if config['ensemble']>1: if config['ensemble']>1:
# config['model']+='_ensemble' config['model']+='_ensemble'
# if config['split']: if config['split']:
# config['model'] = config['model'] + '_cluster' config['model'] = config['model'] + '_cluster'
# hist.history['val_loss'] = loss hist.history['val_loss'] = loss
# hist.history['val_accuracy'] = accuracy hist.history['val_accuracy'] = accuracy
plot_loss(hist, config['model_dir'], config['model'], val = True) plot_loss(hist, config['model_dir'], config['model'], val = True)
plot_acc(hist, config['model_dir'], config['model'], val = True) plot_acc(hist, config['model_dir'], config['model'], val = True)
save_logs(hist, config['model_dir'], config['model'], pytorch = False) save_logs(hist, config['model_dir'], config['model'], pytorch = False)
\ No newline at end of file
...@@ -24,11 +24,11 @@ def main(): ...@@ -24,11 +24,11 @@ def main():
trainX = np.transpose(trainX, (0, 2, 1)) trainX = np.transpose(trainX, (0, 2, 1))
logging.info(trainX.shape) logging.info(trainX.shape)
tune(trainX,trainY) # tune(trainX,trainY)
# run(trainX,trainY) # run(trainX,trainY)
# select_best_model() # select_best_model()
# comparison_plot(n_best = 4) comparison_plot()
logging.info("--- Runtime: %s seconds ---" % (time.time() - start_time)) logging.info("--- Runtime: %s seconds ---" % (time.time() - start_time))
logging.info('Finished Logging') logging.info('Finished Logging')
......
...@@ -68,45 +68,31 @@ def plot_loss_torch(loss, output_directory, model): ...@@ -68,45 +68,31 @@ def plot_loss_torch(loss, output_directory, model):
def cp_dir(source, target): def cp_dir(source, target):
call(['cp', '-a', source, target]) call(['cp', '-a', source, target])
def comparison_plot(n_best = 3): def comparison_plot():
results = {}
run_dir = os.getcwd()+'/results/'
#get best model in runs for all model_name
for experiment in os.listdir(run_dir):
_,name = experiment.split('_',1)
if os.path.isdir(run_dir+experiment):
try:
summary = pd.read_csv(run_dir+experiment+'/'+name+'_df_best_model.csv')
try:
acc = float(summary['best_model_val_acc'])
if not (name in results.keys()): run_dir = os.getcwd()+'/results/'
results[name] = acc
else:
if acc > results[name]:
results[name] = acc
except KeyError:
pass
except FileNotFoundError:
pass
#plot n_best model
plt.figure() plt.figure()
plt.title('Validation accuracy of the {} best models'.format(n_best)) plt.title('Comparison of the Validation accuracy' )
plt.grid(True) plt.grid(True)
plt.xlabel('epochs') plt.xlabel('epochs')
plt.ylabel('accuracy (%)') plt.ylabel('accuracy (%)')
for experiment in os.listdir(run_dir):
position = experiment.find('ensemble')
if position != -1:
name_split = experiment.split('_')
name=name_split[1]
for i in range(2,len(name_split)):
name += '_'
name += name_split[i]
summary = pd.read_csv(run_dir+experiment+'/'+name+'_history.csv')
acc = summary['val_accuracy']
plt.plot(acc,'-',label=name)
for i in range(n_best):
best = max(results.items(), key=operator.itemgetter(1))# find best model
name = best[0]
acc = 100*pd.read_csv(os.getcwd()+'/results/best_'+name+'/'+name+'_history.csv')['val_accuracy'] #to have %
plt.plot(acc,'-',label=name)
results[name]=0 # to eliminate this model
plt.legend() plt.legend()
plt.savefig(os.getcwd()+'/results/comparison_accuracy.png') plt.savefig(run_dir+'/comparison_accuracy.png')
def select_best_model(): def select_best_model():
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment