web-dev-qa-db-fra.com

itertools non défini lorsqu'il est utilisé à l'intérieur du module

J'enregistre mes fonctions personnalisées dans un module séparé que je peux appeler quand j'en ai besoin. Une de mes nouvelles fonctions utilise itertools, mais je reçois toujours une erreur de nom.

NameError: name 'itertools' is not defined

C'est vraiment bizarre. Je peux très bien importer itertools dans la console, mais lorsque j'appelle ma fonction, j'obtiens une erreur de nom. Habituellement, je peux utiliser des fonctions d'autres bibliothèques (pandas, sklearn, etc.) à l'intérieur d'une fonction personnalisée très bien tant que j'importe la bibliothèque en premier.

MAIS si j'importe itertools dans la console, copiez et collez ma fonction dans la console, puis appelez la fonction, cela fonctionne très bien.

Cela me rend fou, mais je pense que je ne comprends peut-être tout simplement pas les règles des modules ou quelque chose.

voici la fonction que j'utilise dans le module. il est simplement copié et collé à partir d'un des exemples sklearn:

import itertools    
def plot_confusion_matrix(cm, classes,
                              normalize=False,
                              title='Confusion matrix',
                              cmap=plt.cm.Blues):
        import itertools
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        plt.colorbar()
        tick_marks = np.arange(len(classes))
        plt.xticks(tick_marks, classes, rotation=45)
        plt.yticks(tick_marks, classes)

        if normalize:
            cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
            print("Normalized confusion matrix")
        else:
            print('Confusion matrix, without normalization')

        print(cm)

        thresh = cm.max() / 2.
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(j, i, cm[i, j],
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")

        plt.tight_layout()
        plt.ylabel('True label')
        plt.xlabel('Predicted label')

J'ai essayé de l'importer à l'intérieur de la fonction, à l'intérieur du module et à l'intérieur du fichier où je l'appelle - le tout sans succès. Si je l'importe dans la console, ça va. Même après son importation dans la console, si je l'exécute dans le fichier sur lequel je travaille à nouveau, cela donne la même erreur.

8
Adam

Ça fonctionne maintenant.

LEÇON IMPORTANTE: Si vous modifiez un module, vous devez fermer et rouvrir spyder/ipython/peu importe. La simple réinitialisation du noyau n'est pas suffisante. Je suis stupide, je sais, mais peut-être que cette réponse fera gagner du temps à quelqu'un.

7
Adam

Vous pouvez d'abord utiliser à partir du produit d'importation itertools, puis changer simplement itertools.product en produit. Cela devrait fonctionner.

0
Qin Peng Michelle

Vous changez juste
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):

À:

for i in range (cm.shape[0]): for j in range (cm.shape[1]):

0
Vong Ho