Commit 94dec806 authored by Rafael Daetwyler

added correct handling of sequences

parent 913804c9
@@ -148,12 +148,13 @@ class RolloutObservationDataset(_RolloutDataset): # pylint: disable=too-few-public-methods
         return self._transform(data['observations'][seq_index])

 class LatentStateDataset(torch.utils.data.Dataset):
-    def __init__(self, file, train=True):
+    def __init__(self, file, seq_len, train=True):
         self._file = file
+        self._seq_len = seq_len

         # Reload latent
         latents = torch.load(file)
-        print(latents.shape)  # [video_number][frame_number][mu,logsigma][latent_dimension]
+        print('Latents shape: ', latents.shape)  # [video_number][frame_number][mu,logsigma][latent_dimension]

         # Compute z vector.
         mu = latents[:, :, 0]
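Note: the hunk above documents the latents layout but cuts off before z is built. A minimal sketch of what that layout implies, with all sizes and the sampling step assumed rather than taken from the hidden lines:

import torch

# Illustrative latents tensor matching the printed layout:
# [video_number][frame_number][mu,logsigma][latent_dimension] (all sizes assumed)
latents = torch.randn(3, 100, 2, 64)
mu = latents[:, :, 0]        # (3, 100, 64)
logsigma = latents[:, :, 1]  # (3, 100, 64)
# Standard VAE reparameterization; whether the elided code samples or just keeps mu is an assumption.
z = mu + logsigma.exp() * torch.randn_like(mu)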
@@ -164,21 +165,28 @@ class LatentStateDataset(torch.utils.data.Dataset):
         no_vids = z.size()[0]
         split_vids = (no_vids > 1)

         if split_vids:
-            train_separation = np.ceil(no_vids*0.1)
+            train_separation = int(np.ceil(no_vids*0.1))
             if train:
                 self._data = z[:-train_separation]
             else:
                 self._data = z[-train_separation:]
         else:
             no_frames = z.size()[1]
-            train_separation = np.ceil(no_frames*0.1)
+            train_separation = int(np.ceil(no_frames*0.1))
             if train:
                 self._data = z[0][:-train_separation].unsqueeze(0)
             else:
                 self._data = z[0][-train_separation:].unsqueeze(0)

+        self._no_seq_tuples_per_vid = self._data.shape[1] - self._seq_len
+
     def __len__(self):
-        return self._data.size
+        return self._data.shape[0] * self._no_seq_tuples_per_vid

     def __getitem__(self, i):
-        return self._data[i][2:], self._data[i+1][2:]
+        video_no, seq_start = divmod(i, self._no_seq_tuples_per_vid)
+        return self._get_seqs(video_no, seq_start)
+
+    def _get_seqs(self, video_no, i):
+        return self._data[video_no][i:i+self._seq_len], self._data[video_no][i+1:i+self._seq_len+1]
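To make the new indexing concrete: with the change above, __getitem__ enumerates every overlapping (input, target) sequence pair across all videos. A minimal sketch, with the tensor sizes and seq_len assumed for illustration:

import torch

data = torch.randn(3, 100, 64)  # assumed: 3 videos, 100 frames, 64-dim latents
seq_len = 32
no_seq_tuples_per_vid = data.shape[1] - seq_len  # 100 - 32 = 68 start positions per video

# __len__ reports 3 * 68 = 204 pairs; a flat index maps to a (video, start) pair:
video_no, seq_start = divmod(150, no_seq_tuples_per_vid)  # -> (2, 14)
inp = data[video_no][seq_start:seq_start + seq_len]             # frames t .. t+seq_len-1
target = data[video_no][seq_start + 1:seq_start + seq_len + 1]  # frames t+1 .. t+seq_len
assert inp.shape == target.shape == (seq_len, 64)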
@@ -10,7 +10,7 @@ from torchvision import transforms
 import numpy as np
 from tqdm import tqdm
 from utils.misc import save_checkpoint
-from utils.misc import ASIZE, LSIZE, RSIZE, RED_SIZE, SIZE
+from utils.misc import LSIZE, RSIZE, RED_SIZE, SIZE
 from utils.learning import EarlyStopping
 ## WARNING : THIS SHOULD BE REPLACED WITH PYTORCH 0.5
 from utils.learning import ReduceLROnPlateau
@@ -43,6 +43,8 @@ rnn_file = join(rnn_dir, 'best.tar')
 if not exists(rnn_dir):
     mkdir(rnn_dir)

+# override the default setting, which is LSIZE = 32, RSIZE = 256
+LSIZE, RSIZE = 64, 128
 mdrnn = MDRNN(LSIZE, RSIZE, 5)
 mdrnn.to(device)
 optimizer = torch.optim.RMSprop(mdrnn.parameters(), lr=1e-3, alpha=.9)
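The three-argument call MDRNN(LSIZE, RSIZE, 5) suggests an action-free mixture-density RNN taking (latents, hiddens, gaussians). The real module is defined elsewhere in the repo; the stand-in below is only an assumed sketch of that interface and of a typical GMM-style output head:

import torch
import torch.nn as nn

class MDRNNSketch(nn.Module):
    """Stand-in for MDRNN(latents, hiddens, gaussians); the layer layout is assumed."""
    def __init__(self, latents, hiddens, gaussians):
        super().__init__()
        self.latents, self.gaussians = latents, gaussians
        self.rnn = nn.LSTM(latents, hiddens)
        # Per step and mixture component: a mean and a sigma over the latent dim, plus a weight.
        self.gmm_linear = nn.Linear(hiddens, (2 * latents + 1) * gaussians)

    def forward(self, latent_seq):  # latent_seq: (SEQ_LEN, BSIZE, latents)
        outs, _ = self.rnn(latent_seq)
        gmm = self.gmm_linear(outs)
        stride = self.gaussians * self.latents
        shape = (*outs.shape[:2], self.gaussians, self.latents)
        mus = gmm[:, :, :stride].reshape(shape)
        sigmas = torch.exp(gmm[:, :, stride:2 * stride]).reshape(shape)
        logpi = torch.log_softmax(gmm[:, :, 2 * stride:], dim=-1)
        return mus, sigmas, logpi

model = MDRNNSketch(64, 128, 5)  # mirrors MDRNN(LSIZE, RSIZE, 5) after the override
mus, sigmas, logpi = model(torch.randn(32, 16, 64))  # (SEQ_LEN, BSIZE, LSIZE)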
@@ -63,11 +65,11 @@ if exists(rnn_file) and not args.noreload:

 # Data Loading
 train_loader = DataLoader(
-    LatentStateDataset(args.latent_file),
+    LatentStateDataset(args.latent_file, SEQ_LEN),
     batch_size=BSIZE, num_workers=8)
 test_loader = DataLoader(
-    LatentStateDataset(args.latent_file, train=False),
+    LatentStateDataset(args.latent_file, SEQ_LEN, train=False),
     batch_size=BSIZE, num_workers=8)

 def get_loss(latent_obs, latent_next_obs):
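With the change above, each dataset item is a pair of (SEQ_LEN, LSIZE) tensors, so the default collate yields batches matching the shapes documented in get_loss below. A small sanity sketch with assumed sizes:

import torch
from torch.utils.data import DataLoader

BSIZE, SEQ_LEN, LSIZE = 16, 32, 64  # assumed values, for illustration only

# Any dataset returning (seq, next_seq) pairs of shape (SEQ_LEN, LSIZE) batches the same way.
pairs = [(torch.randn(SEQ_LEN, LSIZE), torch.randn(SEQ_LEN, LSIZE)) for _ in range(64)]
loader = DataLoader(pairs, batch_size=BSIZE)

latent_obs, latent_next_obs = next(iter(loader))
assert latent_obs.shape == latent_next_obs.shape == (BSIZE, SEQ_LEN, LSIZE)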
@@ -82,8 +84,7 @@ def get_loss(latent_obs, latent_next_obs):
     :args latent_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor
     :args latent_next_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor

-    :returns: dictionary of losses, containing the gmm, the mse, the bce and
-        the averaged loss.
+    :returns: dictionary of losses, containing the gmm and the averaged loss.
     """
     latent_obs, latent_next_obs = [arr.transpose(1, 0)
                                    for arr in [latent_obs, latent_next_obs]]