¿Cuál es la salida de un tf.nn.dynamic_rnn ()?

8

No estoy seguro de lo que entiendo de la documentación oficial, que dice:

Devuelve: Un par (salidas, estado) donde:

outputs: El tensor de salida RNN.

Si time_major == False(por defecto), esta será la forma de un tensor: [batch_size, max_time, cell.output_size].

Si time_major == True, esta será una forma Tensor: [max_time, batch_size, cell.output_size].

Tenga en cuenta que si cell.output_sizees una tupla (posiblemente anidada) de enteros u objetos TensorShape, las salidas serán una tupla que tenga la misma estructura que cell.output_size, que contenga Tensores con formas correspondientes a los datos de forma cell.output_size.

state: El estado final. Si cell.state_size es un int, se formará [batch_size, cell.state_size]. Si es un TensorShape, se formará [batch_size] + cell.state_size. Si se trata de una tupla (posiblemente anidada) de entradas o TensorShape, será una tupla con las formas correspondientes. Si las celdas son LSTMCells, el estado será una tupla que contiene una LSTMStateTuple para cada celda.

¿Es output[-1] siempre (en los tres tipos de celdas, es decir, RNN, GRU, LSTM) igual al estado (segundo elemento de la tupla de retorno)? Supongo que la literatura en todas partes es demasiado liberal en el uso del término estado oculto. ¿Es el estado oculto en las tres celdas el resultado que sale?

MiloMinderbinder
fuente

Respuestas:

10

Sí, la salida de la celda es igual al estado oculto. En el caso de LSTM, es la parte a corto plazo de la tupla (segundo elemento de LSTMStateTuple), como se puede ver en esta imagen:

LSTM

Pero para tf.nn.dynamic_rnn, el estado devuelto puede ser diferente cuando la secuencia es más corta ( sequence_lengthargumento). Echale un vistazo a éste ejemplo:

n_steps = 2
n_inputs = 3
n_neurons = 5

X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
seq_length = tf.placeholder(tf.int32, [None])

basic_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, sequence_length=seq_length, dtype=tf.float32)

X_batch = np.array([
  # t = 0      t = 1
  [[0, 1, 2], [9, 8, 7]], # instance 0
  [[3, 4, 5], [0, 0, 0]], # instance 1
  [[6, 7, 8], [6, 5, 4]], # instance 2
  [[9, 0, 1], [3, 2, 1]], # instance 3
])
seq_length_batch = np.array([2, 1, 2, 2])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  outputs_val, states_val = sess.run([outputs, states], 
                                     feed_dict={X: X_batch, seq_length: seq_length_batch})

  print(outputs_val)
  print()
  print(states_val)

Aquí el lote de entrada contiene 4 secuencias y una de ellas es corta y rellena con ceros. Al correr deberías algo como esto:

[[[ 0.2315362  -0.37939444 -0.625332   -0.80235624  0.2288385 ]
  [ 0.9999524   0.99987394  0.33580178 -0.9981791   0.99975705]]

 [[ 0.97374666  0.8373545  -0.7455188  -0.98751736  0.9658986 ]
  [ 0.          0.          0.          0.          0.        ]]

 [[ 0.9994331   0.9929737  -0.8311569  -0.99928087  0.9990415 ]
  [ 0.9984355   0.9936006   0.3662448  -0.87244385  0.993848  ]]

 [[ 0.9962312   0.99659646  0.98880637  0.99548346  0.9997809 ]
  [ 0.9915743   0.9936939   0.4348318   0.8798458   0.95265496]]]

[[ 0.9999524   0.99987394  0.33580178 -0.9981791   0.99975705]
 [ 0.97374666  0.8373545  -0.7455188  -0.98751736  0.9658986 ]
 [ 0.9984355   0.9936006   0.3662448  -0.87244385  0.993848  ]
 [ 0.9915743   0.9936939   0.4348318   0.8798458   0.95265496]]

... que de hecho lo muestra state == output[1]para secuencias completas y state == output[0]para secuencias cortas. También output[1]es un vector cero para esta secuencia. Lo mismo vale para las celdas LSTM y GRU.

Entonces, statees un tensor conveniente que contiene el último estado RNN real , ignorando los ceros. El outputtensor contiene las salidas de todas las celdas, por lo que no ignora los ceros. Esa es la razón para devolver a los dos.

Máxima
fuente
2

Posible copia de /programming/36817596/get-last-output-of-dynamic-rnn-in-tensorflow/49705930#49705930

De todos modos, sigamos con la respuesta.

Este recorte de código podría ayudar a comprender lo que realmente está devolviendo la dynamic_rnncapa

=> Tupla de (salidas, final_output_state) .

Entonces, para una entrada con una longitud máxima de secuencia de pasos de tiempo T, las salidas tienen la forma [Batch_size, T, num_inputs](dado time_major= Falso; valor predeterminado) y contienen el estado de salida en cada paso de tiempo h1, h2.....hT.

Y final_output_state tiene la forma [Batch_size,num_inputs]y tiene el estado final de la celda cTy el estado hTde salida de cada secuencia de lotes.

Pero dado que dynamic_rnnse está utilizando, supongo que las longitudes de su secuencia varían para cada lote.

    import tensorflow as tf
    import numpy as np
    from tensorflow.contrib import rnn
    tf.reset_default_graph()

    # Create input data
    X = np.random.randn(2, 10, 8)

    # The second example is of length 6 
    X[1,6:] = 0
    X_lengths = [10, 6]

    cell = tf.nn.rnn_cell.LSTMCell(num_units=64, state_is_tuple=True)

    outputs, states  = tf.nn.dynamic_rnn(cell=cell,
                                         dtype=tf.float64,
                                         sequence_length=X_lengths,
                                         inputs=X)

    result = tf.contrib.learn.run_n({"outputs": outputs, "states":states},
                                    n=1,
                                    feed_dict=None)
    assert result[0]["outputs"].shape == (2, 10, 64)
    print result[0]["outputs"].shape
    print result[0]["states"].h.shape
    # the final outputs state and states returned must be equal for each      
    # sequence
    assert(result[0]["outputs"][0][-1]==result[0]["states"].h[0]).all()
    assert(result[0]["outputs"][-1][5]==result[0]["states"].h[-1]).all()
    assert(result[0]["outputs"][-1][-1]==result[0]["states"].h[-1]).all()

La afirmación final fallará ya que el estado final para la segunda secuencia es en el sexto paso, es decir. el índice 5 y el resto de las salidas de [6: 9] son ​​todos ceros en el segundo paso de tiempo

Bhaskar Arun
fuente