Custom Recurrent Layers
A bit of a mess. Also check out tfp_utils, lfads_utils, and lfads_complex_cell.
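For context: GenerativeRNN unrolls a cell for a fixed number of timesteps, feeding each step's output back in as the next step's input (which is why K.rnn can't be used; see the old call() kept at the bottom). Below is a minimal sketch of that feedback loop using only stock Keras pieces; the function and variable names are illustrative, not the layer's actual API.
import tensorflow as tf

def generate_sequence(cell, first_input, initial_state, timesteps):
    # Unroll `cell` for `timesteps` steps, feeding each output back as the next input.
    output, states = first_input, initial_state
    outputs = []
    for _ in range(timesteps):
        output, states = cell(output, states)
        outputs.append(output)
    return tf.stack(outputs, axis=1), states  # (batch, timesteps, units), final states

gru_cell = tf.keras.layers.GRUCell(24)
x0 = tf.zeros((1, 24))    # first input; feature dim matches units so outputs can be fed back
s0 = [tf.zeros((1, 24))]  # GRUCell carries a single state tensor
seq, final_states = generate_sequence(gru_cell, x0, s0, timesteps=115)
print(seq.shape)  # (1, 115, 24)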
Test GenerativeRNN
import tensorflow as tf
from tensorflow.keras import backend as K
tfkl = tf.keras.layers
# GenerativeRNN is the custom layer under test; it is defined elsewhere in this repo.

N_IN = 8
N_UNITS = 24
N_OUT_TIMESTEPS = 115
cell = tfkl.GRUCell  # LSTMCell or GRUCell
# Test regular RNN with zeros input
reg_rnn_layer = tfkl.RNN(cell(N_UNITS), return_state=True, return_sequences=True)
in_ = tf.zeros((1, N_OUT_TIMESTEPS, N_IN))
x_ = reg_rnn_layer(in_)
print(K.any(x_[0])) # Just to remind myself that input zeros and state zeros will yield output zeros.
# Test placeholder tensor with no timesteps
K.clear_session()
gen_rnn_layer = GenerativeRNN(cell(N_UNITS), return_sequences=True, return_state=True,
                              timesteps=N_OUT_TIMESTEPS)
in_ = tfkl.Input(shape=(N_IN,))
x_, cell_state_ = gen_rnn_layer(in_)
print("Test placeholder tensor")
model = tf.keras.Model(inputs=in_, outputs=x_)
model.summary()
# Test placeholder tensor with no timesteps as initial state
K.clear_session()
gen_rnn_layer = GenerativeRNN(cell(N_UNITS), return_sequences=True, return_state=True,
                              timesteps=N_OUT_TIMESTEPS)
in_ = tfkl.Input(shape=(N_UNITS,))
x_, cell_state_ = gen_rnn_layer(None, initial_state=in_)
print("Test placeholder tensor")
model = tf.keras.Model(inputs=in_, outputs=x_)
model.summary()
# Test None input --> uses zeros
K.clear_session()
gen_rnn_layer = GenerativeRNN(cell(N_UNITS), return_sequences=True, return_state=True,
                              timesteps=N_OUT_TIMESTEPS)
print("Test None input")
x_, cell_state_ = gen_rnn_layer(None)  # explicit None input -> the layer should fall back to zeros
print(x_.shape, cell_state_.shape)
print(K.any(x_), K.any(cell_state_)) # <- any non-zero values?
# Test random input
K.clear_session()
gen_rnn_layer = GenerativeRNN(cell(N_UNITS), return_sequences=True, return_state=True,
                              timesteps=N_OUT_TIMESTEPS)
in_ = tf.random.uniform((1, 8, N_UNITS), minval=-1.0, maxval=1.0)
print("Test zeros input")
x_, cell_state_ = gen_rnn_layer(in_)
print(x_.shape, cell_state_.shape)
print(K.any(x_), K.any(cell_state_)) # <- any non-zero values?
# Test random states
K.clear_session()
gen_rnn_layer = GenerativeRNN(cell(N_UNITS), return_sequences=True, return_state=True,
                              timesteps=N_OUT_TIMESTEPS)
print(gen_rnn_layer.compute_output_shape())
init_states = [tf.random.uniform((1, N_UNITS), minval=-1.0, maxval=1.0) for _ in range(1)]  # 1 state tensor for GRU; LSTM needs 2
x_, cell_state_ = gen_rnn_layer(None, initial_state=init_states)
print(x_.shape, cell_state_.shape)
print(K.any(x_), K.any(cell_state_)) # <- any non-zero values?
# Test masking
K.clear_session()
tmp = tf.range(N_OUT_TIMESTEPS)[tf.newaxis, :, tf.newaxis]
mask = tf.math.logical_or(tmp < 5, tmp > 100)
gen_rnn_layer = GenerativeRNN(cell(N_UNITS), return_sequences=True, return_state=True,
                              timesteps=N_OUT_TIMESTEPS, tile_input=True)
in_ = tf.random.uniform((5, N_OUT_TIMESTEPS, N_UNITS), minval=-1.0, maxval=1.0)
x_, cell_state_ = gen_rnn_layer(in_, mask=mask)
print(x_.shape, cell_state_.shape)
print(K.any(x_), K.any(cell_state_)) # <- any non-zero values?
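As a reference point (separate from whatever mask handling GenerativeRNN implements), a standard Keras-style boolean mask has shape (batch, timesteps), and tf.sequence_mask is the usual way to build one from per-example lengths; the lengths below are made up for illustration.
lengths = tf.constant([5, 100, 115, 20, 60])  # one valid-length per batch element
seq_mask = tf.sequence_mask(lengths, maxlen=N_OUT_TIMESTEPS)  # bool, shape (5, N_OUT_TIMESTEPS)
print(seq_mask.shape, seq_mask.dtype)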
# Garbage code I don't want to throw out yet: an old version of GenerativeRNN.call().
# Note: it leans on TF-internal helpers (nest from tensorflow.python.util, generic_utils
# from the Keras utils) plus tensorflow_probability, so it only runs inside the layer
# module, not in this notebook.
if False:
    def call(self, inputs, mask=None, training=None, initial_state=None, constants=None):
        assert mask is None, "mask not supported."
        # First part copied from super().call().
        # The input should be dense, padded with zeros. If a ragged input is fed
        # into the layer, it is padded and the row lengths are used for masking.
        inputs, row_lengths = K.convert_inputs_if_ragged(inputs)
        is_ragged_input = (row_lengths is not None)
        self._validate_args_if_ragged(is_ragged_input, mask)

        # Get initial_state. Merge the provided initial_state with the preserved state if
        # self.stateful; otherwise use the provided state, or zeros if it is None.
        inputs, initial_state, constants = self._process_inputs(
            inputs, initial_state, constants)

        self._maybe_reset_cell_dropout_mask(self.cell)
        if isinstance(self.cell, tfkl.StackedRNNCells):
            for cell in self.cell.cells:
                self._maybe_reset_cell_dropout_mask(cell)

        kwargs = {}
        if generic_utils.has_arg(self.cell.call, 'training'):
            kwargs['training'] = training

        # TF RNN cells expect a single tensor as state instead of a list-wrapped tensor.
        is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None
        if constants:
            if not generic_utils.has_arg(self.cell.call, 'constants'):
                raise ValueError('RNN cell does not support constants')

            def step(inputs, states):
                constants = states[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type
                states = states[:-self._num_constants]  # pylint: disable=invalid-unary-operand-type
                states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
                output, new_states = self.cell.call(
                    inputs, states, constants=constants, **kwargs)
                if not nest.is_sequence(new_states):
                    new_states = [new_states]
                return output, new_states
        else:
            def step(inputs, states):
                states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
                output, new_states = self.cell.call(inputs, states, **kwargs)
                if not nest.is_sequence(new_states):
                    new_states = [new_states]
                return output, new_states

        ######################################
        # Begin deviation from super().call() #
        ######################################
        # We do not use K.rnn because it does not support feeding the output back
        # as the input to the next step.
        def _process_single_input_t(input_t):
            input_t = tf.unstack(input_t, axis=-2)  # unstack along the time_step dim
            if self.go_backwards:
                input_t.reverse()
            return input_t

        if nest.is_sequence(inputs):
            processed_input = nest.map_structure(_process_single_input_t, inputs)
        else:
            processed_input = (_process_single_input_t(inputs),)
        cell_input = nest.pack_sequence_as(inputs, [_[0] for _ in processed_input])
        cell_state = tuple(initial_state)

        out_states = []
        out_inputs = []
        for step_ix in range(self.timesteps):
            cell_input, new_states = step(cell_input, cell_state)
            flat_new_states = nest.flatten(new_states)
            cell_state = nest.pack_sequence_as(cell_state, flat_new_states)
            out_states.append(cell_state)
            out_inputs.append(cell_input)
        out_inputs = tf.stack(out_inputs, axis=-2)

        # If the cell outputs a distribution then we might do the following,
        # but the base class would have to change.
        if False:
            if hasattr(out_inputs[0], 'parameters') and 'distribution' in out_inputs[0].parameters:
                dist0_parms = out_inputs[0].parameters['distribution'].parameters
                coll_parms = {}
                for p_name, p_val in dist0_parms.items():
                    if tf.is_tensor(p_val):
                        coll_parms[p_name] = []
                for dist in out_inputs:
                    for p_name in coll_parms.keys():
                        coll_parms[p_name].append(dist.parameters['distribution'].parameters[p_name])
                for p_name in coll_parms.keys():
                    coll_parms[p_name] = tf.stack(coll_parms[p_name], axis=-2)
                dist_class = out_inputs[0].parameters['distribution'].__class__
                out_inputs = dist_class(**coll_parms)
                # Warning! The time dimension gets lost in the batch with None.
                out_inputs = tfp.distributions.Independent(out_inputs, reinterpreted_batch_ndims=1)

        out_states = tf.stack(out_states, axis=-2)
        out_states = tf.unstack(out_states, axis=0)
        if not hasattr(self.cell.state_size, '__len__'):
            out_states = out_states[0]

        if not self.return_sequences:
            out_inputs = out_inputs[..., -1, :]
            out_states = [_[..., -1, :] for _ in out_states] if isinstance(out_states, list) else out_states[..., -1, :]

        if self.return_state:
            return out_inputs, out_states
        return out_inputs