Estoy tratando de usar la pérdida al cuadrado para hacer una clasificación binaria en un conjunto de datos de juguete.
Estoy usando mtcars
un conjunto de datos, uso millas por galón y peso para predecir el tipo de transmisión. El gráfico a continuación muestra los dos tipos de datos del tipo de transmisión en diferentes colores y el límite de decisión generado por la función de pérdida diferente. La pérdida al cuadrado es
donde es la etiqueta de verdad básica (0 o 1) y es la probabilidad predicha . En otras palabras, estoy reemplazando la pérdida logística con pérdida cuadrada en la configuración de clasificación, otras partes son iguales.
Para un ejemplo de juguete con mtcars
datos, en muchos casos, obtuve un modelo "similar" a la regresión logística (ver figura siguiente, con semilla aleatoria 0).
Pero en algunas cosas (si lo hacemos set.seed(1)
), la pérdida al cuadrado parece no funcionar bien.
¿Que está sucediendo aquí? La optimización no converge? ¿La pérdida logística es más fácil de optimizar en comparación con la pérdida al cuadrado? Cualquier ayuda sería apreciada.
Código
d=mtcars[,c("am","mpg","wt")]
plot(d$mpg,d$wt,col=factor(d$am))
lg_fit=glm(am~.,d, family = binomial())
abline(-lg_fit$coefficients[1]/lg_fit$coefficients[3],
-lg_fit$coefficients[2]/lg_fit$coefficients[3])
grid()
# sq loss
lossSqOnBinary<-function(x,y,w){
p=plogis(x %*% w)
return(sum((y-p)^2))
}
# ----------------------------------------------------------------
# note, this random seed is important for squared loss work
# ----------------------------------------------------------------
set.seed(0)
x0=runif(3)
x=as.matrix(cbind(1,d[,2:3]))
y=d$am
opt=optim(x0, lossSqOnBinary, method="BFGS", x=x,y=y)
abline(-opt$par[1]/opt$par[3],
-opt$par[2]/opt$par[3], lty=2)
legend(25,5,c("logisitc loss","squared loss"), lty=c(1,2))
fuente
optim
te dice que no ha terminado, eso es todo: está convergiendo. Puede aprender mucho volviendo a ejecutar su código con el argumento adicionalcontrol=list(maxit=10000)
, trazando su ajuste y comparando sus coeficientes con los originales.Respuestas:
Parece que ha solucionado el problema en su ejemplo particular, pero creo que todavía vale la pena estudiar más detenidamente la diferencia entre los mínimos cuadrados y la regresión logística de máxima probabilidad.
Consigamos algo de notación. LetLS(yi,y^i)=12(yi−y^i)2 yLL(yi,y^i)=yilogy^i+(1−yi)log(1−y^i) . Si estamos haciendo máxima verosimilitud (o mínimo registro de probabilidad negativo como yo estoy haciendo aquí), tenemos
β L:=argminb∈ Rβ^L:=argminb∈Rp−∑i=1nyilogg−1(xTib)+(1−yi)log(1−g−1(xTib))
cong como nuestra función de enlace.
Alternativamente tenemos β S : = argmin b ∈ R p 1β^S:=argminb∈Rp12∑i=1n(yi−g−1(xTib))2
como la solución de mínimos cuadrados. Por lo tanto β SminimizaLSy de manera similar paraLLβ^S LS LL .
DejefS y fL ser las funciones objetivo correspondientes a minimizar LS y LL respectivamente como se hace para β S y β L . Por último, dejar que h = g - 1 por loβ^S β^L h=g−1 y^i=h(xTib) . Tenga en cuenta que si estamos usando el enlace canónico tenemos
h(z)=11+e−z⟹h′(z)=h(z)(1−h(z)).
Para la regresión logística regular tenemos∂fL∂bj=−∑i=1nh′(xTib)xij(yih(xTib)−1−yi1−h(xTib)).
Usandoh′=h⋅(1−h) podemos simplificar esto a
∂fL∂bj=−∑i=1nxij(yi(1−y^i)−(1−yi)y^i)=−∑i=1nxij(yi−y^i)
entonces
∇fL(b)=−XT(Y−Y^).
A continuación, hagamos segundas derivadas. El hessiano
Comparemos esto con mínimos cuadrados.
Esto significa que tenemos∇fS(b)=−XTA(Y−Y^).
Este es un punto vital: el gradiente es casi el mismo, excepto para todos i y i ( 1 - Y i ) ∈ ( 0 , 1 ) , así que básicamente estamos aplanamiento de la pendiente en relación con ∇ f L . Esto hará que la convergencia sea más lenta.y^i(1−y^i)∈(0,1) ∇fL
Para el Hessian podemos escribir primero∂fS∂bj=−∑i=1nxij(yi−y^i)y^i(1−y^i)=−∑i=1nxij(yiy^i−(1+yi)y^2i+y^3i).
LetB=diag(yi−2(1+yi)y^i+3y^2i) . We now have
HS=−XTABX.
Unfortunately for us, the weights inB are not guaranteed to be non-negative: if yi=0 then yi−2(1+yi)y^i+3y^2i=y^i(3y^i−2) which is positive iff y^i>23 . Similarly, if yi=1 then yi−2(1+yi)y^i+3y^2i=1−4y^i+3y^2i which is positive when y^i<13 (it's also positive for y^i>1 but that's not possible). This means that HS is not necessarily PSD, so not only are we squashing our gradients which will make learning harder, but we've also messed up the convexity of our problem.
All in all, it's no surprise that least squares logistic regression struggles sometimes, and in your example you've got enough fitted values close to0 or 1 so that y^i(1−y^i) can be pretty small and thus the gradient is quite flattened.
Connecting this to neural networks, even though this is but a humble logistic regression I think with squared loss you're experiencing something like what Goodfellow, Bengio, and Courville are referring to in their Deep Learning book when they write the following:
and, in 6.2.2,
(both excerpts are from chapter 6).
fuente
I would thank to thank @whuber and @Chaconne for help. Especially @Chaconne, this derivation is what I wished to have for years.
The problem IS in the optimization part. If we set the random seed to 1, the default BFGS will not work. But if we change the algorithm and change the max iteration number it will work again.
As @Chaconne mentioned, the problem is squared loss for classification is non-convex and harder to optimize. To add on @Chaconne's math, I would like to present some visualizations on to logistic loss and squared loss.
We will change the demo data from3 coefficients including the intercept. We will use another toy data set generated from 2 parameters, which is better for visualization.
mtcars
, since the original toy example hasmlbench
, in this data set, we setHere is the demo
The data is shown in the left figure: we have two classes in two colors. x,y are two features for the data. In addition, we use red line to represent the linear classifier from logistic loss, and the blue line represent the linear classifier from squared loss.
The middle figure and right figure shows the contour for logistic loss (red) and squared loss (blue). x, y are two parameters we are fitting. The dot is the optimal point found by BFGS.
From the contour we can easily see how why optimizing squared loss is harder: as Chaconne mentioned, it is non-convex.
Here is one more view from persp3d.
Code
fuente