web-dev-qa-db-fra.com

Comment obtenir le nom de fichier d'un échantillon à partir d'un DataLoader?

J'ai besoin d'écrire un fichier avec le résultat du test de données d'un réseau neuronal convolutionnel que j'ai formé. Les données comprennent la collecte de données vocales. Le format de fichier doit être "nom de fichier, prédiction", mais j'ai du mal à extraire le nom de fichier. Je charge les données comme ceci:

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

TEST_DATA_PATH = ...

trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

test_dataset = torchvision.datasets.MNIST(
    root=TEST_DATA_PATH,
    train=False,
    transform=trans,
    download=True
)

test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

et j'essaie d'écrire dans le fichier comme suit:

f = open("test_y", "w")
with torch.no_grad():
    for i, (images, labels) in enumerate(test_loader, 0):
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        file = os.listdir(TEST_DATA_PATH + "/all")[i]
        format = file + ", " + str(predicted.item()) + '\n'
        f.write(format)
f.close()

Le problème avec os.listdir(TESTH_DATA_PATH + "/all")[i] est qu'il n'est pas synchronisé avec l'ordre des fichiers chargés de test_loader. Que puis-je faire?

5
Almog Levi

Dans le cas général DataLoader est là pour vous fournir les lots du ou des jeux de données qu'il contient.

Comme @Barriel l'a mentionné en cas de problèmes de classification simple/multi-étiquette, le DataLoader n'a pas de nom de fichier image, juste les tenseurs représentant les images et les classes/étiquettes.

Cependant, le constructeur DataLoader lors du chargement d'objets peut prendre de petites choses (avec le jeu de données, vous pouvez emballer les cibles/étiquettes et les noms de fichier si vous le souhaitez), même un cadre de données

De cette façon, le DataLoader peut en quelque sorte saisir ce dont vous avez besoin.

1
prosti

Eh bien, cela dépend de la façon dont votre Dataset est implémenté. Par exemple, dans le cas torchvision.datasets.MNIST(...), vous ne pouvez pas récupérer le nom de fichier simplement parce qu'il n'y a rien de tel que le nom de fichier d'un seul échantillon (les échantillons MNIST sont chargés d'une manière différente ) .

Comme vous n'avez pas montré votre implémentation de Dataset, je vais vous dire comment cela pourrait être fait avec la torchvision.datasets.ImageFolder(...) (ou n'importe quelle torchvision.datasets.DatasetFolder(...) ):

f = open("test_y", "w")
with torch.no_grad():
    for i, (images, labels) in enumerate(test_loader, 0):
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        sample_fname, _ = test_loader.dataset.samples[i]
        f.write("{}, {}\n".format(sample_fname, predicted.item()))
f.close()

Vous pouvez voir que le chemin du fichier est récupéré pendant __getitem__(self, index) , en particulier ici .

Si vous avez implémenté votre propre Dataset (et que vous souhaitez peut-être prendre en charge shuffle et batch_size > 1), Je retournerais le sample_fname Sur la fonction __getitem__(...) appeler et faire quelque chose comme ceci:

for i, (images, labels, sample_fname) in enumerate(test_loader, 0):
    # [...]

De cette façon, vous n'auriez pas à vous soucier de shuffle. Et si le batch_size Est supérieur à 1, vous devrez changer le contenu de la boucle pour quelque chose de plus générique, par exemple:

f = open("test_y", "w")
for i, (images, labels, samples_fname) in enumerate(test_loader, 0):
    outputs = model(images)
    pred = torch.max(outputs, 1)[1]
    f.write("\n".join([
        ", ".join(x)
        for x in Zip(map(str, pred.cpu().tolist()), samples_fname)
    ]) + "\n")
f.close()
1
Berriel