Commit 9e7c0330 authored by okiss

plot all ensemble

parent 7b1c42d1
@@ -8,13 +8,14 @@ from keras.callbacks import CSVLogger
 import logging
 class prediction_history(tf.keras.callbacks.Callback):
-    def __init__(self):
+    def __init__(self,validation_data):
+        self.validation_data = validation_data
         self.predhis = []
-        self.targets = []
+        self.targets = validation_data[1]
     def on_epoch_end(self, epoch, logs={}):
         y_pred = self.model.predict(self.validation_data[0])
-        self.targets.append(self.validation_data[1])
         self.predhis.append(y_pred)
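Note: newer tf.keras releases no longer populate self.validation_data on custom callbacks, which is presumably what this rework addresses: the held-out data is now passed through the constructor, and the targets are stored once instead of being re-appended every epoch. A minimal usage sketch (model and the arrays are stand-ins, not objects from this commit):

# Minimal sketch: wiring the reworked callback into a Keras fit() call.
# `model`, X_train, y_train, X_val, y_val are illustrative stand-ins.
prediction_ensemble = prediction_history((X_val, y_val))
model.fit(X_train, y_train,
          validation_data=(X_val, y_val),
          epochs=5,
          callbacks=[prediction_ensemble])
# prediction_ensemble.predhis -> one prediction array per epoch
# prediction_ensemble.targets -> y_val, stored once at construction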
@@ -116,7 +117,7 @@ class ConvNet(ABC):
         ckpt_dir = config['model_dir'] + '/' + config['model'] + '_' + 'best_model.h5'
         ckpt = tf.keras.callbacks.ModelCheckpoint(ckpt_dir, verbose=1, monitor='val_accuracy', save_best_only=True, mode='auto')
         X_train, X_val, y_train, y_val = train_test_split(x, y, test_size=0.2, random_state=42)
-        prediction_ensemble = prediction_history()
+        prediction_ensemble = prediction_history((X_val,y_val))
         hist = self.model.fit(X_train, y_train, verbose=1, batch_size=self.batch_size, validation_data=(X_val,y_val),
                               epochs=self.epochs, callbacks=[csv_logger, ckpt, early_stop, prediction_ensemble])
         return hist, prediction_ensemble
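A side note on why summing predictions across ensemble members in run() (next hunk) lines up: train_test_split with a fixed random_state returns the identical split on every call, so each member is validated on the same X_val/y_val. A quick standalone check of that assumption:

import numpy as np
from sklearn.model_selection import train_test_split

# With random_state fixed, repeated calls yield the identical split,
# so every ensemble member sees the same validation set.
x, y = np.arange(20).reshape(10, 2), np.arange(10)
first = train_test_split(x, y, test_size=0.2, random_state=42)
second = train_test_split(x, y, test_size=0.2, random_state=42)
assert all(np.array_equal(a, b) for a, b in zip(first, second))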
@@ -45,25 +45,25 @@ def run(trainX, trainY):
             logging.info('Cannot start the program. Please choose one model in the config.py file')
         hist, pred_ensemble = classifier.fit(trainX,trainY)
-        # if i == 0:
-        #     pred = pred_ensemble.predhis
-        # else:
-        #     for j, pred_epoch in enumerate(pred_ensemble.predhis):
-        #         pred[j] = (np.array(pred[j])+np.array(pred_epoch))
+        if i == 0:
+            pred = pred_ensemble.predhis
+        else:
+            for j, pred_epoch in enumerate(pred_ensemble.predhis):
+                pred[j] = (np.array(pred[j])+np.array(pred_epoch))
-    # for j, pred_epoch in enumerate(pred):
-    #     pred_epoch = (pred_epoch/config['ensemble']).tolist()
-    #     loss.append(bce(pred_ensemble.targets[j],pred_epoch).numpy())
-    #     pred_epoch = np.round(pred_epoch,0)
-    #     accuracy.append(np.mean((np.array(pred_epoch).reshape(-1)+np.array(pred_ensemble.targets[j]).reshape(-1)-1)**2))
+    for j, pred_epoch in enumerate(pred):
+        pred_epoch = (pred_epoch/config['ensemble']).tolist()
+        loss.append(bce(pred_ensemble.targets,pred_epoch).numpy())
+        pred_epoch = np.round(pred_epoch,0)
+        accuracy.append(np.mean((np.array(pred_epoch).reshape(-1)+np.array(pred_ensemble.targets).reshape(-1)-1)**2))
-    # if config['ensemble']>1:
-    #     config['model']+='_ensemble'
-    # if config['split']:
-    #     config['model'] = config['model'] + '_cluster'
+    if config['ensemble']>1:
+        config['model']+='_ensemble'
+    if config['split']:
+        config['model'] = config['model'] + '_cluster'
-    # hist.history['val_loss'] = loss
-    # hist.history['val_accuracy'] = accuracy
+    hist.history['val_loss'] = loss
+    hist.history['val_accuracy'] = accuracy
     plot_loss(hist, config['model_dir'], config['model'], val = True)
     plot_acc(hist, config['model_dir'], config['model'], val = True)
     save_logs(hist, config['model_dir'], config['model'], pytorch = False)
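Two details in the newly enabled aggregation are easy to miss. pred[j] accumulates the element-wise sum of epoch-j predictions over all ensemble members, so dividing by config['ensemble'] gives the averaged probabilities that bce and the accuracy are computed on. The accuracy line then relies on an identity for binary labels: with p, t in {0, 1}, (p + t - 1)**2 is 1 exactly when p == t, so its mean is the fraction of correct predictions. A standalone sketch with made-up values:

import numpy as np

# Two members' predictions for one epoch, summed then averaged as in run().
member_preds = [np.array([0.8, 0.3, 0.6, 0.2]),
                np.array([1.0, 0.1, 0.8, 0.6])]
avg = sum(member_preds) / len(member_preds)   # ensemble-averaged probabilities

t = np.array([1, 0, 0, 0])                    # made-up binary targets
p = np.round(avg, 0)                          # hard 0/1 predictions

# (p + t - 1)**2 is 1 where p == t (1+1-1 and 0+0-1 both square to 1)
# and 0 where they disagree, so its mean is plain binary accuracy.
assert np.mean((p + t - 1) ** 2) == np.mean(p == t)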
@@ -24,11 +24,11 @@ def main():
     trainX = np.transpose(trainX, (0, 2, 1))
     logging.info(trainX.shape)
-    tune(trainX,trainY)
+    # tune(trainX,trainY)
     # run(trainX,trainY)
     # select_best_model()
-    # comparison_plot(n_best = 4)
+    comparison_plot()
     logging.info("--- Runtime: %s seconds ---" % (time.time() - start_time))
     logging.info('Finished Logging')
@@ -68,45 +68,31 @@ def plot_loss_torch(loss, output_directory, model):
 def cp_dir(source, target):
     call(['cp', '-a', source, target])
-def comparison_plot(n_best = 3):
-    results = {}
-    run_dir = os.getcwd()+'/results/'
-    #get best model in runs for all model_name
-    for experiment in os.listdir(run_dir):
-        _,name = experiment.split('_',1)
-        if os.path.isdir(run_dir+experiment):
-            try:
-                summary = pd.read_csv(run_dir+experiment+'/'+name+'_df_best_model.csv')
-                try:
-                    acc = float(summary['best_model_val_acc'])
-                    if not (name in results.keys()):
-                        results[name] = acc
-                    else:
-                        if acc > results[name]:
-                            results[name] = acc
-                except KeyError:
-                    pass
-            except FileNotFoundError:
-                pass
+def comparison_plot():
     #plot n_best model
+    run_dir = os.getcwd()+'/results/'
     plt.figure()
-    plt.title('Validation accuracy of the {} best models'.format(n_best))
+    plt.title('Comparison of the Validation accuracy' )
     plt.grid(True)
     plt.xlabel('epochs')
     plt.ylabel('accuracy (%)')
-    for i in range(n_best):
-        best = max(results.items(), key=operator.itemgetter(1))# find best model
-        name = best[0]
-        acc = 100*pd.read_csv(os.getcwd()+'/results/best_'+name+'/'+name+'_history.csv')['val_accuracy'] #to have %
+    for experiment in os.listdir(run_dir):
+        position = experiment.find('ensemble')
+        if position != -1:
+            name_split = experiment.split('_')
+            name=name_split[1]
+            for i in range(2,len(name_split)):
+                name += '_'
+                name += name_split[i]
+            summary = pd.read_csv(run_dir+experiment+'/'+name+'_history.csv')
+            acc = summary['val_accuracy']
             plt.plot(acc,'-',label=name)
-        results[name]=0 # to eliminate this model
     plt.legend()
-    plt.savefig(os.getcwd()+'/results/comparison_accuracy.png')
+    plt.savefig(run_dir+'/comparison_accuracy.png')
 def select_best_model():
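The name-reconstruction loop in the new comparison_plot is the long form of a single join: every token after the leading run id in the experiment directory name is glued back together with underscores, and the history CSV is then expected at results/<experiment>/<name>_history.csv. A sketch of the parsing with a made-up directory name:

# Made-up experiment directory: '<run_id>_<model>_ensemble'
experiment = '1607354042_eegnet_ensemble'

if experiment.find('ensemble') != -1:
    # Equivalent to the loop in comparison_plot above: rejoin every
    # token after the leading run id with underscores.
    name = '_'.join(experiment.split('_')[1:])  # -> 'eegnet_ensemble'
    history_csv = 'results/' + experiment + '/' + name + '_history.csv'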