¿Cómo se propagan los gradientes en una red neuronal recurrente desenrollada?

8

Estoy tratando de entender cómo se pueden usar los rnn para predecir secuencias trabajando con un ejemplo simple. Aquí está mi red simple, que consta de una entrada, una neurona oculta y una salida:

ingrese la descripción de la imagen aquí

La neurona oculta es la función sigmoidea, y se considera que la salida es una salida lineal simple. Entonces, creo que la red funciona de la siguiente manera: si la unidad oculta comienza en estado s, y estamos procesando un punto de datos que es una secuencia de longitud3, (x1,x2,x3), entonces:

En el momento 1, el valor predicho,p1, es

p1=u×σ(ws+vx1)

En el momento 2, tenemos

p2=u×σ(w×σ(ws+vx1)+vx2)

En el momento 3, tenemos

p3=u×σ(w×σ(w×σ(ws+vx1)+vx2)+vx3)

¿Hasta aquí todo bien?

El rnn "desenrollado" se ve así:

ingrese la descripción de la imagen aquí

Si usamos un término de suma de error cuadrado para la función objetivo, ¿cómo se define? En toda la secuencia? En cuyo caso tendríamos algo comoE=(p1x1)2+(p2x2)2+(p3x3)2?

¿Se actualizan los pesos solo una vez que se examinó la secuencia completa (en este caso, la secuencia de 3 puntos)?

En cuanto al gradiente con respecto a los pesos, necesitamos calcular dE/dw,dE/dv,dE/du, Intentaré hacerlo simplemente examinando las 3 ecuaciones para piarriba, si todo lo demás parece correcto. Además de hacerlo de esa manera, esto no me parece una propagación inversa, porque los mismos parámetros aparecen en diferentes capas de la red. ¿Cómo nos ajustamos para eso?

Si alguien puede ayudarme a guiarme a través de este ejemplo de juguete, estaría muy agradecido.

Fequish
fuente
Creo que hay algo mal con la función de error, probablemente obtengas p1 como término del segundo elemento y debes compararlo probablemente con x2, en caso perfecto deben ser iguales. En su función de error, simplemente compara la entrada y la salida de la red.
itdxer
Pensé que ese podría ser el caso. Pero entonces, ¿cómo se define el error para el último elemento predicho,p3?
Fequish

Respuestas:

1

Creo que necesitas valores objetivo. Entonces para la secuencia(x1,x2,x3), necesitarías objetivos correspondientes (t1,t2,t3). Como parece que desea predecir el próximo término de la secuencia de entrada original, necesitaría:

t1=x2, t2=x3, t3=x4

Necesitarías definir x4, así que si tuviera una secuencia de entrada de longitud N para entrenar al RNN, solo podrás usar el primero N1 términos como valores de entrada y el último N1 términos como valores objetivo.

Si usamos un término de suma de error cuadrado para la función objetivo, ¿cómo se define?

Hasta donde sé, tienes razón: el error es la suma de toda la secuencia. Esto es porque los pesosu, v y w son los mismos en el RNN desplegado.

Entonces,

E=tEt=t(ttpt)2

¿Se actualizan los pesos solo una vez que se examinó la secuencia completa (en este caso, la secuencia de 3 puntos)?

Sí, si utilizo la propagación inversa a través del tiempo, creo que sí.

En cuanto a los diferenciales, no querrá expandir toda la expresión para Ey diferenciarlo cuando se trata de RNN más grandes. Entonces, alguna notación puede hacerlo más ordenado:

  • Dejar zt denotar la entrada a la neurona oculta en el momento t (es decir z1=ws+vx1)
  • Dejar yt denotar la salida de la neurona oculta en el momento t (es decir y1=σ(ws+vx1))
  • Dejar y0=s
  • Dejar δt=Ezt

Entonces, los derivados son:

Eu=ytEv=tδtxtEw=tδtyt1

Dónde t[1, T] para una secuencia de longitud Ty:

δt=σ(zt)(u+δt+1w)

Esta relación recurrente proviene de darse cuenta de que el tth la actividad oculta no solo afecta el error en tth salida, Et, pero también afecta el resto del error más abajo en el RNN, EEt:

Ezt=Etytytzt+(EEt)zt+1zt+1ytytztEzt=ytzt(Etyt+(EEt)zt+1zt+1yt)Ezt=σ(zt)(u+(EEt)zt+1w)δt=Ezt=σ(zt)(u+δt+1w)

Además de hacerlo de esta manera, esto no me parece una propagación inversa, porque los mismos parámetros aparecen en diferentes capas de la red. ¿Cómo nos ajustamos para eso?

Este método se llama propagación inversa a través del tiempo (BPTT), y es similar a la propagación inversa en el sentido de que utiliza la aplicación repetida de la regla de la cadena.

Un ejemplo trabajado más detallado pero complicado para un RNN se puede encontrar en el Capítulo 3.2 de 'Etiquetado de secuencias supervisadas con redes neuronales recurrentes' por Alex Graves - ¡lectura realmente interesante!

dok
fuente
0

El error que describió anteriormente (después de la modificación que escribí en el comentario debajo de la pregunta) puede usarlo solo como un error de predicción total, pero no puede usarlo en el proceso de aprendizaje. En cada iteración, coloca un valor de entrada en la red y obtiene una salida. Cuando obtenga resultados, debe verificar el resultado de su red y propagar el error a todos los pesos. Después de la actualización, colocará el siguiente valor en secuencia y hará una predicción para este valor, de lo que también propagará el error, etc.

itdxer
fuente