Commit 859ea82a authored by nstorni

Merge branch 'lstm-adaptation' into 'development'

Lstm adaptation into development

See merge request !1
parents f28e4f85 eb1ae055
......@@ -133,6 +133,32 @@ For large datasets run the script in the cluster:
bsub -o $HOME/job_logs -n 4 -R "rusage[mem=4096]" "$HOME/world-models/utils/generate_moving_mnist.sh --num_videos 10000 --num_videos_val 300 --num_frames 10 --digits_dim 40 --frame_dim 256"
```
## 6. Running the RNN
Once you have the latent variables from the VAE, you can use them to train the RNN.
First, create a folder for the latent-variable dataset:
```bash
mkdir -p $SCRATCH/data/latent_variables
```
Afterwards, copy the dataset containing the latent variables into this folder. If you need to upload it from your local machine, use the following command:
```bash
USERNAME="your-username"
scp data.pt $USERNAME@login.leonhard.ethz.ch:/cluster/scratch/$USERNAME/data/latent_variables
```
Then, you can copy the config template for the RNN into the configs folder:
```bash
cd $HOME
cp world-models/template_config_rnn.json training_configs/rnn_config1.json
```
Edit the config file to match your setup and parameters. Then do a short test run to check that the script works:
```bash
$HOME/world-models/start_training.sh --modeldir trainlstm.py --modelconfigdir $HOME/training_configs/rnn_config1.json
```
If no errors show up, abort the run (Ctrl+C) and submit the job to the cluster:
```bash
bsub -o job_logs -n 4 -R "rusage[ngpus_excl_p=1,mem=4096]" "$HOME/world-models/start_training.sh --modeldir trainlstm.py --modelconfigdir $HOME/training_configs/rnn_config1.json"
```
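You can monitor the submitted job with the usual LSF commands (the job id is printed by `bsub`), for example:
```bash
# List your pending and running jobs
bjobs
# Peek at the stdout of a running job (replace <jobid> with the id reported by bsub)
bpeek <jobid>
```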
## 7. Generate interframe differences dataset
```bash
......
......@@ -6,6 +6,7 @@ from tqdm import tqdm
import torch
import torch.utils.data
import numpy as np
import csv
class _RolloutDataset(torch.utils.data.Dataset): # pylint: disable=too-few-public-methods
    def __init__(self, root, transform, buffer_size=200, train=True): # pylint: disable=too-many-arguments
......@@ -145,3 +146,47 @@ class RolloutObservationDataset(_RolloutDataset): # pylint: disable=too-few-publ
    def _get_data(self, data, seq_index):
        return self._transform(data['observations'][seq_index])
class LatentStateDataset(torch.utils.data.Dataset):
    def __init__(self, file, seq_len, train=True):
        self._file = file
        self._seq_len = seq_len

        # Reload latent variables produced by the VAE.
        latents = torch.load(file)
        print('Latents shape: ', latents.shape)  # [video_number][frame_number][mu, logsigma][latent_dimension]

        # Compute the z vector via the reparameterization trick.
        mu = latents[:, :, 0]
        logsigma = latents[:, :, 1]
        sigma = logsigma.exp()
        eps = torch.randn_like(sigma)
        z = eps.mul(sigma).add_(mu)

        no_vids = z.size()[0]
        split_vids = (no_vids > 1)
        if split_vids:
            # Several videos: hold out the last ~10% of videos for validation.
            train_separation = int(np.ceil(no_vids * 0.1))
            if train:
                self._data = z[:-train_separation]
            else:
                self._data = z[-train_separation:]
        else:
            # Single video: hold out the last ~10% of frames instead.
            no_frames = z.size()[1]
            train_separation = int(np.ceil(no_frames * 0.1))
            if train:
                self._data = z[0][:-train_separation].unsqueeze(0)
            else:
                self._data = z[0][-train_separation:].unsqueeze(0)

        # Number of (input, target) sequence pairs that fit in each video.
        self._no_seq_tuples_per_vid = self._data.shape[1] - self._seq_len

    def __len__(self):
        return self._data.shape[0] * self._no_seq_tuples_per_vid

    def __getitem__(self, i):
        video_no, seq_start = divmod(i, self._no_seq_tuples_per_vid)
        return self._get_seqs(video_no, seq_start)

    def _get_seqs(self, video_no, i):
        # Input sequence and its one-step-shifted target sequence.
        return self._data[video_no][i:i + self._seq_len], self._data[video_no][i + 1:i + self._seq_len + 1]
import torch
from torch.utils.data import DataLoader
import numpy as np

from data.loaders import LatentStateDataset
from models.mdrnn import MDRNN
from torch.distributions.normal import Normal

dataset = '/cluster/scratch/darafael/data/latent_variables/sin.pt'
reload_file = '/cluster/scratch/darafael/experiments/lstm/models/sin1_D20200116T155627/best.tar'
seq_len = 32
batch_size = 32
no_gaussians = 5

model = MDRNN(64, 128, no_gaussians)
rnn_state = torch.load(reload_file, map_location=torch.device('cpu'))
model.load_state_dict(rnn_state["state_dict"])
model.eval()

loader = DataLoader(
    LatentStateDataset(dataset, seq_len, train=False),
    batch_size=batch_size, num_workers=8)
loader_iterable = iter(loader)
with torch.no_grad():
    # latent: batch_size, seq_len, LSIZE
    latent, next_latent = next(loader_iterable)
    # Note: MDRNN.forward documents its input as (SEQ_LEN, BSIZE, LSIZE);
    # trainlstm.py transposes the batch before calling the model, which is skipped here.
    mus, sigmas, logpi = model(latent)
    normal_dist = Normal(mus, sigmas)
    # sample: batch_size, seq_len, no_gaussians, LSIZE
    sample = normal_dist.sample()
    print('Latent shape:', latent.shape)
    print('Next prediction')
    print(sample)
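As a quick sanity check (a sketch only, not part of this merge request), the mixture can be collapsed into its expected value and compared against the ground-truth next latent returned by the loader:
```python
# Sketch: uses mus, logpi and next_latent from the script above.
# Weight each Gaussian mean by its mixture probability and sum over components.
pi = logpi.exp().unsqueeze(-1)            # (..., N_GAUSS, 1)
predicted_mean = (pi * mus).sum(dim=-2)   # (..., LSIZE), expected next latent
mse = torch.nn.functional.mse_loss(predicted_mean, next_latent)
print('MSE between expected next latent and ground truth:', mse.item())
```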
......@@ -53,7 +53,7 @@ class _MDRNNBase(nn.Module):
        self.gaussians = gaussians

        self.gmm_linear = nn.Linear(
-            hiddens, (2 * latents + 1) * gaussians + 2)
+            hiddens, (2 * latents + 1) * gaussians)

    def forward(self, *inputs):
        pass
......@@ -69,14 +69,11 @@ class MDRNN(_MDRNNBase):
        :args latents: (SEQ_LEN, BSIZE, LSIZE) torch tensor

-        :returns: mu_nlat, sig_nlat, pi_nlat, rs, ds, parameters of the GMM
-        prediction for the next latent, gaussian prediction of the reward and
-        logit prediction of terminality.
+        :returns: mu_nlat, sig_nlat, pi_nlat, parameters of the GMM
+        prediction for the next latent.
            - mu_nlat: (SEQ_LEN, BSIZE, N_GAUSS, LSIZE) torch tensor
            - sigma_nlat: (SEQ_LEN, BSIZE, N_GAUSS, LSIZE) torch tensor
            - logpi_nlat: (SEQ_LEN, BSIZE, N_GAUSS) torch tensor
-            - rs: (SEQ_LEN, BSIZE) torch tensor
-            - ds: (SEQ_LEN, BSIZE) torch tensor
        """
        seq_len, bs = latents.size(0), latents.size(1)
......@@ -96,11 +93,7 @@ class MDRNN(_MDRNNBase):
        pi = pi.view(seq_len, bs, self.gaussians)
        logpi = f.log_softmax(pi, dim=-1)

-        rs = gmm_outs[:, :, -2]
-        ds = gmm_outs[:, :, -1]
-        return mus, sigmas, logpi, rs, ds
+        return mus, sigmas, logpi

class MDRNNCell(_MDRNNBase):
    """ MDRNN model for one step forward """
......@@ -114,14 +107,11 @@ class MDRNNCell(_MDRNNBase):
        :args latents: (BSIZE, LSIZE) torch tensor
        :args hidden: (BSIZE, RSIZE) torch tensor

-        :returns: mu_nlat, sig_nlat, pi_nlat, r, d, next_hidden, parameters of
-        the GMM prediction for the next latent, gaussian prediction of the
-        reward, logit prediction of terminality and next hidden state.
+        :returns: mu_nlat, sig_nlat, pi_nlat, next_hidden, parameters of
+        the GMM prediction for the next latent and next hidden state.
            - mu_nlat: (BSIZE, N_GAUSS, LSIZE) torch tensor
            - sigma_nlat: (BSIZE, N_GAUSS, LSIZE) torch tensor
            - logpi_nlat: (BSIZE, N_GAUSS) torch tensor
-            - rs: (BSIZE) torch tensor
-            - ds: (BSIZE) torch tensor
        """
        next_hidden = self.rnn(latent, hidden)
......@@ -142,8 +132,4 @@ class MDRNNCell(_MDRNNBase):
        pi = pi.view(-1, self.gaussians)
        logpi = f.log_softmax(pi, dim=-1)

-        r = out_full[:, -2]
-        d = out_full[:, -1]
-        return mus, sigmas, logpi, r, d, next_hidden
+        return mus, sigmas, logpi, next_hidden
{
    "model": "rnn",
    "exp_name": "test",
    "learning_rate": 0.001,
    "batch_size": 64,
    "reload": "False",
    "reload_dir": "$SCRATCH/experiments/lstm/models/",
    "epochs": 30,
    "input_dim": 64,
    "output_dim": 128,
    "seq_len": 32,
    "early_stopping": "True",
    "weight_decay": "True",
    "logdir": "$SCRATCH/experiments/lstm",
    "dataset": "$SCRATCH/data/latent_variables/data.pt"
}
\ No newline at end of file
""" Recurrent model training """
import argparse
from functools import partial
from os.path import join, exists
from os import mkdir
import os
import torch
import torch.nn.functional as f
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
from datetime import datetime
import json
from tqdm import tqdm
from utils.misc import save_checkpoint
from utils.misc import RED_SIZE, SIZE
from utils.learning import EarlyStopping
## WARNING : THIS SHOULD BE REPLACED WITH PYTORCH 0.5
from utils.learning import ReduceLROnPlateau
from data.loaders import RolloutSequenceDataset
from data.loaders import LatentStateDataset
from models.vae import VAE
from models.mdrnn import MDRNN, gmm_loss
from utils.config import get_config_from_json
parser = argparse.ArgumentParser(description='LSTM training')
parser.add_argument('--modelconfig', type=str, help='File path of training configuration')
args = parser.parse_args()
config, config_json = get_config_from_json(args.modelconfig)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if not exists(config.logdir):
    mkdir(config.logdir)

if not exists(join(config.logdir, 'models')):
    mkdir(join(config.logdir, 'models'))

# Stamp training experiment name with datetime stamp for unique naming.
datetimestamp = (datetime.now()).strftime("D%Y%m%dT%H%M%S")
lstm_dir = join(config.logdir, 'models/' + config.exp_name + "_" + datetimestamp)
if not exists(lstm_dir):
    mkdir(lstm_dir)
    mkdir(join(lstm_dir, 'samples'))

# Save configuration file
with open(join(lstm_dir, 'config.json'), 'w') as outfile:
    json.dump(config_json, outfile)
mdrnn = MDRNN(config.input_dim, config.output_dim, 5)
mdrnn.to(device)
optimizer = torch.optim.RMSprop(mdrnn.parameters(), lr=config.learning_rate, alpha=.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
earlystopping = EarlyStopping('min', patience=30)
reload_file = join(os.path.expandvars(config.reload_dir), 'best.tar')
if config.reload == "True" and exists(reload_file):
    rnn_state = torch.load(reload_file)
    print("Loading MDRNN at epoch {} "
          "with test error {}".format(
              rnn_state["epoch"], rnn_state["precision"]))
    mdrnn.load_state_dict(rnn_state["state_dict"])
    # optimizer.load_state_dict(rnn_state["optimizer"])
    # scheduler.load_state_dict(state['scheduler'])
    # earlystopping.load_state_dict(state['earlystopping'])
elif config.reload == "True" and not exists(reload_file):
    raise Exception('Reload file not found: {}'.format(reload_file))
# Data Loading
train_loader = DataLoader(
    LatentStateDataset(config.dataset, config.seq_len),
    batch_size=config.batch_size, num_workers=8)
test_loader = DataLoader(
    LatentStateDataset(config.dataset, config.seq_len, train=False),
    batch_size=config.batch_size, num_workers=8)
def get_loss(latent_obs, latent_next_obs):
    """ Compute losses.

    The loss that is computed is:
    GMMLoss(latent_next_obs, GMMPredicted) / LSIZE

    The LSIZE factor is here to counteract the fact that the GMMLoss scales
    approximately linearly with LSIZE. The loss is averaged both on the
    batch and the sequence dimensions (the first two dimensions).

    :args latent_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor
    :args latent_next_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor

    :returns: dictionary of losses, containing the gmm and the averaged loss.
    """
    latent_obs, latent_next_obs = [arr.transpose(1, 0)
                                   for arr in [latent_obs, latent_next_obs]]
    mus, sigmas, logpi = mdrnn(latent_obs)
    gmm = gmm_loss(latent_next_obs, mus, sigmas, logpi)
    loss = gmm / config.input_dim
    return dict(gmm=gmm, loss=loss)
def data_pass(epoch, train): # pylint: disable=too-many-locals
    """ One pass through the data """
    if train:
        mdrnn.train()
        loader = train_loader
    else:
        mdrnn.eval()
        loader = test_loader

    cum_loss = 0
    cum_gmm = 0

    pbar = tqdm(total=len(loader.dataset), desc="Epoch {}".format(epoch))
    for i, data in enumerate(loader):
        latent_obs, latent_next_obs = [arr.to(device) for arr in data]

        if train:
            losses = get_loss(latent_obs, latent_next_obs)

            optimizer.zero_grad()
            losses['loss'].backward()
            optimizer.step()
        else:
            with torch.no_grad():
                losses = get_loss(latent_obs, latent_next_obs)

        cum_loss += losses['loss'].item()
        cum_gmm += losses['gmm'].item()

        pbar.set_postfix_str("loss={loss:10.6f} gmm={gmm:10.6f}".format(
            loss=cum_loss / (i + 1),
            gmm=cum_gmm / config.input_dim / (i + 1)))
        pbar.update(config.batch_size)
    pbar.close()
    return cum_loss * config.batch_size / len(loader.dataset)
train = partial(data_pass, train=True)
test = partial(data_pass, train=False)
cur_best = None
for e in range(config.epochs):
    train(e)
    test_loss = test(e)
    # scheduler.step(test_loss)
    # earlystopping.step(test_loss)

    is_best = not cur_best or test_loss < cur_best
    if is_best:
        cur_best = test_loss

    best_filename = join(lstm_dir, 'best.tar')
    checkpoint_fname = join(lstm_dir, 'checkpoint.tar')
    save_checkpoint({
        "state_dict": mdrnn.state_dict(),
        "optimizer": optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'earlystopping': earlystopping.state_dict(),
        "precision": test_loss,
        "epoch": e}, is_best, checkpoint_fname,
        best_filename)

    if earlystopping.stop:
        print("End of Training because of early stopping at epoch {}".format(e))
        break
......@@ -31,9 +31,12 @@ def get_config_from_json(json_file):
        config_dict = json.load(config_file)

    # EasyDict allows to access dict values as attributes (works recursively).
    config = EasyDict(config_dict)
-    config.train_dataset_dir = os.path.expandvars(config.train_dataset_dir)
-    config.val_dataset_dir = os.path.expandvars(config.val_dataset_dir)
-    config.dataset_dir = os.path.expandvars(config.dataset_dir)
+    if hasattr(config, 'model') and config.model == "rnn":
+        config.dataset = os.path.expandvars(config.dataset)
+    else:
+        config.train_dataset_dir = os.path.expandvars(config.train_dataset_dir)
+        config.val_dataset_dir = os.path.expandvars(config.val_dataset_dir)
+        config.dataset_dir = os.path.expandvars(config.dataset_dir)
    config.logdir = os.path.expandvars(config.logdir)

    # Make experiment name legal filename:
    valid_chars = "-_%s%s" % (string.ascii_letters, string.digits)
......