¿Cómo se entrena el generador en una GAN?

9

El documento sobre GAN dice que el discriminador usa el siguiente gradiente para entrenar:

θre1metroyo=1metro[Iniciar sesiónre(X(yo))+Iniciar sesión(1-re(sol(z(yo))))]

Los valores se muestrean, se pasan a través del generador para generar muestras de datos, y luego el discriminador se retroproyecta utilizando las muestras de datos generadas. Una vez que el generador genera los datos, ya no juega ningún papel en el entrenamiento del discriminador. En otras palabras, el generador se puede eliminar completamente de la métrica haciendo que genere muestras de datos y luego solo trabajando con las muestras.z

Sin embargo, estoy un poco más confundido acerca de cómo se entrena el generador. Utiliza el siguiente gradiente:

θsol1metroyo=1metro[Iniciar sesión(1-re(sol(z(yo))))]

En este caso, el discriminador es parte de la métrica. No se puede eliminar como en el caso anterior. Cosas como mínimos cuadrados o probabilidad de registro en modelos discriminativos regulares se pueden diferenciar fácilmente porque tienen una definición agradable y cercana. Sin embargo, estoy un poco confundido acerca de cómo se propaga cuando la métrica depende de otra red neuronal. ¿Esencialmente conecta las salidas del generador a las entradas del discriminador y luego trata todo como una red gigante donde los pesos en la parte del discriminador son constantes?

Fidias
fuente

Respuestas:

10

Es útil pensar en este proceso en pseudocódigo. Sea generator(z)una función que toma un vector de ruido muestreado uniformemente zy devuelve un vector del mismo tamaño que el vector de entrada X; Llamemos a esta longitud d. Sea discriminator(x)una función que toma un dvector dimensional y devuelve una probabilidad escalar que xpertenece a la distribución de datos verdadera. Para entrenamiento:

G_sample = generator(Z)
D_real = discriminator(X)
D_fake = discriminator(G_sample)

D_loss = maximize mean of (log(D_real) + log(1 - D_fake))
G_loss = maximize mean of log(D_fake)

# Only update D(X)'s parameters
D_solver = Optimizer().minimize(D_loss, theta_D)
# Only update G(X)'s parameters
G_solver = Optimizer().minimize(G_loss, theta_G)

# theta_D and theta_G are the weights and biases of D and G respectively
Repeat the above for a number of epochs

Entonces, sí, tiene razón en que esencialmente pensamos en el generador y el discriminador como una red gigante para alternar minibatches mientras usamos datos falsos. La función de pérdida del generador se encarga de los gradientes para esta mitad. Si piensa en este entrenamiento de red de forma aislada, entonces se entrena tal como lo haría normalmente con un MLP, siendo su entrada la salida de la última capa de la red del generador.

Puede seguir una explicación detallada con código en Tensorflow aquí (entre muchos lugares): http://wiseodd.github.io/techblog/2016/09/17/gan-tensorflow/

Debería ser fácil de seguir una vez que mira el código.

tejaskhot
fuente
1
¿Podría dar más detalles D_lossy G_loss? ¿Maximizando sobre qué espacio? IIUC, D_realy D_fakecada uno es un lote, ¿entonces estamos maximizando sobre el lote?
P i
@ Pi Sí, estamos maximizando sobre un lote.
tejaskhot
1

¿Esencialmente conecta las salidas del generador a las entradas del discriminador?> Y luego trata todo como una red gigante donde los pesos en la parte del discriminador son constantes?

En breve: Sí. (Busqué algunas de las fuentes de la GAN para verificar esto)

También hay mucho más en el entrenamiento de GAN como: deberíamos actualizar D y G cada vez o D en iteraciones impares y G en pares, y mucho más. También hay un muy buen artículo sobre este tema:

"Técnicas mejoradas para la formación de GAN"

Tim Salimans, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, Xi Chen

https://arxiv.org/abs/1606.03498

Liberus
fuente
¿Podría por favor proporcionar enlaces a las fuentes que buscó? Sería útil para mí leerlos.
Vivek Subramanian
0

Recientemente he subido una colección de varios modelos de GAN en el repositorio de Github. Está basado en torch7 y es muy fácil de ejecutar. El código es lo suficientemente simple como para comprenderlo con resultados experimentales. Espero que esto ayude

https://github.com/nashory/gans-collection.torch

nashory
fuente