Cómo ponderar la pérdida de KLD frente a la pérdida de reconstrucción en el codificador automático variacional

26

En casi todos los ejemplos de código que he visto de un VAE, las funciones de pérdida se definen de la siguiente manera (este es el código de tensorflow, pero he visto algo similar para theano, torch, etc.) También es para un convnet, pero eso tampoco es demasiado relevante , solo afecta a los ejes donde se toman las sumas):

# latent space loss. KL divergence between latent space distribution and unit gaussian, for each batch.
# first half of eq 10. in https://arxiv.org/abs/1312.6114
kl_loss = -0.5 * tf.reduce_sum(1 + log_sigma_sq - tf.square(mu) - tf.exp(log_sigma_sq), axis=1)

# reconstruction error, using pixel-wise L2 loss, for each batch
rec_loss = tf.reduce_sum(tf.squared_difference(y, x), axis=[1,2,3])

# or binary cross entropy (assuming 0...1 values)
y = tf.clip_by_value(y, 1e-8, 1-1e-8) # prevent nan on log(0)
rec_loss = -tf.reduce_sum(x * tf.log(y) + (1-x) * tf.log(1-y), axis=[1,2,3])

# sum the two and average over batches
loss = tf.reduce_mean(kl_loss + rec_loss)

Sin embargo, el rango numérico de kl_loss y rec_loss depende mucho de las atenuaciones de espacio latente y el tamaño de la característica de entrada (por ejemplo, resolución de píxeles) respectivamente. ¿Sería sensato reemplazar los reduce_sum's por reduce_mean para obtener por z-dim KLD y por píxel (o característica) LSE o BCE? Más importante aún, ¿cómo ponderamos la pérdida latente con la pérdida de reconstrucción cuando sumamos la pérdida final? ¿Es solo prueba y error? ¿o hay alguna teoría (o al menos regla general) para ello? No pude encontrar ninguna información sobre esto en ninguna parte (incluido el documento original).


El problema que tengo es que si el equilibrio entre las dimensiones de mi característica de entrada (x) y las dimensiones del espacio latente (z) no es 'óptimo', mis reconstrucciones son muy buenas, pero el espacio latente aprendido no está estructurado (si las dimensiones x es muy alto y el error de reconstrucción domina sobre KLD), o viceversa (las reconstrucciones no son buenas, pero el espacio latente aprendido está bien estructurado si KLD domina).

Me encuentro teniendo que normalizar la pérdida de reconstrucción (dividiendo por el tamaño de la característica de entrada) y KLD (dividiendo por las dimensiones z) y luego ponderando manualmente el término KLD con un factor de peso arbitrario (La normalización es para que pueda usar el mismo o peso similar independiente de las dimensiones de x o z ). Empíricamente, he encontrado que alrededor de 0.1 proporciona un buen equilibrio entre la reconstrucción y el espacio latente estructurado, lo que me parece un "punto ideal". Estoy buscando trabajo previo en esta área.


A pedido, notación matemática de arriba (enfocándose en la pérdida de L2 por error de reconstrucción)

Llunatminortet(yo)=-12j=1J(1+Iniciar sesión(σj(yo))2-(μj(yo))2-(σj(yo))2)

Lrmidoonorte(yo)=-k=1K(yk(yo)-Xk(yo))2

L(metro)=1METROyo=1METRO(Llunatminortet(yo)+Lrmidoonorte(yo))

Jzμσ2KMETRO(yo)yoL(metro)metro

memorándum
fuente

Respuestas:

17

Para cualquiera que se encuentre con esta publicación y también busque una respuesta, este hilo de Twitter ha agregado mucha información muy útil.

A saber:

beta-VAE: Aprendizaje de conceptos visuales básicos con un marco variacional restringido

βnorteormetro

y lectura relacionada (donde se discuten temas similares)

memorándum
fuente
7

Me gustaría agregar un artículo más relacionado con este tema (no puedo comentar debido a mi baja reputación en este momento).

En la subsección 3.1 del documento, los autores especificaron que no pudieron entrenar una implementación directa de VAE que ponderara igualmente la probabilidad y la divergencia KL. En su caso, la pérdida de KL se redujo indeseablemente a cero, aunque se esperaba que tuviera un valor pequeño. Para superar esto, propusieron utilizar el "recocido de costos KL", que aumentó lentamente el factor de peso del término de divergencia KL (curva azul) de 0 a 1.

Figura 2. El peso del término de divergencia KL del límite inferior variacional de acuerdo con un programa de recocido sigmoide típico trazado junto con el valor (no ponderado) del término de divergencia KL para nuestro VAE en el Penn TreeBank.

Esta solución alternativa también se aplica en Ladder VAE.

Papel:

Bowman, SR, Vilnis, L., Vinyals, O., Dai, AM, Jozefowicz, R. y Bengio, S., 2015. Generando oraciones desde un espacio continuo . preimpresión de arXiv arXiv: 1511.06349.

Cuong
fuente