Estoy tratando de hacer un diagrama de dispersión simple en pyplot usando un objeto Pandas DataFrame, pero quiero una forma eficiente de trazar dos variables pero que los símbolos estén dictados por una tercera columna (clave). He intentado varias formas usando df.groupby, pero no con éxito. A continuación se muestra un ejemplo de secuencia de comandos df. Esto colorea los marcadores de acuerdo con 'key1', pero me gustaría ver una leyenda con las categorías de 'key1'. Estoy cerca? Gracias.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)
plt.show()
fuente
ax.legend(numpoints=1)
para mostrar solo un marcador. Hay dos, como con aLine2D
, a menudo hay una línea que conecta los dos marcadores.plt.hold(True)
después delax.plot()
comando. ¿Alguna idea de por qué?set_color_cycle()
quedó obsoleto en matplotlib 1.5. La hayset_prop_cycle()
, ahora.Esto es simple de hacer con Seaborn (
pip install seaborn
) como un delineadorsns.scatterplot(x_vars="one", y_vars="two", data=df, hue="key1")
:import seaborn as sns import pandas as pd import numpy as np np.random.seed(1974) df = pd.DataFrame( np.random.normal(10, 1, 30).reshape(10, 3), index=pd.date_range('2010-01-01', freq='M', periods=10), columns=('one', 'two', 'three')) df['key1'] = (4, 4, 4, 6, 6, 6, 8, 8, 8, 8) sns.scatterplot(x="one", y="two", data=df, hue="key1")
Aquí está el marco de datos para referencia:
Dado que tiene tres columnas variables en sus datos, es posible que desee trazar todas las dimensiones por pares con:
sns.pairplot(vars=["one","two","three"], data=df, hue="key1")
https://rasbt.github.io/mlxtend/user_guide/plotting/category_scatter/ es otra opción.
fuente
Con
plt.scatter
, solo puedo pensar en uno: usar un artista proxy:df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three')) df['key1'] = (4,4,4,6,6,6,8,8,8,8) fig1 = plt.figure(1) ax1 = fig1.add_subplot(111) x=ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8) ccm=x.get_cmap() circles=[Line2D(range(1), range(1), color='w', marker='o', markersize=10, markerfacecolor=item) for item in ccm((array([4,6,8])-4.0)/4)] leg = plt.legend(circles, ['4','6','8'], loc = "center left", bbox_to_anchor = (1, 0.5), numpoints = 1)
Y el resultado es:
fuente
Puede usar df.plot.scatter y pasar una matriz al argumento c = que define el color de cada punto:
import numpy as np import pandas as pd import matplotlib.pyplot as plt df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three')) df['key1'] = (4,4,4,6,6,6,8,8,8,8) colors = np.where(df["key1"]==4,'r','-') colors[df["key1"]==6] = 'g' colors[df["key1"]==8] = 'b' print(colors) df.plot.scatter(x="one",y="two",c=colors) plt.show()
fuente
También puede probar Altair o ggpot, que se centran en visualizaciones declarativas.
import numpy as np import pandas as pd np.random.seed(1974) # Generate Data num = 20 x, y = np.random.random((2, num)) labels = np.random.choice(['a', 'b', 'c'], num) df = pd.DataFrame(dict(x=x, y=y, label=labels))
Código de Altair
from altair import Chart c = Chart(df) c.mark_circle().encode(x='x', y='y', color='label')
código ggplot
from ggplot import * ggplot(aes(x='x', y='y', color='label'), data=df) +\ geom_point(size=50) +\ theme_bw()
fuente
Desde matplotlib 3.1 en adelante puede usar
.legend_elements()
. Se muestra un ejemplo en Creación automática de leyendas . La ventaja es que se puede utilizar una única llamada dispersa.En este caso:
import numpy as np import pandas as pd import matplotlib.pyplot as plt df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three')) df['key1'] = (4,4,4,6,6,6,8,8,8,8) fig, ax = plt.subplots() sc = ax.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8) ax.legend(*sc.legend_elements()) plt.show()
En caso de que las claves no se dieran directamente como números, se vería como
import numpy as np import pandas as pd import matplotlib.pyplot as plt df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three')) df['key1'] = list("AAABBBCCCC") labels, index = np.unique(df["key1"], return_inverse=True) fig, ax = plt.subplots() sc = ax.scatter(df['one'], df['two'], marker = 'o', c = index, alpha = 0.8) ax.legend(sc.legend_elements()[0], labels) plt.show()
fuente
fig, ax = plt.subplots(1, 1, figsize = (4,4)) scat = ax.scatter(rand_jitter(important_dataframe["workout_type_int"], jitter = 0.04), important_dataframe["distance"], c = color_list, marker = 'o', alpha = 0.9) print(scat.legends_elements()) #ax.legend(*scat.legend_elements())
legends_elements
ylegend_elements
.Es bastante hacky, pero se puede utilizar
one1
como unFloat64Index
hacer todo de una vez:df.set_index('one').sort_index().groupby('key1')['two'].plot(style='--o', legend=True)
Tenga en cuenta que a partir de 0.20.3, es necesario ordenar el índice y la leyenda es un poco confusa .
fuente
seaborn tiene una función de envoltura
scatterplot
que lo hace de manera más eficiente.sns.scatterplot(data = df, x = 'one', y = 'two', data = 'key1'])
fuente