Commit a56264ce authored by nstorni

Training with json config file. VAE configurability.

parent a4c3ff07
# VAE setup and training.
## 1. Prerequisites
First clone the project files from GitLab (you will be prompted for your credentials):
```bash
git clone https://gitlab.ethz.ch/deep-learning-rodent/world-models.git
```
Then install the requirements (they will be installed for the python_gpu/3.7.1 module only):
```bash
$HOME/world-models/install_requirements.sh
```
(To learn more about using modules on the cluster, see https://scicomp.ethz.ch/wiki/Getting_started_with_clusters.)
## 2. Mini Mice dataset preparation
Prepare the folder structure in the $SCRATCH directory:
```bash
mkdir $SCRATCH/data
mkdir $SCRATCH/data/mini_mice_dataset
mkdir $SCRATCH/data/mini_mice_dataset/video
mkdir $SCRATCH/data/mini_mice_dataset/train
mkdir $SCRATCH/data/mini_mice_dataset/train/nolabel
mkdir $SCRATCH/data/mini_mice_dataset/val
mkdir $SCRATCH/data/mini_mice_dataset/val/nolabel
mkdir $SCRATCH/data/mini_mice_dataset/test
mkdir $SCRATCH/data/mini_mice_dataset/test/nolabel
```
This structure is required by the dataset loader (it expects a root folder containing one subfolder per class; in our case there is only one class, "nolabel").
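For reference, here is a minimal sketch of how torchvision's `ImageFolder` consumes that layout (assuming the directories above exist and already contain frames):
```python
import os
from torchvision import datasets, transforms

# ImageFolder treats every subfolder of the root as one class; with the
# single "nolabel" subfolder, all frames share the same (unused) label.
root = os.path.expandvars("$SCRATCH/data/mini_mice_dataset/train")
dataset = datasets.ImageFolder(root, transform=transforms.ToTensor())
img, label = dataset[0]  # label is always 0, the index of "nolabel"
```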
Upload one video to $SCRATCH/data/mini_mice_dataset/video using WinSCP or similar. For example, to copy OFT_1.mp4 from your local machine to the target directory on the cluster, run this on your local machine (the quotes keep your local shell from expanding $SCRATCH, so it is resolved on the cluster instead):
```bash
scp OFT_1.mp4 USERNAME@login.leonhard.ethz.ch:'$SCRATCH/data/mini_mice_dataset/video'
```
Extract frames from videos:
```bash
$HOME/world-models/video_frame/video2frames.sh
```
Copy some images from the train folder to the val folder (not truly random; this simply takes every file whose name ends in 1; a random alternative is sketched after the command):
```bash
cd $SCRATCH/data/mini_mice_dataset
cp train/nolabel/*1.png val/nolabel
```
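The command above is fine for a quick test. If you want a genuinely random split instead, a minimal Python sketch along these lines would work (assuming the folder layout from step 2):
```python
import os
import random
import shutil

root = os.path.expandvars("$SCRATCH/data/mini_mice_dataset")
train_dir = os.path.join(root, "train/nolabel")
val_dir = os.path.join(root, "val/nolabel")

# copy a random ~10% of the training frames into the validation folder
frames = os.listdir(train_dir)
for name in random.sample(frames, k=len(frames) // 10):
    shutil.copy(os.path.join(train_dir, name), os.path.join(val_dir, name))
```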
## 3. VAE training setup
The configuration for training the VAE is done through a JSON file in which all hyperparameters and dataset directory paths are specified. A template for a training job (template_config.json) is provided; copy it to your own directory and modify it for your training run. This config file is then copied automatically to the output directory together with the trained models.
```bash
cd $HOME
mkdir training_configs
cp world-models/template_config.json training_configs/train_config1.json
```
Modify train_config1.json for your training run; you can then test the configuration by starting the training directly on the Leonhard login node:
```bash
$HOME/world-models/train_vae.sh training_configs/train_config1.json
```
If there are no errors, interrupt the training with CTRL+C; you can now submit it to the cluster (the command below requests a 12-hour walltime, 4 cores, one GPU, and 4096 MB of memory per core):
```bash
bsub -W 12:00 -n 4 -R "rusage[ngpus_excl_p=1,mem=4096]" "$HOME/world-models/train_vae.sh training_configs/train_config1.json"
```
## 4. Monitoring the job with TensorBoard
You can monitor the training process with TensorBoard. Launch it on the cluster:
```bash
$HOME/world-models/start_tensorboard.sh
```
Forward the port used by TensorBoard to your local machine to view training progress in the browser (here remote port 6008 is mapped to local port 6000; then open http://localhost:6000):
```bash
ssh USERNAME@login.leonhard.ethz.ch -L 6000:login.leonhard.ethz.ch:6008
```
The training job will save the models, samples and logs in the $SCRATCH folder.
# PyTorch implementation of "World Models"
Paper: Ha and Schmidhuber, "World Models", 2018. https://doi.org/10.5281/zenodo.1207631. For a quick summary of the paper and some additional experiments, visit the [github page](https://ctallec.github.io/world-models/).
......
#!/bin/bash
echo "Loading modules"
module purge
module load gcc/4.8.5 python_gpu/3.7.1
echo "Installing requirements"
cd $HOME
pip install --user -r world-models/requirements.txt
pip install --user --upgrade torch
"""
Variational encoder model, used as a visual model
for our model of the world.
@@ -7,39 +6,46 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
class Decoder(nn.Module):
class CustomDecoder(nn.Module):
""" VAE decoder """
def __init__(self, img_channels, latent_size):
super(Decoder, self).__init__()
super(CustomDecoder, self).__init__()
self.latent_size = latent_size
self.img_channels = img_channels
self.fc1 = nn.Linear(latent_size, 29*18*512)
self.devonv1 = nn.ConvTranspose2d(512, 256, 5, stride=2)
self.devonv2 = nn.ConvTranspose2d(256, 128, 5, stride=2)
self.deconv3 = nn.ConvTranspose2d(128, 64, 5, stride=2)
self.deconv4 = nn.ConvTranspose2d(64, 32, 6, stride=2)
self.fc1 = nn.Linear(latent_size, 27*16*512)
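        # 27*16*512 matches the flattened conv5 output of the encoder for 928x576 inputs (see CustomEncoder below)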
self.deconv1 = nn.ConvTranspose2d(512, 256, 4, stride=2)
self.deconv2 = nn.ConvTranspose2d(256, 128, 4, stride=2)
self.deconv3 = nn.ConvTranspose2d(128, 64, 4, stride=2)
self.deconv4 = nn.ConvTranspose2d(64, 32, 4, stride=2)
self.deconv5 = nn.ConvTranspose2d(32, img_channels, 6, stride=2)
def forward(self, x): # pylint: disable=arguments-differ
x = F.relu(self.fc1(x))
x = x.view(-1, 512, 18, 29)
# print('Flat before restructuring {}'.format(x.size()))
x = x.view(-1, 512, 16, 27)
# print('Before deconv {}'.format(x.size()))
#x = x.unsqueeze(-1).unsqueeze(-1)
x = F.relu(self.deconv1(x))
# print('Deconv1 {}'.format(x.size()))
x = F.relu(self.deconv2(x))
# print('Deconv2 {}'.format(x.size()))
x = F.relu(self.deconv3(x))
# print('Deconv3 {}'.format(x.size()))
x = F.relu(self.deconv4(x))
# print('Deconv4 {}'.format(x.size()))
reconstruction = torch.sigmoid(self.deconv5(x))  # F.sigmoid is deprecated
# print('Reconstruction {}'.format(reconstruction.size()))
return reconstruction
class Encoder(nn.Module): # pylint: disable=too-many-instance-attributes
class CustomEncoder(nn.Module): # pylint: disable=too-many-instance-attributes
""" VAE encoder """
def __init__(self, img_channels, latent_size):
super(Encoder, self).__init__()
super(CustomEncoder, self).__init__()
self.latent_size = latent_size
#self.img_size = img_size
self.img_channels = img_channels
# 928 x 576
self.conv1 = nn.Conv2d(img_channels, 32, 4, stride=2)
# 463 x 287
@@ -51,17 +57,124 @@ class Encoder(nn.Module): # pylint: disable=too-many-instance-attributes
# 56 x 34
self.conv5 = nn.Conv2d(256, 512, 4, stride=2)
# 27 x 16
self.fc_mu = nn.Linear(29*18*512, latent_size)
self.fc_logsigma = nn.Linear(29*18*512, latent_size)
self.fc_mu = nn.Linear(27*16*512, latent_size)
self.fc_logsigma = nn.Linear(27*16*512, latent_size)
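        # fc_mu and fc_logsigma are the two heads of the approximate posterior q(z|x): the mean and log standard deviation of a diagonal Gaussian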
def forward(self, x): # pylint: disable=arguments-differ
# print('Encoder input {}'.format(x.size()))
x = F.relu(self.conv1(x))
# print('Conv1 {}'.format(x.size()))
x = F.relu(self.conv2(x))
# print('Conv2 {}'.format(x.size()))
x = F.relu(self.conv3(x))
# print('Conv3 {}'.format(x.size()))
x = F.relu(self.conv4(x))
# print('Conv4 {}'.format(x.size()))
x = F.relu(self.conv5(x))
# print('Conv5 {}'.format(x.size()))
x = x.view(x.size(0), -1)
# print('Squashed {}'.format(x.size()))
mu = self.fc_mu(x)
logsigma = self.fc_logsigma(x)
return mu, logsigma
class SquareDecoder(nn.Module):
""" VAE decoder """
def __init__(self, img_channels, latent_size):
super(SquareDecoder, self).__init__()
self.latent_size = latent_size
self.img_channels = img_channels
self.fc1 = nn.Linear(latent_size, 1024)
self.deconv1 = nn.ConvTranspose2d(1024, 128, 5, stride=2)
self.deconv2 = nn.ConvTranspose2d(128, 64, 5, stride=2)
self.deconv3 = nn.ConvTranspose2d(64, 32, 6, stride=2)
self.deconv4 = nn.ConvTranspose2d(32, img_channels, 6, stride=2)
def forward(self, x): # pylint: disable=arguments-differ
x = F.relu(self.fc1(x))
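        # lift the (N, 1024) feature vector to a (N, 1024, 1, 1) map so the transposed convolutions can upsample it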
x = x.unsqueeze(-1).unsqueeze(-1)
x = F.relu(self.deconv1(x))
x = F.relu(self.deconv2(x))
x = F.relu(self.deconv3(x))
reconstruction = torch.sigmoid(self.deconv4(x))  # F.sigmoid is deprecated
return reconstruction
class SquareEncoder(nn.Module): # pylint: disable=too-many-instance-attributes
""" VAE encoder """
def __init__(self, img_channels, latent_size):
super(SquareEncoder, self).__init__()
self.latent_size = latent_size
#self.img_size = img_size
self.img_channels = img_channels
self.conv1 = nn.Conv2d(img_channels, 32, 4, stride=2)
self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
self.conv3 = nn.Conv2d(64, 128, 4, stride=2)
self.conv4 = nn.Conv2d(128, 256, 4, stride=2)
self.fc_mu = nn.Linear(2*2*256, latent_size)
self.fc_logsigma = nn.Linear(2*2*256, latent_size)
def forward(self, x): # pylint: disable=arguments-differ
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
x = F.relu(self.conv4(x))
x = x.view(x.size(0), -1)
mu = self.fc_mu(x)
logsigma = self.fc_logsigma(x)
return mu, logsigma
class Decoder(nn.Module):
""" VAE decoder """
def __init__(self, img_channels, latent_size):
super(Decoder, self).__init__()
self.latent_size = latent_size
self.img_channels = img_channels
self.fc1 = nn.Linear(latent_size, 1024)
self.deconv1 = nn.ConvTranspose2d(1024, 128, 5, stride=2)
self.deconv2 = nn.ConvTranspose2d(128, 64, 5, stride=2)
self.deconv3 = nn.ConvTranspose2d(64, 32, 6, stride=2)
self.deconv4 = nn.ConvTranspose2d(32, img_channels, 6, stride=2)
def forward(self, x): # pylint: disable=arguments-differ
x = F.relu(self.fc1(x))
x = x.unsqueeze(-1).unsqueeze(-1)
x = F.relu(self.deconv1(x))
x = F.relu(self.deconv2(x))
x = F.relu(self.deconv3(x))
reconstruction = torch.sigmoid(self.deconv4(x))  # F.sigmoid is deprecated
return reconstruction
class Encoder(nn.Module): # pylint: disable=too-many-instance-attributes
""" VAE encoder """
def __init__(self, img_channels, latent_size):
super(Encoder, self).__init__()
self.latent_size = latent_size
#self.img_size = img_size
self.img_channels = img_channels
self.conv1 = nn.Conv2d(img_channels, 32, 4, stride=2)
self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
self.conv3 = nn.Conv2d(64, 128, 4, stride=2)
self.conv4 = nn.Conv2d(128, 256, 4, stride=2)
self.fc_mu = nn.Linear(2*2*256, latent_size)
self.fc_logsigma = nn.Linear(2*2*256, latent_size)
def forward(self, x): # pylint: disable=arguments-differ
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
x = F.relu(self.conv4(x))
x = x.view(x.size(0), -1)
mu = self.fc_mu(x)
@@ -69,12 +182,19 @@ class Encoder(nn.Module): # pylint: disable=too-many-instance-attributes
return mu, logsigma
def factory(classname):
cls = globals()[classname]
return cls()
class VAE(nn.Module):
""" Variational Autoencoder """
def __init__(self, img_channels, latent_size):
def __init__(self, type, img_channels, latent_size):
super(VAE, self).__init__()
self.encoder = Encoder(img_channels, latent_size)
self.decoder = Decoder(img_channels, latent_size)
encoder = globals()[type + "Encoder"]
decoder = globals()[type + "Decoder"]
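        # look up the encoder/decoder pair by name: type="Square" selects SquareEncoder/SquareDecoder,
        # type="Custom" selects CustomEncoder/CustomDecoder (note: `type` shadows the builtin)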
self.encoder = encoder(img_channels, latent_size)
self.decoder = decoder(img_channels, latent_size)
def forward(self, x): # pylint: disable=arguments-differ
mu, logsigma = self.encoder(x)
@@ -84,3 +204,5 @@ class VAE(nn.Module):
recon_x = self.decoder(z)
return recon_x, mu, logsigma
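# Usage sketch (illustrative, matching the template config below, not code
# from this commit): with vae_type "Square" the model expects 64x64 inputs in [0, 1].
#   model = VAE("Square", 3, 10)
#   recon_x, mu, logsigma = model(batch)   # batch: (N, 3, 64, 64)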
#!/bin/bash
# Make this file executable with: chmod u+x start_tensorboard.sh
echo "Loading modules"
module purge
module load gcc/4.8.5 python_cpu/3.7.1
echo "Starting tensorboard"
tensorboard --logdir $SCRATCH/vae_training_outputs/tensorboard_logs
\ No newline at end of file
{
"exp_name" : "VAE training",
"learning_rate" : 0.0003,
"batch_size" : 32,
"noreload" : "True",
"reload_dir": "",
"epochs" : 20,
"samples" : "True",
"vae_type": "Square",
"input_dim": [64,64],
"latent_dim": 10,
"early_stopping": "True",
"weight_decay": "True",
"logdir" : "$SCRATCH/vae_training_outputs",
"train_dataset_dir": "$SCRATCH/data/mini_mice_dataset/train",
"val_dataset_dir":"$SCRATCH/data/mini_mice_dataset/val",
"output_dir": ""
}
\ No newline at end of file
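# Usage sketch (illustrative; see trainvae.py and utils/config.py in this
# commit): the template above is loaded via get_config_from_json, which also
# expands $SCRATCH-style variables in the path fields:
#
#   from utils.config import get_config_from_json
#   config, config_json = get_config_from_json("training_configs/train_config1.json")
#   config.vae_type, config.latent_dim    # -> "Square", 10
#
# Caveat: "True"/"False" above are JSON strings, and any non-empty string is
# truthy in Python, so e.g. "noreload": "False" would still behave like True
# in the training script.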
#!/bin/bash
echo "Loading modules"
module purge
module load gcc/4.8.5 python_gpu/3.7.1
cd $HOME
#pip install --user -r world-models/requirements.txt
#pip install --user --upgrade torch
echo $1
# python world-models/trainvae.py --modelconfig $HOME/world-models/template_config.json
python world-models/trainvae.py --modelconfig $1
\ No newline at end of file
@@ -2,6 +2,7 @@
import argparse
from os.path import join, exists
from os import mkdir
import os
import torch
import torch.utils.data
@@ -11,6 +12,8 @@ from torchvision import transforms, datasets
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import json
from models.vae import VAE
@@ -21,19 +24,14 @@ from utils.learning import EarlyStopping
from utils.learning import ReduceLROnPlateau
from data.loaders import RolloutObservationDataset
from utils.config import get_config_from_json
parser = argparse.ArgumentParser(description='VAE Trainer')
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
help='input batch size for training (default: 32)')
parser.add_argument('--epochs', type=int, default=1000, metavar='N',
help='number of epochs to train (default: 1000)')
parser.add_argument('--logdir', type=str, help='Directory where results are logged')
parser.add_argument('--noreload', action='store_true',
help='Best model is not reloaded if specified')
parser.add_argument('--nosamples', action='store_true',
help='Does not save samples during training if specified')
parser.add_argument('--modelconfig', type=str, help='File path of training configuration')
args = parser.parse_args()
config, config_json = get_config_from_json(args.modelconfig)
args = parser.parse_args()
cuda = torch.cuda.is_available()
@@ -43,37 +41,33 @@ torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if cuda else "cpu")
INPUT_WIDTH = config.input_dim[0]
INPUT_HEIGHT = config.input_dim[1]
transform_train = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((RED_SIZE, RED_SIZE)),
transforms.Resize((INPUT_HEIGHT, INPUT_WIDTH)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
])
transform_test = transforms.Compose([
# transforms.ToPILImage(),
transforms.Resize((RED_SIZE, RED_SIZE)),
transforms.Resize((INPUT_HEIGHT, INPUT_WIDTH)),
transforms.ToTensor(),
])
train_dir = '../data/hymenoptera_data/train'
test_dir = '../data/hymenoptera_data/val'
dataset_train = datasets.ImageFolder(train_dir, transform_test)
dataset_test = datasets.ImageFolder(test_dir, transform_test)
dataset_train = datasets.ImageFolder(config.train_dataset_dir, transform_test)
dataset_test = datasets.ImageFolder(config.val_dataset_dir, transform_test)
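# note: transform_test (no random flip) is applied to both splits here;
# transform_train above is defined but not used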
# dataset_train = RolloutObservationDataset('datasets/carracing',
# transform_train, train=True)
# dataset_test = RolloutObservationDataset('datasets/carracing',
# transform_test, train=False)
train_loader = torch.utils.data.DataLoader(
dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=2)
dataset_train, batch_size=config.batch_size, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(
dataset_test, batch_size=args.batch_size, shuffle=True, num_workers=2)
dataset_test, batch_size=config.batch_size, shuffle=True, num_workers=2)
model = VAE(config.vae_type, 3, config.latent_dim).to(device)
model = VAE(3, LSIZE).to(device)
optimizer = optim.Adam(model.parameters())
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
earlystopping = EarlyStopping('min', patience=3000)
@@ -99,10 +93,15 @@ def train(epoch):
# data = data.to(device)
batch_idx = 0
for inputs, labels in train_loader:
batch_idx += len(labels)
# writer.add_image('input', inputs[0], batch_idx * len(inputs))
batch_idx += 1
data = inputs.to(device)
optimizer.zero_grad()
recon_batch, mu, logvar = model(data)
# writer.add_image('reconstruction', recon_batch[0], batch_idx * len(data))
# writer.add_images('train',torch.stack((inputs[0],recon_batch[0]),0), (epoch-1)*len(train_loader) + batch_idx*len(data))
loss = loss_function(recon_batch, data, mu, logvar)
loss.backward()
train_loss += loss.item()
@@ -112,6 +111,8 @@ def train(epoch):
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader),
loss.item() / len(data)))
writer.add_images('train',torch.stack((inputs[0],(recon_batch[0]).cpu()),0), (epoch-1)*len(dataset_train) + batch_idx*len(data))
writer.add_scalar('Loss/train', loss.item() / len(data), (epoch-1)*len(dataset_train) + batch_idx*len(data))
print('====> Epoch: {} Average loss: {:.4f}'.format(
epoch, train_loss / len(train_loader.dataset)))
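# The loss_function called above is elided from this diff. A hedged sketch of
# the usual VAE objective it computes (reconstruction error plus KL divergence,
# assuming logsigma = log(sigma)):
#
#   def loss_function(recon_x, x, mu, logsigma):
#       rec = F.mse_loss(recon_x, x, reduction='sum')
#       kld = -0.5 * torch.sum(1 + 2*logsigma - mu.pow(2) - (2*logsigma).exp())
#       return rec + kld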
@@ -131,22 +132,36 @@ def test(epoch):
test_loss /= len(test_loader.dataset)
print('====> Test set loss: {:.4f}'.format(test_loss))
# writer.add_scalar('Loss/test', test_loss, epoch)
writer.add_scalar('Loss/test', test_loss, epoch)
return test_loss
# Create folders for storing models and training logs.
if not exists(config.logdir):
mkdir(config.logdir)
if not exists(join(config.logdir, 'models')):
mkdir(join(config.logdir, 'models'))
# check vae dir exists, if not, create it
vae_dir = join(args.logdir, 'vae')
runs_dir = join(config.logdir, 'tensorboard_logs' )
if not exists(runs_dir):
mkdir(runs_dir)
# Stamp training experiment name with datetime stamp for unique naming.
datetimestamp = (datetime.now()).strftime("D%Y_%m_%dT%H_%M_%S")
vae_dir = join(config.logdir, 'models/vae_' + datetimestamp)
if not exists(vae_dir):
mkdir(vae_dir)
mkdir(join(vae_dir, 'samples'))
mkdir(join(vae_dir, 'runs'))
# writer = SummaryWriter(log_dir=join(vae_dir, 'runs'))
# Save configuration file
with open(join(vae_dir,'config.json'), 'w') as outfile:
json.dump(config_json, outfile)
writer = SummaryWriter(runs_dir+"/"+datetimestamp)
reload_file = join(vae_dir, 'best.tar')
if not args.noreload and exists(reload_file):
reload_file = join(config.reload_dir, 'best.tar')
if not config.noreload and exists(reload_file):
state = torch.load(reload_file)
print("Reloading model at epoch {}"
", with test error {}".format(
@@ -160,7 +175,7 @@ if not args.noreload and exists(reload_file):
cur_best = None
for epoch in range(1, args.epochs + 1):
for epoch in range(1, config.epochs + 1):
train(epoch)
test_loss = test(epoch)
scheduler.step(test_loss)
@@ -184,11 +199,11 @@ for epoch in range(1, args.epochs + 1):
if not args.nosamples:
if config.samples:
with torch.no_grad():
sample = torch.randn(RED_SIZE, LSIZE).to(device)
sample = torch.randn(1, config.latent_dim).to(device)
sample = model.decoder(sample).cpu()
save_image(sample.view(64, 3, RED_SIZE, RED_SIZE),
save_image(sample.view(1, 3, INPUT_HEIGHT, INPUT_WIDTH),
join(vae_dir, 'samples/sample_' + str(epoch) + '.png'))
if earlystopping.stop:
......
import json
from easydict import EasyDict
import os
def get_config_from_json(json_file):
"""
Get the config from a json file
:param json_file: the path of the config file
:return: config(namespace), config(dictionary)
"""
print(json_file)
# parse the configurations from the config json file provided
with open(json_file, 'r') as config_file:
try:
config_dict = json.load(config_file)
# EasyDict allows to access dict values as attributes (works recursively).
config = EasyDict(config_dict)
config.train_dataset_dir = os.path.expandvars(config.train_dataset_dir)
config.val_dataset_dir = os.path.expandvars(config.val_dataset_dir)
config.logdir = os.path.expandvars(config.logdir)
return config, config_dict
except ValueError:
print("INVALID JSON file format.. Please provide a good json file")
exit(-1)
\ No newline at end of file
from PIL import Image
import sys
import os
import math
import numpy as np
###########################################################################################
# script to generate moving mnist video dataset (frame by frame) as described in
# [1] arXiv:1502.04681 - Unsupervised Learning of Video Representations Using LSTMs
# Srivastava et al
# by Tencia Lee
# saves in hdf5, npz, or jpg (individual frames) format
###########################################################################################
# helper functions
def arr_from_img(im,shift=0):
w,h=im.size
arr=im.getdata()
c = np.product(arr.size) / (w*h)
return np.asarray(arr, dtype=np.float32).reshape((h,w,c)).transpose(2,1,0) / 255. - shift
def get_picture_array(X, index, shift=0):
ch, w, h = X.shape[1], X.shape[2], X.shape[3]
ret = ((X[index]+shift)*255.).reshape(ch,w,h).transpose(2,1,0).clip(0,255).astype(np.uint8)
if ch == 1:
ret=ret.reshape(h,w)
return ret
# loads mnist from web on demand
def load_dataset():
if sys.version_info[0] == 2:
from urllib import urlretrieve
else: