Commit c1c1e58a authored by Rafael Daetwyler

load config from file

parent 94dec806
{
"model" : "rnn",
"exp_name" : "test",
"learning_rate" : 0.001,
"batch_size" : 64,
"reload" : "False",
"reload_dir": "$SCRATCH/experiments/lstm/models/",
"epochs" : 200,
"input_dim": 64,
"output_dim": 128,
"seq_len": 32,
"early_stopping": "True",
"weight_decay": "True",
"logdir" : "$SCRATCH/experiments/lstm",
"dataset":"$SCRATCH/data/latent_variables/data.pt"
}
\ No newline at end of file
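The training script below consumes this file through get_config_from_json, which wraps the parsed dict in an EasyDict so keys are reachable as attributes and expands environment variables such as $SCRATCH in the path entries (see the utils/config.py hunk at the bottom). A minimal sketch of the intended usage, with a hypothetical config path:

    from utils.config import get_config_from_json

    # config is the attribute-style EasyDict; config_json is presumably the raw
    # parsed dict that the script re-saves next to the trained model
    config, config_json = get_config_from_json('configs/rnn_test.json')
    print(config.model, config.learning_rate)  # -> rnn 0.001
    print(config.dataset)                      # $SCRATCH expanded if set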
@@ -3,14 +3,17 @@ import argparse
from functools import partial
from os.path import join, exists
from os import mkdir
import os
import torch
import torch.nn.functional as f
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
from datetime import datetime
import json
from tqdm import tqdm
from utils.misc import save_checkpoint
from utils.misc import LSIZE, RSIZE, RED_SIZE, SIZE
from utils.misc import RED_SIZE, SIZE
from utils.learning import EarlyStopping
## WARNING : THIS SHOULD BE REPLACED WITH PYTORCH 0.5
from utils.learning import ReduceLROnPlateau
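# The locally vendored ReduceLROnPlateau above (see the warning) mirrors the
# scheduler that newer PyTorch versions ship; once the pinned PyTorch version
# is raised, it could presumably be swapped for the built-in:
# from torch.optim.lr_scheduler import ReduceLROnPlateau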
@@ -20,40 +23,40 @@ from data.loaders import LatentStateDataset
from models.vae import VAE
from models.mdrnn import MDRNN, gmm_loss
parser = argparse.ArgumentParser("LSTM training")
parser.add_argument('--logdir', type=str,
                    help="Where things are logged and models are loaded from.")
parser.add_argument('--noreload', action='store_true',
                    help="Do not reload if specified.")
parser.add_argument('--latent_file', type=str,
                    help="Specify the file where the latent representation from the VAE is stored.")
from utils.config import get_config_from_json
parser = argparse.ArgumentParser(description='LSTM training')
parser.add_argument('--modelconfig', type=str, help='File path of training configuration')
args = parser.parse_args()
config, config_json = get_config_from_json(args.modelconfig)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# constants
BSIZE = 16
SEQ_LEN = 32
epochs = 30
if not exists(config.logdir):
    mkdir(config.logdir)
if not exists(join(config.logdir, 'models')):
    mkdir(join(config.logdir, 'models'))
# Loading model
rnn_dir = join(args.logdir, 'mdrnn')
rnn_file = join(rnn_dir, 'best.tar')
# Stamp training experiment name with datetime stamp for unique naming.
datetimestamp = (datetime.now()).strftime("D%Y%m%dT%H%M%S")
lstm_dir = join(config.logdir, 'models/' + config.exp_name + "_" + datetimestamp)
if not exists(lstm_dir):
    mkdir(lstm_dir)
    mkdir(join(lstm_dir, 'samples'))
if not exists(rnn_dir):
    mkdir(rnn_dir)
# Save configuration file
with open(join(lstm_dir,'config.json'), 'w') as outfile:
    json.dump(config_json, outfile)
# override the default setting which is LSIZE = 32, RSIZE = 265
LSIZE, RSIZE = 64, 128
mdrnn = MDRNN(LSIZE, RSIZE, 5)
mdrnn = MDRNN(config.input_dim, config.output_dim, 5)
mdrnn.to(device)
optimizer = torch.optim.RMSprop(mdrnn.parameters(), lr=1e-3, alpha=.9)
optimizer = torch.optim.RMSprop(mdrnn.parameters(), lr=config.learning_rate, alpha=.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
earlystopping = EarlyStopping('min', patience=30)
if exists(rnn_file) and not args.noreload:
    rnn_state = torch.load(rnn_file)
reload_file = join(os.path.expandvars(config.reload_dir), 'best.tar')
if config.reload == "True" and exists(reload_file):
    rnn_state = torch.load(reload_file)
    print("Loading MDRNN at epoch {} "
          "with test error {}".format(
              rnn_state["epoch"], rnn_state["precision"]))
@@ -61,16 +64,17 @@ if exists(rnn_file) and not args.noreload:
    # optimizer.load_state_dict(rnn_state["optimizer"])
    # scheduler.load_state_dict(state['scheduler'])
    # earlystopping.load_state_dict(state['earlystopping'])
elif config.reload == "True" and not exists(reload_file):
    raise Exception('Reload file not found: {}'.format(reload_file))
# Data Loading
train_loader = DataLoader(
    LatentStateDataset(args.latent_file, SEQ_LEN),
    batch_size=BSIZE, num_workers=8)
    LatentStateDataset(config.dataset, config.seq_len),
    batch_size=config.batch_size, num_workers=8)
test_loader = DataLoader(
    LatentStateDataset(args.latent_file, SEQ_LEN, train=False),
    batch_size=BSIZE, num_workers=8)
    LatentStateDataset(config.dataset, config.seq_len, train=False),
    batch_size=config.batch_size, num_workers=8)
def get_loss(latent_obs, latent_next_obs):
""" Compute losses.
@@ -90,7 +94,7 @@ def get_loss(latent_obs, latent_next_obs):
        for arr in [latent_obs, latent_next_obs]]
    mus, sigmas, logpi = mdrnn(latent_obs)
    gmm = gmm_loss(latent_next_obs, mus, sigmas, logpi)
    loss = gmm / LSIZE
    loss = gmm / config.input_dim
    return dict(gmm=gmm, loss=loss)
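# Minimal sketch of calling get_loss on dummy data; the batch-first
# (batch_size, seq_len, input_dim) layout is an assumption based on the
# transpose above, not something this diff states explicitly.
# dummy_obs = torch.randn(config.batch_size, config.seq_len, config.input_dim, device=device)
# dummy_next_obs = torch.randn_like(dummy_obs)
# print(get_loss(dummy_obs, dummy_next_obs)['loss'].item())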
@@ -103,11 +107,6 @@ def data_pass(epoch, train):  # pylint: disable=too-many-locals
        mdrnn.eval()
        loader = test_loader
    # Comment by Rafael:
    # The following line was used in the original world models to load the next batch of the dataset.
    # If we use multiple videos, we might want to use a similar mechanism.
    # loader.dataset.load_next_buffer()
    cum_loss = 0
    cum_gmm = 0
@@ -130,17 +129,17 @@ def data_pass(epoch, train):  # pylint: disable=too-many-locals
        pbar.set_postfix_str("loss={loss:10.6f} gmm={gmm:10.6f}".format(
            loss=cum_loss / (i + 1),
            gmm=cum_gmm / LSIZE / (i + 1)))
        pbar.update(BSIZE)
            gmm=cum_gmm / config.input_dim / (i + 1)))
        pbar.update(config.batch_size)
    pbar.close()
    return cum_loss * BSIZE / len(loader.dataset)
    return cum_loss * config.batch_size / len(loader.dataset)
train = partial(data_pass, train=True)
test = partial(data_pass, train=False)
cur_best = None
for e in range(epochs):
for e in range(config.epochs):
    train(e)
    test_loss = test(e)
    # scheduler.step(test_loss)
@@ -149,7 +148,8 @@ for e in range(epochs):
    is_best = not cur_best or test_loss < cur_best
    if is_best:
        cur_best = test_loss
    checkpoint_fname = join(rnn_dir, 'checkpoint.tar')
    best_filename = join(lstm_dir, 'best.tar')
    checkpoint_fname = join(lstm_dir, 'checkpoint.tar')
    save_checkpoint({
        "state_dict": mdrnn.state_dict(),
        "optimizer": optimizer.state_dict(),
......
@@ -31,9 +31,12 @@ def get_config_from_json(json_file):
        config_dict = json.load(config_file)
    # EasyDict allows to access dict values as attributes (works recursively).
    config = EasyDict(config_dict)
    config.train_dataset_dir = os.path.expandvars(config.train_dataset_dir)
    config.val_dataset_dir = os.path.expandvars(config.val_dataset_dir)
    config.dataset_dir = os.path.expandvars(config.dataset_dir)
    if hasattr(config, 'model') and config.model == "rnn":
        config.dataset = os.path.expandvars(config.dataset)
    else:
        config.train_dataset_dir = os.path.expandvars(config.train_dataset_dir)
        config.val_dataset_dir = os.path.expandvars(config.val_dataset_dir)
        config.dataset_dir = os.path.expandvars(config.dataset_dir)
    config.logdir = os.path.expandvars(config.logdir)
    # Make experiment name legal filename:
    valid_chars = "-_%s%s" % (string.ascii_letters, string.digits)
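    # Note: os.path.expandvars leaves references to unset variables unchanged,
    # so "$SCRATCH/experiments/lstm" only resolves once SCRATCH is defined in
    # the environment, e.g. with a hypothetical export SCRATCH=/cluster/scratch/<user>
    # it becomes "/cluster/scratch/<user>/experiments/lstm".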
......