¿Puedo extraer las reglas de decisión subyacentes (o 'rutas de decisión') de un árbol entrenado en un árbol de decisión como una lista de texto?
Algo como:
if A>0.4 then if B<0.2 then if C>0.8 then class='X'
Gracias por tu ayuda.
python
machine-learning
scikit-learn
decision-tree
random-forest
Dror Hilman
fuente
fuente
Respuestas:
Creo que esta respuesta es más correcta que las otras respuestas aquí:
Esto imprime una función Python válida. Aquí hay un ejemplo de salida para un árbol que está tratando de devolver su entrada, un número entre 0 y 10.
Aquí hay algunos obstáculos que veo en otras respuestas:
tree_.threshold == -2
para decidir si un nodo es una hoja no es una buena idea. ¿Qué pasa si es un nodo de decisión real con un umbral de -2? En cambio, debe mirartree.feature
otree.children_*
.features = [feature_names[i] for i in tree_.feature]
bloquea con mi versión de sklearn, porque algunos valores detree.tree_.feature
son -2 (específicamente para los nodos hoja).fuente
print "{}return {}".format(indent, tree_.value[node])
debe cambiarse aprint "{}return {}".format(indent, np.argmax(tree_.value[node][0]))
para que la función devuelva el índice de clase.RandomForestClassifier.estimators_
, pero no pude averiguar cómo combinar los resultados de los estimadores.print "bla"
=>print("bla")
Creé mi propia función para extraer las reglas de los árboles de decisión creados por sklearn:
Esta función comienza primero con los nodos (identificados por -1 en las matrices secundarias) y luego encuentra recursivamente a los padres. A esto lo llamo el "linaje" de un nodo. En el camino, tomo los valores que necesito para crear la lógica SAS if / then / else:
Los conjuntos de tuplas a continuación contienen todo lo que necesito para crear sentencias SAS if / then / else. No me gusta usar
do
bloques en SAS, por eso creo lógica que describe la ruta completa de un nodo. El número entero único después de las tuplas es la ID del nodo terminal en una ruta. Todas las tuplas anteriores se combinan para crear ese nodo.fuente
(0.5, 2.5]
. Los árboles están hechos con particiones recursivas. No hay nada que impida que una variable se seleccione varias veces.Modifiqué el código enviado por Zelazny7 para imprimir un pseudocódigo:
si llama
get_code(dt, df.columns)
al mismo ejemplo obtendrá:fuente
(threshold[node] != -2)
a( left[node] != -1)
(similar al método siguiente para obtener identificadores de nodos secundarios)Scikit Learn introdujo un nuevo método delicioso llamado
export_text
en la versión 0.21 (mayo de 2019) para extraer las reglas de un árbol. Documentación aquí . Ya no es necesario crear una función personalizada.Una vez que haya ajustado su modelo, solo necesita dos líneas de código. Primero, importa
export_text
:Segundo, crea un objeto que contendrá tus reglas. Para que las reglas se vean más legibles, use el
feature_names
argumento y pase una lista de los nombres de sus características. Por ejemplo, si se llama a su modelomodel
y sus características se nombran en un marco de datos llamadoX_train
, puede crear un objeto llamadotree_rules
:Luego simplemente imprima o guarde
tree_rules
. Su salida se verá así:fuente
Hay un nuevo
DecisionTreeClassifier
métododecision_path
, en el 0.18.0 versión . Los desarrolladores proporcionan un tutorial extenso (bien documentado) .La primera sección de código en el tutorial que imprime la estructura de árbol parece estar bien. Sin embargo, modifiqué el código en la segunda sección para interrogar una muestra. Mis cambios denotados con
# <--
Editar Los cambios marcados
# <--
en el código a continuación se han actualizado en el enlace de recorrido después de que se señalaron los errores en las solicitudes de extracción # 8653 y # 10951 . Es mucho más fácil seguirlo ahora.Cambie
sample_id
para ver las rutas de decisión para otras muestras. No he preguntado a los desarrolladores acerca de estos cambios, simplemente me pareció más intuitivo al trabajar con el ejemplo.fuente
Puedes ver un árbol de dígrafo. Entonces,
clf.tree_.feature
yclf.tree_.value
son una matriz de nodos que dividen la función y la matriz de valores de nodos respectivamente. Puede consultar más detalles de esta fuente de github .fuente
Solo porque todos fueron muy útiles, solo agregaré una modificación a Zelazny7 y las hermosas soluciones de Daniele. Este es para Python 2.7, con pestañas para hacerlo más legible:
fuente
Los códigos a continuación son mi enfoque bajo anaconda python 2.7 más un nombre de paquete "pydot-ng" para hacer un archivo PDF con reglas de decisión. Espero que sea útil.
un gráfico de árbol que se muestra aquí
fuente
He estado pasando por esto, pero necesitaba que las reglas se escribieran en este formato
Así que adapté la respuesta de @paulkernfeld (gracias) que puedes personalizar según tus necesidades
fuente
Aquí hay una manera de traducir todo el árbol en una sola expresión de Python (no necesariamente legible para humanos) utilizando la biblioteca SKompiler :
fuente
Esto se basa en la respuesta de @paulkernfeld. Si tiene un marco de datos X con sus características y un marco de datos de destino y con sus resones y desea hacerse una idea de qué valor de y terminó en qué nodo (y también hormiga para trazarlo en consecuencia), puede hacer lo siguiente:
no es la versión más elegante pero hace el trabajo ...
fuente
Este es el código que necesitas
He modificado el código que más me gustó para sangrar en un jupyter notebook python 3 correctamente
fuente
Aquí hay una función, imprimir reglas de un árbol de decisión de scikit-learn en python 3 y con compensaciones para bloques condicionales para hacer que la estructura sea más legible:
fuente
También puede hacerlo más informativo distinguiéndolo a qué clase pertenece o incluso mencionando su valor de salida.
fuente
Aquí está mi enfoque para extraer las reglas de decisión en una forma que se pueda usar directamente en sql, para que los datos se puedan agrupar por nodo. (Basado en los enfoques de los carteles anteriores).
El resultado serán
CASE
cláusulas posteriores que se pueden copiar en una instrucción sql, ej.SELECT COALESCE(*CASE WHEN <conditions> THEN > <NodeA>*, > *CASE WHEN <conditions> THEN <NodeB>*, > ....)NodeName,* > FROM <table or view>
fuente
Ahora puede usar export_text.
Un ejemplo completo de [sklearn] [1]
fuente
Se modificó el código de Zelazny7 para obtener SQL del árbol de decisión.
fuente
Aparentemente, hace mucho tiempo, alguien ya decidió intentar agregar la siguiente función a las funciones de exportación del árbol de scikit oficial (que básicamente solo admite export_graphviz)
Aquí está su compromiso completo:
https://github.com/scikit-learn/scikit-learn/blob/79bdc8f711d0af225ed6be9fdb708cea9f98a910/sklearn/tree/export.py
No estoy seguro de lo que sucedió con este comentario. Pero también podría intentar usar esa función.
Creo que esto garantiza una solicitud de documentación seria a las buenas personas de scikit-learn para documentar adecuadamente la
sklearn.tree.Tree
API, que es la estructura de árbol subyacente que seDecisionTreeClassifier
expone como su atributotree_
.fuente
Simplemente use la función de sklearn.tree como esta
Y luego busque en la carpeta de su proyecto el archivo tree.dot , copie TODO el contenido y péguelo aquí http://www.webgraphviz.com/ y genere su gráfico :)
fuente
Gracias por la maravillosa solución de @paulkerfeld. En la parte superior de su solución, para todos aquellos que quieren tener una versión serializada de árboles, sólo tiene que utilizar
tree.threshold
,tree.children_left
,tree.children_right
,tree.feature
ytree.value
. Dado que las hojas no tienen divisiones y, por lo tanto, no tienen nombres de características y elementos secundarios, su marcador de posición entree.feature
ytree.children_***
son_tree.TREE_UNDEFINED
y_tree.TREE_LEAF
. A cada división se le asigna un índice único pordepth first search
.Tenga en cuenta que
tree.value
es de forma[n, 1, 1]
fuente
Aquí hay una función que genera código Python a partir de un árbol de decisión al convertir la salida de
export_text
:Uso de la muestra:
Salida de muestra:
El ejemplo anterior se genera con
names = ['f'+str(j+1) for j in range(NUM_FEATURES)]
.Una característica útil es que puede generar un tamaño de archivo más pequeño con un espacio reducido. Solo establece
spacing=2
.fuente