¿Cómo usar RBM para la clasificación?

12

En este momento estoy jugando con Restricted Boltzmann Machines y, dado que estoy allí, me gustaría tratar de clasificar los dígitos escritos a mano.

El modelo que creé ahora es un modelo generativo bastante elegante, pero no sé cómo ir más allá.

En este artículo, el autor dice que después de crear un buen modelo generativo, uno " luego entrena un clasificador discriminativo (es decir, clasificador lineal, máquina de vectores de soporte) en la parte superior del RBM usando las muestras etiquetadas " y además declara " ya que propaga el vectores de datos a las unidades ocultas del modelo RBM para obtener vectores de unidades ocultas, o una representación de nivel superior de los datos ". El problema es que no estoy seguro de hacerlo bien.

¿Eso significa que todo lo que tengo que hacer es propagar la entrada a las unidades ocultas y allí tengo mi función RBM para la clasificación?

¿Alguien puede explicarme este proceso?

nombre para mostrar
fuente
La máquina de Boltzmann restringida es uno de los primeros componentes utilizados para el aprendizaje profundo. De hecho, el primer trabajo importante en DNN realizado por Hinton es que la red de creencias profundas se basó en RBM. Busque este documento (red de creencias profundas, 2007, para Hinton) para obtener más información. En su sitio web puede encontrar recursos muy importantes, así como un experimento de demostración cs.toronto.edu/~hinton/digits.html
Bashar Haddad
@hbaderts Empecé a jugar con RBM. La respuesta aceptada fue fácil de leer. Quería pedir aclaraciones, las capas ocultas de RBM son aleatorias después del muestreo de la distribución binaria. Para la clasificación, ¿se utilizan las probabilidades de unidades ocultas o las unidades ocultas muestreadas de una distribución binaria (1 y 0) pasadas al clasificador?
M3tho5

Respuestas:

15

Revisión de máquinas de Boltzmann restringidas

Una máquina de Boltzmann restringida (RBM) es un modelo generativo , que aprende una distribución de probabilidad sobre la entrada. Eso significa que, después de ser entrenado, el RBM puede generar nuevas muestras a partir de la distribución de probabilidad aprendida. La distribución de probabilidad sobre las unidades visibles viene dada por p ( vh ) = V i = 0 p ( v ih ) , donde p ( v ih ) = σ ( a i + Hv

p(vh)=i=0Vp(vih),
yσes la función sigmoidea,aies el sesgo del nodo visiblei, ywjies el peso dehjavi. De estas dos ecuaciones, se deduce quep(vh)solo depende de los estados ocultosh. Eso significa que la información sobre cómose generauna muestra visiblev, debe almacenarse en las unidades ocultas, los pesos y los sesgos.
p(vih)=σ(ai+j=0Hwjihj)
σaiiwjihjvip(vh)hv

Usando RBMs para la clasificación

h

Este vector oculto es solo una versión transformada de los datos de entrada; esto no puede clasificar nada por sí mismo. Para hacer una clasificación, entrenaría cualquier clasificador (clasificador lineal, SVM, una red neuronal de avance, o cualquier otra cosa) con el vector oculto en lugar de los datos de entrenamiento "en bruto" como entradas.

Si está construyendo una red de creencias profundas (DBN), que se usó para entrenar previamente redes neuronales de alimentación profunda de manera no supervisada, tomaría este vector oculto y lo usaría como entrada para un nuevo RBM, que usted apila en lo alto de ello. De esa manera, puede entrenar la red capa por capa hasta alcanzar el tamaño deseado, sin necesidad de ningún dato etiquetado. Finalmente, agregaría, por ejemplo, una capa softmax a la parte superior, y entrenaría a toda la red con retropropagación en su tarea de clasificación.

hbaderts
fuente
Gracias por la edición @ Seanny123, esto hace que sea mucho más fácil de leer.
hbaderts
5

@hbaderts describió todo el flujo de trabajo a la perfección. Sin embargo, puede que no tenga sentido en caso de que sea completamente nuevo en esta idea. Por lo tanto, voy a explicarlo de manera simple (por lo tanto, omitiré los detalles):

Piense en las redes profundas como una función para transformar sus datos. Ejemplos de transformaciones incluyen la normalización, el registro de datos, etc. Las redes profundas que está entrenando tienen múltiples capas. Cada una de estas capas se entrena utilizando algún tipo de algoritmo de aprendizaje. Para la primera capa, pasa los datos originales como entrada e intenta obtener una función que le devuelva esos "mismos datos originales" que la salida. Sin embargo, no obtienes el resultado perfecto. Por lo tanto, está obteniendo una versión transformada de su entrada como salida de la primera capa.

Ahora, para la segunda capa, toma esos "datos transformados" y los pasa como entrada y repite todo el proceso de aprendizaje. Sigue haciendo eso para todas las capas en su red profunda.

En la última capa, lo que obtienes es una "versión transformada" de tus datos de entrada originales. Esto puede pensarse en una abstracción de nivel superior de sus datos de entrada originales. Tenga en cuenta que todavía no ha utilizado las etiquetas / resultados en su red profunda. Por lo tanto, todo hasta este punto es aprendizaje no supervisado. Esto se llama pre-entrenamiento en capas.

Ahora, desea entrenar un modelo de clasificador / regresión y este es un problema de aprendizaje supervisado. La forma de lograr ese objetivo es tomar la "versión transformada final" de su entrada original de la última capa en su red profunda y usarla como entrada para cualquier clasificador (por ejemplo, clasificador knn / clasificador softmax / regresión logística, etc.). Esto se llama apilamiento.

Cuando entrena a este clasificador / alumno de último paso, propaga todo su aprendizaje en la red completa. Esto garantiza que pueda aprender de las etiquetas / salidas y modificar los parámetros aprendidos en capas en consecuencia.

Entonces, una vez que haya entrenado su modelo generativo, tome el resultado de su modelo generativo y úselo como entrada para un clasificador / alumno. Deje que el error fluya a través de toda la red a medida que continúa el aprendizaje para que pueda modificar el parámetro de capa sabio aprendido en los pasos anteriores.

Sal
fuente