Commit 913804c9 authored by Rafael Dätwyler's avatar Rafael Dätwyler
Browse files

changed loader to match VAE output

parent a7eca285
......@@ -150,22 +150,35 @@ class RolloutObservationDataset(_RolloutDataset): # pylint: disable=too-few-publ
class LatentStateDataset(torch.utils.data.Dataset):
def __init__(self, file, train=True):
self._file = file
data = np.empty(0)
with open(file, 'rb') as csvfile:
reader = csv.reader(csvfile)
i = 0
for row in reader:
data[i] = row
i += 1
# np.append(self._data, row, axis=0)
train_separation = int(data.size*0.1)
if train:
self._data = data[:-train_separation]
# Reload latent
latents = torch.load(file)
print(latents.shape) # [video_number][frame_number][mu,logsigma][latent_dimension]
# Compute z vector.
mu = latents[:, :, 0]
logsigma = latents[:, :, 1]
sigma = logsigma.exp()
eps = torch.randn_like(sigma)
z = eps.mul(sigma).add_(mu)
no_vids = z.size()[0]
split_vids = (no_vids > 1)
if split_vids:
train_separation = np.ceil(no_vids*0.1)
if train:
self._data = z[:-train_separation]
else:
self._data = z[-train_separation:]
else:
self._data = data[-train_separation:]
no_frames = z.size()[1]
train_separation = np.ceil(no_frames*0.1)
if train:
self._data = z[0][:-train_separation].unsqueeze(0)
else:
self._data = z[0][-train_separation:].unsqueeze(0)
def __len__(self):
return self._data.size
def __getitem__(self, i):
return self._data[i]
return self._data[i][2:], self._data[i+1][2:]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment