Commit e2f10638 authored by Lukas Wolf's avatar Lukas Wolf
Browse files

print trainable parameters to config

parent 9ef2a006
......@@ -4,6 +4,8 @@ from config import config
from keras.callbacks import CSVLogger
from sklearn.model_selection import train_test_split
from utils.utils import train_val_split
import numpy as np
import logging
class prediction_history(tf.keras.callbacks.Callback):
......@@ -43,6 +45,9 @@ class BaseNet:
if self.verbose:
self.model.summary()
logging.info(f"Number of trainable parameters: {np.sum([np.prod(v.get_shape()) for v in self.model.trainable_weights])}")
logging.info(f"Number of non-trainable parameters: {np.sum([np.prod(v.get_shape()) for v in self.model.non_trainable_weights])}")
# abstract method
def _split_model(self):
pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment