web-dev-qa-db-fra.com

Existe-t-il une fonction pour créer des matrices de nuage de points dans matplotlib?

Exemple de matrice de nuage de points

enter image description here

Existe-t-il une telle fonction dans matplotlib.pyplot?

49
hatmatrix

De manière générale, matplotlib ne contient généralement pas de fonctions de traçage qui fonctionnent sur plusieurs objets axes (sous-tracé, dans ce cas). On s'attend à ce que vous écriviez une fonction simple pour enchaîner les choses comme vous le souhaitez.

Je ne sais pas trop à quoi ressemblent vos données, mais il est assez simple de simplement créer une fonction pour le faire à partir de zéro. Si vous allez toujours travailler avec des tableaux structurés ou rec, alors vous pouvez simplifier cela d'une touche. (C'est-à-dire qu'il y a toujours un nom associé à chaque série de données, vous pouvez donc omettre de spécifier des noms.)

Par exemple:

import itertools
import numpy as np
import matplotlib.pyplot as plt

def main():
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
            linestyle='none', marker='o', color='black', mfc='none')
    fig.suptitle('Simple Scatterplot Matrix')
    plt.show()

def scatterplot_matrix(data, names, **kwargs):
    """Plots a scatterplot matrix of subplots.  Each row of "data" is plotted
    against other rows, resulting in a nrows by nrows grid of subplots with the
    diagonal subplots labeled with "names".  Additional keyword arguments are
    passed on to matplotlib's "plot" command. Returns the matplotlib figure
    object containg the subplot grid."""
    numvars, numdata = data.shape
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))
    fig.subplots_adjust(hspace=0.05, wspace=0.05)

    for ax in axes.flat:
        # Hide all ticks and labels
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

        # Set up ticks only on one side for the "Edge" subplots...
        if ax.is_first_col():
            ax.yaxis.set_ticks_position('left')
        if ax.is_last_col():
            ax.yaxis.set_ticks_position('right')
        if ax.is_first_row():
            ax.xaxis.set_ticks_position('top')
        if ax.is_last_row():
            ax.xaxis.set_ticks_position('bottom')

    # Plot the data.
    for i, j in Zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i,j), (j,i)]:
            axes[x,y].plot(data[x], data[y], **kwargs)

    # Label the diagonal subplots...
    for i, label in enumerate(names):
        axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                ha='center', va='center')

    # Turn on the proper x or y axes ticks.
    for i, j in Zip(range(numvars), itertools.cycle((-1, 0))):
        axes[j,i].xaxis.set_visible(True)
        axes[i,j].yaxis.set_visible(True)

    return fig

main()

enter image description here

21
Joe Kington

Pour ceux qui ne veulent pas définir leurs propres fonctions, il existe une grande bibliothèque d'analyse de données en Python, appelée Pandas , où l'on peut trouver la méthode scatter_matrix () :

from pandas.plotting import scatter_matrix
df = pd.DataFrame(np.random.randn(1000, 4), columns = ['a', 'b', 'c', 'd'])
scatter_matrix(df, alpha = 0.2, figsize = (6, 6), diagonal = 'kde')

enter image description here

97
Roman Pekar

Vous pouvez également utiliser fonction pairplot de Seaborn :

import seaborn as sns
sns.set()
df = sns.load_dataset("iris")
sns.pairplot(df, hue="species")
12
sushmit

Merci de partager votre code! Vous avez compris toutes les choses difficiles pour nous. Pendant que je travaillais avec, j'ai remarqué quelques petites choses qui ne semblaient pas tout à fait correctes.

  1. [FIX # 1] Les tics des axes ne s'alignaient pas comme je m'y attendais (c'est-à-dire, dans votre exemple ci-dessus, vous devriez pouvoir tracer une ligne verticale et horizontale à travers n'importe quel point de tous les tracés et les lignes devraient traverser la ligne correspondante point dans les autres parcelles, mais comme il se trouve maintenant, cela ne se produit pas.

  2. [FIX # 2] Si vous avez un nombre impair de variables avec lesquelles vous tracez, les axes en bas à droite ne tirent pas les xtics ou ytics corrects. Il le laisse juste comme les ticks par défaut 0..1.

  3. Pas un correctif, mais je l'ai rendu facultatif pour entrer explicitement names, afin qu'il place un xi par défaut pour la variable i dans les positions diagonales.

Vous trouverez ci-dessous une version mise à jour de votre code qui répond à ces deux points, sinon préservant la beauté de votre code.

import itertools
import numpy as np
import matplotlib.pyplot as plt

def scatterplot_matrix(data, names=[], **kwargs):
    """
    Plots a scatterplot matrix of subplots.  Each row of "data" is plotted
    against other rows, resulting in a nrows by nrows grid of subplots with the
    diagonal subplots labeled with "names".  Additional keyword arguments are
    passed on to matplotlib's "plot" command. Returns the matplotlib figure
    object containg the subplot grid.
    """
    numvars, numdata = data.shape
    fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8,8))
    fig.subplots_adjust(hspace=0.0, wspace=0.0)

    for ax in axes.flat:
        # Hide all ticks and labels
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

        # Set up ticks only on one side for the "Edge" subplots...
        if ax.is_first_col():
            ax.yaxis.set_ticks_position('left')
        if ax.is_last_col():
            ax.yaxis.set_ticks_position('right')
        if ax.is_first_row():
            ax.xaxis.set_ticks_position('top')
        if ax.is_last_row():
            ax.xaxis.set_ticks_position('bottom')

    # Plot the data.
    for i, j in Zip(*np.triu_indices_from(axes, k=1)):
        for x, y in [(i,j), (j,i)]:
            # FIX #1: this needed to be changed from ...(data[x], data[y],...)
            axes[x,y].plot(data[y], data[x], **kwargs)

    # Label the diagonal subplots...
    if not names:
        names = ['x'+str(i) for i in range(numvars)]

    for i, label in enumerate(names):
        axes[i,i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                ha='center', va='center')

    # Turn on the proper x or y axes ticks.
    for i, j in Zip(range(numvars), itertools.cycle((-1, 0))):
        axes[j,i].xaxis.set_visible(True)
        axes[i,j].yaxis.set_visible(True)

    # FIX #2: if numvars is odd, the bottom right corner plot doesn't have the
    # correct axes limits, so we pull them from other axes
    if numvars%2:
        xlimits = axes[0,-1].get_xlim()
        ylimits = axes[-1,0].get_ylim()
        axes[-1,-1].set_xlim(xlimits)
        axes[-1,-1].set_ylim(ylimits)

    return fig

if __name__=='__main__':
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
            linestyle='none', marker='o', color='black', mfc='none')
    fig.suptitle('Simple Scatterplot Matrix')
    plt.show()

Merci encore d'avoir partagé cela avec nous. Je l'ai utilisé plusieurs fois! Oh, et j'ai réorganisé la partie main() du code afin qu'il puisse être un exemple de code formel ou ne pas être appelé s'il est importé dans un autre morceau de code.

10
tisimst

En lisant la question, je m'attendais à voir une réponse comprenant rpy . Je pense que c'est une bonne option en profitant de deux belles langues. Voici donc:

import rpy
import numpy as np

def main():
    np.random.seed(1977)
    numvars, numdata = 4, 10
    data = 10 * np.random.random((numvars, numdata))
    mpg = data[0,:]
    disp = data[1,:]
    drat = data[2,:]
    wt = data[3,:]
    rpy.set_default_mode(rpy.NO_CONVERSION)

    R_data = rpy.r.data_frame(mpg=mpg,disp=disp,drat=drat,wt=wt)

    # Figure saved as eps
    rpy.r.postscript('pairsPlot.eps')
    rpy.r.pairs(R_data,
       main="Simple Scatterplot Matrix Via RPy")
    rpy.r.dev_off()

    # Figure saved as png
    rpy.r.png('pairsPlot.png')
    rpy.r.pairs(R_data,
       main="Simple Scatterplot Matrix Via RPy")
    rpy.r.dev_off()

    rpy.set_default_mode(rpy.BASIC_CONVERSION)


if __== '__main__': main()

Je ne peux pas poster une image pour montrer le résultat :( désolé!

4
omun