Trazar la matriz de correlación usando pandas

212

Tengo un conjunto de datos con una gran cantidad de características, por lo que analizar la matriz de correlación se ha vuelto muy difícil. Quiero trazar una matriz de correlación que obtenemos usando la dataframe.corr()función de la biblioteca de pandas. ¿Hay alguna función incorporada proporcionada por la biblioteca de pandas para trazar esta matriz?

Gaurav Singh
fuente
Puede encontrar respuestas relacionadas aquí Cómo hacer un mapa de calor de pandas DataFrame
joelostblom

Respuestas:

293

Puedes usar pyplot.matshow() desde matplotlib:

import matplotlib.pyplot as plt

plt.matshow(dataframe.corr())
plt.show()

Editar:

En los comentarios había una solicitud de cómo cambiar las etiquetas de marca del eje. Aquí hay una versión de lujo que se dibuja en un tamaño de figura más grande, tiene etiquetas de eje para que coincidan con el marco de datos y una leyenda de la barra de colores para interpretar la escala de colores.

Incluyo cómo ajustar el tamaño y la rotación de las etiquetas, y estoy usando una relación de figura que hace que la barra de colores y la figura principal salgan a la misma altura.

f = plt.figure(figsize=(19, 15))
plt.matshow(df.corr(), fignum=f.number)
plt.xticks(range(df.shape[1]), df.columns, fontsize=14, rotation=45)
plt.yticks(range(df.shape[1]), df.columns, fontsize=14)
cb = plt.colorbar()
cb.ax.tick_params(labelsize=14)
plt.title('Correlation Matrix', fontsize=16);

ejemplo de diagrama de correlación

jrjc
fuente
1
Debo estar perdiendo algo:AttributeError: 'module' object has no attribute 'matshow'
Tom Russell
1
@TomRussell ¿Lo hiciste import matplotlib.pyplot as plt?
joelostblom 05 de
1
¡Me gustaría pensar que lo hice! :-)
Tom Russell
77
¿Sabes cómo mostrar los nombres reales de las columnas en el gráfico?
WebQube
2
@Cecilia Había resuelto este asunto cambiando el parámetro de rotación a 90
ikbel benabdessamad el
182

Si su objetivo principal es visualizar la matriz de correlación, en lugar de crear un gráfico per se, las pandas opciones de estilo convenientes son una solución integrada viable:

import pandas as pd
import numpy as np

rs = np.random.RandomState(0)
df = pd.DataFrame(rs.rand(10, 10))
corr = df.corr()
corr.style.background_gradient(cmap='coolwarm')
# 'RdBu_r' & 'BrBG' are other good diverging colormaps

ingrese la descripción de la imagen aquí

Tenga en cuenta que esto debe estar en un back-end que admita la representación de HTML, como el JupyterLab Notebook. (El texto claro automático sobre fondos oscuros es de un RP existente y no de la última versión lanzada, pandas0.23).


Estilo

Puede limitar fácilmente la precisión de los dígitos:

corr.style.background_gradient(cmap='coolwarm').set_precision(2)

ingrese la descripción de la imagen aquí

O elimine los dígitos por completo si prefiere la matriz sin anotaciones:

corr.style.background_gradient(cmap='coolwarm').set_properties(**{'font-size': '0pt'})

ingrese la descripción de la imagen aquí

La documentación de estilo también incluye instrucciones de estilos más avanzados, como cómo cambiar la visualización de la celda sobre la que se mueve el puntero del mouse. Para guardar el resultado, puede devolver el HTML agregando el render()método y luego escribirlo en un archivo (o simplemente tomar una captura de pantalla para fines menos formales).


Comparación de tiempo

En mis pruebas, style.background_gradient()fue 4 veces más rápido plt.matshow()y 120 veces más rápido que sns.heatmap()con una matriz de 10x10. Desafortunadamente, no escala tan bien como plt.matshow(): los dos toman aproximadamente el mismo tiempo para una matriz de 100x100, y plt.matshow()es 10 veces más rápido para una matriz de 1000x1000.


Ahorro

Hay algunas formas posibles de guardar el marco de datos estilizado:

  • Devuelva el HTML agregando el render()método y luego escriba el resultado en un archivo.
  • Guarde como un .xslxarchivo con formato condicional agregando el to_excel()método.
  • Combinar con imgkit para guardar un mapa de bits
  • Tome una captura de pantalla (para fines menos formales).

Actualización para pandas> = 0.24

Al configurar axis=None, ahora es posible calcular los colores en función de toda la matriz en lugar de por columna o por fila:

corr.style.background_gradient(cmap='coolwarm', axis=None)

ingrese la descripción de la imagen aquí

joelostblom
fuente
2
Si hubiera una forma de exportar es como una imagen, ¡eso hubiera sido genial!
Kristada673
1
¡Gracias! Definitivamente necesita una paleta divergenteimport seaborn as sns corr = df.corr() cm = sns.light_palette("green", as_cmap=True) cm = sns.diverging_palette(220, 20, sep=20, as_cmap=True) corr.style.background_gradient(cmap=cm).set_precision(2)
pararse Un
1
@stallingOne Un buen punto, no debería haber incluido valores negativos en el ejemplo, podría cambiar eso más tarde. Solo como referencia para las personas que leen esto, no necesita crear un cmap divergente personalizado con seaborn (aunque el que se encuentra en el comentario anterior parece bastante elegante), también puede usar los cmaps divergentes integrados de matplotlib, por ejemplo corr.style.background_gradient(cmap='coolwarm'). Actualmente no hay forma de centrar el cmap en un valor específico, lo que puede ser una buena idea con cmaps divergentes.
joelostblom
1
@rovyko ¿Estás en pandas> = 0.24.0?
joelostblom
2
Estas tramas son visualmente geniales, pero la pregunta de @ Kristada673 es bastante relevante, ¿cómo las exportaría?
Erfan
89

Pruebe esta función, que también muestra nombres de variables para la matriz de correlación:

def plot_corr(df,size=10):
    '''Function plots a graphical correlation matrix for each pair of columns in the dataframe.

    Input:
        df: pandas DataFrame
        size: vertical and horizontal size of the plot'''

    corr = df.corr()
    fig, ax = plt.subplots(figsize=(size, size))
    ax.matshow(corr)
    plt.xticks(range(len(corr.columns)), corr.columns);
    plt.yticks(range(len(corr.columns)), corr.columns);
Apogentus
fuente
66
plt.xticks(range(len(corr.columns)), corr.columns, rotation='vertical')si desea orientación vertical de los nombres de columna en el eje x
nishant
Otra cosa gráfica, pero agregar un plt.tight_layout()también podría ser útil para nombres largos de columna.
user3017048
86

Versión de mapa de calor de Seaborn:

import seaborn as sns
corr = dataframe.corr()
sns.heatmap(corr, 
            xticklabels=corr.columns.values,
            yticklabels=corr.columns.values)
rafaelvalle
fuente
9
El mapa de calor de Seaborn es elegante pero funciona mal en matrices grandes. El método matshow de matplotlib es mucho más rápido.
anilbey
3
Seaborn puede inferir automáticamente las etiquetas de los nombres de columna.
Tulio Casagrande
80

Puede observar la relación entre las características, ya sea dibujando un mapa de calor de los mares marinos o una matriz de dispersión de los pandas.

Matriz de dispersión:

pd.scatter_matrix(dataframe, alpha = 0.3, figsize = (14,8), diagonal = 'kde');

Si también desea visualizar el sesgo de cada característica, use las parcelas nacidas en el mar.

sns.pairplot(dataframe)

Sns Heatmap:

import seaborn as sns

f, ax = pl.subplots(figsize=(10, 8))
corr = dataframe.corr()
sns.heatmap(corr, mask=np.zeros_like(corr, dtype=np.bool), cmap=sns.diverging_palette(220, 10, as_cmap=True),
            square=True, ax=ax)

El resultado será un mapa de correlación de las características. es decir, ver el siguiente ejemplo.

ingrese la descripción de la imagen aquí

La correlación entre comestibles y detergentes es alta. Similar:

Pdoductos con alta correlación:
  1. Comestibles y detergentes.
Productos con correlación media:
  1. Leche y abarrotes
  2. Leche y Detergentes_Papel
Productos con baja correlación:
  1. Leche y delicatessen
  2. Congelados y Frescos.
  3. Congelados y Deli.

Desde parcelas: puede observar el mismo conjunto de relaciones desde parcelas o matriz de dispersión. Pero de estos podemos decir que si los datos se distribuyen normalmente o no.

ingrese la descripción de la imagen aquí

Nota: Lo anterior es el mismo gráfico tomado de los datos, que se utiliza para dibujar el mapa de calor.

phanindravarma
fuente
3
Creo que debería ser .plt no .pl (si esto se refiere a matplotlib)
ghukill
2
@ghukill No necesariamente. Podría haberlo referido comofrom matplotlib import pyplot as pl
Jeru Luke el
cómo establecer el límite de la correlación entre -1 y +1 siempre, en el gráfico de correlación
debaonline4u
7

Puede usar el método imshow () de matplotlib

import pandas as pd
import matplotlib.pyplot as plt
plt.style.use('ggplot')

plt.imshow(X.corr(), cmap=plt.cm.Reds, interpolation='nearest')
plt.colorbar()
tick_marks = [i for i in range(len(X.columns))]
plt.xticks(tick_marks, X.columns, rotation='vertical')
plt.yticks(tick_marks, X.columns)
plt.show()
Khandelwal-manik
fuente
5

Si su marco de datos es dfsimplemente puede usar:

import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(figsize=(15, 10))
sns.heatmap(df.corr(), annot=True)
Harvey
fuente
3

los gráficos de statmodels también ofrecen una buena vista de la matriz de correlación

import statsmodels.api as sm
import matplotlib.pyplot as plt

corr = dataframe.corr()
sm.graphics.plot_corr(corr, xnames=list(corr.columns))
plt.show()
Shahriar Miraj
fuente
3

Para completar, la solución más simple que conozco con seaborn a fines de 2019, si uno está usando Jupyter :

import seaborn as sns
sns.heatmap(dataframe.corr())
Marcin
fuente
1

Junto con otros métodos, también es bueno tener un diagrama de pares que proporcionará un diagrama de dispersión para todos los casos.

import pandas as pd
import numpy as np
import seaborn as sns
rs = np.random.RandomState(0)
df = pd.DataFrame(rs.rand(10, 10))
sns.pairplot(df)
Nishant Tyagi
fuente
0

Matriz de correlación de formularios, en mi caso zdf es el marco de datos que necesito para realizar la matriz de correlaciones.

corrMatrix =zdf.corr()
corrMatrix.to_csv('sm_zscaled_correlation_matrix.csv');
html = corrMatrix.style.background_gradient(cmap='RdBu').set_precision(2).render()

# Writing the output to a html file.
with open('test.html', 'w') as f:
   print('<!DOCTYPE html><html lang="en"><head><meta charset="UTF-8"><meta name="viewport" content="width=device-widthinitial-scale=1.0"><title>Document</title></head><style>table{word-break: break-all;}</style><body>' + html+'</body></html>', file=f)

Entonces podemos tomar una captura de pantalla. o convertir html a un archivo de imagen.

smsivaprakaash
fuente