Estoy experimentando con el algoritmo de la máquina de aumento de gradiente a través del caret
paquete en R.
Usando un pequeño conjunto de datos de admisión a la universidad, ejecuté el siguiente código:
library(caret)
### Load admissions dataset. ###
mydata <- read.csv("http://www.ats.ucla.edu/stat/data/binary.csv")
### Create yes/no levels for admission. ###
mydata$admit_factor[mydata$admit==0] <- "no"
mydata$admit_factor[mydata$admit==1] <- "yes"
### Gradient boosting machine algorithm. ###
set.seed(123)
fitControl <- trainControl(method = 'cv', number = 5, summaryFunction=defaultSummary)
grid <- expand.grid(n.trees = seq(5000,1000000,5000), interaction.depth = 2, shrinkage = .001, n.minobsinnode = 20)
fit.gbm <- train(as.factor(admit_factor) ~ . - admit, data=mydata, method = 'gbm', trControl=fitControl, tuneGrid=grid, metric='Accuracy')
plot(fit.gbm)
y encontré para mi sorpresa que la precisión de validación cruzada del modelo disminuyó en lugar de aumentar a medida que aumentó el número de iteraciones de refuerzo, alcanzando una precisión mínima de aproximadamente .59 a ~ 450,000 iteraciones.
¿Implementé incorrectamente el algoritmo GBM?
EDITAR: Siguiendo la sugerencia de Underminer, volví a ejecutar el caret
código anterior pero me concentré en ejecutar de 100 a 5,000 iteraciones de refuerzo:
set.seed(123)
fitControl <- trainControl(method = 'cv', number = 5, summaryFunction=defaultSummary)
grid <- expand.grid(n.trees = seq(100,5000,100), interaction.depth = 2, shrinkage = .001, n.minobsinnode = 20)
fit.gbm <- train(as.factor(admit_factor) ~ . - admit, data=mydata, method = 'gbm', trControl=fitControl, tuneGrid=grid, metric='Accuracy')
plot(fit.gbm)
El gráfico resultante muestra que la precisión realmente alcanza su punto máximo en casi .705 a ~ 1,800 iteraciones:
Lo curioso es que la precisión no se estabilizó en ~ .70, sino que disminuyó después de 5.000 iteraciones.
fuente
Códigos para reproducir un resultado similar, sin búsqueda de cuadrícula,
fuente
El paquete gbm tiene una función para estimar el número óptimo de iteraciones (= # de árboles, o # de funciones básicas),
No necesitas el tren de Caret para eso.
fuente