Probabilidad marginal de la salida de Gibbs

13

Estoy reproduciendo desde cero los resultados en la Sección 4.2.1 de

Probabilidad marginal de la salida de Gibbs

Siddhartha Chib

Revista de la Asociación Americana de Estadística, vol. 90, núm. 432. (diciembre de 1995), págs. 1313-1321.

Es una mezcla de modelos normales con el número conocido de componentes. k1

f(xw,μ,σ2)=i=1nj=1kN(xiμj,σj2).()

La muestra de Gibbs para este modelo se implementa utilizando la técnica de aumento de datos de Tanner y Wong. Un conjunto de variables de asignación suponiendo que los valores 1, \ dots, k se introduce, y se especifica que \ Pr (Z_I = j \ mediados w) = w_j y f (x_i \ mediados z , \ mu, \ sigma ^ 2) = \ mathrm {N} (x_i \ mid \ mu_ {z_i}, \ sigma ^ 2_ {z_i}) . De ello se deduce que la integración sobre el z_i 's da la probabilidad original (*) .z=(z1,,zn)1,,kPr(zi=jw)=wjz i ( )f(xiz,μ,σ2)=N(xiμzi,σzi2)zi()

El conjunto de datos está formado por velocidades de 82 galaxias de la constelación Corona Borealis.

set.seed(1701)

x <- c(  9.172,  9.350,  9.483,  9.558,  9.775, 10.227, 10.406, 16.084, 16.170, 18.419, 18.552, 18.600, 18.927,
        19.052, 19.070, 19.330, 19.343, 19.349, 19.440, 19.473, 19.529, 19.541, 19.547, 19.663, 19.846, 19.856,
        19.863, 19.914, 19.918, 19.973, 19.989, 20.166, 20.175, 20.179, 20.196, 20.215, 20.221, 20.415, 20.629,
        20.795, 20.821, 20.846, 20.875, 20.986, 21.137, 21.492, 21.701, 21.814, 21.921, 21.960, 22.185, 22.209,
        22.242, 22.249, 22.314, 22.374, 22.495, 22.746, 22.747, 22.888, 22.914, 23.206, 23.241, 23.263, 23.484,
        23.538, 23.542, 23.666, 23.706, 23.711, 24.129, 24.285, 24.289, 24.366, 24.717, 24.990, 25.633, 26.960,
        26.995, 32.065, 32.789, 34.279 )

nn <- length(x)

Asumimos que , 's y ' s son independientes a priori con μ j σ 2 j ( w 1 , , w k ) D i r ( a 1 , , a k )wμjσj2

(w1,,wk)Dir(a1,,ak),μjN(μ0,σ02),σj2IG(ν02,δ02).
k <- 3

mu0 <- 20
va0 <- 100

nu0 <- 6
de0 <- 40

a <- rep(1, k)

Usando el teorema de Bayes, los condicionales completos son en el que con

wμ,σ2,z,xDir(a1+n1,,ak+nk)μjw,σ2,z,xN(njmjσ02+μ0σj2njσ02+σj2,σ02σj2njσ02+σj2)σj2w,μ,z,xIG(ν0+nj2,δ0+δj2)Pr(zi=jw,μ,σ2,x)wj×1σje(xiμj)2/2σj2
nj=|Lj|,mj={1njiLjxiifnj>00otherwise.,δj=iLj(xiμj)2,
Lj={i{1,,n}:zi=j} .

El objetivo es calcular una estimación de la probabilidad marginal del modelo. El método de Chib comienza con una primera ejecución de la muestra de Gibbs utilizando los condicionales completos.

burn_in <- 1000
run     <- 15000

cat("First Gibbs run (full):\n")

N <- burn_in + run

w  <- matrix(1, nrow = N, ncol = k)
mu <- matrix(0, nrow = N, ncol = k)
va <- matrix(1, nrow = N, ncol = k)
z  <- matrix(1, nrow = N, ncol = nn)

n <- integer(k)
m <- numeric(k)
de <- numeric(k)

rdirichlet <- function(a) { y <- rgamma(length(a), a, 1); y / sum(y) }

pb <- txtProgressBar(min = 2, max = N, style = 3)
z[1,] <- sample.int(k, size = nn, replace = TRUE)
for (t in 2:N) {
    n <- tabulate(z[t-1,], nbins = k)
    w[t,] <- rdirichlet(a + n)
    m <- sapply(1:k, function(j) sum(x[z[t-1,]==j]))
    m[n > 0] <- m[n > 0] / n[n > 0]
    mu[t,] <- rnorm(k, mean = (n*m*va0+mu0*va[t-1,])/(n*va0+va[t-1,]), sd = sqrt(va0*va[t-1,]/(n*va0+va[t-1,])))
    de <- sapply(1:k, function(j) sum((x[z[t-1,]==j] - mu[t,j])^2))
    va[t,] <- 1 / rgamma(k, shape = (nu0+n)/2, rate = (de0+de)/2)
    z[t,] <- sapply(1:nn, function(i) sample.int(k, size = 1, prob = exp(log(w[t,]) + dnorm(x[i], mean = mu[t,], sd = sqrt(va[t,]), log = TRUE))))
    setTxtProgressBar(pb, t)
}
close(pb)

De esta primera ejecución obtenemos un punto aproximado de máxima probabilidad. Dado que la probabilidad es realmente ilimitada, lo que probablemente da este procedimiento es un MAP local aproximado.(w,μ,σ2)

w  <- w[(burn_in+1):N,]
mu <- mu[(burn_in+1):N,]
va <- va[(burn_in+1):N,]
z  <- z[(burn_in+1):N,]
N  <- N - burn_in

log_L <- function(x, w, mu, va) sum(log(sapply(1:nn, function(i) sum(exp(log(w) + dnorm(x[i], mean = mu, sd = sqrt(va), log = TRUE))))))

ts <- which.max(sapply(1:N, function(t) log_L(x, w[t,], mu[t,], va[t,])))

ws <- w[ts,]
mus <- mu[ts,]
vas <- va[ts,]

La estimación de registro de Chib de la probabilidad marginal es

logf(x)^=logLx(w,μ,σ2)+logπ(w,μ,σ2)logπ(μx)logπ(σ2μ,x)logπ(wμ,σ2,x).

Ya tenemos los dos primeros términos.

log_prior <- function(w, mu, va) {
    lgamma(sum(a)) - sum(lgamma(a)) + sum((a-1)*log(w))
    + sum(dnorm(mu, mean = mu0, sd = sqrt(va0), log = TRUE))
    + sum((nu0/2)*log(de0/2) - lgamma(nu0/2) - (nu0/2+1)*log(va) - de0/(2*va))
}

chib <- log_L(x, ws, mus, vas) + log_prior(ws, mus, vas)

La estimación Rao-Blackwellized de es y se obtiene fácilmente desde la primera carrera de Gibbs.π(μx)

π(μx)=j=1kN(μj|njmjσ02+μ0σj2njσ02+σj2,σ02σj2njσ02+σj2)p(σ2,zx)dσ2dz,
pi.mu_va.z.x <- function(mu, va, z) {
    n <- tabulate(z, nbins = k)
    m <- sapply(1:k, function(j) sum(x[z==j]))
    m[n > 0] <- m[n > 0] / n[n > 0]
    exp(sum(dnorm(mu, mean = (n*m*va0+mu0*va)/(n*va0+va), sd = sqrt(va0*va/(n*va0+va)), log = TRUE)))
}

chib <- chib - log(mean(sapply(1:N, function(t) pi.mu_va.z.x(mus, va[t,], z[t,]))))

La estimación Rao-Blackwellized de es y se calcula a partir de una segunda ejecución reducida de Gibbs en la que los no se actualizan, sino que se hacen igual a en cada paso de iteración.π(σ2μ,x)

π(σ2μ,x)=j=1kIG(σj2|ν0+nj2,δ0+δj2)p(zμ,x)dz,
μjμj
cat("Second Gibbs run (reduced):\n")

N <- burn_in + run

w  <- matrix(1, nrow = N, ncol = k)
va <- matrix(1, nrow = N, ncol = k)
z  <- matrix(1, nrow = N, ncol = nn) 

pb <- txtProgressBar(min = 2, max = N, style = 3)
z[1,] <- sample.int(k, size = nn, replace = TRUE)
for (t in 2:N) {
    n <- tabulate(z[t-1,], nbins = k)
    w[t,] <- rdirichlet(a + n)
    de <- sapply(1:k, function(j) sum((x[z[t-1,]==j] - mus[j])^2))
    va[t,] <- 1 / rgamma(k, shape = (nu0+n)/2, rate = (de0+de)/2)
    z[t,] <- sapply(1:nn, function(i) sample.int(k, size = 1, prob = exp(log(w[t,]) + dnorm(x[i], mean = mus, sd = sqrt(va[t,]), log = TRUE))))
    setTxtProgressBar(pb, t)
}
close(pb)

w  <- w[(burn_in+1):N,]
va <- va[(burn_in+1):N,]
z  <- z[(burn_in+1):N,]
N  <- N - burn_in

pi.va_mu.z.x <- function(va, mu, z) {
    n <- tabulate(z, nbins = k)         
    de <- sapply(1:k, function(j) sum((x[z==j] - mu[j])^2))
    exp(sum(((nu0+n)/2)*log((de0+de)/2) - lgamma((nu0+n)/2) - ((nu0+n)/2+1)*log(va) - (de0+de)/(2*va)))
}

chib <- chib - log(mean(sapply(1:N, function(t) pi.va_mu.z.x(vas, mus, z[t,]))))

Del mismo modo, la estimación Rao-Blackwellized de es y se calcula a partir de una tercera ejecución reducida de Gibbs en la que los 's y los ' s no se actualizan, sino que se igualan a y respectivamente en cada paso de iteración.π(wμ,σ2,x)μ j σ 2 j μ j σ 2 j

π(wμ,σ2,x)=Dir(wa1+n1,,ak+nk)p(zμ,σ2,x)dz,
μjσj2μjσj2
cat("Third Gibbs run (reduced):\n")

N <- burn_in + run

w  <- matrix(1, nrow = N, ncol = k)
z  <- matrix(1, nrow = N, ncol = nn) 

pb <- txtProgressBar(min = 2, max = N, style = 3)
z[1,] <- sample.int(k, size = nn, replace = TRUE)
for (t in 2:N) {
    n <- tabulate(z[t-1,], nbins = k)
    w[t,] <- rdirichlet(a + n)
    z[t,] <- sapply(1:nn, function(i) sample.int(k, size = 1, prob = exp(log(w[t,]) + dnorm(x[i], mean = mus, sd = sqrt(vas), log = TRUE))))
    setTxtProgressBar(pb, t)
}
close(pb)

w  <- w[(burn_in+1):N,]
z  <- z[(burn_in+1):N,]
N  <- N - burn_in

pi.w_z.x <- function(w, z) {
    n <- tabulate(z, nbins = k)
    exp(lgamma(sum(a+n)) - sum(lgamma(a+n)) + sum((a+n-1)*log(w)))
}

chib <- chib - log(mean(sapply(1:N, function(t) pi.w_z.x(ws, z[t,]))))

Después de todo esto, obtenemos una estimación que es más grande que la informada por Chib: con el error de Monte Carlo .- 224.138 .086217.9199224.138.086

Para comprobar si de alguna manera estropeé los muestreadores Gibbs, reimplementé todo usando RJAGS. El siguiente código da los mismos resultados.

x <- c( 9.172,  9.350,  9.483,  9.558,  9.775, 10.227, 10.406, 16.084, 16.170, 18.419, 18.552, 18.600, 18.927, 19.052, 19.070, 19.330,
       19.343, 19.349, 19.440, 19.473, 19.529, 19.541, 19.547, 19.663, 19.846, 19.856, 19.863, 19.914, 19.918, 19.973, 19.989, 20.166,
       20.175, 20.179, 20.196, 20.215, 20.221, 20.415, 20.629, 20.795, 20.821, 20.846, 20.875, 20.986, 21.137, 21.492, 21.701, 21.814,
       21.921, 21.960, 22.185, 22.209, 22.242, 22.249, 22.314, 22.374, 22.495, 22.746, 22.747, 22.888, 22.914, 23.206, 23.241, 23.263,
       23.484, 23.538, 23.542, 23.666, 23.706, 23.711, 24.129, 24.285, 24.289, 24.366, 24.717, 24.990, 25.633, 26.960, 26.995, 32.065,
       32.789, 34.279 )

library(rjags)

nn <- length(x)

k <- 3

mu0 <- 20
va0 <- 100

nu0 <- 6
de0 <- 40

a <- rep(1, k)

burn_in <- 10^3

N <- 10^4

full <- "
    model {
        for (i in 1:n) {
            x[i] ~ dnorm(mu[z[i]], tau[z[i]])
            z[i] ~ dcat(w[])
        }
        for (i in 1:k) {
            mu[i] ~ dnorm(mu0, 1/va0)
            tau[i] ~ dgamma(nu0/2, de0/2)
            va[i] <- 1/tau[i]
        }
        w ~ ddirich(a)
    }
"
data <- list(x = x, n = nn, k = k, mu0 = mu0, va0 = va0, nu0 = nu0, de0 = de0, a = a)
model <- jags.model(textConnection(full), data = data, n.chains = 1, n.adapt = 100)
update(model, n.iter = burn_in)
samples <- jags.samples(model, c("mu", "va", "w", "z"), n.iter = N)

mu <- matrix(samples$mu, nrow = N, byrow = TRUE)
    va <- matrix(samples$va, nrow = N, byrow = TRUE)
w <- matrix(samples$w, nrow = N, byrow = TRUE)
    z <- matrix(samples$z, nrow = N, byrow = TRUE)

log_L <- function(x, w, mu, va) sum(log(sapply(1:nn, function(i) sum(exp(log(w) + dnorm(x[i], mean = mu, sd = sqrt(va), log = TRUE))))))

ts <- which.max(sapply(1:N, function(t) log_L(x, w[t,], mu[t,], va[t,])))

ws <- w[ts,]
mus <- mu[ts,]
vas <- va[ts,]

log_prior <- function(w, mu, va) {
    lgamma(sum(a)) - sum(lgamma(a)) + sum((a-1)*log(w))
    + sum(dnorm(mu, mean = mu0, sd = sqrt(va0), log = TRUE))
    + sum((nu0/2)*log(de0/2) - lgamma(nu0/2) - (nu0/2+1)*log(va) - de0/(2*va))
}

chib <- log_L(x, ws, mus, vas) + log_prior(ws, mus, vas)

cat("log-likelihood + log-prior =", chib, "\n")

pi.mu_va.z.x <- function(mu, va, z, x) {
    n <- sapply(1:k, function(j) sum(z==j))
    m <- sapply(1:k, function(j) sum(x[z==j]))
    m[n > 0] <- m[n > 0] / n[n > 0]
    exp(sum(dnorm(mu, mean = (n*m*va0+mu0*va)/(n*va0+va), sd = sqrt(va0*va/(n*va0+va)), log = TRUE)))
}

chib <- chib - log(mean(sapply(1:N, function(t) pi.mu_va.z.x(mus, va[t,], z[t,], x))))

cat("log-likelihood + log-prior - log-pi.mu_ =", chib, "\n")

fixed.mu <- "
    model {
        for (i in 1:n) {
            x[i] ~ dnorm(mus[z[i]], tau[z[i]])
            z[i] ~ dcat(w[])
        }
        for (i in 1:k) {
            tau[i] ~ dgamma(nu0/2, de0/2)
            va[i] <- 1/tau[i]
        }
        w ~ ddirich(a)
    }
"
data <- list(x = x, n = nn, k = k, nu0 = nu0, de0 = de0, a = a, mus = mus)
model <- jags.model(textConnection(fixed.mu), data = data, n.chains = 1, n.adapt = 100)
update(model, n.iter = burn_in)
samples <- jags.samples(model, c("va", "w", "z"), n.iter = N)

va <- matrix(samples$va, nrow = N, byrow = TRUE)
    w <- matrix(samples$w, nrow = N, byrow = TRUE)
z <- matrix(samples$z, nrow = N, byrow = TRUE)

pi.va_mu.z.x <- function(va, mu, z, x) {
    n <- sapply(1:k, function(j) sum(z==j))
    de <- sapply(1:k, function(j) sum((x[z==j] - mu[j])^2))
    exp(sum(((nu0+n)/2)*log((de0+de)/2) - lgamma((nu0+n)/2) - ((nu0+n)/2+1)*log(va) - (de0+de)/(2*va)))
}

chib <- chib - log(mean(sapply(1:N, function(t) pi.va_mu.z.x(vas, mus, z[t,], x))))

cat("log-likelihood + log-prior - log-pi.mu_ - log-pi.va_ =", chib, "\n")

fixed.mu.and.va <- "
    model {
        for (i in 1:n) {
            x[i] ~ dnorm(mus[z[i]], 1/vas[z[i]])
            z[i] ~ dcat(w[])
        }
        w ~ ddirich(a)
    }
"
data <- list(x = x, n = nn, a = a, mus = mus, vas = vas)
model <- jags.model(textConnection(fixed.mu.and.va), data = data, n.chains = 1, n.adapt = 100)
update(model, n.iter = burn_in)
samples <- jags.samples(model, c("w", "z"), n.iter = N)

w <- matrix(samples$w, nrow = N, byrow = TRUE)
    z <- matrix(samples$z, nrow = N, byrow = TRUE)

pi.w_z.x <- function(w, z, x) {
    n <- sapply(1:k, function(j) sum(z==j))
    exp(lgamma(sum(a)+nn) - sum(lgamma(a+n)) + sum((a+n-1)*log(w)))
}

chib <- chib - log(mean(sapply(1:N, function(t) pi.w_z.x(ws, z[t,], x))))

cat("log-likelihood + log-prior - log-pi.mu_ - log-pi.va_ - log-pi.w_ =", chib, "\n")

Mi pregunta es si en la descripción anterior hay algún malentendido del método de Chib o algún error en su implementación.

zen
fuente
1
Ejecutando la simulación 100 veces, los resultados están en el rango . [218.7655;216.8824]
Zen

Respuestas:

6

Hay un pequeño error de programación en el anterior

log_prior <- function(w, mu, va) {
    lgamma(sum(a)) - sum(lgamma(a)) + sum((a-1)*log(w))
    + sum(dnorm(mu, mean = mu0, sd = sqrt(va0), log = TRUE))
    + sum((nu0/2)*log(de0/2) - lgamma(nu0/2) - (nu0/2+1)*log(va) - de0/(2*va))
}

como debería ser en su lugar

log_prior <- function(w, mu, va) {
    lgamma(sum(a)) - sum(lgamma(a)) + sum((a-1)*log(w)) +
      sum(dnorm(mu, mean = mu0, sd = sqrt(va0), log = TRUE)) +
      sum((nu0/2)*log(de0/2) - lgamma(nu0/2) - (nu0/2+1)*log(va) - de0/(2*va))
}

Volver a ejecutar el código de esta manera conduce a

> chib
[1] -228.194

¡cuál no es el valor producido en Chib (1995) para ese caso! Sin embargo, en el reanálisis del problema de Neal (1999), menciona que

Según un árbitro anónimo de JASA, la cifra de -224.138 para el registro de la probabilidad marginal para el modelo de tres componentes con variaciones desiguales que se dio en el documento de Chib es un "error tipográfico" con la cifra correcta siendo -228.608.

Entonces esto resuelve el problema de discrepancia.

Xi'an
fuente
2
Prof. Christian Robert y Kate Lee: ¿sabes lo genial que eres?
Zen
2
Por cierto, este es definitivamente un ejemplo de "sintaxis malvada". No voy a olvidar este.
Zen