web-dev-qa-db-fra.com

AttributeError: l'objet 'collections.OrderedDict' n'a pas d'attribut 'eval'

J'ai un fichier modèle qui ressemble à ceci

OrderedDict([('inp.conv1.conv.weight', 
          (0 ,0 ,0 ,.,.) = 
           -1.5073e-01  6.4760e-02  1.9156e-01
            1.2175e-01  3.5886e-02  1.3992e-01
           -1.5903e-01  8.2055e-02  1.7820e-01

          (0 ,0 ,1 ,.,.) = 
            1.0604e-01 -1.3653e-01  1.4803e-01
            6.0276e-02 -1.4674e-02  2.3059e-06
           -6.2192e-02 -5.1061e-03 -7.4145e-03

          (0 ,0 ,2 ,.,.) = 
           -5.5632e-02  3.5326e-02  6.5108e-02
            1.1411e-01 -4.4160e-02  8.2610e-02
            8.9979e-02 -3.5454e-02  4.2549e-02

          (1 ,0 ,0 ,.,.) = 
            4.8523e-02 -4.3961e-02  5.3614e-02
           -1.2644e-01  1.2777e-01  8.9547e-02
            3.8392e-02  2.7016e-02 -1.4552e-01

          (1 ,0 ,1 ,.,.) = 
            9.5537e-02  2.8748e-02  3.9772e-02
           -6.2410e-02  1.1264e-01  7.8663e-02
           -2.6374e-02  1.4401e-01 -1.7109e-01

          (1 ,0 ,2 ,.,.) = 
            5.1791e-02 -1.6388e-01 -1.7605e-01
            3.5028e-02  7.7164e-02 -1.4499e-01
           -2.9189e-02  2.7064e-03 -2.3228e-02

          (2 ,0 ,0 ,.,.) = 
           -7.4446e-03 -9.7202e-02 -1.4704e-01
           -1.0019e-02  8.1780e-02 -5.3530e-02
           -1.8412e-01  1.5988e-01 -1.3450e-01

          (2 ,0 ,1 ,.,.) = 
           -1.1075e-01 -5.2478e-02  6.0658e-02
            1.6739e-01 -2.9360e-02  1.2621e-01
            2.0686e-02  1.1468e-01  1.2282e-01

Je veux faire l'inférence sur ce modèle, mais quand je fais model.eval () j'obtiens,

AttributeError: 'collections.OrderedDict' object has no attribute 'eval Je ne sais pas trop comment procéder, toute suggestion sur la façon de résoudre ce problème sera très utile, merci d'avance

6
Ryan

Ce n'est pas un fichier modèle, mais plutôt un fichier d'état. Dans un fichier de modèle, le modèle complet est stocké, tandis que dans un fichier d'état, seuls les paramètres sont stockés.
Ainsi, vos OrderedDict ne sont que des valeurs pour votre modèle. Vous devrez créer le modèle, puis charger ces valeurs dans votre modèle. Ainsi, le processus sera quelque chose sous forme de

import torch
import torch.nn as nn

class TempModel(nn.Module):
    def __init__(self):
        self.conv1 = nn.Conv2d(3, 5, (3, 3))
    def forward(self, inp):
        return self.conv1(inp)

model = TempModel()
model.load_state_dict(torch.load(file_path))
model.eval()

Vous devrez définir correctement votre modèle. Celui donné dans l'exemple ci-dessus est juste un mannequin. Si vous construisez votre modèle vous-même, vous devrez peut-être mettre à jour les clés du fichier dict enregistré comme mentionné ici . La meilleure solution consiste à définir votre modèle exactement de la même manière que lorsque le state_dict a été enregistré puis exécuté directement model.load_state_dict marchera.

11
layog