Commit 0cce655a authored by Rafael Daetwyler's avatar Rafael Daetwyler
Browse files

unfinished script to compare in- & output of RNN

parent 1f0be44f
import torch
from torch.utils.data import DataLoader
import numpy as np
from data.loaders import LatentStateDataset
from models.mdrnn import MDRNN
from torch.distributions.normal import Normal
dataset = '/cluster/scratch/darafael/data/latent_variables/sin.pt'
reload_file = '/cluster/scratch/darafael/experiments/lstm/models/sin1_D20200116T155627/best.tar'
seq_len = 32
batch_size = 32
no_gaussians = 5
model = MDRNN(64, 128, no_gaussians)
rnn_state = torch.load(reload_file, map_location=torch.device('cpu'))
model.load_state_dict(rnn_state["state_dict"])
model.eval()
loader = DataLoader(
LatentStateDataset(dataset, seq_len, train=False),
batch_size=batch_size, num_workers=8)
loader_iterable = iter(loader)
with torch.no_grad():
# latent: batch_size, seq_len, LSIZE
latent, next_latent = next(loader_iterable)
mus, sigmas, logpi = model(latent)
normal_dist = Normal(mus, sigmas)
# sample: batch_size, seq_len, no_gaussians, LSIZE
sample = normal_dist.sample()
print('Latent shape')
print(latent)
print('Next prediction')
print(sample)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment