web-dev-qa-db-fra.com

RuntimeError: Objet attendu du type torch.DoubleTensor mais le type trouvé torch.FloatTensor pour l'argument n ° 2 'poids'

Mon tenseur d’entrée est du type torch.DoubleTensor. Mais j'ai le RuntimeError ci-dessous:

RuntimeError: Expected object of type torch.DoubleTensor but found type torch.FloatTensor for argument #2 'weight'

Je n’ai pas spécifié explicitement le type de poids (c’est-à-dire que je n’ai pas lancé mon poids seul. Le poids est créé par pytorch). Qu'est-ce qui va influencer le type de poids dans le processus de transfert?

Merci beaucoup!! 

16
Eric Kani

Le type par défaut pour weights et biases est torch.FloatTensor. Ainsi, vous devrez convertir votre modèle en torch.DoubleTensor ou vos entrées en torch.FloatTensor. Vous pouvez faire vos entrées

X = X.float()

ou lancez votre modèle complet en DoubleTensor comme

model = model.double()

Vous pouvez également définir le type par défaut pour tous les tenseurs à l'aide de

pytorch.set_default_tensor_type('torch.DoubleTensor')

Il est préférable de convertir vos entrées en float plutôt que de convertir votre modèle en double, car les calculs mathématiques sur le type de données double sont considérablement plus lents sur GPU.

20
layog

Je recevais aussi exactement la même erreur. La cause fondamentale s'est avérée être cette déclaration dans mon code de chargement de données:

t = t.astype(np.float)

Ici, np.float est traduit en float 64 bits qui correspond à DoubleTensor. Donc, changer ceci pour,

t = t.astype(np.float32)

résolu le problème.

0
Shital Shah