Commit 1b643a73 authored by Rafael Dätwyler's avatar Rafael Dätwyler

copy trainmdrnn.py into new file and adapt to our purposes

parent e5821656
@@ -6,6 +6,7 @@ from tqdm import tqdm
import torch
import torch.utils.data
import numpy as np
import csv
class _RolloutDataset(torch.utils.data.Dataset): # pylint: disable=too-few-public-methods
def __init__(self, root, transform, buffer_size=200, train=True): # pylint: disable=too-many-arguments
@@ -145,3 +146,26 @@ class RolloutObservationDataset(_RolloutDataset): # pylint: disable=too-few-public-methods
def _get_data(self, data, seq_index):
return self._transform(data['observations'][seq_index])
class LatentStateDataset(torch.utils.data.Dataset):
    def __init__(self, file, train=True):
        self._file = file
        # collect all rows first; assumes each row is a flat list of float
        # latent values
        rows = []
        with open(file, 'r', newline='') as csvfile:  # csv.reader needs text mode
            reader = csv.reader(csvfile)
            for row in reader:
                rows.append([float(value) for value in row])
        data = np.array(rows)

        # hold out the last 10% of the rows as the test split; keep at least
        # one row so data[:-0] never empties the train split
        train_separation = max(1, int(len(data) * 0.1))
        if train:
            self._data = data[:-train_separation]
        else:
            self._data = data[-train_separation:]

    def __len__(self):
        return len(self._data)

    def __getitem__(self, i):
        return self._data[i]
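# A minimal usage sketch (not part of the module), assuming the CSV holds one
# flattened latent vector per row; 'latents.csv' is a hypothetical file name:
#
#     train_set = LatentStateDataset('latents.csv', train=True)
#     test_set = LatentStateDataset('latents.csv', train=False)
#     print(len(train_set), len(test_set), train_set[0].shape)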
""" Recurrent model training """
import argparse
from functools import partial
from os.path import join, exists
from os import mkdir
import torch
import torch.nn.functional as f
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
from tqdm import tqdm
from utils.misc import save_checkpoint
from utils.misc import ASIZE, LSIZE, RSIZE, RED_SIZE, SIZE
from utils.learning import EarlyStopping
## WARNING : THIS SHOULD BE REPLACED WITH PYTORCH 0.5
from utils.learning import ReduceLROnPlateau
from data.loaders import LatentStateDataset
from models.mdrnn import MDRNN, gmm_loss
parser = argparse.ArgumentParser("MDRNN training")
parser.add_argument('--logdir', type=str,
help="Where things are logged and models are loaded from.")
parser.add_argument('--noreload', action='store_true',
help="Do not reload if specified.")
parser.add_argument('--include_reward', action='store_true',
help="Add a reward modelisation term to the loss.")
parser.add_argument('--latent_file', type=str,
help="Specify the file where the latent representation from the VAE is stored.")
args = parser.parse_args()
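# Example invocation, assuming this adapted copy is saved as
# trainmdrnn_latent.py (hypothetical name) and the VAE latents were exported
# to latents.csv:
#
#     python trainmdrnn_latent.py --logdir exp_dir --latent_file latents.csv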
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# constants
BSIZE = 16
SEQ_LEN = 32
epochs = 30
# Loading model
rnn_dir = join(args.logdir, 'mdrnn')
rnn_file = join(rnn_dir, 'best.tar')
if not exists(rnn_dir):
mkdir(rnn_dir)
mdrnn = MDRNN(LSIZE, RSIZE, 5)
mdrnn.to(device)
optimizer = torch.optim.RMSprop(mdrnn.parameters(), lr=1e-3, alpha=.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
earlystopping = EarlyStopping('min', patience=30)
if exists(rnn_file) and not args.noreload:
rnn_state = torch.load(rnn_file)
print("Loading MDRNN at epoch {} "
"with test error {}".format(
rnn_state["epoch"], rnn_state["precision"]))
mdrnn.load_state_dict(rnn_state["state_dict"])
# optimizer.load_state_dict(rnn_state["optimizer"])
# scheduler.load_state_dict(state['scheduler'])
# earlystopping.load_state_dict(state['earlystopping'])
# Data Loading
train_loader = DataLoader(
LatentStateDataset(args.latent_file),
batch_size=BSIZE, num_workers=8)
test_loader = DataLoader(
LatentStateDataset(args.latent_file, train=False),
batch_size=BSIZE, num_workers=8)
def get_loss(latent_obs, reward, terminal,
latent_next_obs, include_reward: bool):
""" Compute losses.
The loss that is computed is:
(GMMLoss(latent_next_obs, GMMPredicted) + MSE(reward, predicted_reward) +
BCE(terminal, logit_terminal)) / (LSIZE + 2)
The LSIZE + 2 factor is here to counteract the fact that the GMMLoss scales
approximately linearily with LSIZE. All losses are averaged both on the
batch and the sequence dimensions (the two first dimensions).
:args latent_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor
:args reward: (BSIZE, SEQ_LEN) torch tensor
:args latent_next_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor
:returns: dictionary of losses, containing the gmm, the mse, the bce and
the averaged loss.
"""
    # move to (SEQ_LEN, BSIZE, ...) ordering expected by the recurrent model
    latent_obs, reward, terminal, latent_next_obs = [
        arr.transpose(1, 0)
        for arr in [latent_obs, reward, terminal, latent_next_obs]]
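    # unlike the upstream trainmdrnn.py, the adapted model is fed the latent
    # observations only; the action input has been dropped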
mus, sigmas, logpi, rs, ds = mdrnn(latent_obs)
gmm = gmm_loss(latent_next_obs, mus, sigmas, logpi)
bce = f.binary_cross_entropy_with_logits(ds, terminal)
if include_reward:
mse = f.mse_loss(rs, reward)
scale = LSIZE + 2
else:
mse = 0
scale = LSIZE + 1
loss = (gmm + bce + mse) / scale
return dict(gmm=gmm, bce=bce, mse=mse, loss=loss)
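# Shape sanity-check sketch (assumed shapes, mirroring the docstring; not part
# of the training flow):
#
#     latent_obs = torch.randn(BSIZE, SEQ_LEN, LSIZE)
#     reward = torch.randn(BSIZE, SEQ_LEN)
#     terminal = torch.randint(0, 2, (BSIZE, SEQ_LEN)).float()
#     latent_next_obs = torch.randn(BSIZE, SEQ_LEN, LSIZE)
#     losses = get_loss(latent_obs, reward, terminal, latent_next_obs, True)
#     print(losses['loss'])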
def data_pass(epoch, train, include_reward): # pylint: disable=too-many-locals
""" One pass through the data """
if train:
mdrnn.train()
loader = train_loader
else:
mdrnn.eval()
loader = test_loader
    # only the rollout-based datasets have a buffer to refresh;
    # LatentStateDataset does not
    if hasattr(loader.dataset, 'load_next_buffer'):
        loader.dataset.load_next_buffer()
cum_loss = 0
cum_gmm = 0
cum_bce = 0
cum_mse = 0
pbar = tqdm(total=len(loader.dataset), desc="Epoch {}".format(epoch))
for i, data in enumerate(loader):
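        # each batch is expected to unpack into latents, actions, rewards,
        # terminals and next latents; the CSV layout behind LatentStateDataset
        # has to provide all five components per sample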
latent_obs, action, reward, terminal, latent_next_obs = [arr.to(device) for arr in data]
if train:
losses = get_loss(latent_obs, reward,
terminal, latent_next_obs, include_reward)
optimizer.zero_grad()
losses['loss'].backward()
optimizer.step()
else:
with torch.no_grad():
losses = get_loss(latent_obs, reward,
terminal, latent_next_obs, include_reward)
cum_loss += losses['loss'].item()
cum_gmm += losses['gmm'].item()
cum_bce += losses['bce'].item()
cum_mse += losses['mse'].item() if hasattr(losses['mse'], 'item') else \
losses['mse']
pbar.set_postfix_str("loss={loss:10.6f} bce={bce:10.6f} "
"gmm={gmm:10.6f} mse={mse:10.6f}".format(
loss=cum_loss / (i + 1), bce=cum_bce / (i + 1),
gmm=cum_gmm / LSIZE / (i + 1), mse=cum_mse / (i + 1)))
pbar.update(BSIZE)
pbar.close()
return cum_loss * BSIZE / len(loader.dataset)
train = partial(data_pass, train=True, include_reward=args.include_reward)
test = partial(data_pass, train=False, include_reward=args.include_reward)
cur_best = None
for e in range(epochs):
train(e)
test_loss = test(e)
# scheduler.step(test_loss)
# earlystopping.step(test_loss)
is_best = not cur_best or test_loss < cur_best
if is_best:
cur_best = test_loss
checkpoint_fname = join(rnn_dir, 'checkpoint.tar')
save_checkpoint({
"state_dict": mdrnn.state_dict(),
"optimizer": optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
'earlystopping': earlystopping.state_dict(),
"precision": test_loss,
"epoch": e}, is_best, checkpoint_fname,
rnn_file)
if earlystopping.stop:
print("End of Training because of early stopping at epoch {}".format(e))
break