Skip to content

Custom Recurrent Layers

This notebook is a bit of a mess. See also tfp_utils, lfads_utils, and lfads_complex_cell.

Test GenerativeRNN

# Sizes shared by the checks below.
N_IN = 8
N_UNITS = 24
N_OUT_TIMESTEPS = 115
cell = tfkl.GRUCell  # LSTMCell or GRUCell

# Baseline: a stock Keras RNN driven by an all-zeros input.
reg_rnn_layer = tfkl.RNN(
    cell(N_UNITS),
    return_sequences=True,
    return_state=True,
)
in_ = tf.zeros(shape=(1, N_OUT_TIMESTEPS, 16))
x_ = reg_rnn_layer(in_)
# x_ holds the layer's results; its first entry is what gets checked for non-zero values.
# Just to remind myself that input zeros and state zeros will yield output zeros.
any_nonzero_output = K.any(x_[0])
print(any_nonzero_output)
tf.Tensor(False, shape=(), dtype=bool)

# Symbolic (placeholder) input with no time axis, wrapped in a Keras model
# so that the resulting graph can be inspected with summary().
K.clear_session()
gen_rnn_kwargs = dict(
    return_sequences=True,
    return_state=True,
    timesteps=N_OUT_TIMESTEPS,
)
gen_rnn_layer = GenerativeRNN(cell(N_UNITS), **gen_rnn_kwargs)
in_ = tfkl.Input(shape=(N_IN,))
layer_results = gen_rnn_layer(in_)
x_, cell_state_ = layer_results
print("Test placeholder tensor")
model = tf.keras.Model(inputs=in_, outputs=x_)
model.summary()
Test placeholder tensor
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 8)]          0                                            
__________________________________________________________________________________________________
tf_op_layer_strided_slice (Tens [(None, 1, 8)]       0           input_1[0][0]                    
__________________________________________________________________________________________________
tf_op_layer_strided_slice_1 (Te [(None, 8)]          0           tf_op_layer_strided_slice[0][0]  
__________________________________________________________________________________________________
tf_op_layer_ZerosLike (TensorFl [(None, 8)]          0           tf_op_layer_strided_slice_1[0][0]
__________________________________________________________________________________________________
tf_op_layer_strided_slice_2 (Te [(None, 1, 8)]       0           tf_op_layer_ZerosLike[0][0]      
__________________________________________________________________________________________________
tf_op_layer_AddV2 (TensorFlowOp [(None, 114, 8)]     0           tf_op_layer_strided_slice_2[0][0]
__________________________________________________________________________________________________
tf_op_layer_concat (TensorFlowO [(None, 115, 8)]     0           tf_op_layer_strided_slice[0][0]  
                                                                 tf_op_layer_AddV2[0][0]          
__________________________________________________________________________________________________
generative_rnn (GenerativeRNN)  [(None, 115, 24), (N 2448        tf_op_layer_concat[0][0]         
==================================================================================================
Total params: 2,448
Trainable params: 2,448
Non-trainable params: 0
__________________________________________________________________________________________________

# Symbolic (placeholder) tensor supplied only as the initial state;
# the layer's input argument is explicitly None.
K.clear_session()
gen_rnn_layer = GenerativeRNN(
    cell(N_UNITS),
    timesteps=N_OUT_TIMESTEPS,
    return_state=True,
    return_sequences=True,
)
in_ = tfkl.Input(shape=(N_UNITS,))
x_, cell_state_ = gen_rnn_layer(None, initial_state=in_)
print("Test placeholder tensor")
model = tf.keras.Model(inputs=in_, outputs=x_)
model.summary()
Test placeholder tensor
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 24)]         0                                            
__________________________________________________________________________________________________
tf_op_layer_strided_slice (Tens [(None,)]            0           input_1[0][0]                    
__________________________________________________________________________________________________
tf_op_layer_strided_slice_1 (Te [(None, 1)]          0           tf_op_layer_strided_slice[0][0]  
__________________________________________________________________________________________________
tf_op_layer_ZerosLike (TensorFl [(None, 1)]          0           tf_op_layer_strided_slice_1[0][0]
__________________________________________________________________________________________________
tf_op_layer_strided_slice_2 (Te [(None, 1, 1)]       0           tf_op_layer_ZerosLike[0][0]      
__________________________________________________________________________________________________
tf_op_layer_strided_slice_3 (Te [(None, 1)]          0           tf_op_layer_strided_slice_2[0][0]
__________________________________________________________________________________________________
tf_op_layer_ZerosLike_1 (Tensor [(None, 1)]          0           tf_op_layer_strided_slice_3[0][0]
__________________________________________________________________________________________________
tf_op_layer_strided_slice_4 (Te [(None, 1, 1)]       0           tf_op_layer_ZerosLike_1[0][0]    
__________________________________________________________________________________________________
tf_op_layer_AddV2 (TensorFlowOp [(None, 114, 1)]     0           tf_op_layer_strided_slice_4[0][0]
__________________________________________________________________________________________________
tf_op_layer_concat (TensorFlowO [(None, 115, 1)]     0           tf_op_layer_strided_slice_2[0][0]
                                                                 tf_op_layer_AddV2[0][0]          
__________________________________________________________________________________________________
generative_rnn (GenerativeRNN)  [(None, 115, 24), (N 1944        tf_op_layer_concat[0][0]         
                                                                 input_1[0][0]                    
==================================================================================================
Total params: 1,944
Trainable params: 1,944
Non-trainable params: 0
__________________________________________________________________________________________________

# Call the layer with neither inputs nor an initial state --> uses zeros.
K.clear_session()
gen_rnn_kwargs = dict(
    return_sequences=True,
    return_state=True,
    timesteps=N_OUT_TIMESTEPS,
)
gen_rnn_layer = GenerativeRNN(cell(N_UNITS), **gen_rnn_kwargs)
print("Test None input")
x_, cell_state_ = gen_rnn_layer()
print(x_.shape, cell_state_.shape)
# <- any non-zero values?
outputs_nonzero = K.any(x_)
state_nonzero = K.any(cell_state_)
print(outputs_nonzero, state_nonzero)
Test None input
(1, 115, 24) (1, 24)
tf.Tensor(False, shape=(), dtype=bool) tf.Tensor(False, shape=(), dtype=bool)

# Test random input: uniform values in [-1, 1) should drive the layer to
# non-zero outputs and a non-zero final state.
K.clear_session()
gen_rnn_layer = GenerativeRNN(cell(N_UNITS), return_sequences=True, return_state=True,
                              timesteps=N_OUT_TIMESTEPS)
in_ = tf.random.uniform((1, 8, N_UNITS), minval=-1.0, maxval=1.0)
# The label previously read "Test zeros input", which contradicted the random
# input built above and made the printed log misleading.
print("Test random input")
x_, cell_state_ = gen_rnn_layer(in_)
print(x_.shape, cell_state_.shape)
print(K.any(x_), K.any(cell_state_))  # <- any non-zero values?
Test zeros input
(1, 115, 24) (1, 24)
tf.Tensor(True, shape=(), dtype=bool) tf.Tensor(True, shape=(), dtype=bool)

# Test random states: no inputs are given, only a random initial state.
K.clear_session()
gen_rnn_layer = GenerativeRNN(cell(N_UNITS), return_sequences=True, return_state=True,
                              timesteps=N_OUT_TIMESTEPS)
print(gen_rnn_layer.compute_output_shape())
init_states = [tf.random.uniform((1, N_UNITS), minval=-1.0, maxval=1.0) for _ in range(1)]
# Unpack into `cell_state_`, the name the prints below read. The original
# unpacked into `cell_states_`, so the prints silently reported the stale
# `cell_state_` left over from the previous cell instead of this call's state.
x_, cell_state_ = gen_rnn_layer(initial_state=init_states)
print(x_.shape, cell_state_.shape)
print(K.any(x_), K.any(cell_state_))  # <- any non-zero values?
[TensorShape([None, 115, 24]), TensorShape([None, 24])]
(1, 115, 24) (1, 24)
tf.Tensor(True, shape=(), dtype=bool) tf.Tensor(True, shape=(), dtype=bool)

# Test masking: the mask is True for step indices below 5 or above 100.
K.clear_session()

step_index = tf.range(N_OUT_TIMESTEPS)
tmp = step_index[tf.newaxis, :, tf.newaxis]
is_early = tmp < 5
is_late = tmp > 100
mask = tf.math.logical_or(is_early, is_late)
gen_rnn_layer = GenerativeRNN(
    cell(N_UNITS),
    return_sequences=True,
    return_state=True,
    timesteps=N_OUT_TIMESTEPS,
    tile_input=True,
)
in_ = tf.random.uniform((5, N_OUT_TIMESTEPS, N_UNITS), minval=-1.0, maxval=1.0)
x_, cell_state_ = gen_rnn_layer(in_, mask=mask)
print(x_.shape, cell_state_.shape)
# <- any non-zero values?
outputs_nonzero = K.any(x_)
state_nonzero = K.any(cell_state_)
print(outputs_nonzero, state_nonzero)
(5, 115, 24) (5, 24)
tf.Tensor(True, shape=(), dtype=bool) tf.Tensor(True, shape=(), dtype=bool)

# Garbage code I don't want to throw out yet.
# The `if False:` guard means nothing below ever runs, so `call` is never defined.
# It is kept as a reference for a hand-unrolled loop that feeds each step's output
# back in as the next step's input.
if False:
    def call(self, inputs, mask=None, training=None, initial_state=None, constants=None):
        """Unroll the cell for ``self.timesteps`` steps, feeding each output back as the next input.

        Only the first slice of ``inputs`` along axis -2 is used (the last slice when
        ``self.go_backwards`` is set); every later step receives the previous step's
        cell output as its input.

        Args:
            inputs: Passed through ``K.convert_inputs_if_ragged`` and
                ``self._process_inputs`` before being unstacked along axis -2.
            mask: Must be ``None``; masking is not supported by this implementation.
            training: Forwarded to the cell's ``call`` when that method accepts a
                ``training`` argument.
            initial_state: Passed to ``self._process_inputs``; the processed result
                seeds the loop state.
            constants: Passed to ``self._process_inputs``; when non-empty, the step
                function forwards constants to the cell on every step.

        Returns:
            ``out_inputs`` alone, or ``(out_inputs, out_states)`` when
            ``self.return_state`` is set. When ``self.return_sequences`` is false,
            only the last index along axis -2 is kept for both.

        Raises:
            AssertionError: If ``mask`` is not ``None``.
            ValueError: If ``constants`` is given but the cell's ``call`` has no
                ``constants`` argument.
        """
        assert(mask is None), "mask not supported."
        # First part copied from super call()

        # The input should be dense, padded with zeros. If a ragged input is fed
        # into the layer, it is padded and the row lengths are used for masking.
        inputs, row_lengths = K.convert_inputs_if_ragged(inputs)
        is_ragged_input = (row_lengths is not None)
        self._validate_args_if_ragged(is_ragged_input, mask)

        # Get initial_state. Merge provided initial_state and preserved if self.stateful,
        # otherwise use provided or zeros if provided is None.
        inputs, initial_state, constants = self._process_inputs(
            inputs, initial_state, constants)

        # Reset dropout masks on the cell, and on each sub-cell when the cell is stacked.
        self._maybe_reset_cell_dropout_mask(self.cell)
        if isinstance(self.cell, tfkl.StackedRNNCells):
            for cell in self.cell.cells:
                self._maybe_reset_cell_dropout_mask(cell)

        # Only pass `training` through if the cell's call signature accepts it.
        kwargs = {}
        if generic_utils.has_arg(self.cell.call, 'training'):
            kwargs['training'] = training

        # TF RNN cells expect single tensor as state instead of list wrapped tensor.
        is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None
        if constants:
            if not generic_utils.has_arg(self.cell.call, 'constants'):
                raise ValueError('RNN cell does not support constants')

            def step(inputs, states):
                # The trailing self._num_constants entries of `states` are the constants;
                # split them off before calling the cell.
                constants = states[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type
                states = states[:-self._num_constants]  # pylint: disable=invalid-unary-operand-type

                states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
                output, new_states = self.cell.call(
                    inputs, states, constants=constants, **kwargs)
                # Always hand back the new states as a sequence.
                if not nest.is_sequence(new_states):
                    new_states = [new_states]
                return output, new_states
        else:

            def step(inputs, states):
                states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
                output, new_states = self.cell.call(inputs, states, **kwargs)
                # Always hand back the new states as a sequence.
                if not nest.is_sequence(new_states):
                    new_states = [new_states]
                return output, new_states

        # Begin deviation from super call() #
        #####################################
        # We do not do K.rnn because it does not support feeding the output back as the input to the next step.
        def _process_single_input_t(input_t):
            # Split one input tensor into a list of per-step slices along axis -2.
            input_t = tf.unstack(input_t, axis=-2)  # unstack for time_step dim
            if self.go_backwards:
                input_t.reverse()
            return input_t

        if nest.is_sequence(inputs):
            processed_input = nest.map_structure(_process_single_input_t, inputs)
        else:
            processed_input = (_process_single_input_t(inputs),)
        # Seed the loop with the first slice of each unstacked input.
        # NOTE(review): iterating processed_input directly looks like it assumes a flat
        # structure -- confirm behaviour for nested inputs.
        cell_input = nest.pack_sequence_as(inputs, [_[0] for _ in processed_input])

        cell_state = tuple(initial_state)

        out_states = []
        out_inputs = []
        for step_ix in range(self.timesteps):
            # The cell output overwrites cell_input, so it becomes the next step's input.
            cell_input, new_states = step(cell_input, cell_state)
            flat_new_states = nest.flatten(new_states)
            cell_state = nest.pack_sequence_as(cell_state, flat_new_states)
            out_states.append(cell_state)
            out_inputs.append(cell_input)

        # Collect the per-step outputs into one tensor along axis -2.
        out_inputs = tf.stack(out_inputs, axis=-2)
        # if cell outputs a distribution, then we might do the following, but base class
        # would have to change.
        if False:
            # Disabled sketch: gather the tensor-valued parameters of each step's inner
            # distribution, stack them along axis -2, and rebuild one distribution.
            if hasattr(out_inputs[0], 'parameters') and 'distribution' in out_inputs[0].parameters:
                dist0_parms = out_inputs[0].parameters['distribution'].parameters
                coll_parms = {}
                for p_name, p_val in dist0_parms.items():
                    if K.tensor_util.is_tensor(p_val):
                        coll_parms[p_name] = []
                for dist in out_inputs:
                    for p_name in coll_parms.keys():
                        coll_parms[p_name].append(dist.parameters['distribution'].parameters[p_name])
                for p_name in coll_parms.keys():
                    coll_parms[p_name] = tf.stack(coll_parms[p_name], axis=-2)
                dist_class = out_inputs[0].parameters['distribution'].__class__
                out_inputs = dist_class(**coll_parms)
                # Warning! time dimension lost in batch with None
                out_inputs = tfp.distributions.Independent(out_inputs, reinterpreted_batch_ndims=1)

        # Each list entry is one step's state tuple; stack the steps along axis -2, then
        # split along axis 0.
        # NOTE(review): presumably axis 0 indexes the individual state tensors after
        # stacking -- confirm the resulting shapes.
        out_states = tf.stack(out_states, axis=-2)
        out_states = tf.unstack(out_states, axis=0)
        # A state_size without __len__ is treated as a single state: unwrap the list.
        if not hasattr(self.cell.state_size, '__len__'):
            out_states = out_states[0]

        if not self.return_sequences:
            # Keep only the last index along axis -2 for outputs and states.
            out_inputs = out_inputs[..., -1, :]
            out_states = [_[..., -1, :] for _ in out_states] if isinstance(out_states, list) else out_states[..., -1, :]
        if self.return_state:
            return out_inputs, out_states
        return out_inputs