web-dev-qa-db-fra.com

Comment vérifier si deux tenseurs ou matrices de torche sont égaux?

J'ai besoin d'une commande Torch qui vérifie si deux tenseurs ont le même contenu et renvoie VRAI s'ils ont le même contenu.

Par exemple:

local tens_a = torch.Tensor({9,8,7,6});
local tens_b = torch.Tensor({9,8,7,6});

if (tens_a EQUIVALENCE_COMMAND tens_b) then ... end

Que devrais-je utiliser dans ce script au lieu de EQUIVALENCE_COMMAND?

J'ai simplement essayé avec == mais cela ne fonctionne pas. 

15
DavideChicco.it

https://github.com/torch/torch7/blob/master/doc/maths.md#torcheqa-b

torch.eq(a, b)

Implémente == l'opérateur comparant chaque élément de a avec b (si b est un nombre) ou chaque élément de a avec l'élément correspondant de b.

--METTRE À JOUR

de @deltheil

torch.all(torch.eq(tens_a, tens_b))

ou même plus simple

torch.all(tens_a:eq(tens_b))
16
YuTse

Essayez ceci si vous voulez ignorer les petites différences de précision communes aux flottants.

torch.all(torch.lt(torch.abs(torch.add(tens_a, -tens_b)), 1e-12))
4
tworec

Cette solution ci-dessous a fonctionné pour moi: 

torch.equal(tensorA, tensorB)
0
Erik