¿Cómo funciona el truco de reparameterization para VAE y por qué es importante?

57

¿Cómo funciona el truco de reparameterization para autoencoders variacionales (VAE)? ¿Existe una explicación intuitiva y fácil sin simplificar las matemáticas subyacentes? ¿Y por qué necesitamos el 'truco'?

David Dao
fuente
55
Una parte de la respuesta es notar que todas las distribuciones normales son versiones escaladas y traducidas de Normal (1, 0). Para dibujar desde Normal (mu, sigma) puede dibujar desde Normal (1, 0), multiplicar por sigma (escala) y agregar mu (traducir).
monje el
@monk: debería haber sido Normal (0,1) en lugar de (1,0) a la derecha o de lo contrario, multiplicar y desplazar se volvería completamente heno.
Rika
@Breeze Ha! Si, por supuesto, gracias.
monje

Respuestas:

57

Después de leer las diapositivas del taller NIPS 2015 de Kingma , me di cuenta de que necesitamos el truco de reparametrización para propagar hacia atrás a través de un nodo aleatorio.

Intuitivamente, en su forma original, los VAE toman muestras de un nodo aleatorio que se aproxima mediante el modelo paramétrico del verdadero posterior. Backprop no puede fluir a través de un nodo aleatorio.q ( z ϕ , x )zq(zϕ,x)

La introducción de un nuevo parámetro nos permite volver a parametrizar de una manera que permita que el backprop fluya a través de los nodos deterministas.zϵz

forma original y reparameterized

David Dao
fuente
3
¿Por qué es determinista ahora a la derecha? z
bringingdownthegauss
2
No lo es, pero no es una "fuente de aleatoriedad": este rol ha sido asumido por . ϵ
quant_dev
Tenga en cuenta que este método se ha propuesto varias veces antes de 2014: blog.shakirm.com/2015/10/…
quant_dev
2
¡Tan simple, tan intuitivo! ¡Gran respuesta!
Serhiy
2
Lamentablemente no lo es. La forma original todavía puede ser retropropagable, sin embargo, con una mayor varianza. Los detalles se pueden encontrar en mi publicación .
JP Zhang
56

Supongamos que tenemos una distribución normal que está parametrizada por θ , específicamente q θ ( x ) = N ( θ , 1 ) . Queremos resolver el siguiente problema min θqθqθ(X)=norte(θ,1) Esto es por supuesto un problema bastante tonto y la óptima θ es obvia. Sin embargo, aquí solo queremos entender cómo el truco de reparameterización ayuda a calcular el gradiente de este objetivo E q [ x 2 ] .

minθmiq[X2]
θmiq[X2]

Una forma de calcular es la siguiente θ E q [ x 2 ] = θq θ ( x ) x 2 d x = x 2 θ q θ ( x ) q θ ( x )θmiq[X2]

θmiq[X2]=θqθ(X)X2reX=X2θqθ(X)qθ(X)qθ(X)reX=qθ(X)θIniciar sesiónqθ(X)X2reX=miq[X2θIniciar sesiónqθ(X)]

Para nuestro ejemplo donde , este método da θ E q [ x 2 ] = E q [ x 2 ( x - θ ) ]qθ(X)=norte(θ,1)

θmiq[X2]=miq[X2(X-θ)]

El truco de reparametrización es una forma de reescribir la expectativa para que la distribución con respecto a la cual tomamos el gradiente sea independiente del parámetro . Para lograr esto, necesitamos hacer que el elemento estocástico en q sea independiente de θ . Por lo tanto, escribimos x como x = θ + ϵ ,θqθX Entonces, podemos escribir E q [ x 2 ] = E p [ ( θ + ϵ ) 2 ] donde p es la distribución de ϵ , es decir, N ( 0 , 1 ) . Ahora podemos escribir la derivada de E q [ x 2 ] de la siguiente manera θ E q [ x 2 ] =

X=θ+ϵ,ϵnorte(0 0,1)
miq[X2]=mipags[(θ+ϵ)2]
pagsϵnorte(0 0,1)miq[X2]
θmiq[X2]=θmipags[(θ+ϵ)2]=mipags[2(θ+ϵ)]

Aquí hay un cuaderno de IPython que he escrito que analiza la varianza de estas dos formas de calcular gradientes. http://nbviewer.jupyter.org/github/gokererdogan/Notebooks/blob/master/Reparameterization%20Trick.ipynb

goker
fuente
44
¿Cuál es el theta "obvio" para la primera ecuación?
gwg
2
es 0. una forma de ver eso es notar que E [x ^ 2] = E [x] ^ 2 + Var (x), que es theta ^ 2 + 1 en este caso. Entonces theta = 0 minimiza este objetivo.
Goker
Entonces, ¿depende completamente del problema? Para decir min_ \ theta E_q [| x | ^ (1/4)], ¿podría ser completamente diferente?
Anne van Rossum
¿Qué depende del problema? La theta óptima? Si es así, sí, ciertamente depende del problema.
goker
θmiq[X2]=miq[X2(X-θ)qθ(X)]θmiq[X2]=miq[X2(X-θ)]
17

En la respuesta de Goker se da un ejemplo razonable de las matemáticas del "truco de reparameterización", pero alguna motivación podría ser útil. (No tengo permisos para comentar esa respuesta; por lo tanto, aquí hay una respuesta por separado).

solθ

solθ=θmiXqθ[...]

miXqθ[solθmist(X)]

solθmist(X)=...1qθ(X)θqθ(X)=...θIniciar sesión(qθ(X))

Xqθsolθmistsolθθ

solθmistsolθ

solθXXqθ(X)1qθ(X)XsolθqθsolθmistXqθθ, que puede estar lejos de ser óptimo (por ejemplo, un valor inicial elegido arbitrariamente). Es un poco como la historia de la persona borracha que busca sus llaves cerca de la farola (porque allí es donde puede ver / probar) en lugar de cerca de donde las dejó caer.

Xϵpagsθsolθpags

solθ=θmiϵpags[J(θ,ϵ)]=miϵpags[θJ(θ,ϵ)]
J(θ,ϵ)

θJ(θ,ϵ)pagsϵpagsθpags

θJ(θ,ϵ)solθsolθϵpagspagsϵJ

Espero que eso ayude.

Seth Bruder
fuente
"El factor de 1 / qθ (x) está aumentando su estimación para dar cuenta de esto, pero si nunca ve un valor de x, esa escala no ayudará". ¿Puedes explicarme mas?
czxttkl
qθXXsolθmist(X)1/ /qθ
10

Permítanme explicar primero, ¿por qué necesitamos el truco de reparametrización en VAE?

VAE tiene codificador y decodificador. Decodificador de muestras al azar de verdadero posterior Z ~ q (z∣ϕ, x) . Para implementar el codificador y el decodificador como una red neuronal, debe realizar una retropropagación mediante muestreo aleatorio y ese es el problema porque la retropropagación no puede fluir a través de un nodo aleatorio; Para superar este obstáculo, utilizamos el truco de reparameterización.

Ahora vamos a engañar. Dado que nuestro posterior está normalmente distribuido, podemos aproximarlo con otra distribución normal. Aproximamos Z con ε normalmente distribuido .

ingrese la descripción de la imagen aquí

Pero, ¿cómo es esto relevante?

Ahora, en lugar de decir que Z se muestrea a partir de q (z∣ϕ, x) , podemos decir que Z es una función que toma el parámetro (ε, (µ, L)) y estos µ, L proviene de la red neuronal superior (codificador) . Por lo tanto, mientras que la retropropagación todo lo que necesitamos es derivadas parciales wrt µ, L y ε es irrelevante para tomar derivados.

ingrese la descripción de la imagen aquí

Sherlock
fuente
El mejor video para entender este concepto. Recomendaría ver un video completo para una mejor comprensión, pero si desea comprender solo el truco de reparametrización, mire desde 8 minutos. youtube.com/channel/UCNIkB2IeJ-6AmZv7bQ1oBYg
Sherlock
9

Pensé que la explicación encontrada en el curso Stanford CS228 sobre modelos gráficos probabilísticos era muy buena. Se puede encontrar aquí: https://ermongroup.github.io/cs228-notes/extras/vae/

He resumido / copiado las partes importantes aquí por conveniencia / mi propia comprensión (aunque recomiendo encarecidamente que consulte el enlace original).

ϕmizq(zEl |X)[F(X,z)]

Si está familiarizado con los estimadores de la función de puntuación (creo que REINFORCE es solo un caso especial de esto), notará que ese es el problema que resuelven. Sin embargo, el estimador de la función de puntuación tiene una gran varianza, lo que genera dificultades para aprender modelos la mayor parte del tiempo.

qϕ(zEl |X)

ϵpags(ϵ)solϕ(ϵ,X)qϕ

Como ejemplo, usemos una q muy simple de la que tomamos muestras.

zqμ,σ=norte(μ,σ)
q
z=solμ,σ(ϵ)=μ+ϵσ
ϵnorte(0 0,1)

pags(ϵ)

ϕmizq(zEl |X)[F(X,z)]=miϵpags(ϵ)[ϕF(X,sol(ϵ,X))]

Esto tiene una varianza menor, por razones imo, no triviales. Consulte la parte D del apéndice aquí para obtener una explicación: https://arxiv.org/pdf/1401.4082.pdf

horace he
fuente
Hola, ¿sabes por qué en la implementación dividen el estándar por 2? (es decir, std = torch.exp (z_var / 2)) en la reparameterization?
Rika
4

Tenemos nuestro modelo probablístico. Y quiere recuperar los parámetros del modelo. Reducimos nuestra tarea a la optimización del límite inferior variacional (VLB). Para hacer esto, deberíamos poder hacer dos cosas:

  • calcular VLB
  • obtener gradiente de VLB

Los autores sugieren usar el Estimador de Monte Carlo para ambos. Y, de hecho, presentan este truco para obtener un estimador de gradiente Monte Carlo más preciso de VLB.

Es solo una mejora del método numérico.

Anton
fuente
2

El truco de reparameterization reduce dramáticamente la varianza del estimador MC para el gradiente. Entonces es una técnica de reducción de varianza :

ϕmiq(z(yo)X(yo);ϕ)[Iniciar sesiónpags(X(yo)z(yo),w)]

ϕmiq(z(yo)X(yo);ϕ)[Iniciar sesiónpags(X(yo)z(yo),w)]=miq(z(yo)X(yo);ϕ)[Iniciar sesiónpags(X(yo)z(yo),w)ϕIniciar sesiónqϕ(z)]
pags(X(yo)z(yo),w)Iniciar sesiónpags(X(yo)z(yo),w)es muy grande y el valor en sí mismo es negativo. Entonces tendríamos una gran varianza.

Con reparametrización z(yo)=sol(ϵ(yo),X(yo),ϕ)

ϕmiq(z(yo)X(yo);ϕ)[Iniciar sesiónpags(X(yo)z(yo),w)]=mipags(ϵ(yo))[ϕIniciar sesiónpags(X(yo)sol(ϵ(yo),X(yo),ϕ),w)]

pags(ϵ(yo))pags(ϵ(yo))ϕ

z(yo)z(yo)=sol(ϵ(yo),X(yo),ϕ)

Chris Elgoog
fuente