Commit b3d9699a authored by nstorni's avatar nstorni
Browse files

Cleanup

parent d87b0574
# VAE setup and training.
# Generate "Dream" of mice behaviour
The following script is the final result of the project for the Deep Learning Lecture. It uses a trained VAE and MDN-RNN to generate a "dreamed" sequence of frames.
The length of the teaching sequence can be choosen by setting the argument --teaching_duration and the length of the successive "dreamed" sequence by setting --dream_duration.
If this script is runned on the leonhard cluster the following lines will load the correct modules and install the requirements before starting the generate_dream script:
```bash
$HOME/world-models/install_requirements.sh
module purge
module load gcc/4.8.5 python_gpu/3.7.1
python generate_dream.py --rnn_model_dir=weights/rnn.tar --vae_model_dir=weights/vae.tar --vae_latents=weights/video_vae_latents.tar --teaching_duration=100 --dream_duration=500 --video_name="dreamingMices"
```
(For running on a different system install the requirements in the requirements.txt file)
The generated video will be saved in the directory where the script is launched.
# Setup, dataset preparations and Training.
## 1. Prerequisites
First ssh in leonhard and clone the project files from Gitlab (you will need to enter your credentials) to your $HOME directory:
......@@ -192,104 +209,3 @@ module load gcc/4.8.5 python_gpu/3.7.1
python $HOME/world-models/video_frame/gen_tempdiff_data.py
```
# Pytorch implementation of the "WorldModels"
Paper: Ha and Schmidhuber, "World Models", 2018. https://doi.org/10.5281/zenodo.1207631. For a quick summary of the paper and some additional experiments, visit the [github page](https://ctallec.github.io/world-models/).
## Prerequisites
First clone the project files from Gitlab (you will need to enter your credentials):
```bash
git clone https://gitlab.ethz.ch/deep-learning-rodent/world-models.git
```
Navigate to the project directory and execute the following command to build the Docker image. This might take a while, but you only need to do this once.
```bash
docker build -t deep-learning:worldmodels .
```
To run the container, run the following command in the project directory, depending on your OS:
```bash
Windows (PowerShell): docker run -it --rm -v ${pwd}:/app deep-learning:worldmodels
Linux: docker run -it --rm -v $(pwd):/app deep-learning:worldmodels
```
## Running the worldmodels
To run the model, run the Docker container (see above) and execute the command inside the container.
The model is composed of three parts:
1. A Variational Auto-Encoder (VAE), whose task is to compress the input images into a compact latent representation.
2. A Mixture-Density Recurrent Network (MDN-RNN), trained to predict the latent encoding of the next frame given past latent encodings and actions.
3. A linear Controller (C), which takes both the latent encoding of the current frame, and the hidden state of the MDN-RNN given past latents and actions as input and outputs an action. It is trained to maximize the cumulated reward using the Covariance-Matrix Adaptation Evolution-Strategy ([CMA-ES](http://www.cmap.polytechnique.fr/~nikolaus.hansen/cmaartic.pdf)) from the `cma` python package.
In the given code, all three sections are trained separately, using the scripts `trainvae.py`, `trainmdrnn.py` and `traincontroller.py`.
Training scripts take as argument:
* **--logdir** : The directory in which the models will be stored. If the logdir specified already exists, it loads the old model and continues the training.
* **--noreload** : If you want to override a model in *logdir* instead of reloading it, add this option.
### 1. Data generation
Before launching the VAE and MDN-RNN training scripts, you need to generate a dataset of random rollouts and place it in the `datasets/carracing` folder.
Data generation is handled through the `data/generation_script.py` script, e.g.
```bash
python data/generation_script.py --rollouts 1000 --rootdir datasets/carracing --threads 8
```
Rollouts are generated using a *brownian* random policy, instead of the *white noise* random `action_space.sample()` policy from gym, providing more consistent rollouts.
### 2. Training the VAE
The VAE is trained using the `trainvae.py` file, e.g.
```bash
python trainvae.py --logdir exp_dir
```
### 3. Training the MDN-RNN
The MDN-RNN is trained using the `trainmdrnn.py` file, e.g.
```bash
python trainmdrnn.py --logdir exp_dir
```
A VAE must have been trained in the same `exp_dir` for this script to work.
### 4. Training and testing the Controller
Finally, the controller is trained using CMA-ES, e.g.
```bash
python traincontroller.py --logdir exp_dir --n-samples 4 --pop-size 4 --target-return 950 --display
```
You can test the obtained policy with `test_controller.py` e.g.
```bash
python test_controller.py --logdir exp_dir
```
### Notes
When running on a headless server, you will need to use `xvfb-run` to launch the controller training script. For instance,
```bash
xvfb-run -s "-screen 0 1400x900x24" python traincontroller.py --logdir exp_dir --n-samples 4 --pop-size 4 --target-return 950 --display
```
If you do not have a display available and you launch `traincontroller` without
`xvfb-run`, the script will fail silently (but logs are available in
`logdir/tmp`).
Be aware that `traincontroller` requires heavy gpu memory usage when launched
on gpus. To reduce the memory load, you can directly modify the maximum number
of workers by specifying the `--max-workers` argument.
If you have several GPUs available, `traincontroller` will take advantage of
all gpus specified by `CUDA_VISIBLE_DEVICES`.
## Authors
* **Corentin Tallec** - [ctallec](https://github.com/ctallec)
* **Léonard Blier** - [leonardblier](https://github.com/leonardblier)
* **Diviyan Kalainathan** - [diviyan-kalainathan](https://github.com/diviyan-kalainathan)
## License
This project is licensed under the MIT License - see the [LICENSE.md](LICENSE.md) file for details
""" Recurrent model training """
import argparse
from functools import partial
from os.path import join, exists
......@@ -18,10 +17,10 @@ from models.mdrnn import MDRNN, MDRNNCell
import cv2
from utils.config import get_config_from_json
parser = argparse.ArgumentParser(description='LSTM training')
parser.add_argument('--rnn_model_dir', type=str, default="/cluster/scratch/nstorni/data/rnn_D20200116T151028")
parser.add_argument('--vae_model_dir', type=str, default='/cluster/scratch/nstorni/experiments/shvae_temporalmask/models/vae_D20200111T092802')
parser.add_argument('--vae_latents', type=str, default='/cluster/scratch/nstorni/data/mice_tempdiff_medium/vae_D20200111T092802.pt')
parser = argparse.ArgumentParser(description='Dream Generation')
parser.add_argument('--rnn_model_dir', type=str, default="weights/rnn.tar")
parser.add_argument('--vae_model_dir', type=str, default='weights/vae.tar')
parser.add_argument('--vae_latents', type=str, default='weights/video_vae_latents.tar')
parser.add_argument('--video_name', type=str, default='test')
parser.add_argument('--teaching_duration', type=int, default = 100)
parser.add_argument('--dream_duration', type=int, default = 2500)
......@@ -36,7 +35,8 @@ out = cv2.VideoWriter(args.video_name+'.mp4',cv2.VideoWriter_fourcc(*'MP4V'), 25
# Reload RNN model
mdrnn = MDRNN(args.vae_latent_dim, args.rnn_hidden_dim, args.rnn_gaussians_n)
mdrnncell = MDRNNCell(args.vae_latent_dim, args.rnn_hidden_dim, args.rnn_gaussians_n)
reload_file = join(os.path.expandvars(args.rnn_model_dir), 'checkpoint.tar')
# reload_file = join(os.path.expandvars(args.rnn_model_dir), 'best.tar')
reload_file = args.rnn_model_dir
if exists(reload_file):
rnn_state = torch.load(reload_file, map_location=torch.device('cpu'))
print("Loading MDRNN at epoch {} "
......@@ -60,7 +60,8 @@ z = eps.mul(sigma).add_(mu)
# Load VAE model
model = VAE(4, args.vae_latent_dim)
reload_file = join(os.path.expandvars(args.vae_model_dir), 'best.tar')
# reload_file = join(os.path.expandvars(args.vae_model_dir), 'best.tar')
reload_file = args.vae_model_dir
state = torch.load(reload_file, map_location=torch.device('cpu'))
model.load_state_dict(state['state_dict'])
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment