Commit e8e178e9 authored by PhilFischer

Added visualizations of likelihood means and of train images

parent 62a87f78
@@ -8,11 +8,10 @@ import numpy as np
import tensorflow as tf
from datetime import datetime
from sklearn.model_selection import train_test_split
-from sklearn.utils import shuffle
from tensorboard.plugins.hparams import api as hp
from src.preprocessing.loading import create_dataset, load_data
-from src.preprocessing.transformations import scale_image, clip_open
+from src.preprocessing.transformations import scale_image
from src.preprocessing.augmentation import augment_flips_rotations
from src.metrics.visualizations import log_reconstruction, log_generation, log_style_transfer
from src.model.cvae import define_model
@@ -51,10 +50,7 @@ print(f"*** Styles: {', '.join(styles)}")
########################################################
train, test = scale_image(train), scale_image(test)
-train, test = clip_open(train), clip_open(test)
train, s_train = augment_flips_rotations(train, s_train)
-train, s_train = shuffle(train, s_train, random_state=SEED)
n_styles = np.max(s_train) + 1
s_train, s_test = tf.one_hot(s_train, n_styles), tf.one_hot(s_test, n_styles)
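Note (illustration, not part of the diff): the two lines above derive the style count from the largest integer label and one-hot encode the labels, which assumes labels are contiguous from 0. A minimal sketch with hypothetical values:

    import numpy as np
    import tensorflow as tf

    s_train = np.array([0, 2, 1, 2])         # hypothetical integer style labels
    n_styles = np.max(s_train) + 1           # labels assumed to span 0..max, so this gives 3
    one_hot = tf.one_hot(s_train, n_styles)  # shape (4, 3); row 1 becomes [0., 0., 1.]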
@@ -65,15 +61,15 @@ s_train, s_test = tf.one_hot(s_train, n_styles), tf.one_hot(s_test, n_styles)
########################################################
hparams = [
-    hp.HParam('architecture', hp.Discrete(['σ-VAE-1'])),
+    hp.HParam('architecture', hp.Discrete(['σ-VAE-2'])),
    hp.HParam('likelihood', hp.Discrete(['sigma'])),
-    hp.HParam('log_lr', hp.RealInterval(-5.6, -3.6)),
-    hp.HParam('log_beta', hp.RealInterval(1.2, 1.2)),
-    hp.HParam('latent_dim', hp.IntInterval(12, 12)),
+    hp.HParam('log_lr', hp.RealInterval(-4., -4.)),
+    hp.HParam('log_beta', hp.RealInterval(1., 1.)),
+    hp.HParam('latent_dim', hp.IntInterval(64, 64)),
    hp.HParam('feature_layers', hp.IntInterval(2, 2)),
    hp.HParam('style_layers', hp.IntInterval(1, 1)),
    hp.HParam('encoder_layers', hp.IntInterval(2, 2)),
-    hp.HParam('decoder_layers', hp.IntInterval(4, 5)),
+    hp.HParam('decoder_layers', hp.IntInterval(4, 4)),
    hp.HParam('feature_filters', hp.IntInterval(8, 32)),
    hp.HParam('encoder_filters', hp.IntInterval(8, 32)),
    hp.HParam('decoder_filters', hp.IntInterval(8, 32)),
@@ -81,7 +77,7 @@ hparams = [
    hp.HParam('encoder_kernel_size', hp.IntInterval(3, 3)),
    hp.HParam('decoder_kernel_size', hp.IntInterval(3, 3)),
    hp.HParam('feature_multiple', hp.IntInterval(2, 3)),
-    hp.HParam('encoder_multiple', hp.IntInterval(2, 4)),
+    hp.HParam('encoder_multiple', hp.IntInterval(1, 3)),
    hp.HParam('decoder_multiple', hp.IntInterval(2, 3))
]
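Note (illustration, not part of the diff): most intervals above are now degenerate (min == max), so they pin a value rather than define a search range; only the filter and multiple ranges still vary between runs. A plausible sketch of how one run's configuration could be drawn from the hparams list defined above, assuming the repo's Tuner samples domains uniformly (its internals are not shown here) and that log_lr is a base-10 exponent:

    import random

    rng = random.Random(42)
    # Domain.sample_uniform picks from a Discrete set, a RealInterval, or an IntInterval
    params = {h.name: h.domain.sample_uniform(rng) for h in hparams}
    lr = 10 ** params['log_lr']  # -4. would map to a learning rate of 1e-4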
metrics = [
@@ -95,7 +91,7 @@ log_dir = 'logs'
ex_name = os.path.join('pex01', datetime.now().strftime("%y%m%d-%H%M%S"))
tuner = Tuner((train, s_train), (test, s_test), define_model, hparams, metrics, log_dir, seed=SEED)
-tuner.tune(ex_name, runs=10, epochs=1000)
+tuner.tune(ex_name, runs=5, epochs=1000)
if tuner.best_model is None:
    print('\n*** No model gave finite results. Report omitted.')
@@ -106,19 +102,24 @@ if tuner.best_model is None:
### INFERENCE WITH BEST MODEL
########################################################
-idx = np.argmax(s_test, axis=0)
-idx2 = np.argmax(s_test[np.max(idx)+1:], axis=0)
-example_batch = np.concatenate([test[idx], test[np.max(idx)+1:][idx2]])
-style_batch = tf.one_hot(list(range(n_styles))*2, n_styles)
-# Visualize reconstructions from latent means
-log_reconstruction(tuner.best_model, example_batch, style_batch, os.path.join(log_dir, ex_name))
-# Visualize style transfer
-log_style_transfer(tuner.best_model, example_batch, tf.one_hot(range(n_styles), n_styles), os.path.join(log_dir, ex_name))
-# Visualize generative sampling
-log_generation(tuner.best_model, tf.one_hot(range(n_styles), n_styles), os.path.join(log_dir, ex_name))
+def log_inference(data, s_data, model, dirname):
+    idx = np.argmax(s_data, axis=0)
+    idx2 = np.argmax(s_data[np.max(idx)+1:], axis=0)
+    example_batch = np.concatenate([data[idx], data[np.max(idx)+1:][idx2]])
+    style_batch = tf.one_hot(list(range(n_styles))*2, n_styles)
+    # Visualize reconstructions from latent means
+    log_reconstruction(model, example_batch, style_batch, os.path.join(log_dir, ex_name, dirname))
+    # Visualize style transfer
+    log_style_transfer(model, example_batch, tf.one_hot(range(n_styles), n_styles), os.path.join(log_dir, ex_name, dirname))
+    # Visualize generative sampling
+    log_generation(model, tf.one_hot(range(n_styles), n_styles), os.path.join(log_dir, ex_name, dirname))
+# Log inference on train and test set
+log_inference(train, s_train, tuner.best_model, 'train')
+log_inference(test, s_test, tuner.best_model, 'test')
########################################################
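Note (illustration, not part of the diff): the argmax trick in log_inference is compact but subtle. Because s_data is one-hot, np.argmax(..., axis=0) returns the first row index at which each style occurs; searching again past the largest of those indices yields a second, distinct example per style. With hypothetical labels:

    import numpy as np

    s_data = np.array([[0, 1, 0],   # row 0: style 1
                       [1, 0, 0],   # row 1: style 0
                       [0, 0, 1],   # row 2: style 2
                       [1, 0, 0]])  # row 3: style 0
    idx = np.argmax(s_data, axis=0)                     # first example per style: [1, 0, 2]
    idx2 = np.argmax(s_data[np.max(idx) + 1:], axis=0)  # searched from row 3 onward

This assumes every style still occurs after that cutoff; for any style that does not, argmax silently falls back to index 0 of the remainder.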
@@ -6,9 +6,11 @@ from tensorflow.summary import create_file_writer
def log_reconstruction(model, example_batch, style_batch, log_dir):
    examples = np.concatenate(list(example_batch), axis=-2)
-    mean_batch = model.reconstruct(example_batch, style_batch, sample_posterior=False)
-    pmeans = np.concatenate(list(mean_batch.mean()), axis=-2)
-    means = np.concatenate(list(mean_batch.sample()), axis=-2)
+    pmean_batch = model.reconstruct(example_batch, style_batch, sample_posterior=False).mean()
+    pmeans = np.concatenate(list(pmean_batch), axis=-2)
+    mean_batches = [model.reconstruct(example_batch, style_batch, sample_posterior=True).mean() for _ in range(3)]
+    mean_batch = np.concatenate(mean_batches, axis=-3)
+    means = np.concatenate(list(mean_batch), axis=-2)
    sample_batches = [model.reconstruct(example_batch, style_batch, sample_posterior=True).sample() for _ in range(3)]
    sample_batch = np.concatenate(sample_batches, axis=-3)
    samples = np.concatenate(list(sample_batch), axis=-2)
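Note (illustration, not part of the diff): the switch from .sample() to .mean() throughout this commit matters because model.reconstruct evidently returns a distribution object. Assuming a TensorFlow Probability Gaussian likelihood (consistent with the 'sigma' likelihood hyperparameter, though the model code is not shown here), the two calls differ as follows:

    import numpy as np
    import tensorflow_probability as tfp

    loc = np.zeros((8, 8, 3), dtype=np.float32)        # hypothetical decoded pixel means
    likelihood = tfp.distributions.Independent(
        tfp.distributions.Normal(loc=loc, scale=0.1),  # 0.1: hypothetical noise scale
        reinterpreted_batch_ndims=3)
    likelihood.mean()    # the noise-free image: just the predicted pixel means
    likelihood.sample()  # the means plus per-pixel observation noise

Logging all three of pmeans, means, and samples thus separates posterior-mean reconstructions, decoded means of posterior samples, and fully sampled images.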
@@ -23,7 +25,7 @@ def log_style_transfer(model, example_batch, styles, log_dir):
    batch_size = tf.shape(example_batch)[:1]
    for style in tf.unstack(styles):
        style_batch = tf.transpose(tf.repeat(tf.expand_dims(style, axis=-1), batch_size, axis=1))
-        transferred_batch = model.reconstruct(example_batch, style_batch, sample_posterior=True).sample()
+        transferred_batch = model.reconstruct(example_batch, style_batch, sample_posterior=True).mean()
        images.append(np.concatenate(transferred_batch, axis=-2))
    image = np.concatenate((examples, *images), axis=-3)
    with create_file_writer(os.path.join(log_dir, 'images')).as_default():
@@ -33,7 +35,7 @@ def log_style_transfer(model, example_batch, styles, log_dir):
def log_generation(model, style_batch, log_dir):
    images = []
    for style in tf.unstack(style_batch):
-        generated_batch = [model.generate(style).sample() for _ in range(6)]
+        generated_batch = [model.generate(style).mean() for _ in range(6)]
        images.append(np.concatenate(generated_batch, axis=-2))
    image = np.concatenate(images, axis=-3)
    with create_file_writer(os.path.join(log_dir, 'images')).as_default():
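Note (illustration, not part of the diff): all three logging functions assemble image grids purely with np.concatenate. For images shaped (height, width, channels), axis=-2 tiles horizontally and axis=-3 stacks vertically:

    import numpy as np

    imgs = [np.zeros((32, 32, 3)) for _ in range(6)]
    row = np.concatenate(imgs, axis=-2)         # (32, 192, 3): six images side by side
    grid = np.concatenate([row, row], axis=-3)  # (64, 192, 3): two rows stacked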
@@ -28,7 +28,7 @@ class Tuner():
        self.shape = train[0].shape[1:]
        self.n_styles = int(tf.shape(train[1])[-1])
-        self.train = Dataset.from_tensor_slices((train, train[0])).batch(32)
+        self.train = Dataset.from_tensor_slices((train, train[0])).shuffle(1000).batch(32)
        self.test = Dataset.from_tensor_slices((test, test[0])).batch(32)
        self.model_fn = model_fn
        self.hparams = hparams
@@ -55,12 +55,12 @@ class Tuner():
        model = self.model_fn(self.shape, self.n_styles, params)
        checkpoint = ModelCheckpoint(weights_dir, save_best_only=True, save_weights_only=True)
-        stopping = EarlyStopping(patience=100)
+        stopping = EarlyStopping(patience=200)
        tensorboard = TensorBoard(log_dir=run_dir)
        hpboard = hp.KerasCallback(run_dir, params)
        print(f"\n*** Starting run and logging to {run_dir}")
-        model.fit(self.train, epochs=epochs, shuffle=False, verbose=2, validation_data=self.test,
+        model.fit(self.train, epochs=epochs, verbose=2, validation_data=self.test,
                  callbacks=[checkpoint, stopping, tensorboard, hpboard])
        print('Validation:')
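Note (illustration, not part of the diff): dropping shuffle=False from model.fit is consistent with Keras behavior, since the shuffle argument is ignored when the input is a tf.data.Dataset. Shuffling now happens in the pipeline built in __init__, which reshuffles every epoch by default (a buffer of 1000 approximates a full shuffle only if it is at least the dataset size). A minimal sketch:

    import tensorflow as tf

    ds = tf.data.Dataset.range(10).shuffle(buffer_size=10).batch(4)
    for epoch in range(2):
        print([batch.numpy().tolist() for batch in ds])  # different order each epoch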