¿Qué son exactamente los mecanismos de atención?

23

Los mecanismos de atención se han utilizado en varios documentos de Deep Learning en los últimos años. Ilya Sutskever, jefe de investigación de Open AI, los ha elogiado con entusiasmo: https://towardsdatascience.com/the-fall-of-rnn-lstm-2d1594c74ce0

Eugenio Culurciello de la Universidad de Purdue ha afirmado que las RNN y LSTM deben abandonarse en favor de las redes neuronales basadas exclusivamente en la atención:

https://towardsdatascience.com/the-fall-of-rnn-lstm-2d1594c74ce0

Esto parece una exageración, pero es innegable que los modelos puramente basados ​​en la atención han funcionado bastante bien en las tareas de modelado de secuencias: todos sabemos sobre el papel de Google, la atención es todo lo que necesita.

Sin embargo, ¿ qué son exactamente los modelos basados ​​en la atención? Todavía tengo que encontrar una explicación clara de tales modelos. Supongamos que quiero pronosticar los nuevos valores de una serie de tiempo multivariante, dados sus valores históricos. Está bastante claro cómo hacer eso con un RNN que tiene células LSTM. ¿Cómo haría lo mismo con un modelo basado en la atención?

DeltaIV
fuente

Respuestas:

20

La atención es un método para agregar un conjunto de vectores en un solo vector, a menudo a través de un vector de búsqueda . Por lo general, son las entradas al modelo o los estados ocultos de pasos de tiempo anteriores, o los estados ocultos un nivel hacia abajo (en el caso de LSTM apilados).viuvi

El resultado a menudo se llama el vector de contexto , ya que contiene el contexto relevante para el paso de tiempo actual.c

Este vector de contexto adicional se alimenta al RNN / LSTM (puede simplemente concatenarse con la entrada original). Por lo tanto, el contexto puede usarse para ayudar con la predicción.c

La forma más sencilla de hacer esto es calcular el vector de probabilidad y donde es la concatenación de todos los anteriores . Un vector de búsqueda común es el estado oculto actual .p=softmax(VTu)c=ipiviVviuht

Hay muchas variaciones en esto, y puedes hacer las cosas tan complicadas como quieras. Por ejemplo, en lugar de usar como logits, uno puede elegir , donde es una red neuronal arbitraria.viTuf(vi,u)f

Un mecanismo de atención común para los modelos de secuencia a secuencia utiliza , donde son los estados ocultos del codificador y es el oculto actual estado del decodificador y ambos s son parámetros.p=softmax(qTtanh(W1vi+W2ht))vhtqW

Algunos documentos que muestran diferentes variaciones en la idea de atención:

Las redes de punteros prestan atención a las entradas de referencia para resolver problemas de optimización combinatoria.

Las redes de entidades recurrentes mantienen estados de memoria separados para diferentes entidades (personas / objetos) mientras leen texto, y actualizan el estado de memoria correcto con atención.

Los modelos de transformadores también hacen un uso extensivo de la atención. Su formulación de atención es un poco más general y también involucra vectores clave : los pesos de atención se calculan realmente entre las teclas y la búsqueda, y el contexto se construye con .kipvi


Aquí hay una implementación rápida de una forma de atención, aunque no puedo garantizar la corrección más allá del hecho de que pasó algunas pruebas simples.

RNN básico:

def rnn(inputs_split):
    bias = tf.get_variable('bias', shape = [hidden_dim, 1])
    weight_hidden = tf.tile(tf.get_variable('hidden', shape = [1, hidden_dim, hidden_dim]), [batch, 1, 1])
    weight_input = tf.tile(tf.get_variable('input', shape = [1, hidden_dim, in_dim]), [batch, 1, 1])

    hidden_states = [tf.zeros((batch, hidden_dim, 1), tf.float32)]
    for i, input in enumerate(inputs_split):
        input = tf.reshape(input, (batch, in_dim, 1))
        last_state = hidden_states[-1]
        hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )
        hidden_states.append(hidden)
    return hidden_states[-1]

Con atención, agregamos solo unas pocas líneas antes de que se calcule el nuevo estado oculto:

        if len(hidden_states) > 1:
            logits = tf.transpose(tf.reduce_mean(last_state * hidden_states[:-1], axis = [2, 3]))
            probs = tf.nn.softmax(logits)
            probs = tf.reshape(probs, (batch, -1, 1, 1))
            context = tf.add_n([v * prob for (v, prob) in zip(hidden_states[:-1], tf.unstack(probs, axis = 1))])
        else:
            context = tf.zeros_like(last_state)

        last_state = tf.concat([last_state, context], axis = 1)

        hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )

el código completo

shimao
fuente
p=softmax(VTu)ic=ipivipiVTvVTv
1
zi=viTup=softmax(z)pi=eizjejz
ppi
1
sí, eso es lo que quise decir
shimao
@shimao Creé una sala de chat , avíseme si estaría interesado en hablar (no sobre esta pregunta)
DeltaIV