web-dev-qa-db-fra.com

PyTorch: Comment utiliser DataLoaders pour des jeux de données personnalisés

Comment utiliser le torch.utils.data.Dataset et torch.utils.data.DataLoader sur vos propres données (pas seulement le torchvision.datasets)?

Est-il possible d'utiliser le DataLoaders incorporé qu'ils utilisent sur TorchVisionDatasets pour être utilisé dans n'importe quel jeu de données?

45
Sarthak

Oui c'est possible Créez simplement les objets vous-même, par exemple.

import torch.utils.data as data_utils

train = data_utils.TensorDataset(features, targets)
train_loader = data_utils.DataLoader(train, batch_size=50, shuffle=True)

features et targets sont des tenseurs. features doit être 2-D, c'est-à-dire une matrice où chaque ligne représente un échantillon d'apprentissage, et targets peut être 1-D ou 2-D, selon que vous essayez ou non de prédire une scalaire ou un vecteur.

J'espère que ça t'as aidé!


EDIT : réponse à la question de @ sarthak

Fondamentalement oui. Si vous créez un objet de type TensorData, le constructeur examine alors si les premières dimensions du tenseur de la fonction (qui s'appelle en réalité data_tensor) et le tenseur cible (appelé target_tensor) ont la même longueur:

assert data_tensor.size(0) == target_tensor.size(0)

Toutefois, si vous souhaitez ensuite intégrer ces données dans un réseau de neurones, vous devez être prudent. Bien que les couches de convolution fonctionnent sur des données comme la vôtre, tous les autres types de couches s’attendent à ce que les données soient fournies sous forme de matrice. Donc, si vous rencontrez un problème comme celui-ci, une solution simple consisterait à convertir votre jeu de données 4D (donné sous forme de tenseur, par exemple FloatTensor) en une matrice à l'aide de la méthode view. Pour votre jeu de données 5000xnxnx3, ceci ressemblerait à ceci:

2d_dataset = 4d_dataset.view(5000, -1)

(La valeur -1 indique à PyTorch de déterminer automatiquement la longueur de la deuxième dimension.)

46
pho7

Vous pouvez facilement le faire en étendant le data.Dataset classe. Selon le API , tout ce que vous avez à faire est d'implémenter deux fonctions: __getitem__ et __len__.

Vous pouvez ensuite envelopper l'ensemble de données avec le chargeur de données, comme indiqué dans l'API et dans la réponse de @ pho7.

Je pense que la classe ImageFolder est une référence. Voir le code ici .

9
user3693922