TensorFlow Probability Utilities

This notebook is a bit of a mess after the refactor: its code has been moved to indl.model.tfp and indl.model.tfp.dsae, and many of the tests have been moved to the unit tests.

import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as tfkl
from tensorflow.keras import backend as K
import tensorflow_probability as tfp
tfd = tfp.distributions
tfpl = tfp.layers
tfb = tfp.bijectors
# scale_shift is the inverse softplus of 1: softplus(0 + scale_shift) == 1, so scale
# parameters initialized near zero start with a scale of ~1.
scale_shift = np.log(np.exp(1) - 1).astype(np.float32)
from indl.model.tfp.devae import *
from indl.model.tfp import *
# An example of how this would work in a variational autoencoder.
N_TIMES = 10
N_SENSORS = 8
N_SAMPLES = 2
N_HIDDEN = 5
KL_WEIGHT = 0.05

t_vec = tf.range(N_TIMES, dtype=tf.float32) / N_TIMES
# Dividing by sig_vec is equivalent to multiplying by sigmoid(10 * (t - 0.5)),
# i.e. a logistic ramp across the trial (used in the decoder below).
sig_vec = 1 + tf.exp(-10*(t_vec - 0.5))
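As an aside (not from the original notebook), the reciprocal of sig_vec is a logistic ramp from roughly 0 to 1, so dividing a latent sample by sig_vec stretches it into a time course:

print(np.round((1.0 / sig_vec).numpy(), 3))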


def make_model(prior):
    input_ = tfkl.Input(shape=(LATENT_SIZE,))

    # Encoder
    make_latent_dist_fn, latent_params = make_mvn_dist_fn(
        input_, LATENT_SIZE, offdiag=True, loc_name="latent_loc")
    q_latent = tfpl.DistributionLambda(
        name="q_latent",
        make_distribution_fn=make_latent_dist_fn,
        convert_to_tensor_fn=lambda s: s.sample(N_SAMPLES),
        activity_regularizer=tfpl.KLDivergenceRegularizer(prior,
                                                          use_exact_kl=True,
                                                          weight=KL_WEIGHT)
    )(latent_params)

    # Decoder
    # Scale the latent sample with a per-timestep logistic ramp
    # (dividing by 1 + exp(-x) is the same as multiplying by sigmoid(x)).
    y_ = q_latent[..., tf.newaxis, :] / sig_vec[:, tf.newaxis]
    # Alternative: broadcast-add zeros to restore timesteps, then run an RNN over them:
    #y_ = q_latent[..., tf.newaxis, :] + tf.zeros([N_TIMES, 1])
    #y_ = tf.reshape(y_, [-1, N_TIMES, LATENT_SIZE])
    #y_ = tfkl.LSTM(N_HIDDEN, return_sequences=True)(y_)
    #y_ = tf.reshape(y_, [N_SAMPLES, -1, N_TIMES, N_HIDDEN])
    make_out_dist_fn, out_dist_params = make_mvn_dist_fn(y_, N_SENSORS, loc_name="out_loc")
    p_out = tfpl.DistributionLambda(
        make_distribution_fn=make_out_dist_fn, name="p_out")(out_dist_params)
    # no prior on the output.

    # Model
    model = tf.keras.Model(inputs=input_, outputs=[q_latent, p_out])
    return model
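make_mvn_dist_fn and make_mvn_prior come from indl.model.tfp. As a rough sketch of what make_mvn_dist_fn is assumed to do, based only on how it is called in this notebook (the layer structure, names, and scale transform below are assumptions, not the indl implementation):

def sketch_mvn_dist_fn(x, ndim, loc_name="loc"):
    # Hypothetical stand-in, NOT the real make_mvn_dist_fn: Dense heads for the
    # location and (softplus-transformed) scale of a multivariate normal.
    # The offdiag=True path presumably parameterizes a lower-triangular scale instead.
    loc = tfkl.Dense(ndim, name=loc_name)(x)
    scale_params = tfkl.Dense(ndim)(x)

    def make_dist(params):
        _loc, _scale = params
        return tfd.MultivariateNormalDiag(
            loc=_loc,
            scale_diag=tf.nn.softplus(_scale + scale_shift) + 1e-5)

    return make_dist, [loc, scale_params]

make_mvn_prior(LATENT_SIZE, ...) is likewise assumed to return a MultivariateNormalDiag prior whose mean and scale may or may not be trainable depending on the flags.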
# Create a fake dataset to train the model.
LATENT_SIZE = 4
BATCH_SIZE = 6
# The latents are sampled from a distribution with known parameters.
true_dist = tfd.MultivariateNormalDiag(
    loc=[-1., 1., 5, -5],  # must have length == LATENT_SIZE
    scale_diag=[0.5, 0.5, 0.9, 0.2]
)
# They parameterize sigmoid end points,
from indl.misc.sigfuncs import sigmoid
from functools import partial
t_vec = (np.arange(N_TIMES, dtype=np.float32) / N_TIMES)[None, :]
f_sig = partial(sigmoid, t_vec, B=10, x_offset=0.5)
# which are then mixed with a known mixing matrix
mix_mat = np.array([
    [-0.3, -.28, -0.38, -0.45, -0.02, -0.12, -0.05, -0.48],
    [0.27, 0.29, -0.34, 0.2, 0.41, 0.08, 0.11, 0.13],
    [-0.14, 0.26, -0.28, -0.14, 0.1, -0.2, 0.4, 0.11],
    [-0.05, -0.12, 0.28, 0.49, -0.12, 0.1, 0.17, 0.22]
], dtype=np.float32).T
#mix_mat = tf.convert_to_tensor(mix_mat)


def gen_ds(n_iters=1e2, latent_size=LATENT_SIZE):
    iter_ix = 0
    while iter_ix < n_iters:
        _input = tf.ones((latent_size,), dtype=tf.float32)
        latent = true_dist.sample().numpy()
        _y = np.reshape(latent, [latent_size, 1])
        _y = f_sig(K=_y)
        _y = mix_mat @ _y
        _y = _y.T
        yield _input, _y
        iter_ix += 1


ds = tf.data.Dataset.from_generator(gen_ds, args=[1e2], output_types=(tf.float32, tf.float32),
                                    output_shapes=((LATENT_SIZE,), (N_TIMES, N_SENSORS)))
# The model has two outputs (q_latent, p_out); give q_latent an empty dummy target
# because its explicit loss is weighted 0 in compile() below.
ds = ds.map(lambda x, y: (x, (tf.zeros(0, dtype=tf.float32), y))).batch(BATCH_SIZE)
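A quick shape check on one batch (illustrative only); the empty first target pairs with the q_latent output:

for x_batch, (dummy_target, y_batch) in ds.take(1):
    # expected: (BATCH_SIZE, LATENT_SIZE) (BATCH_SIZE, 0) (BATCH_SIZE, N_TIMES, N_SENSORS)
    print(x_batch.shape, dummy_target.shape, y_batch.shape)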
# Train the model.
# Try playing with the second entry of loss_weights (below) and with KL_WEIGHT (above).
N_EPOCHS = 100

K.clear_session()
prior = make_mvn_prior(LATENT_SIZE, trainable_mean=True, trainable_var=True, offdiag=False)
model_ = make_model(prior)

model_.compile(optimizer='adam',
               # The explicit KL loss gets weight 0.0 and is reported only as the
               # q_latent_loss metric; the KL penalty that actually trains the encoder
               # comes from the KLDivergenceRegularizer attached to q_latent.
               loss=[lambda _, model_latent: tfd.kl_divergence(model_latent, prior),
                     lambda y_true, model_out: -model_out.log_prob(y_true)],
               loss_weights=[0.0, 1.0])
hist = model_.fit(ds, epochs=N_EPOCHS, verbose=2)
Epoch 1/100
17/17 - 0s - loss: 91.9389 - q_latent_loss: 4.6523 - p_out_loss: 90.5613
Epoch 2/100
17/17 - 0s - loss: 11531.2471 - q_latent_loss: 4.3536 - p_out_loss: 11529.9590
Epoch 3/100
17/17 - 0s - loss: 743.5315 - q_latent_loss: 4.1376 - p_out_loss: 742.3065
Epoch 4/100
17/17 - 0s - loss: 137048.2969 - q_latent_loss: 3.9767 - p_out_loss: 137047.1250
Epoch 5/100
17/17 - 0s - loss: 113.2025 - q_latent_loss: 3.8272 - p_out_loss: 112.0695
Epoch 6/100
17/17 - 0s - loss: 74.9116 - q_latent_loss: 3.7352 - p_out_loss: 73.8058
Epoch 7/100
17/17 - 0s - loss: 79.9207 - q_latent_loss: 3.6573 - p_out_loss: 78.8380
Epoch 8/100
17/17 - 0s - loss: 372.0323 - q_latent_loss: 3.5850 - p_out_loss: 370.9711
Epoch 9/100
17/17 - 0s - loss: 60.3690 - q_latent_loss: 3.5167 - p_out_loss: 59.3279
Epoch 10/100
17/17 - 0s - loss: 7418.3252 - q_latent_loss: 3.4512 - p_out_loss: 7417.3027
Epoch 11/100
17/17 - 0s - loss: 3513.5659 - q_latent_loss: 3.3824 - p_out_loss: 3512.5649
Epoch 12/100
17/17 - 0s - loss: 53.7796 - q_latent_loss: 3.3220 - p_out_loss: 52.7962
Epoch 13/100
17/17 - 0s - loss: 20.4849 - q_latent_loss: 3.2672 - p_out_loss: 19.5177
Epoch 14/100
17/17 - 0s - loss: 1562.3738 - q_latent_loss: 3.2157 - p_out_loss: 1561.4218
Epoch 15/100
17/17 - 0s - loss: 963.6710 - q_latent_loss: 3.1652 - p_out_loss: 962.7339
Epoch 16/100
17/17 - 0s - loss: 2838.6292 - q_latent_loss: 3.1165 - p_out_loss: 2837.7068
Epoch 17/100
17/17 - 0s - loss: 35.5000 - q_latent_loss: 3.0708 - p_out_loss: 34.5910
Epoch 18/100
17/17 - 0s - loss: 706.3328 - q_latent_loss: 3.0287 - p_out_loss: 705.4363
Epoch 19/100
17/17 - 0s - loss: 52.2510 - q_latent_loss: 2.9887 - p_out_loss: 51.3663
Epoch 20/100
17/17 - 0s - loss: 7950.6792 - q_latent_loss: 2.9512 - p_out_loss: 7949.8062
Epoch 21/100
17/17 - 0s - loss: 190.5290 - q_latent_loss: 2.9112 - p_out_loss: 189.6672
Epoch 22/100
17/17 - 0s - loss: 27.1404 - q_latent_loss: 2.8756 - p_out_loss: 26.2891
Epoch 23/100
17/17 - 0s - loss: 222.1122 - q_latent_loss: 2.8434 - p_out_loss: 221.2705
Epoch 24/100
17/17 - 0s - loss: 26.6817 - q_latent_loss: 2.8132 - p_out_loss: 25.8489
Epoch 25/100
17/17 - 0s - loss: 301.0465 - q_latent_loss: 2.7844 - p_out_loss: 300.2223
Epoch 26/100
17/17 - 0s - loss: 17.3472 - q_latent_loss: 2.7568 - p_out_loss: 16.5312
Epoch 27/100
17/17 - 0s - loss: 47.1069 - q_latent_loss: 2.7306 - p_out_loss: 46.2986
Epoch 28/100
17/17 - 0s - loss: 34.0318 - q_latent_loss: 2.7057 - p_out_loss: 33.2309
Epoch 29/100
17/17 - 0s - loss: 25.4790 - q_latent_loss: 2.6819 - p_out_loss: 24.6851
Epoch 30/100
17/17 - 0s - loss: 35.2237 - q_latent_loss: 2.6592 - p_out_loss: 34.4365
Epoch 31/100
17/17 - 0s - loss: 17.7497 - q_latent_loss: 2.6375 - p_out_loss: 16.9690
Epoch 32/100
17/17 - 0s - loss: 19.0096 - q_latent_loss: 2.6168 - p_out_loss: 18.2350
Epoch 33/100
17/17 - 0s - loss: 17.3889 - q_latent_loss: 2.5969 - p_out_loss: 16.6202
Epoch 34/100
17/17 - 0s - loss: 3164.2163 - q_latent_loss: 2.5783 - p_out_loss: 3163.4539
Epoch 35/100
17/17 - 0s - loss: 318.1325 - q_latent_loss: 2.5614 - p_out_loss: 317.3743
Epoch 36/100
17/17 - 0s - loss: 496.7764 - q_latent_loss: 2.5442 - p_out_loss: 496.0233
Epoch 37/100
17/17 - 0s - loss: 43.8554 - q_latent_loss: 2.5274 - p_out_loss: 43.1073
Epoch 38/100
17/17 - 0s - loss: 127.4449 - q_latent_loss: 2.5114 - p_out_loss: 126.7015
Epoch 39/100
17/17 - 0s - loss: 34.3881 - q_latent_loss: 2.4962 - p_out_loss: 33.6492
Epoch 40/100
17/17 - 0s - loss: 92.4194 - q_latent_loss: 2.4816 - p_out_loss: 91.6848
Epoch 41/100
17/17 - 0s - loss: 2711.8994 - q_latent_loss: 2.4669 - p_out_loss: 2711.1697
Epoch 42/100
17/17 - 0s - loss: 32.7170 - q_latent_loss: 2.4484 - p_out_loss: 31.9922
Epoch 43/100
17/17 - 0s - loss: 368.3795 - q_latent_loss: 2.4338 - p_out_loss: 367.6590
Epoch 44/100
17/17 - 0s - loss: 185.1732 - q_latent_loss: 2.4207 - p_out_loss: 184.4566
Epoch 45/100
17/17 - 0s - loss: 1496.3325 - q_latent_loss: 2.4090 - p_out_loss: 1495.6195
Epoch 46/100
17/17 - 0s - loss: 40.2950 - q_latent_loss: 2.3979 - p_out_loss: 39.5852
Epoch 47/100
17/17 - 0s - loss: 3341.2693 - q_latent_loss: 2.3869 - p_out_loss: 3340.5627
Epoch 48/100
17/17 - 0s - loss: 317.7818 - q_latent_loss: 2.3705 - p_out_loss: 317.0801
Epoch 49/100
17/17 - 0s - loss: 129.0085 - q_latent_loss: 2.3577 - p_out_loss: 128.3107
Epoch 50/100
17/17 - 0s - loss: 677.0181 - q_latent_loss: 2.3475 - p_out_loss: 676.3232
Epoch 51/100
17/17 - 0s - loss: 44.6301 - q_latent_loss: 2.3379 - p_out_loss: 43.9381
Epoch 52/100
17/17 - 0s - loss: 13.6527 - q_latent_loss: 2.3290 - p_out_loss: 12.9633
Epoch 53/100
17/17 - 0s - loss: 2174.0337 - q_latent_loss: 2.3206 - p_out_loss: 2173.3469
Epoch 54/100
17/17 - 0s - loss: 85.2265 - q_latent_loss: 2.3118 - p_out_loss: 84.5422
Epoch 55/100
17/17 - 0s - loss: 13.3743 - q_latent_loss: 2.3037 - p_out_loss: 12.6924
Epoch 56/100
17/17 - 0s - loss: 46996.4805 - q_latent_loss: 2.2929 - p_out_loss: 46995.8008
Epoch 57/100
17/17 - 0s - loss: 23.0730 - q_latent_loss: 2.2527 - p_out_loss: 22.4061
Epoch 58/100
17/17 - 0s - loss: 11.3692 - q_latent_loss: 2.2361 - p_out_loss: 10.7074
Epoch 59/100
17/17 - 0s - loss: 25.0329 - q_latent_loss: 2.2285 - p_out_loss: 24.3733
Epoch 60/100
17/17 - 0s - loss: 17.7706 - q_latent_loss: 2.2227 - p_out_loss: 17.1126
Epoch 61/100
17/17 - 0s - loss: 10.4981 - q_latent_loss: 2.2176 - p_out_loss: 9.8417
Epoch 62/100
17/17 - 0s - loss: 32.7279 - q_latent_loss: 2.2127 - p_out_loss: 32.0729
Epoch 63/100
17/17 - 0s - loss: 30.3548 - q_latent_loss: 2.2081 - p_out_loss: 29.7012
Epoch 64/100
17/17 - 0s - loss: 30.8280 - q_latent_loss: 2.2037 - p_out_loss: 30.1757
Epoch 65/100
17/17 - 0s - loss: 41.9700 - q_latent_loss: 2.1996 - p_out_loss: 41.3189
Epoch 66/100
17/17 - 0s - loss: 54.7435 - q_latent_loss: 2.1956 - p_out_loss: 54.0936
Epoch 67/100
17/17 - 0s - loss: 43.0750 - q_latent_loss: 2.1918 - p_out_loss: 42.4262
Epoch 68/100
17/17 - 0s - loss: 243.1561 - q_latent_loss: 2.1882 - p_out_loss: 242.5083
Epoch 69/100
17/17 - 0s - loss: 22.0350 - q_latent_loss: 2.1842 - p_out_loss: 21.3885
Epoch 70/100
17/17 - 0s - loss: 99.8952 - q_latent_loss: 2.1805 - p_out_loss: 99.2498
Epoch 71/100
17/17 - 0s - loss: 31.7489 - q_latent_loss: 2.1772 - p_out_loss: 31.1044
Epoch 72/100
17/17 - 0s - loss: 10.3156 - q_latent_loss: 2.1741 - p_out_loss: 9.6721
Epoch 73/100
17/17 - 0s - loss: 10.8083 - q_latent_loss: 2.1710 - p_out_loss: 10.1656
Epoch 74/100
17/17 - 0s - loss: 149.1215 - q_latent_loss: 2.1682 - p_out_loss: 148.4797
Epoch 75/100
17/17 - 0s - loss: 11.4474 - q_latent_loss: 2.1655 - p_out_loss: 10.8064
Epoch 76/100
17/17 - 0s - loss: 11.3532 - q_latent_loss: 2.1629 - p_out_loss: 10.7130
Epoch 77/100
17/17 - 0s - loss: 9.8263 - q_latent_loss: 2.1603 - p_out_loss: 9.1868
Epoch 78/100
17/17 - 0s - loss: 12.8481 - q_latent_loss: 2.1578 - p_out_loss: 12.2094
Epoch 79/100
17/17 - 0s - loss: 10.0819 - q_latent_loss: 2.1554 - p_out_loss: 9.4439
Epoch 80/100
17/17 - 0s - loss: 9.9589 - q_latent_loss: 2.1531 - p_out_loss: 9.3216
Epoch 81/100
17/17 - 0s - loss: 148.1879 - q_latent_loss: 2.1508 - p_out_loss: 147.5513
Epoch 82/100
17/17 - 0s - loss: 13.6919 - q_latent_loss: 2.1483 - p_out_loss: 13.0560
Epoch 83/100
17/17 - 0s - loss: 11.4130 - q_latent_loss: 2.1462 - p_out_loss: 10.7777
Epoch 84/100
17/17 - 0s - loss: 62.1876 - q_latent_loss: 2.1442 - p_out_loss: 61.5529
Epoch 85/100
17/17 - 0s - loss: 37.1363 - q_latent_loss: 2.1424 - p_out_loss: 36.5021
Epoch 86/100
17/17 - 0s - loss: 9.0239 - q_latent_loss: 2.1406 - p_out_loss: 8.3903
Epoch 87/100
17/17 - 0s - loss: 9.6573 - q_latent_loss: 2.1389 - p_out_loss: 9.0242
Epoch 88/100
17/17 - 0s - loss: 16.5087 - q_latent_loss: 2.1371 - p_out_loss: 15.8761
Epoch 89/100
17/17 - 0s - loss: 9.0933 - q_latent_loss: 2.1355 - p_out_loss: 8.4612
Epoch 90/100
17/17 - 0s - loss: 11.1045 - q_latent_loss: 2.1339 - p_out_loss: 10.4729
Epoch 91/100
17/17 - 0s - loss: 14.6627 - q_latent_loss: 2.1323 - p_out_loss: 14.0315
Epoch 92/100
17/17 - 0s - loss: 12.7246 - q_latent_loss: 2.1308 - p_out_loss: 12.0939
Epoch 93/100
17/17 - 0s - loss: 18.2482 - q_latent_loss: 2.1294 - p_out_loss: 17.6179
Epoch 94/100
17/17 - 0s - loss: 12.0889 - q_latent_loss: 2.1280 - p_out_loss: 11.4590
Epoch 95/100
17/17 - 0s - loss: 35.7015 - q_latent_loss: 2.1267 - p_out_loss: 35.0720
Epoch 96/100
17/17 - 0s - loss: 156.0881 - q_latent_loss: 2.1253 - p_out_loss: 155.4590
Epoch 97/100
17/17 - 0s - loss: 8.8568 - q_latent_loss: 2.1240 - p_out_loss: 8.2280
Epoch 98/100
17/17 - 0s - loss: 10.9563 - q_latent_loss: 2.1228 - p_out_loss: 10.3279
Epoch 99/100
17/17 - 0s - loss: 10.1586 - q_latent_loss: 2.1217 - p_out_loss: 9.5306
Epoch 100/100
17/17 - 0s - loss: 14.0194 - q_latent_loss: 2.1206 - p_out_loss: 13.3917
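As an aside (not part of the original notebook), a hedged sanity check of what the compiled objective amounts to: the reconstruction negative log-likelihood plus roughly KL_WEIGHT times the latent KL, the latter entering through the KLDivergenceRegularizer rather than the zero-weighted explicit loss:

for x_batch, (_dummy, y_batch) in ds.take(1):
    q_batch, p_batch = model_(x_batch)  # DistributionLambda outputs still behave as distributions
    nll = -tf.reduce_mean(p_batch.log_prob(y_batch))
    kl = tf.reduce_mean(tfd.kl_divergence(q_batch, prior))
    print("approx. objective on one batch:", (nll + KL_WEIGHT * kl).numpy())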

# Recover the latent loc the encoder assigns to the all-ones input directly from the
# latent_loc weights, push it through the output loc head, and compare to ground truth.
lat_wts = model_.get_layer("latent_loc").weights
lat_locs = np.ones((1, LATENT_SIZE)) @ lat_wts[0].numpy() + lat_wts[1].numpy()
mix_wts = model_.get_layer("out_loc").weights
model_out = lat_locs @ mix_wts[0].numpy() + mix_wts[1].numpy()
true_out = mix_mat @ true_dist.mean().numpy()

print(f"Model est lat: {lat_locs}")
print(f"Model est out: {model_out}")
print(f"prior mean: {prior.mean().numpy()}")
print(f"true lat: {true_dist.mean().numpy()}")
print(f"true out: {true_out.T}")
Model est lat: [[ 0.51195845 -1.02728739  0.67845761 -0.11073773]]
Model est out: [[ 0.10048061  1.36010552  0.30864524 -0.09840383  1.17217551 -0.67099368
   0.95406085  0.03414997]]
prior mean: [ 0.5117957  -0.8991166   0.66152537 -0.11197621]
true lat: [-1.  1.  5. -5.]
true out: [ 0.12000006  2.4699998  -2.76       -2.5         1.53       -1.3
  1.31        0.05999994]

# test LearnableMultivariateNormalDiag
prior_factory = LearnableMultivariateNormalDiag(LATENT_SIZE)
learnable_prior = prior_factory()
sample = learnable_prior.sample((100, 64))
print(sample.shape)
print(learnable_prior.trainable_variables)
(100, 64, 4)
(<tf.Variable 'learnable_multivariate_normal_diag_2/mean:0' shape=(4,) dtype=float32, numpy=array([ 0.16748714, -0.1799583 ,  0.0387747 ,  0.11378615], dtype=float32)>, <tf.Variable 'learnable_multivariate_normal_diag_2/transformed_scale:0' shape=(4,) dtype=float32, numpy=array([-0.11407143,  0.06062925,  0.02439827, -0.01735771], dtype=float32)>)
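For reference, a minimal sketch of what LearnableMultivariateNormalDiag is assumed to look like, based on the variable names and shapes printed above (an assumption, not the indl implementation):

class SketchLearnableMVNDiag(tf.Module):
    # Hypothetical stand-in: a distribution factory with a trainable mean and a
    # trainable, softplus-transformed scale; calling it builds the prior.
    def __init__(self, ndim, name=None):
        super().__init__(name=name)
        self._mean = tf.Variable(tf.random.normal([ndim], stddev=0.1), name="mean")
        self._transformed_scale = tf.Variable(
            tf.random.normal([ndim], stddev=0.1), name="transformed_scale")

    def __call__(self):
        # scale_shift puts softplus(~0 + scale_shift) near 1, so the prior starts
        # close to a standard normal.
        return tfd.MultivariateNormalDiag(
            loc=self._mean,
            scale_diag=tf.nn.softplus(self._transformed_scale + scale_shift) + 1e-5)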

K.clear_session()
model_ = make_model(learnable_prior)

model_.compile(optimizer='adam',
               loss=[lambda _, model_latent: tfd.kl_divergence(model_latent, learnable_prior),
                     lambda y_true, model_out: -model_out.log_prob(y_true)],
               loss_weights=[0.0, 1.0])
# The learnable prior's variables are picked up as model weights via the KL regularizer:
print(learnable_prior.trainable_variables)
print([_.name for _ in model_.trainable_variables])

hist = model_.fit(ds, epochs=N_EPOCHS, verbose=2)

lat_wts = model_.get_layer("latent_loc").weights
lat_locs = np.ones((1, LATENT_SIZE)) @ lat_wts[0].numpy() + lat_wts[1].numpy()
mix_wts = model_.get_layer("out_loc").weights
model_out = lat_locs @ mix_wts[0].numpy() + mix_wts[1].numpy()
true_out = mix_mat @ true_dist.mean().numpy()

print(f"Model est lat: {lat_locs}")
print(f"Model est out: {model_out}")
print(f"prior mean: {learnable_prior.mean().numpy()}")
print(f"true lat: {true_dist.mean().numpy()}")
print(f"true out: {true_out.T}")
(<tf.Variable 'learnable_multivariate_normal_diag_2/mean:0' shape=(4,) dtype=float32, numpy=array([ 0.16748714, -0.1799583 ,  0.0387747 ,  0.11378615], dtype=float32)>, <tf.Variable 'learnable_multivariate_normal_diag_2/transformed_scale:0' shape=(4,) dtype=float32, numpy=array([-0.11407143,  0.06062925,  0.02439827, -0.01735771], dtype=float32)>)
['dense/kernel:0', 'dense/bias:0', 'latent_loc/kernel:0', 'latent_loc/bias:0', 'dense_1/kernel:0', 'dense_1/bias:0', 'out_loc/kernel:0', 'out_loc/bias:0', 'learnable_multivariate_normal_diag_2/mean:0', 'learnable_multivariate_normal_diag_2/transformed_scale:0']
Epoch 1/100
17/17 - 0s - loss: 266.6295 - q_latent_loss: 6.5640 - p_out_loss: 264.6859
Epoch 2/100
17/17 - 0s - loss: 2032.0966 - q_latent_loss: 6.2853 - p_out_loss: 2030.2358
Epoch 3/100
17/17 - 0s - loss: 83.5799 - q_latent_loss: 6.0718 - p_out_loss: 81.7824
Epoch 4/100
17/17 - 0s - loss: 82.7522 - q_latent_loss: 5.8942 - p_out_loss: 81.0072
Epoch 5/100
17/17 - 0s - loss: 59.2224 - q_latent_loss: 5.7269 - p_out_loss: 57.5269
Epoch 6/100
17/17 - 0s - loss: 38.8948 - q_latent_loss: 5.5706 - p_out_loss: 37.2456
Epoch 7/100
17/17 - 0s - loss: 47.8537 - q_latent_loss: 5.4227 - p_out_loss: 46.2483
Epoch 8/100
17/17 - 0s - loss: 60.9186 - q_latent_loss: 5.2828 - p_out_loss: 59.3546
Epoch 9/100
17/17 - 0s - loss: 80.7008 - q_latent_loss: 5.1479 - p_out_loss: 79.1768
Epoch 10/100
17/17 - 0s - loss: 29.5548 - q_latent_loss: 5.0204 - p_out_loss: 28.0686
Epoch 11/100
17/17 - 0s - loss: 100.5337 - q_latent_loss: 4.9013 - p_out_loss: 99.0827
Epoch 12/100
17/17 - 0s - loss: 208.5356 - q_latent_loss: 4.7856 - p_out_loss: 207.1189
Epoch 13/100
17/17 - 0s - loss: 47.4895 - q_latent_loss: 4.6692 - p_out_loss: 46.1072
Epoch 14/100
17/17 - 0s - loss: 51.8070 - q_latent_loss: 4.5624 - p_out_loss: 50.4563
Epoch 15/100
17/17 - 0s - loss: 49.2825 - q_latent_loss: 4.4640 - p_out_loss: 47.9610
Epoch 16/100
17/17 - 0s - loss: 63.7341 - q_latent_loss: 4.3716 - p_out_loss: 62.4399
Epoch 17/100
17/17 - 0s - loss: 35.2299 - q_latent_loss: 4.2837 - p_out_loss: 33.9617
Epoch 18/100
17/17 - 0s - loss: 45.8432 - q_latent_loss: 4.2006 - p_out_loss: 44.5997
Epoch 19/100
17/17 - 0s - loss: 25.7876 - q_latent_loss: 4.1215 - p_out_loss: 24.5675
Epoch 20/100
17/17 - 0s - loss: 268.8558 - q_latent_loss: 4.0396 - p_out_loss: 267.6599
Epoch 21/100
17/17 - 0s - loss: 45.2869 - q_latent_loss: 3.9513 - p_out_loss: 44.1171
Epoch 22/100
17/17 - 0s - loss: 30.1766 - q_latent_loss: 3.8787 - p_out_loss: 29.0284
Epoch 23/100
17/17 - 0s - loss: 32.2969 - q_latent_loss: 3.8108 - p_out_loss: 31.1688
Epoch 24/100
17/17 - 0s - loss: 64.0437 - q_latent_loss: 3.7457 - p_out_loss: 62.9348
Epoch 25/100
17/17 - 0s - loss: 39.9464 - q_latent_loss: 3.6825 - p_out_loss: 38.8562
Epoch 26/100
17/17 - 0s - loss: 33.5094 - q_latent_loss: 3.6220 - p_out_loss: 32.4372
Epoch 27/100
17/17 - 0s - loss: 31.4306 - q_latent_loss: 3.5643 - p_out_loss: 30.3755
Epoch 28/100
17/17 - 0s - loss: 27.8061 - q_latent_loss: 3.5087 - p_out_loss: 26.7674
Epoch 29/100
17/17 - 0s - loss: 65.8272 - q_latent_loss: 3.4540 - p_out_loss: 64.8047
Epoch 30/100
17/17 - 0s - loss: 25.9475 - q_latent_loss: 3.4009 - p_out_loss: 24.9408
Epoch 31/100
17/17 - 0s - loss: 30.2780 - q_latent_loss: 3.3515 - p_out_loss: 29.2859
Epoch 32/100
17/17 - 0s - loss: 21.8850 - q_latent_loss: 3.3044 - p_out_loss: 20.9068
Epoch 33/100
17/17 - 0s - loss: 36.6851 - q_latent_loss: 3.2587 - p_out_loss: 35.7204
Epoch 34/100
17/17 - 0s - loss: 25.5569 - q_latent_loss: 3.2120 - p_out_loss: 24.6061
Epoch 35/100
17/17 - 0s - loss: 24.6902 - q_latent_loss: 3.1682 - p_out_loss: 23.7523
Epoch 36/100
17/17 - 0s - loss: 116.0450 - q_latent_loss: 3.1269 - p_out_loss: 115.1194
Epoch 37/100
17/17 - 0s - loss: 20.0418 - q_latent_loss: 3.0908 - p_out_loss: 19.1268
Epoch 38/100
17/17 - 0s - loss: 56.1398 - q_latent_loss: 3.0497 - p_out_loss: 55.2370
Epoch 39/100
17/17 - 0s - loss: 27.5171 - q_latent_loss: 3.0067 - p_out_loss: 26.6270
Epoch 40/100
17/17 - 0s - loss: 20.7006 - q_latent_loss: 2.9684 - p_out_loss: 19.8219
Epoch 41/100
17/17 - 0s - loss: 26.9046 - q_latent_loss: 2.9328 - p_out_loss: 26.0364
Epoch 42/100
17/17 - 0s - loss: 18.7693 - q_latent_loss: 2.8996 - p_out_loss: 17.9110
Epoch 43/100
17/17 - 0s - loss: 22.3650 - q_latent_loss: 2.8670 - p_out_loss: 21.5163
Epoch 44/100
17/17 - 0s - loss: 32.9155 - q_latent_loss: 2.8352 - p_out_loss: 32.0763
Epoch 45/100
17/17 - 0s - loss: 19.9130 - q_latent_loss: 2.8037 - p_out_loss: 19.0830
Epoch 46/100
17/17 - 0s - loss: 19.9001 - q_latent_loss: 2.7740 - p_out_loss: 19.0789
Epoch 47/100
17/17 - 0s - loss: 25.4838 - q_latent_loss: 2.7436 - p_out_loss: 24.6716
Epoch 48/100
17/17 - 0s - loss: 23.9622 - q_latent_loss: 2.7135 - p_out_loss: 23.1589
Epoch 49/100
17/17 - 0s - loss: 20.7703 - q_latent_loss: 2.6849 - p_out_loss: 19.9756
Epoch 50/100
17/17 - 0s - loss: 19.6302 - q_latent_loss: 2.6576 - p_out_loss: 18.8435
Epoch 51/100
17/17 - 0s - loss: 18.7125 - q_latent_loss: 2.6321 - p_out_loss: 17.9334
Epoch 52/100
17/17 - 0s - loss: 21.4065 - q_latent_loss: 2.6073 - p_out_loss: 20.6347
Epoch 53/100
17/17 - 0s - loss: 37.3685 - q_latent_loss: 2.5831 - p_out_loss: 36.6039
Epoch 54/100
17/17 - 0s - loss: 15.8975 - q_latent_loss: 2.5606 - p_out_loss: 15.1395
Epoch 55/100
17/17 - 0s - loss: 15.6574 - q_latent_loss: 2.5387 - p_out_loss: 14.9059
Epoch 56/100
17/17 - 0s - loss: 28.7901 - q_latent_loss: 2.5174 - p_out_loss: 28.0449
Epoch 57/100
17/17 - 0s - loss: 99.3240 - q_latent_loss: 2.4972 - p_out_loss: 98.5848
Epoch 58/100
17/17 - 0s - loss: 19.6783 - q_latent_loss: 2.4761 - p_out_loss: 18.9453
Epoch 59/100
17/17 - 0s - loss: 18.9958 - q_latent_loss: 2.4563 - p_out_loss: 18.2688
Epoch 60/100
17/17 - 0s - loss: 21.3663 - q_latent_loss: 2.4364 - p_out_loss: 20.6451
Epoch 61/100
17/17 - 0s - loss: 26.8008 - q_latent_loss: 2.4179 - p_out_loss: 26.0850
Epoch 62/100
17/17 - 0s - loss: 13.9355 - q_latent_loss: 2.3984 - p_out_loss: 13.2256
Epoch 63/100
17/17 - 0s - loss: 14.0786 - q_latent_loss: 2.3803 - p_out_loss: 13.3740
Epoch 64/100
17/17 - 0s - loss: 20.6991 - q_latent_loss: 2.3634 - p_out_loss: 19.9995
Epoch 65/100
17/17 - 0s - loss: 33.9438 - q_latent_loss: 2.3476 - p_out_loss: 33.2488
Epoch 66/100
17/17 - 0s - loss: 19.5023 - q_latent_loss: 2.3325 - p_out_loss: 18.8118
Epoch 67/100
17/17 - 0s - loss: 16.1214 - q_latent_loss: 2.3179 - p_out_loss: 15.4353
Epoch 68/100
17/17 - 0s - loss: 33.3983 - q_latent_loss: 2.3044 - p_out_loss: 32.7162
Epoch 69/100
17/17 - 0s - loss: 14.1833 - q_latent_loss: 2.2933 - p_out_loss: 13.5045
Epoch 70/100
17/17 - 0s - loss: 33.0913 - q_latent_loss: 2.2802 - p_out_loss: 32.4163
Epoch 71/100
17/17 - 0s - loss: 15.5565 - q_latent_loss: 2.2661 - p_out_loss: 14.8857
Epoch 72/100
17/17 - 0s - loss: 23.7552 - q_latent_loss: 2.2522 - p_out_loss: 23.0885
Epoch 73/100
17/17 - 0s - loss: 15.8186 - q_latent_loss: 2.2402 - p_out_loss: 15.1554
Epoch 74/100
17/17 - 0s - loss: 15.8109 - q_latent_loss: 2.2277 - p_out_loss: 15.1514
Epoch 75/100
17/17 - 0s - loss: 23.2216 - q_latent_loss: 2.2153 - p_out_loss: 22.5659
Epoch 76/100
17/17 - 0s - loss: 17.1244 - q_latent_loss: 2.2021 - p_out_loss: 16.4725
Epoch 77/100
17/17 - 0s - loss: 20.2818 - q_latent_loss: 2.1875 - p_out_loss: 19.6343
Epoch 78/100
17/17 - 0s - loss: 20.1146 - q_latent_loss: 2.1751 - p_out_loss: 19.4708
Epoch 79/100
17/17 - 0s - loss: 12.0626 - q_latent_loss: 2.1647 - p_out_loss: 11.4218
Epoch 80/100
17/17 - 0s - loss: 17.5959 - q_latent_loss: 2.1547 - p_out_loss: 16.9580
Epoch 81/100
17/17 - 0s - loss: 46.9798 - q_latent_loss: 2.1428 - p_out_loss: 46.3454
Epoch 82/100
17/17 - 0s - loss: 20.0080 - q_latent_loss: 2.1244 - p_out_loss: 19.3792
Epoch 83/100
17/17 - 0s - loss: 11.8902 - q_latent_loss: 2.1120 - p_out_loss: 11.2651
Epoch 84/100
17/17 - 0s - loss: 17.5359 - q_latent_loss: 2.1015 - p_out_loss: 16.9139
Epoch 85/100
17/17 - 0s - loss: 14.8084 - q_latent_loss: 2.0917 - p_out_loss: 14.1892
Epoch 86/100
17/17 - 0s - loss: 9.6016 - q_latent_loss: 2.0823 - p_out_loss: 8.9852
Epoch 87/100
17/17 - 0s - loss: 13.4432 - q_latent_loss: 2.0736 - p_out_loss: 12.8294
Epoch 88/100
17/17 - 0s - loss: 16.5978 - q_latent_loss: 2.0652 - p_out_loss: 15.9865
Epoch 89/100
17/17 - 0s - loss: 21.3484 - q_latent_loss: 2.0557 - p_out_loss: 20.7399
Epoch 90/100
17/17 - 0s - loss: 11.3400 - q_latent_loss: 2.0439 - p_out_loss: 10.7350
Epoch 91/100
17/17 - 0s - loss: 14.2551 - q_latent_loss: 2.0345 - p_out_loss: 13.6529
Epoch 92/100
17/17 - 0s - loss: 14.2384 - q_latent_loss: 2.0268 - p_out_loss: 13.6385
Epoch 93/100
17/17 - 0s - loss: 16.3489 - q_latent_loss: 2.0194 - p_out_loss: 15.7511
Epoch 94/100
17/17 - 0s - loss: 14.2265 - q_latent_loss: 2.0118 - p_out_loss: 13.6310
Epoch 95/100
17/17 - 0s - loss: 11.5992 - q_latent_loss: 2.0041 - p_out_loss: 11.0060
Epoch 96/100
17/17 - 0s - loss: 11.7333 - q_latent_loss: 1.9971 - p_out_loss: 11.1421
Epoch 97/100
17/17 - 0s - loss: 12.1329 - q_latent_loss: 1.9904 - p_out_loss: 11.5438
Epoch 98/100
17/17 - 0s - loss: 13.7211 - q_latent_loss: 1.9832 - p_out_loss: 13.1341
Epoch 99/100
17/17 - 0s - loss: 37.2112 - q_latent_loss: 1.9745 - p_out_loss: 36.6267
Epoch 100/100
17/17 - 0s - loss: 9.3972 - q_latent_loss: 1.9630 - p_out_loss: 8.8161
Model est lat: [[ 2.51292503 -1.26424221 -1.11180196 -0.02588509]]
Model est out: [[-0.27425705  2.39304429 -2.49596335  0.29950965  1.37276651 -0.43098222
  -0.45473986  0.07839915]]
prior mean: [ 1.3514248  -1.2266612  -0.8751619  -0.03475915]
true lat: [-1.  1.  5. -5.]
true out: [ 0.12000006  2.4699998  -2.76       -2.5         1.53       -1.3
  1.31        0.05999994]

Latent Dynamic Factor

# Alternative: return 3 targets, with the first 2 null:
#ds_dyn = ds.map(lambda x, y: (x, (y[0], y[0], y[1])))
# Keep only the full output sequence as the target.
ds_dyn = ds.map(lambda x, y: (x, y[1]))
KL_WEIGHT = 0.001
LATENT_SIZE_DYNAMIC = 1  # Integer dimensionality of each dynamic, time-variant latent variable `z_t`.
K.clear_session()
tmp = LearnableMultivariateNormalDiagCell(3, 4)
#tmp.build((None, 10, 5))
#tmp.summary()
# test DynamicEncoder and LearnableMultivariateNormalDiagCell
K.clear_session()
dynamic_encoder = DynamicEncoder(N_HIDDEN, N_TIMES, LATENT_SIZE_DYNAMIC)
sample, dynamic_prior = dynamic_encoder.sample_dynamic_prior(
    N_TIMES, samples=N_SAMPLES, batches=1)
print(sample.shape)
print("mean:", np.squeeze(dynamic_prior.mean()))
print("stddev:", np.squeeze(dynamic_prior.stddev()))
print([_.name for _ in dynamic_encoder.trainable_variables])
(2, 1, 10, 1)
mean: [[ 0.          0.42388976  0.45631832  0.365768    0.20130846  0.37873474
   0.31262326  0.26073667  0.15399611  0.14049806]
 [ 0.          0.08804662  0.0361465  -0.03267653 -0.08733355  0.19941618
   0.30335566  0.3730844   0.2744042   0.17948757]]
stddev: [[1.00001    0.929902   0.9554563  0.99371654 1.0142238  0.97018814
  1.0006421  1.0033575  1.0105829  1.0065393 ]
 [1.00001    0.9952911  1.0008274  1.0000954  0.99874425 0.98230976
  0.9771384  0.9683764  1.000174   1.0049998 ]]
['learnable_multivariate_normal_diag_cell/mvndiagcell_lstm/kernel:0', 'learnable_multivariate_normal_diag_cell/mvndiagcell_lstm/recurrent_kernel:0', 'learnable_multivariate_normal_diag_cell/mvndiagcell_lstm/bias:0', 'learnable_multivariate_normal_diag_cell/mvndiagcell_loc/kernel:0', 'learnable_multivariate_normal_diag_cell/mvndiagcell_loc/bias:0', 'learnable_multivariate_normal_diag_cell/mvndiagcell_scale/kernel:0', 'learnable_multivariate_normal_diag_cell/mvndiagcell_scale/bias:0']
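Roughly, sample_dynamic_prior is assumed to roll an LSTM cell forward in time, emitting a per-step loc/scale and feeding each sampled z_t back in as the next input. A self-contained sketch under that assumption (plain tf.keras layers standing in for the indl cell; details may differ):

def sketch_sample_dynamic_prior(n_timesteps, samples=1, batches=1,
                                hidden_size=N_HIDDEN, latent_dim=LATENT_SIZE_DYNAMIC):
    # Hypothetical rollout, NOT the indl implementation.
    lstm = tfkl.LSTMCell(hidden_size)
    loc_layer = tfkl.Dense(latent_dim)
    scale_layer = tfkl.Dense(latent_dim)
    z = tf.zeros([samples * batches, latent_dim])
    state = lstm.get_initial_state(batch_size=samples * batches, dtype=tf.float32)
    zs, locs, scales = [], [], []
    for _ in range(n_timesteps):
        h, state = lstm(z, state)
        loc = loc_layer(h)
        scale = tf.nn.softplus(scale_layer(h) + scale_shift) + 1e-5
        z = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale).sample()
        zs.append(z)
        locs.append(loc)
        scales.append(scale)
    # Stack along time and split the (samples * batches) axis back apart.
    stack = lambda ts: tf.reshape(tf.stack(ts, axis=1),
                                  [samples, batches, n_timesteps, latent_dim])
    return stack(zs), tfd.MultivariateNormalDiag(loc=stack(locs), scale_diag=stack(scales))

With a zero initial input and state, the first step's loc is 0 and its scale is softplus(scale_shift) + 1e-5 = 1.00001, which would explain the first column of the printed prior above.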

K.clear_session()
f_model = FactorizedAutoEncoder(N_HIDDEN, N_TIMES, LATENT_SIZE, LATENT_SIZE_DYNAMIC, N_SENSORS)
# Most of the trainable variables aren't created until the model pieces have been called.
print([_.name for _ in f_model.static_encoder.trainable_variables])
print([_.name for _ in f_model.dynamic_encoder.trainable_variables])
print([_.name for _ in f_model.decoder.trainable_variables])
[]
['learnable_multivariate_normal_diag/mean:0', 'learnable_multivariate_normal_diag/untransformed_stddev:0']
[]
[]
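This is ordinary Keras deferred building; variables only exist after the first call. A minimal illustration (not from the original notebook):

layer = tfkl.Dense(3)
print(layer.trainable_variables)  # [] -- not built yet
_ = layer(tf.zeros((1, 5)))       # first call creates the kernel and bias
print([v.name for v in layer.trainable_variables])  # e.g. ['dense/kernel:0', 'dense/bias:0']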

N_EPOCHS = 200
# When False, skip compile/fit (reconstruction loss only) and instead run the custom
# training loop below, which adds weighted KL penalties for the static and dynamic latents.
if False:
    f_model.compile(optimizer='adam',
                    loss=lambda y_true, model_out: -model_out.log_prob(y_true))
    hist = f_model.fit(ds_dyn, epochs=N_EPOCHS, verbose=2)
else:
    @tf.function
    def grad(model, inputs, preds):  # `preds` holds the target observations (y_true)
        with tf.GradientTape() as tape:
            q_f = model.static_encoder(inputs)
            q_z = model.dynamic_encoder(inputs)
            p_full = model.decoder([tf.convert_to_tensor(q_f),
                                    tf.convert_to_tensor(q_z)])

            # Reconstruction log-likelihood: p(output|input)
            recon_post_log_prob = p_full.log_prob(preds)
            recon_post_log_prob = tf.reduce_sum(recon_post_log_prob,
                                                axis=-1)  # Sum over time axis
            recon_post_log_prob = tf.reduce_mean(recon_post_log_prob)

            # KL Divergence - analytical
            # Static
            static_prior = model.static_encoder.static_prior_factory()
            stat_kl = tfd.kl_divergence(q_f, static_prior)
            stat_kl = KL_WEIGHT * stat_kl
            stat_kl = tf.reduce_mean(stat_kl)

            # Dynamic
            _, dynamic_prior = model.dynamic_encoder.sample_dynamic_prior(
                N_TIMES, samples=1, batches=1
            )
            dyn_kl = tfd.kl_divergence(q_z, dynamic_prior)
            dyn_kl = tf.reduce_sum(dyn_kl, axis=-1)
            dyn_kl = tf.squeeze(dyn_kl)
            dyn_kl = KL_WEIGHT * dyn_kl
            dyn_kl = tf.reduce_mean(dyn_kl)

            loss = -recon_post_log_prob + stat_kl + dyn_kl

        grads = tape.gradient(loss, model.trainable_variables)
        return loss, grads, (-recon_post_log_prob, stat_kl, dyn_kl)

    optim = tf.keras.optimizers.Adam(learning_rate=1e-3)
    for epoch_ix in range(N_EPOCHS):
        for step_ix, batch in enumerate(ds_dyn):
            inputs, preds = batch
            loss, grads, loss_comps = grad(f_model, inputs, preds)
            optim.apply_gradients(zip(grads, f_model.trainable_variables))
            if (step_ix % 200) == 0:
                print('.')
        print(f"Epoch {epoch_ix}/{N_EPOCHS}:\tloss={loss:.3f}; "
              f"Losses: {[_.numpy() for _ in loss_comps]}")
.
Epoch 0/200:    loss=3777.438; Losses: [3777.3972, 0.0028260916, 0.03819199]
.
Epoch 1/200:    loss=908.620; Losses: [908.5834, 0.002681077, 0.034407526]
.
Epoch 2/200:    loss=682.733; Losses: [682.7001, 0.0025911012, 0.030505255]
.
Epoch 3/200:    loss=317.985; Losses: [317.9555, 0.0024803109, 0.027142774]
.
Epoch 4/200:    loss=737.753; Losses: [737.7305, 0.0024128703, 0.019867169]
.
Epoch 5/200:    loss=293.983; Losses: [293.9585, 0.0023704215, 0.022514882]
.
Epoch 6/200:    loss=412.075; Losses: [412.0592, 0.002335041, 0.013460781]
.
Epoch 7/200:    loss=306.618; Losses: [306.60327, 0.002284743, 0.012861049]
.
Epoch 8/200:    loss=215.916; Losses: [215.9029, 0.0022479186, 0.0111077]
.
Epoch 9/200:    loss=223.136; Losses: [223.12366, 0.0022122823, 0.010420884]
.
Epoch 10/200:   loss=300.156; Losses: [300.14383, 0.0021807426, 0.009667382]
.
Epoch 11/200:   loss=179.766; Losses: [179.75607, 0.0021481065, 0.007676568]
.
Epoch 12/200:   loss=215.981; Losses: [215.97124, 0.0021149626, 0.008112778]
.
Epoch 13/200:   loss=208.054; Losses: [208.04565, 0.0020882282, 0.006530414]
.
Epoch 14/200:   loss=207.146; Losses: [207.13806, 0.0020594192, 0.0063194335]
.
Epoch 15/200:   loss=261.805; Losses: [261.7982, 0.0020306753, 0.0045144106]
.
Epoch 16/200:   loss=178.393; Losses: [178.38637, 0.0020005947, 0.0045015276]
.
Epoch 17/200:   loss=284.872; Losses: [284.86642, 0.0019730518, 0.003997615]
.
Epoch 18/200:   loss=212.054; Losses: [212.04866, 0.0019469936, 0.0034680453]
.
Epoch 19/200:   loss=145.862; Losses: [145.85641, 0.001922644, 0.004011639]
.
Epoch 20/200:   loss=182.153; Losses: [182.14563, 0.0018981744, 0.0059485184]
.
Epoch 21/200:   loss=154.575; Losses: [154.56941, 0.0018728755, 0.0038126395]
.
Epoch 22/200:   loss=214.752; Losses: [214.74734, 0.0018469194, 0.003135754]
.
Epoch 23/200:   loss=179.347; Losses: [179.34248, 0.001825613, 0.0029100403]
.
Epoch 24/200:   loss=354.274; Losses: [354.26898, 0.0018005224, 0.0028155171]
.
Epoch 25/200:   loss=181.006; Losses: [181.0021, 0.0017763268, 0.0022955306]
.
Epoch 26/200:   loss=142.006; Losses: [142.00166, 0.0017520584, 0.0022001117]
.
Epoch 27/200:   loss=158.932; Losses: [158.92723, 0.0017273662, 0.0026816982]
.
Epoch 28/200:   loss=175.159; Losses: [175.15494, 0.0017028436, 0.0019087334]
.
Epoch 29/200:   loss=167.915; Losses: [167.91084, 0.0016797789, 0.002798946]
.
Epoch 30/200:   loss=152.785; Losses: [152.78006, 0.001658167, 0.0029456303]
.
Epoch 31/200:   loss=158.407; Losses: [158.40344, 0.0016350556, 0.0018697139]
.
Epoch 32/200:   loss=151.065; Losses: [151.06094, 0.0016126116, 0.00282249]
.
Epoch 33/200:   loss=181.075; Losses: [181.07072, 0.0015892526, 0.0025259533]
.
Epoch 34/200:   loss=157.210; Losses: [157.20605, 0.0015682159, 0.0026225548]
.
Epoch 35/200:   loss=151.200; Losses: [151.19623, 0.0015464819, 0.0017979483]
.
Epoch 36/200:   loss=157.111; Losses: [157.10796, 0.0015237778, 0.0016808716]
.
Epoch 37/200:   loss=160.398; Losses: [160.39449, 0.0015015532, 0.0018031715]
.
Epoch 38/200:   loss=133.418; Losses: [133.41472, 0.0014805306, 0.0017712188]
.
Epoch 39/200:   loss=161.662; Losses: [161.65845, 0.0014592485, 0.0018676165]
.
Epoch 40/200:   loss=181.789; Losses: [181.78532, 0.0014379938, 0.0023018767]
.
Epoch 41/200:   loss=183.568; Losses: [183.56451, 0.00142015, 0.001671126]
.
Epoch 42/200:   loss=211.679; Losses: [211.67618, 0.0014024411, 0.001817507]
.
Epoch 43/200:   loss=235.384; Losses: [235.38095, 0.0013794828, 0.0018946148]
.
Epoch 44/200:   loss=139.088; Losses: [139.08447, 0.0013571468, 0.0019089653]
.
Epoch 45/200:   loss=160.254; Losses: [160.25092, 0.0013359843, 0.0013933791]
.
Epoch 46/200:   loss=142.206; Losses: [142.20291, 0.0013167453, 0.0014015753]
.
Epoch 47/200:   loss=138.814; Losses: [138.8115, 0.0012925405, 0.0014298331]
.
Epoch 48/200:   loss=130.677; Losses: [130.67343, 0.0012679367, 0.0021604125]
.
Epoch 49/200:   loss=145.296; Losses: [145.29306, 0.0012432866, 0.0013332999]
.
Epoch 50/200:   loss=133.338; Losses: [133.33516, 0.0012180805, 0.0012685797]
.
Epoch 51/200:   loss=138.212; Losses: [138.20908, 0.0011976506, 0.0017989981]
.
Epoch 52/200:   loss=136.139; Losses: [136.13644, 0.0011768966, 0.001470595]
.
Epoch 53/200:   loss=141.839; Losses: [141.83646, 0.0011539405, 0.0017275128]
.
Epoch 54/200:   loss=159.297; Losses: [159.29402, 0.0011311076, 0.0017414299]
.
Epoch 55/200:   loss=125.919; Losses: [125.91669, 0.0011091225, 0.0013509693]
.
Epoch 56/200:   loss=135.710; Losses: [135.7079, 0.0010871431, 0.0010571101]
.
Epoch 57/200:   loss=124.548; Losses: [124.545654, 0.0010660599, 0.0011921946]
.
Epoch 58/200:   loss=128.607; Losses: [128.60472, 0.0010461143, 0.0010578154]
.
Epoch 59/200:   loss=200.879; Losses: [200.87674, 0.0010245068, 0.0014237395]
.
Epoch 60/200:   loss=181.939; Losses: [181.93698, 0.0010024481, 0.0010916217]
.
Epoch 61/200:   loss=161.069; Losses: [161.06737, 0.0009815091, 0.0009930782]
.
Epoch 62/200:   loss=129.661; Losses: [129.65881, 0.0009596758, 0.001090972]
.
Epoch 63/200:   loss=342.733; Losses: [342.73068, 0.000939325, 0.001181667]
.
Epoch 64/200:   loss=160.802; Losses: [160.8003, 0.00091621274, 0.0012079075]
.
Epoch 65/200:   loss=123.200; Losses: [123.19836, 0.00089694076, 0.0009983401]
.
Epoch 66/200:   loss=134.465; Losses: [134.46295, 0.00087896077, 0.0009002821]
.
Epoch 67/200:   loss=205.839; Losses: [205.83714, 0.0008604548, 0.001061962]
.
Epoch 68/200:   loss=144.191; Losses: [144.18906, 0.0008421927, 0.0011492949]
.
Epoch 69/200:   loss=164.397; Losses: [164.39539, 0.0008238552, 0.0010998722]
.
Epoch 70/200:   loss=131.024; Losses: [131.02272, 0.00080561615, 0.000867295]
.
Epoch 71/200:   loss=130.408; Losses: [130.40652, 0.0007878286, 0.0009589862]
.
Epoch 72/200:   loss=120.511; Losses: [120.509705, 0.0007217523, 0.0008722847]
.
Epoch 73/200:   loss=122.388; Losses: [122.38655, 0.00069110864, 0.00080618635]
.
Epoch 74/200:   loss=123.200; Losses: [123.198654, 0.0006731994, 0.0008244566]
.
Epoch 75/200:   loss=117.884; Losses: [117.88217, 0.00065816526, 0.00086460914]
.
Epoch 76/200:   loss=123.508; Losses: [123.50694, 0.0006448629, 0.0008477152]
.
Epoch 77/200:   loss=121.749; Losses: [121.74744, 0.0006321136, 0.0008150753]
.
Epoch 78/200:   loss=145.549; Losses: [145.5473, 0.00061951997, 0.00082959904]
.
Epoch 79/200:   loss=135.341; Losses: [135.33992, 0.00060778105, 0.0007312283]
.
Epoch 80/200:   loss=131.476; Losses: [131.47452, 0.0005965105, 0.0008391714]
.
Epoch 81/200:   loss=123.978; Losses: [123.976944, 0.00058624754, 0.0008059401]
.
Epoch 82/200:   loss=136.084; Losses: [136.08298, 0.0005766748, 0.0007252015]
.
Epoch 83/200:   loss=137.815; Losses: [137.81375, 0.00056776253, 0.0009108439]
.
Epoch 84/200:   loss=116.955; Losses: [116.95401, 0.0005592232, 0.0008040637]
.
Epoch 85/200:   loss=131.525; Losses: [131.52376, 0.00055153086, 0.0007604256]
.
Epoch 86/200:   loss=135.716; Losses: [135.71432, 0.00054452394, 0.0006871256]
.
Epoch 87/200:   loss=191.940; Losses: [191.93927, 0.0005378095, 0.0006448448]
.
Epoch 88/200:   loss=170.746; Losses: [170.74509, 0.00053181086, 0.0006591724]
.
Epoch 89/200:   loss=121.373; Losses: [121.37161, 0.00052640436, 0.00079500565]
.
Epoch 90/200:   loss=126.909; Losses: [126.90761, 0.0005213883, 0.0006021685]
.
Epoch 91/200:   loss=122.121; Losses: [122.120255, 0.0005157513, 0.0006873938]
.
Epoch 92/200:   loss=129.156; Losses: [129.15442, 0.0005111307, 0.00064897194]
.
Epoch 93/200:   loss=113.183; Losses: [113.18219, 0.0005073488, 0.000602577]
.
Epoch 94/200:   loss=146.389; Losses: [146.38794, 0.00050397916, 0.0005585851]
.
Epoch 95/200:   loss=130.446; Losses: [130.44531, 0.0005009744, 0.00056175387]
.
Epoch 96/200:   loss=118.884; Losses: [118.88327, 0.0004950377, 0.0005437781]
.
Epoch 97/200:   loss=114.319; Losses: [114.3183, 0.0004938163, 0.0006134198]
.
Epoch 98/200:   loss=134.747; Losses: [134.74594, 0.0004912615, 0.0005468172]
.
Epoch 99/200:   loss=127.224; Losses: [127.22336, 0.0004886875, 0.00050384714]
.
Epoch 100/200:  loss=123.007; Losses: [123.00618, 0.0004862357, 0.00056520273]
.
Epoch 101/200:  loss=130.374; Losses: [130.3726, 0.0004841056, 0.0005738659]
.
Epoch 102/200:  loss=120.120; Losses: [120.11898, 0.00048118114, 0.00054365897]
.
Epoch 103/200:  loss=107.167; Losses: [107.16606, 0.00047940924, 0.0004886348]
.
Epoch 104/200:  loss=112.270; Losses: [112.268585, 0.00047771176, 0.0004745159]
.
Epoch 105/200:  loss=125.537; Losses: [125.53565, 0.0004758754, 0.00047247266]
.
Epoch 106/200:  loss=109.324; Losses: [109.32308, 0.00047426036, 0.00046340926]
.
Epoch 107/200:  loss=113.328; Losses: [113.327286, 0.0004729222, 0.00046490182]
.
Epoch 108/200:  loss=117.106; Losses: [117.10452, 0.000471713, 0.00062898267]
.
Epoch 109/200:  loss=122.371; Losses: [122.37039, 0.0004705812, 0.00051386596]
.
Epoch 110/200:  loss=119.422; Losses: [119.42122, 0.0004696009, 0.0005573735]
.
Epoch 111/200:  loss=131.784; Losses: [131.78348, 0.000468354, 0.00041559048]
.
Epoch 112/200:  loss=124.476; Losses: [124.475006, 0.00046699354, 0.00041292777]
.
Epoch 113/200:  loss=104.487; Losses: [104.486534, 0.00046530017, 0.00042556314]
.
Epoch 114/200:  loss=119.418; Losses: [119.41684, 0.0004641684, 0.00039730535]
.
Epoch 115/200:  loss=117.776; Losses: [117.77547, 0.00046339296, 0.00056288636]
.
Epoch 116/200:  loss=112.189; Losses: [112.18817, 0.0004628609, 0.00045002214]
.
Epoch 117/200:  loss=116.317; Losses: [116.31613, 0.0004616853, 0.00036934114]
.
Epoch 118/200:  loss=159.105; Losses: [159.10422, 0.00046110825, 0.0003716088]
.
Epoch 119/200:  loss=116.958; Losses: [116.95712, 0.00045984008, 0.0003543413]
.
Epoch 120/200:  loss=108.100; Losses: [108.09944, 0.00045923653, 0.00045118056]
.
Epoch 121/200:  loss=107.565; Losses: [107.56447, 0.00045848632, 0.00033903003]
.
Epoch 122/200:  loss=117.631; Losses: [117.62992, 0.0004574218, 0.00033954927]
.
Epoch 123/200:  loss=116.075; Losses: [116.07385, 0.0004564249, 0.00036835673]
.
Epoch 124/200:  loss=106.798; Losses: [106.79701, 0.0004566427, 0.00032678706]
.
Epoch 125/200:  loss=113.363; Losses: [113.36252, 0.0004562596, 0.00042438688]
.
Epoch 126/200:  loss=118.104; Losses: [118.10279, 0.00045581322, 0.00032525152]
.
Epoch 127/200:  loss=113.516; Losses: [113.51486, 0.00045547614, 0.0005056638]
.
Epoch 128/200:  loss=117.624; Losses: [117.62353, 0.00045549794, 0.00034517745]
.
Epoch 129/200:  loss=112.273; Losses: [112.272316, 0.00045538746, 0.00029629772]
.
Epoch 130/200:  loss=113.328; Losses: [113.32768, 0.0004549961, 0.0003071945]
.
Epoch 131/200:  loss=111.756; Losses: [111.75513, 0.00045427072, 0.00032724047]
.
Epoch 132/200:  loss=107.796; Losses: [107.795494, 0.00045377907, 0.00027464304]
.
Epoch 133/200:  loss=150.595; Losses: [150.59428, 0.00045307074, 0.0003010978]
.
Epoch 134/200:  loss=120.134; Losses: [120.13356, 0.00045292114, 0.00029552256]
.
Epoch 135/200:  loss=120.130; Losses: [120.12947, 0.00045320712, 0.0002585037]
.
Epoch 136/200:  loss=117.070; Losses: [117.06926, 0.0004533619, 0.00026611943]
.
Epoch 137/200:  loss=111.006; Losses: [111.00518, 0.00045333267, 0.0002824646]
.
Epoch 138/200:  loss=115.901; Losses: [115.90064, 0.00045347284, 0.00030576388]
.
Epoch 139/200:  loss=111.147; Losses: [111.146286, 0.000453111, 0.0002558468]
.
Epoch 140/200:  loss=103.128; Losses: [103.12687, 0.0004522237, 0.00038727812]
.
Epoch 141/200:  loss=115.025; Losses: [115.024475, 0.00045187408, 0.0002339398]
.
Epoch 142/200:  loss=117.170; Losses: [117.16969, 0.0004520767, 0.00023050333]
.
Epoch 143/200:  loss=105.315; Losses: [105.31448, 0.0004517807, 0.00026579303]
.
Epoch 144/200:  loss=114.114; Losses: [114.113625, 0.00045176974, 0.0002442702]
.
Epoch 145/200:  loss=108.179; Losses: [108.178154, 0.00045183185, 0.00022483827]
.
Epoch 146/200:  loss=119.408; Losses: [119.407074, 0.0004509559, 0.00021656642]
.
Epoch 147/200:  loss=116.504; Losses: [116.50336, 0.000450821, 0.00021879328]
.
Epoch 148/200:  loss=108.465; Losses: [108.464005, 0.0004506137, 0.00031175395]
.
Epoch 149/200:  loss=99.284; Losses: [99.28293, 0.00045024417, 0.00022024836]
.
Epoch 150/200:  loss=105.143; Losses: [105.14198, 0.00044986696, 0.00024956765]
.
Epoch 151/200:  loss=107.015; Losses: [107.01389, 0.0004503439, 0.00019209863]
.
Epoch 152/200:  loss=109.961; Losses: [109.96058, 0.00045052002, 0.00036505712]
.
Epoch 153/200:  loss=110.943; Losses: [110.94235, 0.00045058262, 0.00019362733]
.
Epoch 154/200:  loss=105.146; Losses: [105.14586, 0.0004505078, 0.00018350873]
.
Epoch 155/200:  loss=153.239; Losses: [153.23862, 0.00045094165, 0.00018688777]
.
Epoch 156/200:  loss=97.193; Losses: [97.192276, 0.0004497521, 0.0002643012]
.
Epoch 157/200:  loss=116.076; Losses: [116.075356, 0.00044927179, 0.00018770616]
.
Epoch 158/200:  loss=99.644; Losses: [99.64349, 0.00044897772, 0.00019957994]
.
Epoch 159/200:  loss=101.686; Losses: [101.68573, 0.00044913022, 0.0001649716]
.
Epoch 160/200:  loss=114.998; Losses: [114.99737, 0.00044872603, 0.0001598363]
.
Epoch 161/200:  loss=126.449; Losses: [126.44795, 0.00044798924, 0.00017636445]
.
Epoch 162/200:  loss=99.323; Losses: [99.32204, 0.00044971833, 0.0001718228]
.
Epoch 163/200:  loss=118.403; Losses: [118.402115, 0.0004499098, 0.00016231032]
.
Epoch 164/200:  loss=101.217; Losses: [101.21654, 0.00044922353, 0.00015090306]
.
Epoch 165/200:  loss=132.002; Losses: [132.0016, 0.0004491811, 0.00018679435]
.
Epoch 166/200:  loss=103.262; Losses: [103.26103, 0.00044870118, 0.00014671378]
.
Epoch 167/200:  loss=98.593; Losses: [98.592026, 0.0004482735, 0.00017167021]
.
Epoch 168/200:  loss=102.641; Losses: [102.64062, 0.00044823167, 0.0001966377]
.
Epoch 169/200:  loss=110.199; Losses: [110.19867, 0.00044768318, 0.00014072815]
.
Epoch 170/200:  loss=98.456; Losses: [98.45533, 0.00044640992, 0.00013351238]
.
Epoch 171/200:  loss=107.700; Losses: [107.698944, 0.00044608887, 0.000128763]
.
Epoch 172/200:  loss=110.314; Losses: [110.31317, 0.0004454155, 0.00015472547]
.
Epoch 173/200:  loss=101.824; Losses: [101.82384, 0.00044520054, 0.00012376827]
.
Epoch 174/200:  loss=100.615; Losses: [100.61448, 0.00044518447, 0.00012106997]
.
Epoch 175/200:  loss=99.010; Losses: [99.00989, 0.00044499052, 0.0001265088]
.
Epoch 176/200:  loss=104.999; Losses: [104.99841, 0.0004448745, 0.00017745573]
.
Epoch 177/200:  loss=98.012; Losses: [98.0118, 0.0004448767, 0.00016406355]
.
Epoch 178/200:  loss=99.610; Losses: [99.60982, 0.0004446858, 0.00011498472]
.
Epoch 179/200:  loss=108.899; Losses: [108.89821, 0.00044491427, 0.00010847509]
.
Epoch 180/200:  loss=118.136; Losses: [118.13521, 0.00044519745, 0.00010489453]
.
Epoch 181/200:  loss=98.810; Losses: [98.80894, 0.000445671, 0.00012609128]
.
Epoch 182/200:  loss=96.406; Losses: [96.40544, 0.0004462903, 0.00010188887]
.
Epoch 183/200:  loss=98.501; Losses: [98.50068, 0.00044781342, 9.952572e-05]
.
Epoch 184/200:  loss=97.149; Losses: [97.14829, 0.00044934792, 0.00010874969]
.
Epoch 185/200:  loss=100.637; Losses: [100.63678, 0.00044996737, 9.560937e-05]
.
Epoch 186/200:  loss=97.202; Losses: [97.201294, 0.00045056257, 9.1939364e-05]
.
Epoch 187/200:  loss=99.839; Losses: [99.83876, 0.0004511009, 8.99044e-05]
.
Epoch 188/200:  loss=118.457; Losses: [118.45691, 0.00045174357, 9.007633e-05]
.
Epoch 189/200:  loss=102.419; Losses: [102.41877, 0.00045339097, 9.094182e-05]
.
Epoch 190/200:  loss=110.813; Losses: [110.812386, 0.00045527803, 8.5699685e-05]
.
Epoch 191/200:  loss=99.384; Losses: [99.38332, 0.000456504, 8.375079e-05]
.
Epoch 192/200:  loss=103.581; Losses: [103.58075, 0.0004564843, 9.893996e-05]
.
Epoch 193/200:  loss=97.621; Losses: [97.62078, 0.00045904273, 8.179152e-05]
.
Epoch 194/200:  loss=94.909; Losses: [94.90842, 0.00046064917, 7.898855e-05]
.
Epoch 195/200:  loss=100.840; Losses: [100.83898, 0.0004623049, 8.4420986e-05]
.
Epoch 196/200:  loss=98.157; Losses: [98.15616, 0.00046370056, 8.1935854e-05]
.
Epoch 197/200:  loss=95.283; Losses: [95.28199, 0.00046543585, 7.4171934e-05]
.
Epoch 198/200:  loss=98.005; Losses: [98.00419, 0.0004669357, 7.385877e-05]
.
Epoch 199/200:  loss=103.609; Losses: [103.608406, 0.00046879202, 7.1431816e-05]

_, dyn_prior = f_model.dynamic_encoder.sample_dynamic_prior(10)
np.squeeze(dyn_prior.mean().numpy())
array([-1.6796231, -2.5568666, -2.4644861, -2.4013069, -2.3772488,
       -2.3707619, -2.377105 , -2.332031 , -2.3327737, -2.360692 ],
      dtype=float32)
K.clear_session()
dynamic_prior = RNNMultivariateNormalDiag(
    VariationalLSTMCell(N_HIDDEN, output_dim=LATENT_SIZE_DYNAMIC),
    n_timesteps=N_TIMES, output_dim=LATENT_SIZE_DYNAMIC)
sample = dynamic_prior.sample((N_SAMPLES, BATCH_SIZE))
print(sample.shape)
print(dynamic_prior.mean())
(2, 6, 10, 1)
tf.Tensor(
[[ 0.        ]
 [ 0.04328619]
 [ 0.08498121]
 [-0.17377347]
 [-0.09743058]
 [-0.30255282]
 [-0.22110605]
 [-0.36379734]
 [-0.30933833]
 [-0.13590682]], shape=(10, 1), dtype=float32)