# Commit 1d029b9c (author: nstorni) — Dream generation script.
# parent 859ea82a
""" Dream generation: roll out a trained MDRNN in VAE latent space and decode the predicted latents into video frames. """
import argparse
from functools import partial
from os.path import join, exists
from os import mkdir
import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
from datetime import datetime
import json
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
from models.lwtvae import VAE
from models.mdrnn import MDRNN, MDRNNCell
import cv2
from utils.config import get_config_from_json
# Command-line configuration for the dream-generation run.
parser = argparse.ArgumentParser(description='LSTM training')

# Path/string options: (flag, default).
_path_options = [
    ('--rnn_model_dir', "/cluster/scratch/nstorni/data/rnn_D20200116T151028"),
    ('--vae_model_dir', '/cluster/scratch/nstorni/experiments/shvae_temporalmask/models/vae_D20200111T092802'),
    ('--vae_latents', '/cluster/scratch/nstorni/data/mice_tempdiff_medium/vae_D20200111T092802.pt'),
    ('--video_name', 'test'),
]
for _flag, _default in _path_options:
    parser.add_argument(_flag, type=str, default=_default)

# Integer options: (flag, default).
_int_options = [
    ('--teaching_duration', 100),   # frames driven by real latents (teacher forcing)
    ('--dream_duration', 2500),     # total frames to generate
    ('--vae_latent_dim', 64),
    ('--rnn_hidden_dim', 128),
    ('--rnn_gaussians_n', 5),
    ('--video_n', 1),               # which stored video's latents seed the dream
]
for _flag, _default in _int_options:
    parser.add_argument(_flag, type=int, default=_default)

args = parser.parse_args()

# Output video: <video_name>.mp4, 25 fps, 256x256 frames.
out = cv2.VideoWriter(args.video_name + '.mp4', cv2.VideoWriter_fourcc(*'MP4V'), 25, (256, 256))
# --- Reload the trained MDRNN and transfer its weights into a one-step cell ---
# The full MDRNN was trained on sequences; the MDRNNCell is used below to
# advance the latent state one frame at a time.
mdrnn = MDRNN(args.vae_latent_dim, args.rnn_hidden_dim, args.rnn_gaussians_n)
mdrnncell = MDRNNCell(args.vae_latent_dim, args.rnn_hidden_dim, args.rnn_gaussians_n)
reload_file = join(os.path.expandvars(args.rnn_model_dir), 'checkpoint.tar')
# BUG FIX: the original `elif` tested `config.reload`, but `config` is never
# defined in this script, so a missing checkpoint raised a NameError instead
# of the intended message. Fail fast with the intended Exception instead.
if not exists(reload_file):
    raise Exception('Reload file not found: {}'.format(reload_file))
rnn_state = torch.load(reload_file, map_location=torch.device('cpu'))
print("Loading MDRNN at epoch {} "
      "with test error {}".format(
          rnn_state["epoch"], rnn_state["precision"]))
mdrnn.load_state_dict(rnn_state["state_dict"])
# The sequence model's LSTM parameters carry a '_l0' layer suffix
# (e.g. 'rnn.weight_ih_l0') that the cell's parameters lack.
# BUG FIX: the original used k.strip('_l0'), which removes any of the
# CHARACTERS '_', 'l', '0' from both ends of the key — not the literal
# suffix. Remove exactly the trailing '_l0' instead.
rnn_state_dict = {
    (k[:-3] if k.endswith('_l0') else k): v
    for k, v in mdrnn.state_dict().items()
}
mdrnncell.load_state_dict(rnn_state_dict)
# --- Load stored VAE encodings and sample latent trajectories ---
# latents[..., 0] holds the means, latents[..., 1] the log-sigmas; draw
# z ~ N(mu, sigma) via the reparameterization trick.
latents = torch.load(args.vae_latents)
mu = latents[:, :, 0]
logsigma = latents[:, :, 1]
sigma = torch.exp(logsigma)
noise = torch.randn_like(sigma)
z = mu + sigma * noise

# --- Reload the trained VAE (4 input channels) used to decode latents ---
model = VAE(4, args.vae_latent_dim)
reload_file = join(os.path.expandvars(args.vae_model_dir), 'best.tar')
state = torch.load(reload_file, map_location=torch.device('cpu'))
model.load_state_dict(state['state_dict'])
# --- Dream rollout ---------------------------------------------------------
# For the first `teaching_duration` steps the cell is driven by the real
# latents of video `video_n` (teacher forcing); afterwards it feeds back its
# own sampled predictions, producing a "dream" that is decoded frame by frame.
hstate = 2 * [torch.zeros(1, args.rnn_hidden_dim)]  # (h, c) LSTM cell state
with torch.no_grad():
    for i in range(args.dream_duration):
        # While teaching, use the real latent for this frame.
        if i < args.teaching_duration:
            lstate = z[args.video_n][i].view(1, args.vae_latent_dim)
        # Predict the next-latent mixture distribution with the MDRNN cell.
        mu, sigma, pi, n_h = mdrnncell(lstate, hstate)
        # Sample a mixture component (pi is log-weights), then sample a
        # latent from that component's Gaussian.
        pi = pi.squeeze()
        mixt = Categorical(torch.exp(pi)).sample().item()
        recon_z = mu[:, mixt, :] + sigma[:, mixt, :] * torch.randn_like(mu[:, mixt, :])
        # Carry the RNN state and the sampled latent into the next step.
        hstate = n_h
        lstate = recon_z
        # Reconstruct the frame from the predicted latent with the VAE decoder.
        recon_img = model.decoder(recon_z)
        # Write the first 3 channels as an 8-bit 256x256 frame.
        out.write(np.uint8((250 * recon_img[0, 0:3]).permute(1, 2, 0).numpy()))
        # BUG FIX: the original `if i % 100:` printed on every frame EXCEPT
        # multiples of 100; report progress once every 100 frames instead.
        if i % 100 == 0:
            print("Frame ", i)
out.release()
# (end of file)