web-dev-qa-db-fra.com

Scatter Plots in Pandas / Pyplot: Comment tracer par catégorie

J'essaie de créer un diagramme de dispersion simple dans pyplot en utilisant un objet DataFrame Pandas), mais je veux un moyen efficace de tracer deux variables, mais les symboles sont dictés par une troisième colonne (clé). essayé de différentes manières en utilisant df.groupby, mais sans succès. Un exemple de script df est présenté ci-dessous. Il permet de colorer les marqueurs en fonction de 'key1', mais je voudrais voir une légende avec les catégories 'key1'. Suis-je fermé? Merci.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)
plt.show()
68
user2989613

Vous pouvez utiliser scatter pour cela, mais cela nécessite d’avoir des valeurs numériques pour votre key1, et vous n'aurez pas de légende, comme vous l'avez remarqué.

Il vaut mieux utiliser plot pour des catégories discrètes comme celle-ci. Par exemple:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

groups = df.groupby('label')

# Plot
fig, ax = plt.subplots()
ax.margins(0.05) # Optional, just adds 5% padding to the autoscaling
for name, group in groups:
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend()

plt.show()

enter image description here

Si vous souhaitez que les choses ressemblent au style par défaut pandas, mettez simplement à jour rcParams avec la feuille de style pandas) et utilisez son générateur de couleur. (I Je modifie aussi légèrement la légende):

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

groups = df.groupby('label')

# Plot
plt.rcParams.update(pd.tools.plotting.mpl_stylesheet)
colors = pd.tools.plotting._get_standard_colors(len(groups), color_type='random')

fig, ax = plt.subplots()
ax.set_color_cycle(colors)
ax.margins(0.05)
for name, group in groups:
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend(numpoints=1, loc='upper left')

plt.show()

enter image description here

89
Joe Kington

C’est simple à faire avec Seaborn (pip install seaborn) En tant que doublure

sns.pairplot(x_vars=["one"], y_vars=["two"], data=df, hue="key1", size=5) :

import seaborn as sns
import pandas as pd
import numpy as np
np.random.seed(1974)

df = pd.DataFrame(
    np.random.normal(10, 1, 30).reshape(10, 3),
    index=pd.date_range('2010-01-01', freq='M', periods=10),
    columns=('one', 'two', 'three'))
df['key1'] = (4, 4, 4, 6, 6, 6, 8, 8, 8, 8)

sns.pairplot(x_vars=["one"], y_vars=["two"], data=df, hue="key1", size=5)

enter image description here

Voici le dataframe pour référence:

enter image description here

Comme vous disposez de trois colonnes variables dans vos données, vous pouvez tracer toutes les dimensions par paires avec:

sns.pairplot(vars=["one","two","three"], data=df, hue="key1", size=5)

enter image description here

https://rasbt.github.io/mlxtend/user_guide/plotting/category_scatter/ est une autre option.

43
Bob Baxley

Avec plt.scatter, Je ne peux penser qu’à un seul: utiliser un artiste proxy:

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
x=ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)

ccm=x.get_cmap()
circles=[Line2D(range(1), range(1), color='w', marker='o', markersize=10, markerfacecolor=item) for item in ccm((array([4,6,8])-4.0)/4)]
leg = plt.legend(circles, ['4','6','8'], loc = "center left", bbox_to_anchor = (1, 0.5), numpoints = 1)

Et le résultat est:

enter image description here

19
CT Zhu

Vous pouvez utiliser df.plot.scatter et passer un tableau à l'argument c = définissant la couleur de chaque point:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
colors = np.where(df["key1"]==4,'r','-')
colors[df["key1"]==6] = 'g'
colors[df["key1"]==8] = 'b'
print(colors)
df.plot.scatter(x="one",y="two",c=colors)
plt.show()

enter image description here

6
Arjaan Buijk

Vous pouvez également essayer Altair ou ggpot qui sont centrés sur les visualisations déclaratives.

import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

Code Altair

from altair import Chart
c = Chart(df)
c.mark_circle().encode(x='x', y='y', color='label')

enter image description here

code ggplot

from ggplot import *
ggplot(aes(x='x', y='y', color='label'), data=df) +\
geom_point(size=50) +\
theme_bw()

enter image description here

3
Nipun Batra

C'est plutôt hacky, mais vous pouvez utiliser one1 comme un Float64Index tout faire en une fois:

df.set_index('one').sort_index().groupby('key1')['two'].plot(style='--o', legend=True)

enter image description here

Notez qu'à partir de 0.20.3, il est nécessaire de trier l'index , et la légende est n peu confus .

2
fuglede

À partir de matplotlib 3.1, vous pouvez utiliser .legend_elements() . Un exemple est présenté dans Création automatique de légende . L'avantage est qu'un seul appel scatter peut être utilisé.

Dans ce cas:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), 
                  index = pd.date_range('2010-01-01', freq = 'M', periods = 10), 
                  columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)


fig, ax = plt.subplots()
sc = ax.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)
ax.legend(*sc.legend_elements())
plt.show()

enter image description here

Au cas où les clés ne seraient pas directement attribuées sous forme de chiffres, cela ressemblerait à

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), 
                  index = pd.date_range('2010-01-01', freq = 'M', periods = 10), 
                  columns = ('one', 'two', 'three'))
df['key1'] = list("AAABBBCCCC")

labels, index = np.unique(df["key1"], return_inverse=True)

fig, ax = plt.subplots()
sc = ax.scatter(df['one'], df['two'], marker = 'o', c = index, alpha = 0.8)
ax.legend(sc.legend_elements()[0], labels)
plt.show()

enter image description here