TensorFlow Probability Utilities
This notebook is a bit of a mess after the refactor.
All of its code has been moved to indl.model.tfp and indl.model.tfp.dsae,
and many of the tests have been moved to the unit tests.
import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as tfkl
from tensorflow.keras import backend as K
import tensorflow_probability as tfp
tfd = tfp.distributions
tfpl = tfp.layers
tfb = tfp.bijectors
scale_shift = np.log(np.exp(1) - 1).astype(np.float32)  # inverse softplus of 1
from indl.model.tfp.dsae import *
from indl.model.tfp import *
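scale_shift is added to raw scale parameters before a softplus so that scales initialize near 1. A quick check that softplus(scale_shift) is exactly 1:
print(np.log(1 + np.exp(scale_shift)))  # softplus(scale_shift) -> 1.0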
# An example of how this would work in a variational autoencoder.
N_TIMES = 10
N_SENSORS = 8
N_SAMPLES = 2
N_HIDDEN = 5
KL_WEIGHT = 0.05
t_vec = tf.range(N_TIMES, dtype=tf.float32) / N_TIMES
sig_vec = 1 + tf.exp(-10*(t_vec - 0.5))  # denominator of a logistic sigmoid centered at t=0.5
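Dividing by sig_vec is equivalent to multiplying by a logistic sigmoid in time, since 1 / (1 + exp(-x)) is exactly sigmoid(x). A quick check:
print(np.allclose(1.0 / sig_vec, tf.sigmoid(10 * (t_vec - 0.5))))  # True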
def make_model(prior):
    input_ = tfkl.Input(shape=(LATENT_SIZE,))
    # Encoder
    make_latent_dist_fn, latent_params = make_mvn_dist_fn(
        input_, LATENT_SIZE, offdiag=True, loc_name="latent_loc")
    q_latent = tfpl.DistributionLambda(
        name="q_latent",
        make_distribution_fn=make_latent_dist_fn,
        convert_to_tensor_fn=lambda s: s.sample(N_SAMPLES),
        activity_regularizer=tfpl.KLDivergenceRegularizer(prior,
                                                          use_exact_kl=True,
                                                          weight=KL_WEIGHT)
    )(latent_params)
    # Decoder
    # Dividing by sig_vec both broadcasts in a time axis and applies the sigmoid time course.
    y_ = q_latent[..., tf.newaxis, :] / sig_vec[:, tf.newaxis]
    # Alternative: broadcast-add zeros to restore timesteps, then model dynamics with an LSTM.
    #y_ = q_latent[..., tf.newaxis, :] + tf.zeros([N_TIMES, 1])
    #y_ = tf.reshape(y_, [-1, N_TIMES, LATENT_SIZE])
    #y_ = tfkl.LSTM(N_HIDDEN, return_sequences=True)(y_)
    #y_ = tf.reshape(y_, [N_SAMPLES, -1, N_TIMES, N_HIDDEN])
    make_out_dist_fn, out_dist_params = make_mvn_dist_fn(y_, N_SENSORS, loc_name="out_loc")
    p_out = tfpl.DistributionLambda(
        make_distribution_fn=make_out_dist_fn, name="p_out")(out_dist_params)
    # No prior on the output.
    # Model
    model = tf.keras.Model(inputs=input_, outputs=[q_latent, p_out])
    return model
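Note convert_to_tensor_fn above: wherever q_latent is consumed as a tensor (here, by the decoder's indexing and division), it is realized as N_SAMPLES Monte Carlo draws with the sample axis prepended. A minimal standalone illustration of that TFP behavior:
_demo = tfpl.DistributionLambda(
    make_distribution_fn=lambda t: tfd.Normal(loc=t, scale=1.),
    convert_to_tensor_fn=lambda s: s.sample(N_SAMPLES))(tf.zeros((3, 4)))
print(tf.convert_to_tensor(_demo).shape)  # (N_SAMPLES, 3, 4): sample axis first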
# Create a fake dataset to train the model.
LATENT_SIZE = 4
BATCH_SIZE = 6
# The latents are sampled from a distribution with known parameters.
true_dist = tfd.MultivariateNormalDiag(
    loc=[-1., 1., 5., -5.],  # must have length == LATENT_SIZE
    scale_diag=[0.5, 0.5, 0.9, 0.2]
)
# They parameterize each sigmoid's end point (the amplitude K passed to f_sig below),
from indl.misc.sigfuncs import sigmoid
from functools import partial
t_vec = (np.arange(N_TIMES, dtype=np.float32) / N_TIMES)[None, :]  # NB: shadows the earlier tf t_vec with a (1, N_TIMES) array
f_sig = partial(sigmoid, t_vec, B=10, x_offset=0.5)
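f_sig now maps an amplitude K to a time course over t_vec. Assuming indl's sigmoid broadcasts K against the time vector (which gen_ds below relies on), a column of amplitudes yields one row per latent:
print(f_sig(K=np.ones((LATENT_SIZE, 1))).shape)  # expect (LATENT_SIZE, N_TIMES)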
# which are then mixed with a known mixing matrix
mix_mat = np.array([
    [-0.3, -.28, -0.38, -0.45, -0.02, -0.12, -0.05, -0.48],
    [0.27, 0.29, -0.34, 0.2, 0.41, 0.08, 0.11, 0.13],
    [-0.14, 0.26, -0.28, -0.14, 0.1, -0.2, 0.4, 0.11],
    [-0.05, -0.12, 0.28, 0.49, -0.12, 0.1, 0.17, 0.22]
], dtype=np.float32).T  # (N_SENSORS, LATENT_SIZE)
#mix_mat = tf.convert_to_tensor(mix_mat)
def gen_ds(n_iters=100, latent_size=LATENT_SIZE):
    iter_ix = 0
    while iter_ix < n_iters:
        _input = tf.ones((latent_size,), dtype=tf.float32)
        # Draw a latent vector and use it as the sigmoid amplitude (K) for each factor.
        latent = true_dist.sample().numpy()
        _y = np.reshape(latent, [latent_size, 1])
        _y = f_sig(K=_y)       # -> (latent_size, N_TIMES)
        _y = mix_mat @ _y      # -> (N_SENSORS, N_TIMES)
        _y = _y.T              # -> (N_TIMES, N_SENSORS)
        yield _input, _y
        iter_ix += 1
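A quick shape check on one example from the generator:
_x, _y = next(gen_ds(n_iters=1))
print(_x.shape, _y.shape)  # (LATENT_SIZE,), (N_TIMES, N_SENSORS)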
ds = tf.data.Dataset.from_generator(gen_ds, args=[100], output_types=(tf.float32, tf.float32),
                                    output_shapes=((LATENT_SIZE,), (N_TIMES, N_SENSORS)))
# Pair each input with an empty dummy target for the latent output and the real reconstruction target.
ds = ds.map(lambda x, y: (x, (tf.zeros(0, dtype=tf.float32), y))).batch(BATCH_SIZE)
# Train the model.
# Try playing around with the 2nd loss_weights (below) and KL_WEIGHT (above).
N_EPOCHS = 100
K.clear_session()
prior = make_mvn_prior(LATENT_SIZE, trainable_mean=True, trainable_var=True, offdiag=False)
model_ = make_model(prior)
model_.compile(optimizer='adam',
               loss=[lambda _, model_latent: tfd.kl_divergence(model_latent, prior),
                     lambda y_true, model_out: -model_out.log_prob(y_true)],
               loss_weights=[0.0, 1.0])
hist = model_.fit(ds, epochs=N_EPOCHS, verbose=2)
# Reconstruct the latent loc for the all-ones input from the Dense layer's weights.
lat_wts = model_.get_layer("latent_loc").weights
lat_locs = np.ones((1, LATENT_SIZE)) @ lat_wts[0].numpy() + lat_wts[1].numpy()
mix_wts = model_.get_layer("out_loc").weights
model_out = lat_locs @ mix_wts[0].numpy() + mix_wts[1].numpy()
true_out = mix_mat @ true_dist.mean().numpy()
print(f"Model est lat: {lat_locs}")
print(f"Model est out: {model_out}")
print(f"prior mean: {prior.mean().numpy()}")
print(f"true lat: {true_dist.mean().numpy()}")
print(f"true out: {true_out.T}")
# test LearnableMultivariateNormalDiag
prior_factory = LearnableMultivariateNormalDiag(LATENT_SIZE)
learnable_prior = prior_factory()
sample = learnable_prior.sample((100, 64))
print(sample.shape)
print(learnable_prior.trainable_variables)
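For reference, a factory like LearnableMultivariateNormalDiag can be approximated by a module holding a trainable loc and a softplus-transformed raw scale. This is only a hypothetical sketch (the real implementation lives in indl.model.tfp), reusing the scale_shift trick from above:
class LearnableMVNDiagSketch(tf.Module):
    # Hypothetical stand-in for LearnableMultivariateNormalDiag.
    def __init__(self, ndim):
        super().__init__()
        self.loc = tf.Variable(tf.zeros(ndim), name="loc")
        # Zero-init raw scale; softplus(0 + scale_shift) == 1.
        self.raw_scale = tf.Variable(tf.zeros(ndim), name="raw_scale")

    def __call__(self):
        return tfd.MultivariateNormalDiag(
            loc=self.loc,
            scale_diag=tf.nn.softplus(self.raw_scale + scale_shift))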
K.clear_session()
model_ = make_model(learnable_prior)
model_.compile(optimizer='adam',
               loss=[lambda _, model_latent: tfd.kl_divergence(model_latent, learnable_prior),
                     lambda y_true, model_out: -model_out.log_prob(y_true)],
               loss_weights=[0.0, 1.0])
print(learnable_prior.trainable_variables)
print([_.name for _ in model_.trainable_variables])
hist = model_.fit(ds, epochs=N_EPOCHS, verbose=2)
lat_wts = model_.get_layer("latent_loc").weights
lat_locs = np.ones((1, LATENT_SIZE)) @ lat_wts[0].numpy() + lat_wts[1].numpy()
mix_wts = model_.get_layer("out_loc").weights
model_out = lat_locs @ mix_wts[0].numpy() + mix_wts[1].numpy()
true_out = mix_mat @ true_dist.mean().numpy()
print(f"Model est lat: {lat_locs}")
print(f"Model est out: {model_out}")
print(f"prior mean: {learnable_prior.mean().numpy()}")
print(f"true lat: {true_dist.mean().numpy()}")
print(f"true out: {true_out.T}")
Latent Dynamic Factor
# Keep only the reconstruction target.
#ds_dyn = ds.map(lambda x, y: (x, (y[0], y[0], y[1])))  # alternative: 3 outputs, the first 2 null
ds_dyn = ds.map(lambda x, y: (x, y[1]))
KL_WEIGHT = 0.001
LATENT_SIZE_DYNAMIC = 1 # Integer dimensionality of each dynamic, time-variant latent variable `z_t`.
K.clear_session()
tmp = LearnableMultivariateNormalDiagCell(3, 4)
#tmp.build((None, 10, 5))
#tmp.summary()
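For context, such a cell is typically an RNN cell whose per-step output parameterizes a diagonal MVN over the next dynamic latent. A hypothetical sketch of the idea (not the real signature, which lives in indl.model.tfp.dsae):
class MVNDiagCellSketch(tfkl.Layer):
    # Hypothetical stand-in: LSTM cell state -> per-step MVN-diag parameters.
    def __init__(self, units, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.lstm_cell = tfkl.LSTMCell(units)
        self.state_size = self.lstm_cell.state_size
        self.loc_layer = tfkl.Dense(output_dim)
        self.scale_layer = tfkl.Dense(output_dim)

    def call(self, inputs, states):
        h, new_states = self.lstm_cell(inputs, states)
        loc = self.loc_layer(h)
        scale = tf.nn.softplus(self.scale_layer(h) + scale_shift)
        return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale), new_states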
# test DynamicEncoder and LearnableMultivariateNormalDiagCell
K.clear_session()
dynamic_encoder = DynamicEncoder(N_HIDDEN, N_TIMES, LATENT_SIZE_DYNAMIC)
sample, dynamic_prior = dynamic_encoder.sample_dynamic_prior(
    N_TIMES, samples=N_SAMPLES, batches=1)
print(sample.shape)
print("mean:", np.squeeze(dynamic_prior.mean()))
print("stddev:", np.squeeze(dynamic_prior.stddev()))
print([_.name for _ in dynamic_encoder.trainable_variables])
K.clear_session()
f_model = FactorizedAutoEncoder(N_HIDDEN, N_TIMES, LATENT_SIZE, LATENT_SIZE_DYNAMIC, N_SENSORS)
# Most of the trainable variables aren't created until the model pieces are called.
print([_.name for _ in f_model.static_encoder.trainable_variables])
print([_.name for _ in f_model.dynamic_encoder.trainable_variables])
print([_.name for _ in f_model.decoder.trainable_variables])
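To force variable creation before training, the pieces can be called once on a dummy batch. A sketch, assuming the encoders accept (batch, time, sensors) inputs like the dataset provides:
_dummy = tf.zeros((1, N_TIMES, N_SENSORS))
_ = f_model.static_encoder(_dummy)
_ = f_model.dynamic_encoder(_dummy)
print(len(f_model.trainable_variables))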
N_EPOCHS = 200
if False:
f_model.compile(optimizer='adam',
loss=lambda y_true, model_out: -model_out.log_prob(y_true))
hist = f_model.fit(ds_dyn, epochs=N_EPOCHS, verbose=2)
else:
@tf.function
def grad(model, inputs, preds):
with tf.GradientTape() as tape:
q_f = model.static_encoder(inputs)
q_z = model.dynamic_encoder(inputs)
p_full = model.decoder([tf.convert_to_tensor(q_f),
tf.convert_to_tensor(q_z)])
# Reconstruction log-likelihood: p(output|input)
recon_post_log_prob = p_full.log_prob(preds)
recon_post_log_prob = tf.reduce_sum(recon_post_log_prob,
axis=-1) # Sum over time axis
recon_post_log_prob = tf.reduce_mean(recon_post_log_prob)
# KL Divergence - analytical
# Static
static_prior = model.static_encoder.static_prior_factory()
stat_kl = tfd.kl_divergence(q_f, static_prior)
stat_kl = KL_WEIGHT * stat_kl
stat_kl = tf.reduce_mean(stat_kl)
# Dynamic
_, dynamic_prior = model.dynamic_encoder.sample_dynamic_prior(
N_TIMES, samples=1, batches=1
)
dyn_kl = tfd.kl_divergence(q_z, dynamic_prior)
dyn_kl = tf.reduce_sum(dyn_kl, axis=-1)
dyn_kl = tf.squeeze(dyn_kl)
dyn_kl = KL_WEIGHT * dyn_kl
dyn_kl = tf.reduce_mean(dyn_kl)
loss = -recon_post_log_prob + stat_kl + dyn_kl
grads = tape.gradient(loss, model.trainable_variables)
return loss, grads, (-recon_post_log_prob, stat_kl, dyn_kl)
optim = tf.keras.optimizers.Adam(learning_rate=1e-3)
for epoch_ix in range(N_EPOCHS):
for step_ix, batch in enumerate(ds_dyn):
inputs, preds = batch
loss, grads, loss_comps = grad(f_model, inputs, preds)
optim.apply_gradients(zip(grads, f_model.trainable_variables))
if (step_ix % 200) == 0:
print('.')
print(f"Epoch {epoch_ix}/{N_EPOCHS}:\tloss={loss:.3f}; "
f"Losses: {[_.numpy() for _ in loss_comps]}")
_, dyn_prior = f_model.dynamic_encoder.sample_dynamic_prior(10)
np.squeeze(dyn_prior.mean().numpy())
K.clear_session()
dynamic_prior = RNNMultivariateNormalDiag(
    VariationalLSTMCell(N_HIDDEN, output_dim=LATENT_SIZE_DYNAMIC),
    n_timesteps=N_TIMES, output_dim=LATENT_SIZE_DYNAMIC)
sample = dynamic_prior.sample((N_SAMPLES, BATCH_SIZE))
print(sample.shape)
print(dynamic_prior.mean())