web-dev-qa-db-fra.com

À quoi servent les transformés dans PyTorch?

Je suis nouveau chez Pytorch et je ne suis pas très expert en CNN . J'ai réussi un classificateur avec le tutoriel fourni par Tutorial Pytorch , mais je ne comprends pas vraiment ce que je fais lors du chargement des données.

Ils effectuent quelques augmentations et normalisations de données pour la formation, mais lorsque j'essaie de modifier les paramètres, le code ne fonctionne pas.

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

Est-ce que j'étends mon jeu de données d'entraînement? Je ne vois pas l'augmentation des données.

Pourquoi, si je modifie la valeur de transforms.RandomResizedCrop (224), le chargement des données cesse de fonctionner? 

Dois-je également transformer le jeu de données de test?

Je suis un peu confus avec cette transformation de données qu'ils font.

5
carioka88

transforms.Compose regroupe uniquement toutes les transformations qui lui sont fournies. Ainsi, toutes les transformations dans le transforms.Compose sont appliquées à l’entrée une par une. 

Train transforme

  1. transforms.RandomResizedCrop(224): Ceci extraira un patch de taille (224, 224) de votre image d'entrée de manière aléatoire. Ainsi, il pourrait choisir ce chemin depuis le haut, le bas ou juste entre les deux. Vous effectuez donc une augmentation des données dans cette partie. De plus, si vous modifiez cette valeur, les couches entièrement connectées de votre modèle ne fonctionneront pas bien. Il est donc déconseillé de la modifier.
  2. transforms.RandomHorizontalFlip(): Une fois que nous avons notre image de taille (224, 224), nous pouvons choisir de la retourner. C'est une autre partie de l'augmentation des données.
  3. transforms.ToTensor(): Ceci convertit simplement votre image d'entrée en tenseur PyTorch.
  4. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]): Ceci est simplement une mise à l'échelle des données d'entrée et ces valeurs (moyenne et standard) doivent avoir été précalculées pour votre jeu de données. Changer ces valeurs est également déconseillé.

La validation transforme

  1. transforms.Resize(256): D'abord, votre image d'entrée est redimensionnée pour être de taille (256, 256)
  2. transforms.CentreCrop(224): recadre la partie centrale de l'image de la forme (224, 224) 

Le reste est la même chose que le train

P.S .: Vous pouvez en savoir plus sur ces transformations dans la documentation officielle

12
layog

Pour les ambiguïtés concernant l’augmentation des données, je vous renvoie à cette réponse:

Augmentation de données dans PyTorch

En résumé, supposons que vous ne disposiez que d’une transformation par retournement horizontal aléatoire. Lorsque vous parcourez un jeu de données d’images, certaines sont retournées en tant qu’original et d’autres en tant que retournées (les images originales des images retournées ne sont pas renvoyées). En d'autres termes, le nombre d'images renvoyées dans une itération est identique à la taille d'origine du jeu de données et n'est pas augmenté.

1
Ashkan372