Commit 4847e09a authored by zpgeng's avatar zpgeng
Browse files

Merge branch 'master' of https://gitlab.ethz.ch/kard/dl-project

Oriel's commit
parents 507a80a3 1c6212c1
......@@ -4,6 +4,8 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from utils.utils import *
class Net(nn.Module):
def __init__(self):
......@@ -45,10 +47,13 @@ def run(trainX, trainY):
# Newly added lines below
save_logs(hist, pytorch=True)
save_model_param(pytorch=True)
plot_loss_torch(loss)
def train(trainloader, net, optimizer, criterion, epoch=50):
loss=[]
for epoch in range(2): # loop over the dataset multiple times
running_loss = 0.0
loss_values=[]
for i, data in enumerate(trainloader, 0):
# get the inputs; data is a list of [inputs, labels]
inputs, labels = data
......@@ -61,12 +66,16 @@ def train(trainloader, net, optimizer, criterion, epoch=50):
loss = criterion(outputs, labels.squeeze(1))
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.item()
run_loss = loss.item()
running_loss+=run_loss
loss_values.append(run_loss)
if i % 200 == 0: # print every 200 mini-batches
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 200))
running_loss = 0.0
los=np.mean(loss_values)
loss.append(los)
return loss
print('Finished Training')
......@@ -33,15 +33,30 @@ def plot_loss(hist,name,val=False):
plt.plot(epochs,hist.history['val_loss'],'g-',label='validation')
plt.legend()
plt.xlabel('epochs')
plt.ylabel('Binary Cross Entropie')
plt.ylabel('Binary Cross Entropy')
plt.savefig('loss_'+name+'.png')
plt.show()
def plot_loss_torch(loss,name='CNN'):
epochs=np.arange(len(loss))
plt.figure()
plt.title(name+ ' loss')
plt.plot(epochs,loss,'b-',label='training')
plt.legend()
plt.xlabel('epochs')
plt.ylabel('Binary Cross Entropy')
plt.savefig('loss_'+name+'.png')
plt.show()
# Save the logs (newly added without debugging)
# haven't write about the pytorch=True part
def save_logs(hist, output_directory=config['model_dir'], pytorch=False):
if pytorch:
# Maybe Oriel can work on this
df_metrics={'Loss':hist}
df_metrics=pd.DataFrame(df_metrics)
df_metrics.to_csv(output_directory + config['model'] + '_df_metrics.csv', index=False)
else:
hist_df = pd.DataFrame(hist.history)
......@@ -70,4 +85,4 @@ def save_model_param(output_directory=config['model_dir'], pytorch=False):
torch.save(net.state_dict(), output_directory + config['model'] + '_model.pth')
else:
classifier.save(output_directory + config['model'] + '_model.h5')
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment