Parámetro "estratificar" del método "train_test_split" (scikit Learn)

94

Estoy intentando usar el train_test_splitpaquete scikit Learn, pero tengo problemas con el parámetro stratify. A continuación está el código:

from sklearn import cross_validation, datasets 

X = iris.data[:,:2]
y = iris.target

cross_validation.train_test_split(X,y,stratify=y)

Sin embargo, sigo teniendo el siguiente problema:

raise TypeError("Invalid parameters passed: %s" % str(options))
TypeError: Invalid parameters passed: {'stratify': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])}

¿Alguien tiene una idea de lo que está pasando? A continuación se muestra la documentación de la función.

[...]

estratificar : similar a una matriz o Ninguno (el valor predeterminado es Ninguno)

Si no es Ninguno, los datos se dividen de forma estratificada, utilizando esto como matriz de etiquetas.

Nuevo en la versión 0.17: estratificar la división

[...]

Daneel Olivaw
fuente
No, todo resuelto.
Daneel Olivaw

Respuestas:

58

Scikit-Learn solo le dice que no reconoce el argumento "estratificar", no que lo esté usando incorrectamente. Esto se debe a que el parámetro se agregó en la versión 0.17 como se indica en la documentación que citó.

Así que solo necesitas actualizar Scikit-Learn.

Borja
fuente
Recibo el mismo error, aunque tengo la versión 0.21.2 de scikit-learn. scikit-learn 0.21.2 py37h2a6a0b8_0 conda-forge
Kareem Jeiroudi
326

Este stratifyparámetro realiza una división para que la proporción de valores en la muestra producida sea la misma que la proporción de valores proporcionados al parámetro stratify.

Por ejemplo, si la variable yes una variable categórica binaria con valores 0y 1hay un 25% de ceros y un 75% de unos, stratify=yse asegurará de que su división aleatoria tenga un 25% de 0y un 75% de 1.

Fazzolini
fuente
117
Esto realmente no responde a la pregunta, pero es muy útil para entender cómo funciona. Gracias una tonelada.
Reed Jessen
6
Todavía me cuesta entender por qué es necesaria esta estratificación: si hay un equilibrio de clases en los datos, ¿no se conservaría en promedio cuando se hace una división aleatoria de los datos?
Holger Brandl
14
@HolgerBrandl se conservará en promedio; con estratificar, seguro que se conservará.
Yonatan
7
@HolgerBrandl con conjuntos de datos muy pequeños o muy desequilibrados, es muy posible que la división aleatoria pueda eliminar por completo una clase de una de las divisiones.
cddt
1
@HolgerBrandl ¡Buena pregunta! Tal vez podríamos agregar eso primero, tienes que dividirlo en entrenamiento y prueba usando stratify. Luego, en segundo lugar, para corregir el desequilibrio, eventualmente tendrá que realizar un muestreo excesivo o insuficiente en el conjunto de entrenamiento. Muchos clasificadores de Sklearn tienen un parámetro llamado class-weight que puede establecer en balanceado. Finalmente, también podría tomar una métrica más apropiada que la precisión para un conjunto de datos desequilibrado. Pruebe, F1 o área debajo de ROC.
Claude COULOMBE
62

Para mi yo futuro que viene aquí a través de Google:

train_test_splitestá ahora en model_selection, por lo tanto:

from sklearn.model_selection import train_test_split

# given:
# features: xs
# ground truth: ys

x_train, x_test, y_train, y_test = train_test_split(xs, ys,
                                                    test_size=0.33,
                                                    random_state=0,
                                                    stratify=ys)

es la forma de utilizarlo. Establecer el random_statees deseable para la reproducibilidad.

Martín Thoma
fuente
Esta debería ser la respuesta :) Gracias
SwimBikeRun
15

En este contexto, la estratificación significa que el método train_test_split devuelve subconjuntos de entrenamiento y prueba que tienen las mismas proporciones de etiquetas de clase que el conjunto de datos de entrada.

X. Wang
fuente
3

Intente ejecutar este código, "simplemente funciona":

from sklearn import cross_validation, datasets 

iris = datasets.load_iris()

X = iris.data[:,:2]
y = iris.target

x_train, x_test, y_train, y_test = cross_validation.train_test_split(X,y,train_size=.8, stratify=y)

y_test

array([0, 0, 0, 0, 2, 2, 1, 0, 1, 2, 2, 0, 0, 1, 0, 1, 1, 2, 1, 2, 0, 2, 2,
       1, 2, 1, 1, 0, 2, 1])
Sergey Bushmanov
fuente
@ user5767535 Como puede ver, está funcionando en mi máquina Ubuntu, con sklearnla versión '0.17', distribución Anaconda para Python 3,5. Solo puedo sugerir que verifique una vez más si ingresa el código correctamente y actualiza su software.
Sergey Bushmanov
2
@ user5767535 Por cierto, "Nuevo en la versión 0.17: estratificar división" me hace casi seguro que tienes que actualizar tu sklearn...
Sergey Bushmanov