web-dev-qa-db-fra.com

Changement de couleur et de marqueur de chaque point à l'aide de la représentation conjointe Seaborn

J'ai ce code légèrement modifié de ici :

import seaborn as sns
sns.set(style="darkgrid")

tips = sns.load_dataset("tips")
color = sns.color_palette()[5]
g = sns.jointplot("total_bill", "tip", data=tips, kind="reg", stat_func=None,
                  xlim=(0, 60), ylim=(0, 12), color='k', size=7)

g.set_axis_labels('total bill', 'tip', fontsize=16)

et je reçois un joli complot - Cependant, pour mon cas, je dois pouvoir changer la couleur ET le format de chaque point. 

J'ai essayé d'utiliser les mots clés marker, style et fmt, mais l'erreur TypeError: jointplot() got an unexpected keyword argument s'affiche. 

Quelle est la bonne façon de faire cela? J'aimerais éviter d'appeler sns.JointGrid et de tracer manuellement les données et les distributions marginales.

11
pbreach

La résolution de ce problème n’est presque pas différente de celle de matplotlib (tracé d’un nuage de points avec différents marqueurs et couleurs), sauf que je voulais conserver les distributions marginales:

import seaborn as sns
from itertools import product
sns.set(style="darkgrid")

tips = sns.load_dataset("tips")
color = sns.color_palette()[5]
g = sns.jointplot("total_bill", "tip", data=tips, kind="reg", stat_func=None,
                  xlim=(0, 60), ylim=(0, 12), color='k', size=7)

#Clear the axes containing the scatter plot
g.ax_joint.cla()

#Generate some colors and markers
colors = np.random.random((len(tips),3))
markers = ['x','o','v','^','<']*100

#Plot each individual point separately
for i,row in enumerate(tips.values):
    g.ax_joint.plot(row[0], row[1], color=colors[i], marker=markers[i])

g.set_axis_labels('total bill', 'tip', fontsize=16)

Ce qui me donne ceci:

enter image description here

La ligne de régression est maintenant partie, mais c'est tout ce dont j'avais besoin. 

16
pbreach

La réponse acceptée est trop compliquée. plt.sca() peut être utilisé pour le faire de manière plus simple:

import matplotlib.pyplot as plt
import seaborn as sns

tips = sns.load_dataset("tips")
g = sns.jointplot("total_bill", "tip", data=tips, kind="reg", stat_func=None,
                  xlim=(0, 60), ylim=(0, 12))


g.ax_joint.cla() # or g.ax_joint.collections[0].set_visible(False), as per mwaskom's comment

# set the current axis to be the joint plot's axis
plt.sca(g.ax_joint)

# plt.scatter takes a 'c' keyword for color
# you can also pass an array of floats and use the 'cmap' keyword to
# convert them into a colormap
plt.scatter(tips.total_bill, tips.tip, c=np.random.random((len(tips), 3)))
12
Max Shron

Vous pouvez également le préciser directement dans la liste des arguments, grâce au mot clé: joint_kws (testé avec seaborn 0.8.1). Si nécessaire, vous pouvez également modifier les propriétés du marginal avec marginal_kws

Donc, votre code devient: 

import seaborn as sns
colors = np.random.random((len(tips),3))
markers = (['x','o','v','^','<']*100)[:len(tips)]

sns.jointplot("total_bill", "tip", data=tips, kind="reg",
    joint_kws={"color":colors, "marker":markers})
3
Vincent Jeanselme
  1. Dans seaborn/categorical.py, recherchez def swarmplot
  2. Ajoutez le paramètre marker='o' avant **kwargs
  3. Dans kwargs.update, ajoutez marker=marker.

Ensuite, ajoutez par exemple marker='x' en tant que paramètre lors du traçage avec sns.swarmplot() comme vous le feriez avec Matplotlib plt.scatter().

Je viens de rencontrer le même besoin et avoir marker comme kwarg ne fonctionnait pas. Alors j'ai jeté un coup d'œil. Nous pouvons définir d'autres paramètres de manière similaire . https://github.com/ccneko/seaborn/blob/master/seaborn/categorical.py

Seul un petit changement est nécessaire ici, mais voici la page fourchue de GitHub pour une référence rapide;)

1
Claire

Une autre option consiste à utiliser JointGrid, car Jointplot est un wrapper qui simplifie son utilisation.

import matplotlib.pyplot as plt
import seaborn as sns

tips = sns.load_dataset("tips")

g = sns.JointGrid("total_bill", "tip", data=tips)
g = g.plot_joint(plt.scatter, c=np.random.random((len(tips), 3)))
g = g.plot_marginals(sns.distplot, kde=True, color="k")
0
Vlamir