Commit 49d26730 authored by Rafael Dätwyler

removed reward and terminality prediction

parent 1b643a73
@@ -53,7 +53,7 @@ class _MDRNNBase(nn.Module):
         self.gaussians = gaussians
         self.gmm_linear = nn.Linear(
-            hiddens, (2 * latents + 1) * gaussians + 2)
+            hiddens, (2 * latents + 1) * gaussians)

     def forward(self, *inputs):
         pass
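Note: the dropped `+ 2` in the head width was exactly the two extra scalars per step, one for the predicted reward and one for the terminal logit. What remains is the pure GMM head. A minimal sketch of the width arithmetic, with illustrative sizes rather than this repo's config:

```python
# Per mixture component the head emits: LSIZE means + LSIZE sigmas + 1 mixture logit.
latents, gaussians = 32, 5                   # illustrative, not the repo's config

gmm_width = (2 * latents + 1) * gaussians    # new head width: 325
old_width = gmm_width + 2                    # old width: +1 reward, +1 terminal = 327
```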
@@ -69,14 +69,11 @@ class MDRNN(_MDRNNBase):
         :args latents: (SEQ_LEN, BSIZE, LSIZE) torch tensor

-        :returns: mu_nlat, sig_nlat, pi_nlat, rs, ds, parameters of the GMM
-        prediction for the next latent, gaussian prediction of the reward and
-        logit prediction of terminality.
+        :returns: mu_nlat, sig_nlat, pi_nlat, parameters of the GMM
+        prediction for the next latent.
             - mu_nlat: (SEQ_LEN, BSIZE, N_GAUSS, LSIZE) torch tensor
             - sigma_nlat: (SEQ_LEN, BSIZE, N_GAUSS, LSIZE) torch tensor
             - logpi_nlat: (SEQ_LEN, BSIZE, N_GAUSS) torch tensor
-            - rs: (SEQ_LEN, BSIZE) torch tensor
-            - ds: (SEQ_LEN, BSIZE) torch tensor
         """
         seq_len, bs = latents.size(0), latents.size(1)
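For reference, the shape contract of the trimmed forward pass, as a minimal usage sketch; the constructor arguments `(latents, hiddens, gaussians)` are an assumption, since the model construction is not part of this diff:

```python
import torch
from models.mdrnn import MDRNN

SEQ_LEN, BSIZE, LSIZE, RSIZE, N_GAUSS = 10, 16, 32, 256, 5  # illustrative sizes
mdrnn = MDRNN(LSIZE, RSIZE, N_GAUSS)  # assumed signature

latents = torch.randn(SEQ_LEN, BSIZE, LSIZE)
mus, sigmas, logpi = mdrnn(latents)   # rs and ds are gone
assert mus.shape == (SEQ_LEN, BSIZE, N_GAUSS, LSIZE)
assert sigmas.shape == (SEQ_LEN, BSIZE, N_GAUSS, LSIZE)
assert logpi.shape == (SEQ_LEN, BSIZE, N_GAUSS)
```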
@@ -96,11 +93,7 @@ class MDRNN(_MDRNNBase):
         pi = pi.view(seq_len, bs, self.gaussians)
         logpi = f.log_softmax(pi, dim=-1)

-        rs = gmm_outs[:, :, -2]
-        ds = gmm_outs[:, :, -1]
-
-        return mus, sigmas, logpi, rs, ds
+        return mus, sigmas, logpi


 class MDRNNCell(_MDRNNBase):
     """ MDRNN model for one step forward """
@@ -114,14 +107,11 @@ class MDRNNCell(_MDRNNBase):
         :args latents: (BSIZE, LSIZE) torch tensor
         :args hidden: (BSIZE, RSIZE) torch tensor

-        :returns: mu_nlat, sig_nlat, pi_nlat, r, d, next_hidden, parameters of
-        the GMM prediction for the next latent, gaussian prediction of the
-        reward, logit prediction of terminality and next hidden state.
+        :returns: mu_nlat, sig_nlat, pi_nlat, next_hidden, parameters of
+        the GMM prediction for the next latent and next hidden state.
             - mu_nlat: (BSIZE, N_GAUSS, LSIZE) torch tensor
             - sigma_nlat: (BSIZE, N_GAUSS, LSIZE) torch tensor
             - logpi_nlat: (BSIZE, N_GAUSS) torch tensor
-            - rs: (BSIZE) torch tensor
-            - ds: (BSIZE) torch tensor
         """
         next_hidden = self.rnn(latent, hidden)
@@ -142,8 +132,4 @@ class MDRNNCell(_MDRNNBase):
         pi = pi.view(-1, self.gaussians)
         logpi = f.log_softmax(pi, dim=-1)

-        r = out_full[:, -2]
-        d = out_full[:, -1]
-
-        return mus, sigmas, logpi, r, d, next_hidden
+        return mus, sigmas, logpi, next_hidden
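A usage sketch for the trimmed one-step interface, including how a next latent could be sampled from the returned mixture; the constructor signature and the LSTMCell-style hidden tuple are assumptions:

```python
import torch
from models.mdrnn import MDRNNCell

LSIZE, RSIZE, N_GAUSS, BSIZE = 32, 256, 5, 16     # illustrative sizes
cell = MDRNNCell(LSIZE, RSIZE, N_GAUSS)           # assumed signature

latent = torch.randn(BSIZE, LSIZE)
hidden = (torch.zeros(BSIZE, RSIZE), torch.zeros(BSIZE, RSIZE))  # (h, c)

mus, sigmas, logpi, next_hidden = cell(latent, hidden)

# Sample a mixture component per batch element, then a latent from it.
component = torch.distributions.Categorical(logits=logpi).sample()  # (BSIZE,)
idx = component.view(-1, 1, 1).expand(-1, 1, LSIZE)
mu = mus.gather(1, idx).squeeze(1)
sigma = sigmas.gather(1, idx).squeeze(1)
next_latent = mu + sigma * torch.randn_like(mu)
```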
@@ -20,13 +20,11 @@ from data.loaders import LatentStateDataset
 from models.vae import VAE
 from models.mdrnn import MDRNN, gmm_loss

-parser = argparse.ArgumentParser("MDRNN training")
+parser = argparse.ArgumentParser("LSTM training")
 parser.add_argument('--logdir', type=str,
                     help="Where things are logged and models are loaded from.")
 parser.add_argument('--noreload', action='store_true',
                     help="Do not reload if specified.")
-parser.add_argument('--include_reward', action='store_true',
-                    help="Add a reward modelisation term to the loss.")
 parser.add_argument('--latent_file', type=str,
                     help="Specify the file where the latent representation from the VAE is stored.")
 args = parser.parse_args()
@@ -72,44 +70,30 @@ test_loader = DataLoader(
     LatentStateDataset(args.latent_file, train=False),
     batch_size=BSIZE, num_workers=8)


-def get_loss(latent_obs, reward, terminal,
-             latent_next_obs, include_reward: bool):
+def get_loss(latent_obs, latent_next_obs):
     """ Compute losses.

     The loss that is computed is:
-    (GMMLoss(latent_next_obs, GMMPredicted) + MSE(reward, predicted_reward) +
-     BCE(terminal, logit_terminal)) / (LSIZE + 2)
-    The LSIZE + 2 factor is here to counteract the fact that the GMMLoss scales
-    approximately linearily with LSIZE. All losses are averaged both on the
+    GMMLoss(latent_next_obs, GMMPredicted) / LSIZE
+    The LSIZE factor is here to counteract the fact that the GMMLoss scales
+    approximately linearly with LSIZE. The loss is averaged both on the
     batch and the sequence dimensions (the first two dimensions).

     :args latent_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor
-    :args reward: (BSIZE, SEQ_LEN) torch tensor
     :args latent_next_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor

-    :returns: dictionary of losses, containing the gmm, the mse, the bce and
-        the averaged loss.
+    :returns: dictionary of losses, containing the gmm and
+        the averaged loss.
     """
-    latent_obs,\
-        reward, terminal,\
-        latent_next_obs = [arr.transpose(1, 0)
-                           for arr in [latent_obs,
-                                       reward, terminal,
-                                       latent_next_obs]]
-    mus, sigmas, logpi, rs, ds = mdrnn(latent_obs)
+    latent_obs, latent_next_obs = [arr.transpose(1, 0)
+                                   for arr in [latent_obs, latent_next_obs]]
+    mus, sigmas, logpi = mdrnn(latent_obs)
     gmm = gmm_loss(latent_next_obs, mus, sigmas, logpi)
-    bce = f.binary_cross_entropy_with_logits(ds, terminal)
-
-    if include_reward:
-        mse = f.mse_loss(rs, reward)
-        scale = LSIZE + 2
-    else:
-        mse = 0
-        scale = LSIZE + 1
-
-    loss = (gmm + bce + mse) / scale
-    return dict(gmm=gmm, bce=bce, mse=mse, loss=loss)
+    loss = gmm / LSIZE
+    return dict(gmm=gmm, loss=loss)


-def data_pass(epoch, train, include_reward): # pylint: disable=too-many-locals
+def data_pass(epoch, train): # pylint: disable=too-many-locals
     """ One pass through the data """
     if train:
         mdrnn.train()
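`gmm_loss` is imported from models.mdrnn above but not shown in this diff. A sketch of a diagonal-GMM negative log-likelihood consistent with how it is called here, written with `logsumexp` for numerical stability; this is a re-derivation under assumptions, not the file's actual code:

```python
import torch
from torch.distributions import Normal

def gmm_loss_sketch(batch, mus, sigmas, logpi):
    # batch: (SEQ_LEN, BSIZE, LSIZE); mus, sigmas: (SEQ_LEN, BSIZE, N_GAUSS, LSIZE)
    batch = batch.unsqueeze(-2)                      # broadcast over the mixture dim
    log_probs = Normal(mus, sigmas).log_prob(batch)  # per-dimension log densities
    log_probs = logpi + log_probs.sum(dim=-1)        # joint log density per component
    log_prob = torch.logsumexp(log_probs, dim=-1)    # log-likelihood of the mixture
    return -log_prob.mean()                          # average over batch and sequence
```

Because this sums LSIZE per-dimension terms, the loss grows roughly linearly with LSIZE, which is what the `gmm / LSIZE` normalization above counteracts.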
@@ -118,46 +102,41 @@ def data_pass(epoch, train, include_reward): # pylint: disable=too-many-locals
         mdrnn.eval()
         loader = test_loader

-    loader.dataset.load_next_buffer()
+    # Comment by Rafael:
+    # The following line was used in the original world models to load the next batch of the dataset.
+    # If we use multiple videos, we might want to use a similar mechanism.
+    # loader.dataset.load_next_buffer()

     cum_loss = 0
     cum_gmm = 0
-    cum_bce = 0
-    cum_mse = 0

     pbar = tqdm(total=len(loader.dataset), desc="Epoch {}".format(epoch))
     for i, data in enumerate(loader):
-        latent_obs, action, reward, terminal, latent_next_obs = [arr.to(device) for arr in data]
+        latent_obs, latent_next_obs = [arr.to(device) for arr in data]

         if train:
-            losses = get_loss(latent_obs, reward,
-                              terminal, latent_next_obs, include_reward)
+            losses = get_loss(latent_obs, latent_next_obs)

             optimizer.zero_grad()
             losses['loss'].backward()
             optimizer.step()
         else:
             with torch.no_grad():
-                losses = get_loss(latent_obs, reward,
-                                  terminal, latent_next_obs, include_reward)
+                losses = get_loss(latent_obs, latent_next_obs)

         cum_loss += losses['loss'].item()
         cum_gmm += losses['gmm'].item()
-        cum_bce += losses['bce'].item()
-        cum_mse += losses['mse'].item() if hasattr(losses['mse'], 'item') else \
-            losses['mse']

-        pbar.set_postfix_str("loss={loss:10.6f} bce={bce:10.6f} "
-                             "gmm={gmm:10.6f} mse={mse:10.6f}".format(
-                                 loss=cum_loss / (i + 1), bce=cum_bce / (i + 1),
-                                 gmm=cum_gmm / LSIZE / (i + 1), mse=cum_mse / (i + 1)))
+        pbar.set_postfix_str("loss={loss:10.6f} gmm={gmm:10.6f}".format(
+            loss=cum_loss / (i + 1),
+            gmm=cum_gmm / LSIZE / (i + 1)))
         pbar.update(BSIZE)
     pbar.close()
     return cum_loss * BSIZE / len(loader.dataset)


-train = partial(data_pass, train=True, include_reward=args.include_reward)
-test = partial(data_pass, train=False, include_reward=args.include_reward)
+train = partial(data_pass, train=True)
+test = partial(data_pass, train=False)

 cur_best = None
 for e in range(epochs):
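The loop body is collapsed here. A typical shape for it, given the `train`/`test` partials and `cur_best` above (a sketch under assumptions, not the elided original):

```python
for e in range(epochs):
    train(e)
    test_loss = test(e)

    # Track the best test loss; a full script would also checkpoint
    # mdrnn and optimizer state at this point (assumed, not shown).
    if cur_best is None or test_loss < cur_best:
        cur_best = test_loss
```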