Commit b3d9699a authored by nstorni
# Generate "Dream" of mice behaviour
The following script is the final result of our project for the Deep Learning lecture. It uses a trained VAE and MDN-RNN to generate a "dreamed" sequence of frames.
The length of the teaching sequence can be chosen with the --teaching_duration argument, and the length of the subsequent "dreamed" sequence with --dream_duration.
If this script is run on the Leonhard cluster, the following lines load the correct modules and install the requirements before starting the generate_dream script:
```
module purge
module load gcc/4.8.5 python_gpu/3.7.1
python --rnn_model_dir=weights/rnn.tar --vae_model_dir=weights/vae.tar --vae_latents=weights/video_vae_latents.tar --teaching_duration=100 --dream_duration=500 --video_name="dreamingMices"
```
(To run on a different system, install the requirements listed in requirements.txt.)
The generated video will be saved in the directory from which the script is launched.
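The teaching/dreaming split described above can be sketched as follows. This is an illustrative NumPy toy, not the project's code: `encode` and `predict_next_latent` are hypothetical stand-ins for the VAE encoder and one MDN-RNN step.

```python
import numpy as np

LATENT_DIM = 32

def encode(frame):
    # Stand-in for the VAE encoder: maps a frame to a latent vector.
    return frame.mean() * np.ones(LATENT_DIM)

def predict_next_latent(z, hidden):
    # Stand-in for one MDN-RNN step: returns the next latent and new hidden state.
    hidden = 0.9 * hidden + 0.1 * z
    return hidden.copy(), hidden

def generate_dream(frames, teaching_duration, dream_duration):
    hidden = np.zeros(LATENT_DIM)
    latents = []
    # Teaching phase: feed latents of real frames to warm up the RNN state.
    for t in range(teaching_duration):
        z, hidden = predict_next_latent(encode(frames[t]), hidden)
    # Dream phase: feed the RNN's own predictions back as input.
    for _ in range(dream_duration):
        latents.append(z)
        z, hidden = predict_next_latent(z, hidden)
    return latents

frames = np.random.rand(100, 64, 64, 3)
dream = generate_dream(frames, teaching_duration=100, dream_duration=500)
print(len(dream), dream[0].shape)
```

In the real script the teaching latents come from the trained VAE, and each dreamed latent is decoded back into a frame and written to the video.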
# Setup, dataset preparations and Training.
## 1. Prerequisites
First ssh into Leonhard and clone the project files from Gitlab (you will need to enter your credentials) into your $HOME directory:
# PyTorch implementation of "World Models"
Paper: Ha and Schmidhuber, "World Models", 2018. For a quick summary of the paper and some additional experiments, visit the [github page](
## Prerequisites
First clone the project files from Gitlab (you will need to enter your credentials):
```
git clone
```
Navigate to the project directory and execute the following command to build the Docker image. This might take a while, but you only need to do this once.
```
docker build -t deep-learning:worldmodels .
```
To run the container, run the following command in the project directory, depending on your OS:
* Windows (PowerShell): `docker run -it --rm -v ${pwd}:/app deep-learning:worldmodels`
* Linux: `docker run -it --rm -v $(pwd):/app deep-learning:worldmodels`
## Running the worldmodels
To run the model, start the Docker container (see above) and execute the commands below inside the container.
The model is composed of three parts:
1. A Variational Auto-Encoder (VAE), whose task is to compress the input images into a compact latent representation.
2. A Mixture-Density Recurrent Network (MDN-RNN), trained to predict the latent encoding of the next frame given past latent encodings and actions.
3. A linear Controller (C), which takes as input both the latent encoding of the current frame and the hidden state of the MDN-RNN (given past latents and actions), and outputs an action. It is trained to maximize the cumulative reward using the Covariance-Matrix Adaptation Evolution Strategy ([CMA-ES]( from the `cma` python package.
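The V → M → C data flow can be sketched with toy NumPy stand-ins. Everything here (random projections, dimensions) is illustrative, not the repository's API:

```python
import numpy as np

rng = np.random.default_rng(0)
LATENT, HIDDEN, ACTIONS = 32, 256, 3

# Toy stand-ins: random projections in place of trained networks.
W_enc = rng.normal(size=(64 * 64 * 3, LATENT))
W_rnn = rng.normal(size=(LATENT + ACTIONS + HIDDEN, HIDDEN)) * 0.01
W_ctrl = rng.normal(size=(LATENT + HIDDEN, ACTIONS)) * 0.01

def vae_encode(frame):          # V: image -> compact latent
    return frame.reshape(-1) @ W_enc / frame.size

def rnn_step(z, a, h):          # M: (latent, action, hidden) -> next hidden
    return np.tanh(np.concatenate([z, a, h]) @ W_rnn)

def controller(z, h):           # C: (latent, hidden) -> action
    return np.tanh(np.concatenate([z, h]) @ W_ctrl)

h = np.zeros(HIDDEN)
a = np.zeros(ACTIONS)
for _ in range(5):              # one short toy rollout
    frame = rng.random((64, 64, 3))
    z = vae_encode(frame)
    a = controller(z, h)
    h = rnn_step(z, a, h)
print(a.shape, h.shape)
```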
In the given code, all three sections are trained separately, using the scripts ``, `` and ``.
Training scripts take the following arguments:
* **--logdir** : The directory in which the models will be stored. If the logdir specified already exists, it loads the old model and continues the training.
* **--noreload** : If you want to override a model in *logdir* instead of reloading it, add this option.
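The `--logdir` / `--noreload` behaviour described above amounts to logic along these lines. This is a sketch; the checkpoint filename is illustrative and the actual scripts may differ:

```python
import os
import tempfile
from os.path import join, exists

def resolve_checkpoint(logdir, noreload):
    """Return the checkpoint path to resume from, or None to start fresh."""
    os.makedirs(logdir, exist_ok=True)
    ckpt = join(logdir, "checkpoint.tar")  # filename is illustrative
    if exists(ckpt) and not noreload:
        return ckpt   # reload the old model and continue training
    return None       # train from scratch (overriding any old model)

logdir = tempfile.mkdtemp()
print(resolve_checkpoint(logdir, noreload=False))  # fresh logdir: no checkpoint yet
```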
### 1. Data generation
Before launching the VAE and MDN-RNN training scripts, you need to generate a dataset of random rollouts and place it in the `datasets/carracing` folder.
Data generation is handled through the `data/` script, e.g.
```
python data/ --rollouts 1000 --rootdir datasets/carracing --threads 8
```
Rollouts are generated using a *brownian* random policy instead of the *white noise* random `action_space.sample()` policy from gym, which yields more consistent rollouts.
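The difference between the two random policies can be sketched as follows. This is an illustrative NumPy comparison (the step size and action bounds are made up), not the repository's sampling code:

```python
import numpy as np

rng = np.random.default_rng(0)
LOW, HIGH, STEPS = -1.0, 1.0, 1000

# White-noise policy: each action drawn independently
# (analogous to action_space.sample()).
white = rng.uniform(LOW, HIGH, size=STEPS)

# Brownian policy: each action is a small random step from the previous one,
# so consecutive actions stay correlated.
brown = np.empty(STEPS)
brown[0] = 0.0
for t in range(1, STEPS):
    brown[t] = np.clip(brown[t - 1] + rng.normal(scale=0.05), LOW, HIGH)

# Consecutive brownian actions change far less than white-noise ones.
print(np.abs(np.diff(white)).mean() > np.abs(np.diff(brown)).mean())
```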
### 2. Training the VAE
The VAE is trained using the `` file, e.g.
```
python --logdir exp_dir
```
### 3. Training the MDN-RNN
The MDN-RNN is trained using the `` file, e.g.
```
python --logdir exp_dir
```
A VAE must have been trained in the same `exp_dir` for this script to work.
### 4. Training and testing the Controller
Finally, the controller is trained using CMA-ES, e.g.
```
python --logdir exp_dir --n-samples 4 --pop-size 4 --target-return 950 --display
```
You can test the obtained policy with `` e.g.
```
python --logdir exp_dir
```
### Notes
When running on a headless server, you will need to use `xvfb-run` to launch the controller training script. For instance,
```
xvfb-run -s "-screen 0 1400x900x24" python --logdir exp_dir --n-samples 4 --pop-size 4 --target-return 950 --display
```
If you do not have a display available and you launch `traincontroller` without
`xvfb-run`, the script will fail silently (though logs are still written inside the log directory).
Be aware that `traincontroller` requires heavy GPU memory usage when launched
on GPUs. To reduce the memory load, you can limit the maximum number of
workers with the `--max-workers` argument.
If you have several GPUs available, `traincontroller` will take advantage of
all GPUs specified by `CUDA_VISIBLE_DEVICES`.
## Authors
* **Corentin Tallec** - [ctallec](
* **Léonard Blier** - [leonardblier](
* **Diviyan Kalainathan** - [diviyan-kalainathan](
## License
This project is licensed under the MIT License - see the []( file for details.
""" Recurrent model training """
import argparse import argparse
from functools import partial from functools import partial
from os.path import join, exists from os.path import join, exists
...@@ -18,10 +17,10 @@ from models.mdrnn import MDRNN, MDRNNCell ...@@ -18,10 +17,10 @@ from models.mdrnn import MDRNN, MDRNNCell
import cv2 import cv2
from utils.config import get_config_from_json from utils.config import get_config_from_json
parser = argparse.ArgumentParser(description='LSTM training') parser = argparse.ArgumentParser(description='Dream Generation')
parser.add_argument('--rnn_model_dir', type=str, default="/cluster/scratch/nstorni/data/rnn_D20200116T151028") parser.add_argument('--rnn_model_dir', type=str, default="weights/rnn.tar")
parser.add_argument('--vae_model_dir', type=str, default='/cluster/scratch/nstorni/experiments/shvae_temporalmask/models/vae_D20200111T092802') parser.add_argument('--vae_model_dir', type=str, default='weights/vae.tar')
parser.add_argument('--vae_latents', type=str, default='/cluster/scratch/nstorni/data/mice_tempdiff_medium/') parser.add_argument('--vae_latents', type=str, default='weights/video_vae_latents.tar')
parser.add_argument('--video_name', type=str, default='test') parser.add_argument('--video_name', type=str, default='test')
parser.add_argument('--teaching_duration', type=int, default = 100) parser.add_argument('--teaching_duration', type=int, default = 100)
parser.add_argument('--dream_duration', type=int, default = 2500) parser.add_argument('--dream_duration', type=int, default = 2500)
...@@ -36,7 +35,8 @@ out = cv2.VideoWriter(args.video_name+'.mp4',cv2.VideoWriter_fourcc(*'MP4V'), 25 ...@@ -36,7 +35,8 @@ out = cv2.VideoWriter(args.video_name+'.mp4',cv2.VideoWriter_fourcc(*'MP4V'), 25
# Reload RNN model # Reload RNN model
mdrnn = MDRNN(args.vae_latent_dim, args.rnn_hidden_dim, args.rnn_gaussians_n) mdrnn = MDRNN(args.vae_latent_dim, args.rnn_hidden_dim, args.rnn_gaussians_n)
mdrnncell = MDRNNCell(args.vae_latent_dim, args.rnn_hidden_dim, args.rnn_gaussians_n) mdrnncell = MDRNNCell(args.vae_latent_dim, args.rnn_hidden_dim, args.rnn_gaussians_n)
reload_file = join(os.path.expandvars(args.rnn_model_dir), 'checkpoint.tar') # reload_file = join(os.path.expandvars(args.rnn_model_dir), 'best.tar')
reload_file = args.rnn_model_dir
if exists(reload_file): if exists(reload_file):
rnn_state = torch.load(reload_file, map_location=torch.device('cpu')) rnn_state = torch.load(reload_file, map_location=torch.device('cpu'))
print("Loading MDRNN at epoch {} " print("Loading MDRNN at epoch {} "
...@@ -60,7 +60,8 @@ z = eps.mul(sigma).add_(mu) ...@@ -60,7 +60,8 @@ z = eps.mul(sigma).add_(mu)
# Load VAE model # Load VAE model
model = VAE(4, args.vae_latent_dim) model = VAE(4, args.vae_latent_dim)
reload_file = join(os.path.expandvars(args.vae_model_dir), 'best.tar') # reload_file = join(os.path.expandvars(args.vae_model_dir), 'best.tar')
reload_file = args.vae_model_dir
state = torch.load(reload_file, map_location=torch.device('cpu')) state = torch.load(reload_file, map_location=torch.device('cpu'))
model.load_state_dict(state['state_dict']) model.load_state_dict(state['state_dict'])
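For context, the MDN-RNN the script reloads predicts a mixture of Gaussians over the next latent; the line `z = eps.mul(sigma).add_(mu)` above is a reparameterized Gaussian sample. A minimal NumPy sketch of sampling one next latent from mixture parameters (all values here are toy stand-ins, not the repository's tensors):

```python
import numpy as np

rng = np.random.default_rng(0)
LATENT, GAUSSIANS = 32, 5

# Toy mixture parameters, shaped as an MDN-RNN step might output them.
pi = rng.dirichlet(np.ones(GAUSSIANS))                      # mixture weights
mu = rng.normal(size=(GAUSSIANS, LATENT))                   # component means
sigma = np.exp(rng.normal(size=(GAUSSIANS, LATENT)) * 0.1)  # component stds

def sample_next_latent(pi, mu, sigma):
    k = rng.choice(len(pi), p=pi)        # pick a mixture component
    eps = rng.normal(size=mu.shape[1])   # standard-normal noise
    return eps * sigma[k] + mu[k]        # z = eps * sigma + mu

z_next = sample_next_latent(pi, mu, sigma)
print(z_next.shape)
```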