web-dev-qa-db-fra.com

Matplotlib: éviter les points de données qui se chevauchent dans un graphique "scatter / dot / beeswarm"

Lorsque je dessine un tracé de points à l'aide de matplotlib, je voudrais compenser les points de données qui se chevauchent pour les garder tous visibles. Par exemple, si j'ai

CategoryA: 0,0,3,0,5  
CategoryB: 5,10,5,5,10  

Je veux que chacun des points de données CategoryA "0" soit placé côte à côte, plutôt que directement les uns sur les autres, tout en restant distinct de CategoryB.

Dans R (ggplot2) Il y a un "jitter" option qui fait cela. Existe-t-il une option similaire dans matplotlib, ou existe-t-il une autre approche qui conduirait à un résultat similaire?

Modifier: pour clarifier, le "beeswarm" plot in R est essentiellement ce que j'ai en tête, et pybeeswarm est un début précoce mais utile dans une version matplotlib/Python.

Edit: pour ajouter que le Seaborn Swarmplot , introduit dans la version 0.7, est un excellent la mise en œuvre de ce que je voulais.

36
iayork

En étendant la réponse par @ user2467675, voici comment je l'ai fait:

def Rand_jitter(arr):
    stdev = .01*(max(arr)-min(arr))
    return arr + np.random.randn(len(arr)) * stdev

def jitter(x, y, s=20, c='b', marker='o', cmap=None, norm=None, vmin=None, vmax=None, alpha=None, linewidths=None, verts=None, hold=None, **kwargs):
    return scatter(Rand_jitter(x), Rand_jitter(y), s=s, c=c, marker=marker, cmap=cmap, norm=norm, vmin=vmin, vmax=vmax, alpha=alpha, linewidths=linewidths, verts=verts, hold=hold, **kwargs)

La variable stdev s'assure que la gigue est suffisante pour être vue à différentes échelles, mais elle suppose que les limites des axes sont 0 et la valeur max.

Vous pouvez alors appeler jitter au lieu de scatter.

32
yoavram

J'ai utilisé numpy.random pour "diffuser/réchauffer les abeilles" les données le long de l'axe X mais autour d'un point fixe pour chaque catégorie, puis je fais essentiellement pyplot.scatter () pour chaque catégorie:

import matplotlib.pyplot as plt
import numpy as np

#random data for category A, B, with B "taller"
yA, yB = np.random.randn(100), 5.0+np.random.randn(1000)

xA, xB = np.random.normal(1, 0.1, len(yA)), 
         np.random.normal(3, 0.1, len(yB))

plt.scatter(xA, yA)
plt.scatter(xB, yB)
plt.show()

X-scattered data

9
sun.huaiyu

Une façon d'aborder le problème est de penser à chaque "ligne" de votre graphique de dispersion/point/chaleur d'abeille comme un bac dans un histogramme:

data = np.random.randn(100)

width = 0.8     # the maximum width of each 'row' in the scatter plot
xpos = 0        # the centre position of the scatter plot in x

counts, edges = np.histogram(data, bins=20)

centres = (edges[:-1] + edges[1:]) / 2.
yvals = centres.repeat(counts)

max_offset = width / counts.max()
offsets = np.hstack((np.arange(cc) - 0.5 * (cc - 1)) for cc in counts)
xvals = xpos + (offsets * max_offset)

fig, ax = plt.subplots(1, 1)
ax.scatter(xvals, yvals, s=30, c='b')

Cela implique évidemment de regrouper les données, vous risquez donc de perdre une certaine précision. Si vous avez des données discrètes, vous pouvez remplacer:

counts, edges = np.histogram(data, bins=20)
centres = (edges[:-1] + edges[1:]) / 2.

avec:

centres, counts = np.unique(data, return_counts=True)

Une approche alternative qui préserve les coordonnées y exactes, même pour des données continues, consiste à utiliser un estimation de la densité du noya pour mettre à l'échelle l'amplitude de la gigue aléatoire sur l'axe des x:

from scipy.stats import gaussian_kde

kde = gaussian_kde(data)
density = kde(data)     # estimate the local density at each datapoint

# generate some random jitter between 0 and 1
jitter = np.random.Rand(*data.shape) - 0.5 

# scale the jitter by the KDE estimate and add it to the centre x-coordinate
xvals = 1 + (density * jitter * width * 2)

ax.scatter(xvals, data, s=30, c='g')
for sp in ['top', 'bottom', 'right']:
    ax.spines[sp].set_visible(False)
ax.tick_params(top=False, bottom=False, right=False)

ax.set_xticks([0, 1])
ax.set_xticklabels(['Histogram', 'KDE'], fontsize='x-large')
fig.tight_layout()

Cette seconde méthode est vaguement basée sur le fonctionnement de violon plots . Il ne peut toujours pas garantir qu'aucun des points ne se chevauchent, mais je trouve qu'en pratique, il a tendance à donner des résultats assez agréables tant qu'il y a un nombre décent de points (> 20), et la distribution peut être raisonnablement bien approximée par une somme de gaussiens.

enter image description here

8
ali_m

Seaborn fournit des tracés de points catégoriels de type histogramme via sns.swarmplot() et des tracés de points catégoriels instables via sns.stripplot() :

import seaborn as sns

sns.set(style='ticks', context='talk')
iris = sns.load_dataset('iris')

sns.swarmplot('species', 'sepal_length', data=iris)
sns.despine()

enter image description here

sns.stripplot('species', 'sepal_length', data=iris, jitter=0.2)
sns.despine()

enter image description here

8
joelostblom

Ne connaissant pas d'alternative directe mpl ici vous avez une proposition très rudimentaire:

from matplotlib import pyplot as plt
from itertools import groupby

CA = [0,4,0,3,0,5]  
CB = [0,0,4,4,2,2,2,2,3,0,5]  

x = []
y = []
for indx, klass in enumerate([CA, CB]):
    klass = groupby(sorted(klass))
    for item, objt in klass:
        objt = list(objt)
        points = len(objt)
        pos = 1 + indx + (1 - points) / 50.
        for item in objt:
            x.append(pos)
            y.append(item)
            pos += 0.04

plt.plot(x, y, 'o')
plt.xlim((0,3))

plt.show()

enter image description here

7
joaquin

Le swarmplot de Seaborn semble être le plus approprié pour ce que vous avez en tête, mais vous pouvez également trembler avec le regplot de Seaborn:

import seaborn as sns
iris = sns.load_dataset('iris')

sns.regplot(x='sepal_length',
            y='sepal_width',
            data=iris,
            fit_reg=False,  # do not fit a regression line
            x_jitter=0.1,  # could also dynamically set this with range of data
            y_jitter=0.1,
            scatter_kws={'alpha': 0.5})  # set transparency to 50%
3
wordsforthewise