¿Cómo logra un modelo de regresión logística simple una precisión de clasificación del 92% en MNIST?

73

A pesar de que todas las imágenes en el conjunto de datos MNIST están centradas, con una escala similar y boca arriba sin rotaciones, tienen una variación significativa en la escritura a mano que me desconcierta cómo un modelo lineal logra una precisión de clasificación tan alta.

Hasta donde puedo visualizar, dada la importante variación en la escritura a mano, los dígitos deben ser linealmente inseparables en un espacio dimensional de 784, es decir, debe haber un límite no lineal poco complejo (aunque no muy complejo) que separa los diferentes dígitos. , similar al ejemplo bien citado de XOR donde las clases positivas y negativas no pueden separarse por ningún clasificador lineal. Me parece desconcertante cómo la regresión logística de clases múltiples produce una precisión tan alta con características completamente lineales (sin características polinómicas).

Como ejemplo, dado cualquier píxel en la imagen, diferentes variaciones escritas a mano de los dígitos 2 y 3 pueden hacer que ese píxel se ilumine o no. Por lo tanto, con un conjunto de pesos aprendidos, cada píxel puede hacer que un dígito parezca un 2 y un 3 . Solo con una combinación de valores de píxeles debería ser posible decir si un dígito es un 2 o un 3 . Esto es cierto para la mayoría de los pares de dígitos. Entonces, ¿cómo es que la regresión logística, que ciegamente basa su decisión de manera independiente en todos los valores de píxeles (sin considerar ninguna dependencia entre píxeles), es capaz de lograr tan altas precisiones.

Sé que estoy equivocado en alguna parte o simplemente estoy sobreestimando la variación en las imágenes. Sin embargo, sería genial si alguien pudiera ayudarme con una intuición sobre cómo los dígitos son 'casi' linealmente separables.

Nitish Agarwal
fuente
Eche un vistazo al libro de texto Aprendizaje estadístico con escasez: el lazo y las generalizaciones 3.3.1 Ejemplo: dígitos escritos a
Adrian
He tenido curiosidad: ¿qué tan bien hace algo como un modelo lineal penalizado (es decir, glmnet) en el problema? Si mal no recuerdo, lo que está informando es la precisión fuera de la muestra sin potencializar.
Cliff AB

Respuestas:

91

tl; dr Aunque este es un conjunto de datos de clasificación de imágenes, sigue siendo una tarea muy fácil , para la cual se puede encontrar fácilmente un mapeo directo desde las entradas hasta las predicciones.


Responder:

Esta es una pregunta muy interesante y, gracias a la simplicidad de la regresión logística, puede encontrar la respuesta.

78478428×28

Tenga en cuenta, de nuevo, que estos son los pesos .

Ahora eche un vistazo a la imagen de arriba y concéntrese en los dos primeros dígitos (es decir, cero y uno). Los pesos azules significan que la intensidad de este píxel contribuye mucho para esa clase y los valores rojos significan que contribuye negativamente.

0

1

2378

A través de esto, puede ver que la regresión logística tiene una muy buena posibilidad de obtener muchas imágenes correctas y es por eso que tiene una puntuación tan alta.


El código para reproducir la figura anterior está un poco anticuado, pero aquí tienes:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Create model
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))

W = tf.Variable(tf.zeros((784,10)))
b = tf.Variable(tf.zeros((10)))
z = tf.matmul(x, W) + b

y_hat = tf.nn.softmax(z)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_hat), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 

correct_pred = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Train model
batch_size = 64
with tf.Session() as sess:

    loss_tr, acc_tr, loss_ts, acc_ts = [], [], [], []

    sess.run(tf.global_variables_initializer()) 

    for step in range(1, 1001):

        x_batch, y_batch = mnist.train.next_batch(batch_size) 
        sess.run(optimizer, feed_dict={x: x_batch, y: y_batch})

        l_tr, a_tr = sess.run([cross_entropy, accuracy], feed_dict={x: x_batch, y: y_batch})
        l_ts, a_ts = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y: mnist.test.labels})
        loss_tr.append(l_tr)
        acc_tr.append(a_tr)
        loss_ts.append(l_ts)
        acc_ts.append(a_ts)

    weights = sess.run(W)      
    print('Test Accuracy =', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})) 

# Plotting:
for i in range(10):
    plt.subplot(2, 5, i+1)
    weight = weights[:,i].reshape([28,28])
    plt.title(i)
    plt.imshow(weight, cmap='RdBu')  # as noted by @Eric Duminil, cmap='gray' makes the numbers stand out more
    frame1 = plt.gca()
    frame1.axes.get_xaxis().set_visible(False)
    frame1.axes.get_yaxis().set_visible(False)
Djib2011
fuente
13
2378
13
Por supuesto, ayuda que las muestras MNIST estén centradas, escaladas y normalizadas por contraste antes de que el clasificador las vea. No tiene que responder preguntas como "¿y si el borde del cero realmente atraviesa el centro de la caja?" porque el preprocesador ya ha recorrido un largo camino para que todos los ceros se vean iguales.
Hobbs
1
@EricDuminil Agregué un elogio en el script con su sugerencia. Muchas gracias por el aporte! : D
Djib2011
1
@NitishAgarwal, si crees que esta respuesta es la respuesta a tu pregunta, considera marcarla como tal.
sintax
16
Para alguien que esté interesado en este tipo de procesamiento pero no esté particularmente familiarizado con este tipo de respuesta, esta respuesta ofrece un fantástico ejemplo intuitivo de la mecánica.
Chrylis -en huelga-