Arpillera de la función logística

15

Tengo dificultad para derivar el hessiano de la función objetivo, l(θ) , en regresión logística donde l(θ) es:

l(θ)=i=1m[yilog(hθ(xi))+(1yi)log(1hθ(xi))]

hθ(x) es una función logística. El Hessian esXTDX . Traté de deducirlo calculando2l(θ)θiθj , pero entonces no era obvio para mí como llegar a la notación de matriz a partir de2l(θ)θiθj .

¿Alguien sabe alguna forma limpia y fácil de derivar XTDX ?

DSKim
fuente
3
¿Qué obtuviste por ? 2lθiθj
Glen_b -Reinstala a Mónica el
1
Aquí hay un buen conjunto de diapositivas que muestran el cálculo exacto que está buscando: sites.stat.psu.edu/~jiali/course/stat597e/notes2/logit.pdf
Encontré un video maravilloso que computa el Hesse paso a paso. Regresión logística (binario): cálculo del hessiano
Naomi

Respuestas:

19

Aquí deduzco todas las propiedades e identidades necesarias para que la solución sea autónoma, pero aparte de eso, esta derivación es limpia y fácil. Formalicemos nuestra notación y escribamos la función de pérdida un poco más compacta. Considere m muestras {xi,yi} tal que xiRd y yiR . Recuerde que en la regresión logística binaria típicamente tenemos la función de hipótesis hθ ser la función logística. Formalmente

hθ(xi)=σ(ωTxi)=σ(zi)=11+ezi,

donde ωRd y zi=ωTxi . La función de pérdida (que creo que a los OP les falta un signo negativo) se define como:

l(ω)=i=1m(yilogσ(zi)+(1yi)log(1σ(zi)))

Hay dos propiedades importantes de la función logística que obtengo aquí para referencia futura. Primero, tenga en cuenta que 1σ(z)=11/(1+ez)=ez/(1+ez)=1/(1+ez)=σ(z) .

También tenga en cuenta que

zσ(z)=z(1+ez)1=ez(1+ez)2=11+ezez1+ez=σ(z)(1σ(z))

En lugar de tomar derivados con respecto a componentes, aquí trabajaremos directamente con vectores (puede revisar derivados con vectores aquí ). La arpillera de la función de pérdida l(ω) viene dada por 2l(ω) , pero primero recuerde que zω=xTωω=xTyzωT=ωTxωT=x.

Let li(ω)=yilogσ(zi)(1yi)log(1σ(zi)). Using the properties we derived above and the chain rule

logσ(zi)ωT=1σ(zi)σ(zi)ωT=1σ(zi)σ(zi)ziziωT=(1σ(zi))xilog(1σ(zi))ωT=11σ(zi)(1σ(zi))ωT=σ(zi)xi

It's now trivial to show that

li(ω)=li(ω)ωT=yixi(1σ(zi))+(1yi)xiσ(zi)=xi(σ(zi)yi)

whew!

Our last step is to compute the Hessian

2li(ω)=li(ω)ωωT=xixiTσ(zi)(1σ(zi))

For m samples we have 2l(ω)=i=1mxixiTσ(zi)(1σ(zi)). This is equivalent to concatenating column vectors xiRd into a matrix X of size d×m such that i=1mxixiT=XXT. The scalar terms are combined in a diagonal matrix D such that Dii=σ(zi)(1σ(zi)). Finally, we conclude that

H(ω)=2l(ω)=XDXT

A faster approach can be derived by considering all samples at once from the beginning and instead work with matrix derivatives. As an extra note, with this formulation it's trivial to show that l(ω) is convex. Let δ be any vector such that δRd. Then

δTH(ω)δ=δT2l(ω)δ=δTXDXTδ=δTXD(δTX)T=δTDX20

since D>0 and δTX0. This implies H is positive-semidefinite and therefore l is convex (but not strongly convex).

Manuel Morales
fuente
2
In the last equation, shouldn't it be ||δD1/2X|| since XDX = XD1/2(XD1/2)?
appletree
1
Shouldn't it be XTDX?
Chintan Shah