¿Cuáles son la media y la varianza de una normal multivariada censurada por 0?

9

Deje que esté en . ¿Cuáles son la media y la matriz de covarianza de (con el máximo calculado por elemento)?ZN(μ,Σ)RdZ+=max(0,Z)

Esto surge, por ejemplo, porque, si usamos la función de activación ReLU dentro de una red profunda, y asumimos a través del CLT que las entradas a una capa dada son aproximadamente normales, entonces esta es la distribución de las salidas.

(Estoy seguro de que muchas personas han calculado esto antes, pero no pude encontrar el resultado enumerado en ninguna parte de una manera razonablemente legible).

Dougal
fuente
Simplificaría su respuesta, tal vez en gran medida, observar que puede obtenerla combinando los resultados de dos preguntas separadas: (1) cuáles son los momentos de una distribución Normal truncada y (2) cuáles son los momentos de una mezcla ? Este último es sencillo y todo lo que necesita hacer es citar resultados para el primero.
whuber
@whuber Hmm. Aunque no lo dije explícitamente, eso es esencialmente lo que hago en mi respuesta, excepto que no encontré resultados para una distribución bivariada truncada con una media general y una varianza, por lo que tuve que escalar y cambiar de todos modos. ¿Hay alguna forma de derivar, por ejemplo, la covarianza sin hacer la cantidad de álgebra que tenía que hacer? Ciertamente no estoy afirmando que nada en esta respuesta sea novedoso, solo que el álgebra era tedioso y propenso a errores, y tal vez alguien más encuentre la solución útil.
Dougal
Correcto: estoy seguro de que su álgebra es equivalente a lo que describí, por lo que parece que compartimos una apreciación por simplificar el álgebra. Una manera fácil de reducir el álgebra es estandarizar los elementos diagonales de a la unidad, porque todo lo que hace es establecer una unidad de medida para cada variable. En ese punto, puede conectar directamente los resultados de Rosenbaum a las expresiones (simples, obvias) para momentos de mezclas. Ya sea que valga la pena la simplificación algebraica puede ser una cuestión de gustos: sin simplificación, conduce a un programa de computadora simple y modular. Σ
whuber
1
Supongo que uno podría escribir un programa que calcule momentos directamente con los resultados de Rosenbaum y mezcle adecuadamente, y luego los cambie y escale nuevamente al espacio original. Eso probablemente habría sido más rápido que la forma en que lo hice.
Dougal

Respuestas:

7

Primero podemos reducir esto para depender solo de ciertos momentos de distribuciones normales univariadas / bivariadas: tenga en cuenta, por supuesto, que

E[Z+]=[E[(Zi)+]]iCov(Z+)=[Cov((Zi)+,(Zj)+)]ij,
y debido a que estamos haciendo transformaciones coordinadas de ciertas dimensiones de una distribución normal, solo Es necesario preocuparse por la media y la varianza de una normal censurada 1d y la covarianza de dos normales censuradas 1d.

Usaremos algunos resultados de

S Rosenbaum (1961). Momentos de una distribución normal bivariada truncada . JRSS B, vol. 23 págs. 405-408. ( jstor )

Rosenbaum considera y considera el truncamiento al evento .

[X~Y~]N([00],[1ρρ1]),
V={X~aX,Y~aY}

Específicamente, usaremos los siguientes tres resultados, his (1), (3) y (5). Primero, defina lo siguiente:

qx=ϕ(ax)qy=ϕ(ay)Qx=Φ(ax)Qy=Φ(ay)Rxy=Φ(ρaxay1ρ2)Ryx=Φ(ρayax1ρ2)rxy=1ρ22πϕ(h22ρhk+k21ρ2)

Ahora, Rosenbaum muestra que:

(1)Pr(V)E[X~V]=qxRxy+ρqyRyx(3)Pr(V)E[X~2V]=Pr(V)+axqxRxy+ρ2ayqyRyx+ρrxy(5)Pr(V)E[X~Y~V]=ρPr(V)+ρaxqxRxy+ρayqyRyx+rxy.

Será útil considerar también el caso especial de (1) y (3) con , es decir, un truncamiento 1d: ay=

(*)Pr(V)E[X~V]=qx(**)Pr(V)E[X~2V]=Pr(V)=Qx.

Ahora queremos considerar

[XY]=[μxμy]+[σx00σy][X~Y~]N([μXμY],[σx2ρσxσyρσxσyσy2])=N(μ,Σ).

Usaremos que son los valores de y cuando , .

ax=μxσxay=μyσy,
X~Y~X=0Y=0

Ahora, usando (*), obtenemos y usando tanto (*) como (**) produce para que

E[X+]=Pr(X+>0)E[XX>0]+Pr(X+=0)0=Pr(X>0)(μx+σxE[X~X~ax])=Qxμx+qxσx,
E[X+2]=Pr(X+>0)E[X2X>0]+Pr(X+=0)0=Pr(X~ax)E[(μx+σxX~)2X~ax]=Pr(X~ax)E[μx2+μxσxX~+σx2X~2X~ax]=Qxμx2+qxμxσx+Qxσx2
Var[X+]=E[X+2]E[X+]2=Qxμx2+qxμxσx+Qxσx2Qx2μx2qx2σx22qxQxμxσx=Qx(1Qx)μx2+(12Qx)qxμxσx+(Qxqx2)σx2.

Para encontrar , necesitaremos Cov(X+,Y+)

E[X+Y+]=Pr(V)E[XYV]+Pr(¬V)0=Pr(V)E[(μx+σxX~)(μy+σyY~)V]=μxμyPr(V)+μyσxPr(V)E[X~V]+μxσyPr(V)E[Y~V]+σxσyPr(V)E[X~Y~V]=μxμyPr(V)+μyσx(qxRxy+ρqyRyx)+μxσy(ρqxRxy+qyRyx)+σxσy(ρPr(V)ρμxqxRxy/σxρμyqyRyx/σy+rxy)=(μxμy+σxσyρ)Pr(V)+(μyσx+μxσyρρμxσy)qxRxy+(μyσxρ+μxσyρμyσx)qyRyx+σxσyrxy=(μxμy+Σxy)Pr(V)+μyσxqxRxy+μxσyqyRyx+σxσyrxy,
y luego restando obtenemos E[X+]E[Y+]
Cov(X+,Y+)=(μxμy+Σxy)Pr(V)+μyσxqxRxy+μxσyqyRyx+σxσyrxy(Qxμx+qxσx)(Qyμy+qyσy).

Aquí hay un código de Python para calcular los momentos:

import numpy as np
from scipy import stats

def relu_mvn_mean_cov(mu, Sigma):
    mu = np.asarray(mu, dtype=float)
    Sigma = np.asarray(Sigma, dtype=float)
    d, = mu.shape
    assert Sigma.shape == (d, d)

    x = (slice(None), np.newaxis)
    y = (np.newaxis, slice(None))

    sigma2s = np.diagonal(Sigma)
    sigmas = np.sqrt(sigma2s)
    rhos = Sigma / sigmas[x] / sigmas[y]

    prob = np.empty((d, d))  # prob[i, j] = Pr(X_i > 0, X_j > 0)
    zero = np.zeros(d)
    for i in range(d):
        prob[i, i] = np.nan
        for j in range(i + 1, d):
            # Pr(X > 0) = Pr(-X < 0); X ~ N(mu, S) => -X ~ N(-mu, S)
            s = [i, j]
            prob[i, j] = prob[j, i] = stats.multivariate_normal.cdf(
                zero[s], mean=-mu[s], cov=Sigma[np.ix_(s, s)])

    mu_sigs = mu / sigmas

    Q = stats.norm.cdf(mu_sigs)
    q = stats.norm.pdf(mu_sigs)
    mean = Q * mu + q * sigmas

    # rho_cs is sqrt(1 - rhos**2); but don't calculate diagonal, because
    # it'll just be zero and we're dividing by it (but not using result)
    # use inf instead of nan; stats.norm.cdf doesn't like nan inputs
    rho_cs = 1 - rhos**2
    np.fill_diagonal(rho_cs, np.inf)
    np.sqrt(rho_cs, out=rho_cs)

    R = stats.norm.cdf((mu_sigs[y] - rhos * mu_sigs[x]) / rho_cs)

    mu_sigs_sq = mu_sigs ** 2
    r_num = mu_sigs_sq[x] + mu_sigs_sq[y] - 2 * rhos * mu_sigs[x] * mu_sigs[y]
    np.fill_diagonal(r_num, 1)  # don't want slightly negative numerator here
    r = rho_cs / np.sqrt(2 * np.pi) * stats.norm.pdf(np.sqrt(r_num) / rho_cs)

    bit = mu[y] * sigmas[x] * q[x] * R
    cov = (
        (mu[x] * mu[y] + Sigma) * prob
        + bit + bit.T
        + sigmas[x] * sigmas[y] * r
        - mean[x] * mean[y])

    cov[range(d), range(d)] = (
        Q * (1 - Q) * mu**2 + (1 - 2 * Q) * q * mu * sigmas
        + (Q - q**2) * sigma2s)

    return mean, cov

y una prueba de Monte Carlo de que funciona:

np.random.seed(12)
d = 4
mu = np.random.randn(d)
L = np.random.randn(d, d)
Sigma = L.T.dot(L)
dist = stats.multivariate_normal(mu, Sigma)

mn, cov = relu_mvn_mean_cov(mu, Sigma)

samps = dist.rvs(10**7)
mn_est = samps.mean(axis=0)
cov_est = np.cov(samps, rowvar=False)
print(np.max(np.abs(mn - mn_est)), np.max(np.abs(cov - cov_est)))

lo que da 0.000572145310512 0.00298692620286, lo que indica que la expectativa y la covarianza alegadas coinciden con las estimaciones de Monte Carlo (basadas en muestras).10,000,000

Dougal
fuente
¿Puedes resumir cuáles son esos valores finales? ¿Son estimaciones de los parámetros mu y L que generó? Tal vez imprimir esos valores objetivo?
AdamO
No, los valores de retorno son y ; Lo que imprimí fue la distancia entre los estimadores de Monte Carlo de esas cantidades y el valor calculado. Tal vez podría invertir estas expresiones para obtener un estimador de coincidencia de momentos para y ; Rosenbaum en realidad lo hace en su sección 3 en el caso truncado, pero eso no es lo que quería aquí. \E(Z+)\Cov(Z+)LμΣ
Dougal