Ejemplo paso a paso de diferenciación automática en modo inverso

27

No estoy seguro si esta pregunta pertenece aquí, pero está estrechamente relacionada con los métodos de gradiente en la optimización, que parece estar en el tema aquí. De todos modos, siéntase libre de migrar si cree que alguna otra comunidad tiene una mejor experiencia en el tema.

En resumen, estoy buscando un ejemplo paso a paso de diferenciación automática en modo inverso . No hay mucha literatura sobre el tema y las implementaciones existentes (como la de TensorFlow ) son difíciles de entender sin conocer la teoría detrás de esto. Por lo tanto estaría muy agradecido si alguien podría mostrar en detalle lo que pase en , la forma en que procesamos y lo que sacamos de la gráfica computacional.

Un par de preguntas con las que tengo más dificultades:

  • semillas : ¿por qué las necesitamos?
  • reglas de diferenciación inversa : sé cómo hacer una diferenciación hacia adelante, pero ¿cómo vamos hacia atrás? Por ejemplo, en el ejemplo de esta sección , ¿cómo sabemos que ?w2¯=w3¯w1
  • ¿trabajamos solo con símbolos o pasamos valores reales ? Por ejemplo, en el mismo ejemplo , ¿son y símbolos o valores?¯ w iwiwi¯
amigo
fuente
"Aprendizaje automático práctico con Scikit-Learn y TensorFlow" El Apéndice D ofrece una muy buena explicación en mi opinión. Lo recomiendo.
Agustín Barrachina

Respuestas:

37

Digamos que tenemos la expresión y queremos encontrar derivados y . El modo inverso AD divide esta tarea en 2 partes, a saber, pases hacia adelante y hacia atrás.z=x1x2+sin(x1)dzdx1dzdx2

Pase adelantado

Primero, descomponemos nuestra expresión compleja en un conjunto de primitivas, es decir, expresiones que consisten en, a lo sumo, llamada de función única. Tenga en cuenta que también cambio el nombre de las variables de entrada y salida por coherencia, aunque no es necesario:

w1=x1
w2=x2
w3=w1w2
w4=sin(w1)
w5=w3+w4
z=w5

La ventaja de esta representación es que las reglas de diferenciación para cada expresión separada ya son conocidas. Por ejemplo, sabemos que la derivada de es , y entonces . Usaremos este hecho en el paso inverso a continuación.sincosdw4dw1=cos(w1)

Esencialmente, el pase directo consiste en evaluar cada una de estas expresiones y guardar los resultados. Digamos que nuestras entradas son: y . Entonces tenemos:x1=2x2=3

w1=x1=2
w2=x2=3
w3=w1w2=6
w4=sin(w1) =0.9
w5=w3+w4=6.9
z=w5=6.9

Pase reverso

Aquí es donde comienza la magia, y comienza con la regla de la cadena . En su forma básica, la regla de la cadena establece que si tiene una variable que depende de que, a su vez, depende de , entonces:t(u(v))uv

dtdv=dtdududv

o, si depende de través de varias rutas / variables , por ejemplo:tvui

u1=f(v)
u2=g(v)
t=h(u1,u2)

entonces (ver prueba aquí ):

dtdv=idtduiduidv

En términos de gráfico de expresión, si tenemos un nodo final y nodos de entrada , y la ruta de a pasa por los nodos intermedios (es decir, donde ), podemos encontrar derivadas comozwizwiwpz=g(wp)wp=f(wi)dzdwi

dzdwi=pparents(i)dzdwpdwpdwi

En otras palabras, para calcular la derivada de la variable de salida wrt cualquier variable intermedia o de entrada , solo necesitamos conocer las derivadas de sus padres y la fórmula para calcular la derivada de la expresión primitiva .zwiwp=f(wi)

El pase inverso comienza al final (es decir, ) y se propaga hacia atrás a todas las dependencias. Aquí tenemos (expresión para "semilla"):dzdz

dzdz=1

Eso puede leerse como "el cambio en da como resultado exactamente el mismo cambio en ", lo cual es bastante obvio.zz

Entonces sabemos que y así:z=w5

dzdw5=1

w5 depende linealmente de y , entonces y . Usando la regla de la cadena encontramos:w3w4dw5dw3=1dw5dw4=1

dzdw3=dzdw5dw5dw3=1×1=1
dzdw4=dzdw5dw5dw4=1×1=1

De la definición y las reglas de derivadas parciales, encontramos que . Así:w3=w1w2dw3dw2=w1

dzdw2=dzdw3dw3dw2=1×w1=w1

Lo cual, como ya sabemos por pase adelantado, es:

dzdw2=w1=2

Finalmente, contribuye a través de y . Una vez más, a partir de las reglas de derivadas parciales, sabemos que y . Así:w1zw3w4dw3dw1=w2dw4dw1=cos(w1)

dzdw1=dzdw3dw3dw1+dzdw4dw4dw1=w2+cos(w1)

Y de nuevo, dadas las entradas conocidas, podemos calcularlo:

dzdw1=w2+cos(w1)=3+cos(2) =2.58

Como y son solo alias para y , obtenemos nuestra respuesta:w1w2x1x2

dzdx1=2.58
dzdx2=2

¡Y eso es!


Esta descripción se refiere solo a entradas escalares, es decir, números, pero de hecho también se puede aplicar a matrices multidimensionales como vectores y matrices. Dos cosas que uno debe tener en cuenta al diferenciar expresiones con tales objetos:

  1. Los derivados pueden tener una dimensionalidad mucho más alta que las entradas o salidas, por ejemplo, la derivada del vector wrt es una matriz y la derivada de la matriz wrt es una matriz de 4 dimensiones (a veces denominada tensor). En muchos casos, estos derivados son muy escasos.
  2. Cada componente en la matriz de salida es una función independiente de 1 o más componentes de la (s) matriz (s) de entrada. Por ejemplo, si y e son vectores, nunca depende de , sino solo del subconjunto de . En particular, esto significa que encontrar la derivada reduce a rastrear cómo depende de .y=f(x)xyyiyjxkdyidxjyixj

El poder de la diferenciación automática es que puede manejar estructuras complicadas de lenguajes de programación como condiciones y bucles. Sin embargo, si todo lo que necesita son expresiones algebraicas y tiene un marco lo suficientemente bueno como para trabajar con representaciones simbólicas, es posible construir expresiones completamente simbólicas. De hecho, en este ejemplo podríamos producir la expresión y calcular esta derivada para cualquier entrada que queramos.dzdw1=w2+cos(w1)=x2+cos(x1)

amigo
fuente
1
Muy útil pregunta / respuesta. Gracias. Solo una pequeña crítica: parece que te mueves en una estructura de árbol sin explicar (es cuando comienzas a hablar de padres, etc.)
MadHatter
1
Además, no hará daño aclarar por qué necesitamos semillas.
MadHatter
@MadHatter gracias por el comentario. Traté de reformular un par de párrafos (estos que se refieren a los padres) para enfatizar una estructura gráfica. También agregué "semilla" al texto, aunque este nombre en sí mismo puede ser engañoso en mi opinión: en AD semilla siempre es una expresión fija - , no es algo que pueda elegir o generar. dzdz=1
amigo
¡Gracias! Me di cuenta cuando tienes que establecer más de una "semilla", generalmente uno elige 1 y 0. Me gustaría saber por qué. Quiero decir, uno toma el "cociente" de un wrt diferencial, por lo que "1" está al menos intuitivamente justificado. Pero ¿qué pasa con 0? ¿Y si uno tiene que recoger más de 2 semillas?
MadHatter
1
Según tengo entendido, más de una semilla se usa solo en AD de modo directo. En este caso, establece la semilla en 1 para una variable de entrada que desea diferenciar con respecto y establece la semilla en 0 para todas las demás variables de entrada para que no contribuyan al valor de salida. En el modo inverso, establece la semilla en una variable de salida , y normalmente solo tiene una variable de salida. Supongo que puede construir una tubería AD en modo inverso con varias variables de salida y establecerlas todas, pero una en 0, para obtener el mismo efecto que en el modo directo, pero nunca he investigado esta opción.
amigo