Commit edd7be84 authored by Silviu Nastasescu

Added style loss

parent 426b3fca
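In short, the added style loss upsamples the model's reconstruction, feeds it to a frozen, pretrained painter classifier (the "oracle", loaded from models/painter_classifier_model_ResNet-50.h5), and adds the categorical cross-entropy between the input style labels and the oracle's prediction, weighted by 10 ** log_cce. A minimal standalone sketch of the extra term follows; the function name style_loss and its explicit arguments are illustrative and not part of the commit:

import tensorflow as tf

def style_loss(oracle, style_labels, reconstruction, log_cce=3.5):
    # The oracle stays fixed; only the VAE is trained against this term.
    oracle.trainable = False
    # Upsample the reconstruction by 2x to the oracle's input resolution
    # (the dataset is created at size=112, so 112 -> 224).
    upsampled = tf.keras.layers.UpSampling2D(size=(2, 2))(reconstruction)
    # Classify the reconstruction and compare against the one-hot style labels.
    preds = oracle(upsampled, training=False)
    cce = tf.keras.losses.CategoricalCrossentropy()(style_labels, preds)
    # 10 ** log_cce sets the weight of the style term relative to the ELBO.
    return (10.0 ** log_cce) * cce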
@@ -23,114 +23,124 @@ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
SEED = 163
RANDOM = random.Random(SEED)
DATAPATH = 'data'
########################################################
### LOAD DATA
########################################################
if not os.path.exists(DATAPATH) or not os.path.isdir(DATAPATH):
print("*** Data not available at {DATAPATH}. Exiting...")
sys.exit(1)
if not os.path.exists(os.path.join(DATAPATH, 'dataset.npz')):
create_dataset(DATAPATH, size=128)
data, labels, styles = load_data(DATAPATH)
train, test, s_train, s_test = train_test_split(data, labels, test_size=0.2, stratify=labels, random_state=SEED)
assert train.shape[0] == s_train.shape[0] and test.shape[0] == s_test.shape[0]
print(f"*** Training dataset loaded with shape {train.shape}")
print(f"*** Validation dataset loaded with shape {test.shape}")
print(f"*** Styles: {', '.join(styles)}")
########################################################
### TRANSFORM AND AUGMENT DATA
########################################################
train, test = scale_image(train), scale_image(test)
train, s_train = augment_flips_rotations(train, s_train)
n_styles = np.max(s_train) + 1
s_train, s_test = tf.one_hot(s_train, n_styles), tf.one_hot(s_test, n_styles)
########################################################
### BUILD AND TUNE MODEL
########################################################
hparams = [
hp.HParam('architecture', hp.Discrete(['σ-VAE-2'])),
hp.HParam('likelihood', hp.Discrete(['sigma'])),
hp.HParam('output_activation', hp.Discrete(['linear'])),
hp.HParam('log_lr', hp.RealInterval(-4., -4.)),
hp.HParam('log_beta', hp.RealInterval(0., 0.)),
hp.HParam('latent_dim', hp.IntInterval(32, 32)),
hp.HParam('feature_layers', hp.IntInterval(2, 2)),
hp.HParam('style_layers', hp.IntInterval(1, 1)),
hp.HParam('encoder_layers', hp.IntInterval(1, 1)),
hp.HParam('decoder_layers', hp.IntInterval(4, 4)),
hp.HParam('feature_filters', hp.IntInterval(16, 32)),
hp.HParam('encoder_filters', hp.IntInterval(32, 64)),
hp.HParam('decoder_filters', hp.IntInterval(16, 32)),
hp.HParam('feature_kernel_size', hp.IntInterval(3, 3)),
hp.HParam('encoder_kernel_size', hp.IntInterval(3, 3)),
hp.HParam('decoder_kernel_size', hp.IntInterval(3, 3)),
hp.HParam('feature_multiple', hp.IntInterval(2, 2)),
hp.HParam('encoder_multiple', hp.IntInterval(2, 2)),
hp.HParam('decoder_multiple', hp.IntInterval(2, 2))
]
metrics = [
hp.Metric('epoch_loss', group='train', display_name='Train ELBO'),
hp.Metric('epoch_loss', group='validation', display_name='ELBO'),
hp.Metric('epoch_kld', group='train', display_name='Train KLD'),
hp.Metric('epoch_kld', group='validation', display_name='KLD')
]
log_dir = 'logs'
ex_name = os.path.join('pex02', datetime.now().strftime("%y%m%d-%H%M%S"))
tuner = Tuner((train, s_train), (test, s_test), define_model, hparams, metrics, log_dir, seed=SEED)
tuner.tune(ex_name, runs=1, epochs=1000)
if tuner.best_model is None:
print('\n*** No model gave finite results. Report omitted.')
sys.exit(0)
########################################################
### INFERENCE WITH BEST MODEL
########################################################
def log_inference(data, s_data, model, dirname):
idx = np.argmax(s_data, axis=0)
idx2 = np.argmax(s_data[np.max(idx)+1:], axis=0)
example_batch = np.concatenate([data[idx], data[np.max(idx)+1:][idx2]])
style_batch = tf.one_hot(list(range(n_styles))*2, n_styles)
# Visualize reconstructions from latent means
log_reconstruction(model, example_batch, style_batch, os.path.join(log_dir, ex_name, dirname))
# Visualize style transfer
log_style_transfer(model, example_batch, tf.one_hot(range(n_styles), n_styles), os.path.join(log_dir, ex_name, dirname))
# Visualize generative sampling
log_generation(model, tf.one_hot(range(n_styles), n_styles), os.path.join(log_dir, ex_name, dirname))
# Log inference on train and test set
log_inference(train, s_train, tuner.best_model, 'train')
log_inference(test, s_test, tuner.best_model, 'test')
########################################################
### PRINT BEST RESULTS
########################################################
print('\n*** Best hyperparameters found:')
for param, value in tuner.best_params.items():
print(f"{param}: {value}")
print(f"\n*** Best ELBO: {tuner.min_loss}")
print(f"*** Best BPD: {tuner.min_loss/train.shape[1]/train.shape[2]/np.log(2)}")
ORACLEPATH = os.path.join('models', 'painter_classifier_model_ResNet-50.h5')
if __name__ == '__main__':
########################################################
### LOAD DATA
########################################################
if not os.path.exists(DATAPATH) or not os.path.isdir(DATAPATH):
print(f"*** Data not available at {DATAPATH}. Exiting...")
sys.exit(1)
if not os.path.exists(ORACLEPATH):
print(f"*** Oracle not available at {ORACLEPATH}. Exiting...")
sys.exit(1)
oracle = tf.keras.models.load_model(ORACLEPATH)
#oracle.summary()
if not os.path.exists(os.path.join(DATAPATH, 'dataset.npz')):
create_dataset(DATAPATH, size=112)
data, labels, styles = load_data(DATAPATH)
train, test, s_train, s_test = train_test_split(data, labels, test_size=0.2, stratify=labels, random_state=SEED)
assert train.shape[0] == s_train.shape[0] and test.shape[0] == s_test.shape[0]
print(f"*** Training dataset loaded with shape {train.shape}")
print(f"*** Validation dataset loaded with shape {test.shape}")
print(f"*** Styles: {', '.join(styles)}")
########################################################
### TRANSFORM AND AUGMENT DATA
########################################################
train, test = scale_image(train), scale_image(test)
#train, s_train = augment_flips_rotations(train, s_train)
n_styles = np.max(s_train) + 1
s_train, s_test = tf.one_hot(s_train, n_styles), tf.one_hot(s_test, n_styles)
########################################################
### BUILD AND TUNE MODEL
########################################################
# [-4.6, 1.2, 12, 2, 1, 3, 5, 12, 12, 13, 3, 3, 3, 2, 2, 2]
# [-15921, -10503, 14.167, 14.045]
hparams = [
hp.HParam('architecture', hp.Discrete(['σ-VAE-2'])),
hp.HParam('likelihood', hp.Discrete(['sigma'])),
hp.HParam('output_activation', hp.Discrete(['linear'])),
hp.HParam('log_lr', hp.RealInterval(-4., -4.)),
hp.HParam('log_beta', hp.RealInterval(1., 1.)),
hp.HParam('log_cce', hp.RealInterval(3.5, 3.5)),
hp.HParam('latent_dim', hp.IntInterval(32, 32)),
hp.HParam('feature_layers', hp.IntInterval(2, 2)),
hp.HParam('style_layers', hp.IntInterval(1, 1)),
hp.HParam('encoder_layers', hp.IntInterval(2, 2)),
hp.HParam('decoder_layers', hp.IntInterval(4, 4)),
hp.HParam('feature_filters', hp.IntInterval(16, 16)),
hp.HParam('encoder_filters', hp.IntInterval(32, 32)),
hp.HParam('decoder_filters', hp.IntInterval(32, 32)),
hp.HParam('feature_kernel_size', hp.IntInterval(3, 3)),
hp.HParam('encoder_kernel_size', hp.IntInterval(3, 3)),
hp.HParam('decoder_kernel_size', hp.IntInterval(3, 3)),
hp.HParam('feature_multiple', hp.IntInterval(2, 2)),
hp.HParam('encoder_multiple', hp.IntInterval(2, 2)),
hp.HParam('decoder_multiple', hp.IntInterval(2, 2))
]
metrics = [
hp.Metric('epoch_loss', group='train', display_name='Train ELBO'),
hp.Metric('epoch_loss', group='validation', display_name='ELBO'),
hp.Metric('epoch_kld', group='train', display_name='Train KLD'),
hp.Metric('epoch_kld', group='validation', display_name='KLD')
]
log_dir = 'logs'
ex_name = os.path.join('pex02', datetime.now().strftime("%y%m%d-%H%M%S"))
tuner = Tuner((train, s_train), (test, s_test), define_model, hparams, metrics, log_dir, oracle, seed=SEED)
tuner.tune(ex_name, runs=1, epochs=1000)
if tuner.best_model is None:
print('\n*** No model gave finite results. Report omitted.')
sys.exit(0)
########################################################
### INFERENCE WITH BEST MODEL
########################################################
def log_inference(data, s_data, model, dirname):
idx = np.argmax(s_data, axis=0)
idx2 = np.argmax(s_data[np.max(idx)+1:], axis=0)
example_batch = np.concatenate([data[idx], data[np.max(idx)+1:][idx2]])
style_batch = tf.one_hot(list(range(n_styles))*2, n_styles)
# Visualize reconstructions from latent means
log_reconstruction(model, example_batch, style_batch, os.path.join(log_dir, ex_name, dirname))
# Visualize style transfer
log_style_transfer(model, example_batch, tf.one_hot(range(n_styles), n_styles), os.path.join(log_dir, ex_name, dirname))
# Visualize generative sampling
log_generation(model, tf.one_hot(range(n_styles), n_styles), os.path.join(log_dir, ex_name, dirname))
# Log inference on train and test set
log_inference(train, s_train, tuner.best_model, 'train')
log_inference(test, s_test, tuner.best_model, 'test')
########################################################
### PRINT BEST RESULTS
########################################################
print('\n*** Best hyperparameters found:')
for param, value in tuner.best_params.items():
print(f"{param}: {value}")
print(f"\n*** Best ELBO: {tuner.min_loss}")
print(f"*** Best BPD: {tuner.min_loss/train.shape[1]/train.shape[2]/np.log(2)}")
@@ -2,6 +2,8 @@ import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import UpSampling2D
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow_probability.python.layers import KLDivergenceRegularizer, IndependentNormal, DistributionLambda
from tensorflow_probability.python.distributions import Independent, Normal, Kumaraswamy, kl_divergence
@@ -23,7 +25,7 @@ def make_kumaraswamy_distr(cdim, reinterpreted_batch_ndims=1):
class VAE(Model):
def __init__(self, shape, n_styles, latent_dim=32, activation='swish', output_activation='linear', likelihood='sigma', log_beta=0.,
def __init__(self, shape, n_styles, oracle, latent_dim=32, activation='swish', output_activation='linear', likelihood='sigma', log_beta=0., log_cce = 0.,
encoder_layers=1, encoder_filters=4, encoder_kernel_size=3, encoder_multiple=1,
decoder_layers=2, decoder_filters=4, decoder_kernel_size=3, decoder_multiple=1,
feature_layers=3, feature_filters=4, feature_kernel_size=3, feature_multiple=1,
@@ -34,7 +36,7 @@ class VAE(Model):
self._image_features = FeatureExtractionBlock(
layers=feature_layers, filters=feature_filters, kernel_size=feature_kernel_size, multiple=feature_multiple)
feature_shape = self._image_features.compute_output_shape((None, *shape))
self._style_features = StyleExtractionBlock((feature_shape[1], feature_shape[2], 1), activation=activation, layers=style_layers)
self._style_features = StyleExtractionBlock((feature_shape[1], feature_shape[2], feature_shape[3] // 2), activation=activation, layers=style_layers)
# Define encoder and posterior
self._prior = Independent(Normal(tf.zeros(latent_dim), scale=1), reinterpreted_batch_ndims=1)
@@ -58,6 +60,13 @@ class VAE(Model):
self._connect = DistributionLambda(make_kumaraswamy_distr(cdim, reinterpreted_batch_ndims=len(shape)))
self._decoder = Decoder(param_shape, latent_dim=latent_dim+n_styles, activation=activation, output_activation=output_activation,
layers=decoder_layers, filters=decoder_filters, kernel_size=decoder_kernel_size, multiple=decoder_multiple)
# Define oracle
self._up = UpSampling2D(size=(2, 2))
self._oracle = oracle
self._oracle.trainable = False
self._cce = CategoricalCrossentropy()
self._log_cce = 10 ** log_cce  # weight of the oracle style loss (not its log)
def call(self, inputs, training=False):
# Map features
@@ -78,10 +87,18 @@ class VAE(Model):
x = self._decoder(x, training=training)
# Sample likelihood
x = self._connect(x)
result = self._connect(x)
if self.__log_scale is not None:
self.add_metric(tf.math.exp(self.__log_scale), 'sigma')
return x
# Call the frozen oracle on the upsampled reconstruction and add the weighted style loss
x = self._up(result)
x = self._oracle(x)
cce = self._log_cce * self._cce(s, x)
self.add_loss(cce)
self.add_metric(cce, name='cce')
return result
def reconstruct(self, x, s, sample_posterior=True):
xi, xs = self._image_features(x), self._style_features(s)
@@ -104,8 +121,8 @@ class VAE(Model):
return x
def define_model(shape, n_styles, hparams):
model = VAE(shape, n_styles, **hparams)
def define_model(shape, n_styles, oracle, hparams):
model = VAE(shape, n_styles, oracle, **hparams)
lr = 10**hparams['log_lr'] if 'log_lr' in hparams else 0.0001
model.compile(optimizer=Adam(lr=lr), loss=neg_log_likelihood)
return model
@@ -11,7 +11,7 @@ from tensorboard.plugins.hparams import api as hp
class Tuner():
"""A tuner object that will train models with randomly sampled parameters and log them to TensorBoard."""
def __init__(self, train, test, model_fn, hparams, metrics, log_dir, seed=0):
def __init__(self, train, test, model_fn, hparams, metrics, log_dir, oracle, seed=0):
"""Initialize tuner object.
Parameters:
@@ -34,6 +34,7 @@ class Tuner():
self.hparams = hparams
self.metrics = metrics
self.log_dir = log_dir
self.oracle = oracle
self.random = random.Random(seed)
self.min_loss = float('inf')
@@ -52,7 +53,7 @@ class Tuner():
"""
weights_dir = os.path.join(run_dir, 'weights/')
model = self.model_fn(self.shape, self.n_styles, params)
model = self.model_fn(self.shape, self.n_styles, self.oracle, params)
checkpoint = ModelCheckpoint(weights_dir, save_best_only=True, save_weights_only=True)
stopping = EarlyStopping(patience=200)
@@ -64,7 +65,7 @@ class Tuner():
callbacks=[checkpoint, stopping, tensorboard, hpboard])
print('Validation:')
best_model = self.model_fn(self.shape, self.n_styles, params)
best_model = self.model_fn(self.shape, self.n_styles, self.oracle, params)
best_model.load_weights(weights_dir)
metrics = best_model.evaluate(self.test, verbose=2)
return model, metrics[0]
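For reference, the oracle is loaded once in the training script and threaded through the Tuner into the model. A rough sketch of the updated call chain, mirroring the diff lines above rather than adding new code:

oracle = tf.keras.models.load_model(ORACLEPATH)  # frozen painter classifier (ResNet-50)
tuner = Tuner((train, s_train), (test, s_test), define_model, hparams, metrics, log_dir, oracle, seed=SEED)
# inside the Tuner: model = self.model_fn(self.shape, self.n_styles, self.oracle, params)
# inside define_model: model = VAE(shape, n_styles, oracle, **hparams), which registers the CCE term via add_loss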