Commit 7a1d3c82 authored by nstorni

Merge branch 'development' into 'master'

Development into master

See merge request !2
parents 6b8022df 205cc04e
@@ -7,3 +7,6 @@ Dockerfile
frames/
dock_config.json
.mypy_cache
exp_*
lsf.*
lsd.*
# Pytorch implementation of the "WorldModels"
# Generate "Dream" of mice behaviour
The following script is the final result of the project for the Deep Learning lecture. It uses a trained VAE and MDN-RNN to generate a "dreamed" sequence of frames.
The length of the teaching sequence can be chosen with the argument --teaching_duration, and the length of the subsequent "dreamed" sequence with --dream_duration.
Paper: Ha and Schmidhuber, "World Models", 2018. https://doi.org/10.5281/zenodo.1207631. For a quick summary of the paper and some additional experiments, visit the [github page](https://ctallec.github.io/world-models/).
If this script is run on the Leonhard cluster, the following lines load the correct modules and install the requirements before starting the generate_dream script:
```bash
$HOME/world-models/install_requirements.sh
module purge
module load gcc/4.8.5 python_gpu/3.7.1
python generate_dream.py --rnn_model_dir=weights/rnn.tar --vae_model_dir=weights/vae.tar --vae_latents=weights/video_vae_latents.tar --teaching_duration=100 --dream_duration=500 --video_name="dreamingMices"
```
(To run on a different system, install the packages listed in requirements.txt.)
The generated video will be saved in the directory where the script is launched.
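For orientation, the sketch below shows the idea behind the teach-then-dream loop. It is a hypothetical illustration, not the actual generate_dream.py code: it only assumes an MDN-RNN whose forward pass returns `mus, sigmas, logpi` (as in the scripts further down in this diff) and a VAE with a `decoder`; the helper name and exact tensor shapes are assumptions.
```python
import torch
from torch.distributions.normal import Normal

# Hypothetical sketch of "dreaming": teach the MDN-RNN with real VAE latents,
# then feed its own sampled predictions back in. Shapes assumed:
#   mus, sigmas: (batch, seq_len, n_gaussians, LSIZE), logpi: (batch, seq_len, n_gaussians)
def dream(mdrnn, vae, teaching_latents, dream_duration):
    """teaching_latents: (1, teaching_duration, LSIZE) latents of a real video."""
    seq = teaching_latents                       # start from the real ("teaching") sequence
    frames = []
    for _ in range(dream_duration):
        mus, sigmas, logpi = mdrnn(seq)          # mixture parameters for every time step
        mu, sigma, pi = mus[:, -1], sigmas[:, -1], logpi[:, -1].exp()   # keep only the last step
        k = torch.multinomial(pi, 1).squeeze(-1)                        # pick a mixture component
        idx = torch.arange(mu.size(0))
        z_next = Normal(mu[idx, k], sigma[idx, k]).sample()             # sample the next latent
        frames.append(vae.decoder(z_next))       # decode the dreamed latent into an image
        seq = torch.cat([seq, z_next.unsqueeze(1)], dim=1)              # feed the dream back in
    return frames
```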
## Prerequisites
First clone the project files from Gitlab (you will need to enter your credentials):
# Setup, dataset preparation and training
## 1. Prerequisites
First ssh in leonhard and clone the project files from Gitlab (you will need to enter your credentials) to your $HOME directory:
```bash
cd $HOME
git clone https://gitlab.ethz.ch/deep-learning-rodent/world-models.git
```
Then install the requirements (they will be installed for the python_gpu/3.7.1 module only):
```bash
$HOME/world-models/install_requirements.sh
```
(To learn more about module usage on the cluster, see https://scicomp.ethz.ch/wiki/Getting_started_with_clusters.)
Navigate to the project directory and execute the following command to build the Docker image. This might take a while, but you only need to do this once.
## 2. Mini Mice dataset preparation
Prepare folder structure in $SCRATCH directory:
```bash
DATASET_NAME="mini_mice_dataset"
mkdir $SCRATCH/data
mkdir $SCRATCH/data/$DATASET_NAME
mkdir $SCRATCH/data/$DATASET_NAME/video
mkdir $SCRATCH/data/$DATASET_NAME/train
mkdir $SCRATCH/data/$DATASET_NAME/train/nolabel
mkdir $SCRATCH/data/$DATASET_NAME/val
mkdir $SCRATCH/data/$DATASET_NAME/val/nolabel
mkdir $SCRATCH/data/$DATASET_NAME/test
mkdir $SCRATCH/data/$DATASET_NAME/test/nolabel
```
This structure is required by the dataset loader (it expects a folder containing one subfolder per class; in our case there is only a single class, "nolabel").
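For illustration only, this layout is what a torchvision `ImageFolder`-style loader can consume directly (torchvision transforms are used elsewhere in this repository); the resize size below is an assumption:
```python
import os
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Illustrative only: each subfolder of train/ is treated as one class, so with a
# single "nolabel" subfolder every frame gets the same (ignored) label.
transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
train_dir = os.path.join(os.environ["SCRATCH"], "data/mini_mice_dataset/train")
train_set = datasets.ImageFolder(train_dir, transform=transform)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
print(train_set.classes)  # ['nolabel']
```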
Upload one video to $SCRATCH/data/$DATASET_NAME/video using WinSCP or similar.
E.g., copy video.avi from a local directory (ON YOUR LOCAL MACHINE) to the target directory on the cluster:
```bash
docker build -t deep-learning:worldmodels .
scp *.mp4 nstorni@login.leonhard.ethz.ch:$SCRATCH/data/mini_mice_dataset/video
```
To run the container, run the following command in the project directory, depending on your OS:
Extract frames from the videos (this takes a long time; for many videos, submit it as a job to the cluster, see the command below):
```bash
cd $SCRATCH/data/$DATASET_NAME
$HOME/world-models/video_frame/video2frames.sh
cd $HOME
```
Since frame extraction is quite slow, if you want to extract multiple videos and don't want to keep the SSH connection open, you can submit a job to the cluster that extracts the frames:
```bash
Windows (PowerShell): docker run -it --rm -v ${pwd}:/app deep-learning:worldmodels
Linux: docker run -it --rm -v $(pwd):/app deep-learning:worldmodels
bsub -o job_logs -n 1 -R "rusage[mem=4096]" "cd $SCRATCH/data/mini_mice_dataset; pwd; $HOME/world-models/video_frame/video2frames.sh"
```
You can check the job status by typing:
```bash
bbjobs
```
## Running the worldmodels
Compute normalisation statistics (mean and std).
```bash
cd $SCRATCH/data/$DATASET_NAME
module purge
module load gcc/4.8.5 python_cpu/3.7.1
python $HOME/world-models/utils/dataset_statistics.py --imagetype=RGB
```
To run the model, run the Docker container (see above) and execute the command inside the container.
On the cluster:
bsub -o $HOME/job_logs -n 4 -R "rusage[mem=4096]" "python $HOME/world-models/utils/dataset_statistics.py --imagetype=RGBA"
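As a rough illustration of what such a statistics script computes (not its actual implementation), the per-channel mean and standard deviation of all frames can be accumulated like this:
```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Illustrative sketch only: accumulate per-channel mean/std over every frame.
dataset = datasets.ImageFolder("train", transform=transforms.ToTensor())
loader = DataLoader(dataset, batch_size=64, num_workers=4)

n_pixels = 0
channel_sum = torch.zeros(3)
channel_sq_sum = torch.zeros(3)
for images, _ in loader:                               # images: (B, C, H, W), values in [0, 1]
    n_pixels += images.numel() / images.size(1)        # pixels per channel seen so far
    channel_sum += images.sum(dim=(0, 2, 3))
    channel_sq_sum += (images ** 2).sum(dim=(0, 2, 3))

mean = channel_sum / n_pixels
std = (channel_sq_sum / n_pixels - mean ** 2).sqrt()
print("mean:", mean, "std:", std)
```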
The model is composed of three parts (a toy wiring sketch follows the list):
1. A Variational Auto-Encoder (VAE), whose task is to compress the input images into a compact latent representation.
2. A Mixture-Density Recurrent Network (MDN-RNN), trained to predict the latent encoding of the next frame given past latent encodings and actions.
3. A linear Controller (C), which takes both the latent encoding of the current frame, and the hidden state of the MDN-RNN given past latents and actions as input and outputs an action. It is trained to maximize the cumulated reward using the Covariance-Matrix Adaptation Evolution-Strategy ([CMA-ES](http://www.cmap.polytechnique.fr/~nikolaus.hansen/cmaartic.pdf)) from the `cma` python package.
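To make the data flow between the three parts concrete, here is a toy wiring sketch. The modules below are stand-ins with the default dimensions used in this README (latent size 64, hidden size 128); they are not the real implementations in models/, and the action size is an arbitrary placeholder.
```python
import torch
import torch.nn as nn

LSIZE, HSIZE, ASIZE = 64, 128, 3   # latent, hidden and (placeholder) action sizes

vae_encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, LSIZE))  # stand-in for V
mdrnn = nn.LSTMCell(LSIZE + ASIZE, HSIZE)                                 # stand-in for M (memory)
controller = nn.Linear(LSIZE + HSIZE, ASIZE)                              # C really is a single linear layer

obs = torch.rand(1, 3, 64, 64)                       # current frame
action = torch.zeros(1, ASIZE)
h, c = torch.zeros(1, HSIZE), torch.zeros(1, HSIZE)

z = vae_encoder(obs)                                 # V: compress the frame into a latent z
h, c = mdrnn(torch.cat([z, action], dim=1), (h, c))  # M: update the memory from (z, action)
action = controller(torch.cat([z, h], dim=1))        # C: act on the latent + hidden state
```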
Move some images from the train to the val folder (not truly random; just take every file whose name ends with 1), which should amount to about 10% of the train set.
In the given code, all three sections are trained separately, using the scripts `trainvae.py`, `trainmdrnn.py` and `traincontroller.py`.
```bash
cd $SCRATCH/data/$DATASET_NAME
mv train/nolabel/*1.png val/nolabel
```
Training scripts take the following arguments:
* **--logdir** : The directory in which the models will be stored. If the logdir specified already exists, it loads the old model and continues the training.
* **--noreload** : If you want to override a model in *logdir* instead of reloading it, add this option.
### 1. Data generation
Before launching the VAE and MDN-RNN training scripts, you need to generate a dataset of random rollouts and place it in the `datasets/carracing` folder.
## 3. VAE training setup
The configuration for training the VAE is done through a JSON file where all hyperparameters and dataset directory paths are specified. A template for a training job is available in the config directory (template_config.json); copy it to your local directory and modify it for your training specifications. This config file will then be copied automatically to the output directory with the trained models.
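To give an idea of what such a config looks like, it is plain JSON and can be inspected directly; the key names below are only illustrative, the real ones are in template_config.json:
```python
import json

# Illustrative only: dump the keys/values of a training config.
with open("training_configs/train_config1.json") as f:
    config = json.load(f)

# Expect entries such as dataset directories and hyperparameters,
# e.g. config.get("dataset_dir"), config.get("learning_rate") (names assumed).
for key, value in config.items():
    print(f"{key}: {value}")
```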
Data generation is handled through the `data/generation_script.py` script, e.g.
```bash
python data/generation_script.py --rollouts 1000 --rootdir datasets/carracing --threads 8
cd $HOME
mkdir training_configs
cp world-models/template_config.json training_configs/train_config1.json
```
Modify train_config1.json for your training run; you can then test the configuration by starting the training interactively on Leonhard:
Rollouts are generated using a *brownian* random policy, instead of the *white noise* random `action_space.sample()` policy from gym, providing more consistent rollouts.
```bash
$HOME/world-models/start_training.sh --modeldir train_lwtvae.py --modelconfigdir $HOME/training_configs/train_config1.json
```
If there are no errors, interrupt the training with CTRL+C; you can now submit it to the cluster.
### 2. Training the VAE
The VAE is trained using the `trainvae.py` file, e.g.
```bash
python trainvae.py --logdir exp_dir
bsub -o job_logs -n 4 -R "rusage[ngpus_excl_p=1,mem=4096 ]" "$HOME/world-models/start_training.sh --modeldir trainvae.py --modelconfigdir $HOME/training_configs/train_config1.json"
```
### 3. Training the MDN-RNN
The MDN-RNN is trained using the `trainmdrnn.py` file, e.g.
## 4. Monitoring the job with tensorboard
You can monitor the training process with TensorBoard. Launch TensorBoard:
```bash
python trainmdrnn.py --logdir exp_dir
$HOME/world-models/utils/start_tensorboard.sh mini_mice_dataset
```
A VAE must have been trained in the same `exp_dir` for this script to work.
### 4. Training and testing the Controller
Finally, the controller is trained using CMA-ES, e.g.
Forward the port used by TensorBoard to view the training progress in your browser. Do this in a separate bash shell (change the port to the one TensorBoard reports):
```bash
python traincontroller.py --logdir exp_dir --n-samples 4 --pop-size 4 --target-return 950 --display
ssh -L 6993:localhost:6993 nstorni@login.leonhard.ethz.ch
```
You can test the obtained policy with `test_controller.py` e.g.
You can then access TensorBoard in your browser at localhost:6993.
The training job will save the models, samples and logs in the $SCRATCH/experiments folder.
You can delete failed runs from the experiments directory with the following script; the first argument is the experiments subdirectory name and the second argument is the run name (read it from TensorBoard):
```bash
python test_controller.py --logdir exp_dir
$HOME/world-models/utils/delete_failed_runs.sh shvae_temporalmask lwtvae_l3_D20200112T183921
```
### Notes
When running on a headless server, you will need to use `xvfb-run` to launch the controller training script. For instance,
## 5. Generate Moving MNIST dataset
You can generate a custom toy dataset of MNIST digits moving on a black frame and bouncing off the borders with the following script:
```bash
xvfb-run -s "-screen 0 1400x900x24" python traincontroller.py --logdir exp_dir --n-samples 4 --pop-size 4 --target-return 950 --display
$HOME/world-models/utils/generate_moving_mnist.sh --num_videos 10 --num_videos_val 300 --num_frames 10 --digits_dim 40 --frame_dim 256 --num_digits 1
```
If you do not have a display available and you launch `traincontroller` without
`xvfb-run`, the script will fail silently (but logs are available in
`logdir/tmp`).
* **--num_videos** : how many video sequences to generate
* **--num_frames** : how many frames per sequence
* **--digits_dim** : how large the digits are
* **--frame_dim** : how large the image is
* **--num_digits** : how many digits are on each frame
Be aware that `traincontroller` requires heavy gpu memory usage when launched
on gpus. To reduce the memory load, you can directly modify the maximum number
of workers by specifying the `--max-workers` argument.
The script will create a folder in $SCRATCH/data whose name is built as
DATASET_NAME=$CUSTOM_NAME"movingMNIST_""digits"$NUM_DIGITS"ddim"$DIGITS_DIM"fdim"$FRAME_DIM"v"$NUM_VIDEOS"f"$NUM_FRAMES
The script already computes the normalisation statistics and puts them in the dataset folder.
If you have several GPUs available, `traincontroller` will take advantage of
all gpus specified by `CUDA_VISIBLE_DEVICES`.
For large datasets, run the script on the cluster:
## Authors
```bash
bsub -o $HOME/job_logs -n 4 -R "rusage[mem=4096]" "$HOME/world-models/utils/generate_moving_mnist.sh --num_videos 10000 --num_videos_val 300 --num_frames 10 --digits_dim 40 --frame_dim 256"
```
* **Corentin Tallec** - [ctallec](https://github.com/ctallec)
* **Léonard Blier** - [leonardblier](https://github.com/leonardblier)
* **Diviyan Kalainathan** - [diviyan-kalainathan](https://github.com/diviyan-kalainathan)
## 6. Running the RNN
Once you have the latent variables from the VAE, you can use them to train the RNN.
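For reference, the latent file is a saved tensor of shape [video, frame, (mu, logsigma), latent_dim], and the RNN data loader samples z from it with the reparameterisation trick (see the LatentStateDataset class further down in this diff). A minimal sketch of that sampling:
```python
import torch

# Mirrors the sampling done in LatentStateDataset (shown later in this diff).
latents = torch.load("data.pt")                   # shape: [video, frame, 2, latent_dim]
mu, logsigma = latents[:, :, 0], latents[:, :, 1]
z = mu + logsigma.exp() * torch.randn_like(mu)    # one sampled latent per frame
print(z.shape)                                    # [video, frame, latent_dim]
```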
First, create a folder for the RNN:
```bash
mkdir $SCRATCH/data/lstm
```
Afterwards, copy the dataset (containing the latent variables) into this folder. If you need to upload them from your local machine, use the following command:
```bash
USERNAME="your-username"
scp data.pt $USERNAME@login.leonhard.ethz.ch:/cluster/scratch/$USERNAME/data/latent_variables
```
Then, you can copy the config template for the RNN into the configs folder:
```bash
cd $HOME
cp world-models/template_config_rnn.json training_configs/rnn_config1.json
```
Edit the config file to fit your parameters. Then you can do a test run to check that the script works:
```bash
$HOME/world-models/start_training.sh --modeldir trainlstm.py --modelconfigdir $HOME/training_configs/rnn_config1.json
```
If you don't get any errors, you can abort the execution and send the job to be executed on the cluster:
```bash
bsub -o job_logs -n 4 -R "rusage[ngpus_excl_p=1,mem=4096]" "$HOME/world-models/start_training.sh --modeldir trainlstm.py --modelconfigdir $HOME/training_configs/rnn_config1.json"
```
## License
## 7. Generate interframe differences dataset
```bash
DATASET_NAME="interdiff_dataset"
mkdir $SCRATCH/data
mkdir $SCRATCH/data/$DATASET_NAME
mkdir $SCRATCH/data/$DATASET_NAME/video
mkdir $SCRATCH/data/$DATASET_NAME/train
mkdir $SCRATCH/data/$DATASET_NAME/train/nolabel
mkdir $SCRATCH/data/$DATASET_NAME/val
mkdir $SCRATCH/data/$DATASET_NAME/val/nolabel
mkdir $SCRATCH/data/$DATASET_NAME/test
mkdir $SCRATCH/data/$DATASET_NAME/test/nolabel
```
This structure is required by the dataset loader (it expects a folder containing one subfolder per class; in our case there is only a single class, "nolabel").
Upload one video to $SCRATCH/data/$DATASET_NAME/video using WinSCP or similar.
E.g., copy video.avi from a local directory (ON YOUR LOCAL MACHINE) to the target directory on the cluster:
```bash
USERNAME="your-username"
scp *.mp4 $USERNAME@login.leonhard.ethz.ch:$SCRATCH/data/$DATASET_NAME/video
```
Generate interframe differences:
```bash
cd $SCRATCH/data/$DATASET_NAME
module purge
module load gcc/4.8.5 python_gpu/3.7.1
python $HOME/world-models/video_frame/gen_tempdiff_data.py
```
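Conceptually, an interframe difference is just the pixel-wise difference between two consecutive frames, shifted into a displayable range. The sketch below illustrates the idea only; it is not the gen_tempdiff_data.py implementation, and the filenames are placeholders:
```python
import cv2
import numpy as np

# Illustrative only: difference between two consecutive frames, remapped to 0..255.
prev = cv2.imread("frame_0001.png", cv2.IMREAD_GRAYSCALE).astype(np.int16)
curr = cv2.imread("frame_0002.png", cv2.IMREAD_GRAYSCALE).astype(np.int16)
diff = ((curr - prev) // 2 + 128).clip(0, 255).astype(np.uint8)
cv2.imwrite("diff_0001_0002.png", diff)
```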
This project is licensed under the MIT License - see the [LICENSE.md](LICENSE.md) file for details
import cv2
import numpy as np
import glob

# Collect the generated frames (sorted by filename so they stay in order).
img_array = []
for filename in sorted(glob.glob('movingmnistdata/*.jpg')):
    img = cv2.imread(filename)
    height, width, layers = img.shape
    size = (width, height)
    img_array.append(img)

# Write the collected frames to an AVI video at 15 fps.
out = cv2.VideoWriter('test2.avi', cv2.VideoWriter_fourcc(*'DIVX'), 15, size)
for i in range(len(img_array)):
    out.write(img_array[i])
out.release()
\ No newline at end of file
@@ -6,6 +6,7 @@ from tqdm import tqdm
import torch
import torch.utils.data
import numpy as np
import csv
class _RolloutDataset(torch.utils.data.Dataset): # pylint: disable=too-few-public-methods
    def __init__(self, root, transform, buffer_size=200, train=True): # pylint: disable=too-many-arguments
@@ -145,3 +146,47 @@ class RolloutObservationDataset(_RolloutDataset): # pylint: disable=too-few-publ
    def _get_data(self, data, seq_index):
        return self._transform(data['observations'][seq_index])


class LatentStateDataset(torch.utils.data.Dataset):
    def __init__(self, file, seq_len, train=True):
        self._file = file
        self._seq_len = seq_len

        # Reload latent
        latents = torch.load(file)
        print('Latents shape: ', latents.shape)  # [video_number][frame_number][mu,logsigma][latent_dimension]

        # Compute z vector.
        mu = latents[:, :, 0]
        logsigma = latents[:, :, 1]
        sigma = logsigma.exp()
        eps = torch.randn_like(sigma)
        z = eps.mul(sigma).add_(mu)

        no_vids = z.size()[0]
        split_vids = (no_vids > 1)
        if split_vids:
            train_separation = int(np.ceil(no_vids * 0.1))
            if train:
                self._data = z[:-train_separation]
            else:
                self._data = z[-train_separation:]
        else:
            no_frames = z.size()[1]
            train_separation = int(np.ceil(no_frames * 0.1))
            if train:
                self._data = z[0][:-train_separation].unsqueeze(0)
            else:
                self._data = z[0][-train_separation:].unsqueeze(0)

        self._no_seq_tuples_per_vid = self._data.shape[1] - self._seq_len

    def __len__(self):
        return self._data.shape[0] * self._no_seq_tuples_per_vid

    def __getitem__(self, i):
        video_no, seq_start = divmod(i, self._no_seq_tuples_per_vid)
        return self._get_seqs(video_no, seq_start)

    def _get_seqs(self, video_no, i):
        return self._data[video_no][i:i + self._seq_len], self._data[video_no][i + 1:i + self._seq_len + 1]
import math
import os
import sys
import numpy as np
from PIL import Image
###########################################################################################
# script to generate moving mnist video dataset (frame by frame) as described in
# [1] arXiv:1502.04681 - Unsupervised Learning of Video Representations Using LSTMs
# Srivastava et al
# by Tencia Lee
# saves in hdf5, npz, or jpg (individual frames) format
###########################################################################################
# helper functions
def arr_from_img(im, mean=0, std=1):
    '''
    Args:
        im: Image
        mean: Mean to subtract
        std: Standard deviation to divide by
    Returns:
        Image as an np.float32 array in (channel, width, height) format, with values in range [0, 1]
        before the mean/std normalisation. The mean can be used for mean subtraction.
    '''
    width, height = im.size
    arr = im.getdata()
    c = int(np.product(arr.size) / (width * height))
    return (np.asarray(arr, dtype=np.float32).reshape((height, width, c)).transpose(2, 1, 0) / 255. - mean) / std
def get_image_from_array(X, index, mean=0, std=1):
    '''
    Args:
        X: Dataset of shape N x C x W x H
        index: Index of image we want to fetch
        mean: Mean to add
        std: Standard Deviation to add
    Returns:
        Image with dimensions H x W x C or H x W if it's a single channel image
    '''
    ch, w, h = X.shape[1], X.shape[2], X.shape[3]
    ret = (((X[index] + mean) * 255.) * std).reshape(ch, w, h).transpose(2, 1, 0).clip(0, 255).astype(np.uint8)
    if ch == 1:
        ret = ret.reshape(h, w)
    return ret
# loads mnist from web on demand
def load_dataset(training=True):
    if sys.version_info[0] == 2:
        from urllib import urlretrieve
    else:
        from urllib.request import urlretrieve

    def download(filename, source='http://yann.lecun.com/exdb/mnist/'):
        print("Downloading %s" % filename)
        urlretrieve(source + filename, filename)

    import gzip

    def load_mnist_images(filename):
        if not os.path.exists(filename):
            download(filename)
        with gzip.open(filename, 'rb') as f:
            data = np.frombuffer(f.read(), np.uint8, offset=16)
        data = data.reshape(-1, 1, 28, 28).transpose(0, 1, 3, 2)
        return data / np.float32(255)

    if training:
        return load_mnist_images('train-images-idx3-ubyte.gz')
    return load_mnist_images('t10k-images-idx3-ubyte.gz')
def generate_moving_mnist(training, shape=(64, 64), num_frames=30, num_images=100, original_size=28, nums_per_image=2):
    '''
    Args:
        training: Boolean, used to decide if downloading/generating train set or test set
        shape: Shape we want for our moving images (new_width and new_height)
        num_frames: Number of frames in a particular movement/animation/gif
        num_images: Number of movements/animations/gifs to generate
        original_size: Real size of the images (eg: MNIST is 28x28)
        nums_per_image: Digits per movement/animation/gif.
    Returns:
        Dataset of np.uint8 type with dimensions num_images x num_frames x new_width x new_height
    '''
    mnist = load_dataset(training)
    width, height = shape

    # Get how many pixels we can move around a single image
    lims = (x_lim, y_lim) = width - original_size, height - original_size

    # Create a dataset of shape num_images x num_frames x new_width x new_height
    # Eg : 100 x 30 x 64 x 64
    dataset = np.empty((num_images, num_frames, width, height), dtype=np.uint8)

    for img_idx in range(num_images):
        # Randomly generate direction, speed and velocity for both images
        direcs = np.pi * (np.random.rand(nums_per_image) * 2 - 1)
        speeds = np.random.randint(5, size=nums_per_image) + 2
        veloc = np.asarray([(speed * math.cos(direc), speed * math.sin(direc)) for direc, speed in zip(direcs, speeds)])

        # Get a list containing two PIL images randomly sampled from the database
        mnist_images = [Image.fromarray(get_image_from_array(mnist, r, mean=0)).resize((original_size, original_size),
                                                                                       Image.ANTIALIAS)
                        for r in np.random.randint(0, mnist.shape[0], nums_per_image)]

        # Generate tuples of (x,y) i.e initial positions for nums_per_image (default : 2)
        positions = np.asarray([(np.random.rand() * x_lim, np.random.rand() * y_lim) for _ in range(nums_per_image)])

        # Generate new frames for the entire sequence length (num_frames)
        for frame_idx in range(num_frames):
            canvases = [Image.new('L', (width, height)) for _ in range(nums_per_image)]
            canvas = np.zeros((1, width, height), dtype=np.float32)

            # In canv (i.e Image object) place the image at the respective positions
            # Superimpose both images on the canvas (i.e empty np array)
            for i, canv in enumerate(canvases):
                canv.paste(mnist_images[i], tuple(positions[i].astype(int)))
                canvas += arr_from_img(canv, mean=0)

            # Get the next position by adding velocity
            next_pos = positions + veloc

            # Iterate over velocity and see if we hit the wall
            # If we do, then change direction
            for i, pos in enumerate(next_pos):
                for j, coord in enumerate(pos):
                    if coord < -2 or coord > lims[j] + 2:
                        veloc[i] = list(list(veloc[i][:j]) + [-1 * veloc[i][j]] + list(veloc[i][j + 1:]))

            # Make the permanent change to position by adding updated velocity
            positions = positions + veloc

            # Add the canvas to the dataset array
            dataset[img_idx, frame_idx] = (canvas * 255).clip(0, 255).astype(np.uint8)

    return dataset
def main(training, dest, filetype='npz', frame_size=64, num_frames=30, num_images=100, original_size=28,
         nums_per_image=2):
    dat = generate_moving_mnist(training, shape=(frame_size, frame_size), num_frames=num_frames, num_images=num_images,
                                original_size=original_size, nums_per_image=nums_per_image)
    n = num_images * num_frames
    if filetype == 'npz':
        np.savez(dest, dat)
    elif filetype == 'jpg':
        print(dat.shape[0])
        print(dat.shape[1])
        for v in range(dat.shape[0]):
            for f in range(dat.shape[1]):
                # Image.fromarray(get_image_from_array(dat[v], f, mean=0)).save(os.path.join(dest, '{}_{}.jpg'.format(v, f)))
                Image.fromarray(dat[v][f]).save(os.path.join(dest, '{}_{}.jpg'.format(v, f)))
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Command line options')
    parser.add_argument('--dest', type=str, dest='dest', default='movingmnistdata')
    parser.add_argument('--filetype', type=str, dest='filetype', default="npz")
    parser.add_argument('--training', type=bool, dest='training', default=True)
    parser.add_argument('--frame_size', type=int, dest='frame_size', default=64)
    parser.add_argument('--num_frames', type=int, dest='num_frames', default=30)  # length of each sequence
    parser.add_argument('--num_images', type=int, dest='num_images', default=20000)  # number of sequences to generate
    parser.add_argument('--original_size', type=int, dest='original_size',
                        default=28)  # size of mnist digit within frame
    parser.add_argument('--nums_per_image', type=int, dest='nums_per_image',
                        default=2)  # number of digits in each frame
    args = parser.parse_args(sys.argv[1:])
    main(**{k: v for (k, v) in vars(args).items() if v is not None})
\ No newline at end of file
import torch
from torch.utils.data import DataLoader
import numpy as np
from data.loaders import LatentStateDataset
from models.mdrnn import MDRNN
from torch.distributions.normal import Normal
dataset = '/cluster/scratch/darafael/data/latent_variables/sin.pt'
reload_file = '/cluster/scratch/darafael/experiments/lstm/models/sin1_D20200116T155627/best.tar'
seq_len = 32
batch_size = 32
no_gaussians = 5
model = MDRNN(64, 128, no_gaussians)
rnn_state = torch.load(reload_file, map_location=torch.device('cpu'))
model.load_state_dict(rnn_state["state_dict"])
model.eval()
loader = DataLoader(
    LatentStateDataset(dataset, seq_len, train=False),
    batch_size=batch_size, num_workers=8)
loader_iterable = iter(loader)

with torch.no_grad():
    # latent: batch_size, seq_len, LSIZE
    latent, next_latent = next(loader_iterable)

    mus, sigmas, logpi = model(latent)
    normal_dist = Normal(mus, sigmas)
    # sample: batch_size, seq_len, no_gaussians, LSIZE
    sample = normal_dist.sample()

    print('Latent shape')
    print(latent)
    print('Next prediction')
    print(sample)
import argparse
from functools import partial
from os.path import join, exists
from os import mkdir
import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
from datetime import datetime
import json
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
from models.lwtvae import VAE
from models.mdrnn import MDRNN, MDRNNCell
import cv2
from utils.config import get_config_from_json
parser = argparse.ArgumentParser(description='Dream Generation')
parser.add_argument('--rnn_model_dir', type=str, default="weights/rnn.tar")
parser.add_argument('--vae_model_dir', type=str, default='weights/vae.tar')
parser.add_argument('--vae_latents', type=str, default='weights/video_vae_latents.tar')
parser.add_argument('--video_name', type=str, default='test')
parser.add_argument('--teaching_duration', type=int, default = 100)
parser.add_argument('--dream_duration', type=int, default = 2500)
parser.add_argument('--vae_latent_dim', type=int, default = 64)
parser.add_argument('--rnn_hidden_dim', type=int, default = 128)
parser.add_argument('--rnn_gaussians_n', type=int, default = 5)
parser.add_argument('--video_n', type=int, default = 1)
args = parser.parse_args()
out = cv2.VideoWriter(args.video_name+'.mp4',cv2.VideoWriter_fourcc(*'MP4V'), 25, (256,256))
# Reload RNN model
mdrnn = MDRNN(args.vae_latent_dim, args.rnn_hidden_dim, args.rnn_gaussians_n)
mdrnncell = MDRNNCell(args.vae_latent_dim, args.rnn_hidden_dim, args.rnn_gaussians_n)
# reload_file = join(os.path.expandvars(args.rnn_model_dir), 'best.tar')
reload_file = args.rnn_model_dir
if exists(reload_file):
    rnn_state = torch.load(reload_file, map_location=torch.device('cpu'))
    print("Loading MDRNN at epoch {} "
          "with test error {}".format(
              rnn_state["epoch"], rnn_state["precision"]))
    mdrnn.load_state_dict(rnn_state["state_dict"])
else:
    raise Exception('Reload file not found: {}'.format(reload_file))

rnn_state_dict = {k.strip('_l0'): v for k, v in mdrnn.state_dict().items()}
mdrnncell.load_state_dict(rnn_state_dict)
# Load VAE latents
latents = torch.load(args.vae_latents)