web-dev-qa-db-fra.com

Classification multi-étiquettes en pytorch

J'ai un problème de classification multi-étiquettes. J'ai 11 classes, environ 4k exemples. Chaque exemple peut avoir de 1 à 4-5 étiquettes. En ce moment, j'entraîne un classificateur séparément pour chaque classe avec log_loss. Comme vous pouvez vous y attendre, cela prend un certain temps pour former 11 classificateurs, et je voudrais essayer une autre approche et former seulement 1 classifieur. L'idée est que la dernière couche de ce classificateur aurait 11 nœuds et produirait un nombre réel par classes qui serait converti en proba par un sigmoïde. La perte que je veux optimiser est la moyenne de log_loss sur toutes les classes.

Malheureusement, je suis une sorte de noob avec pytorch, et même en lisant le code source des pertes, je ne peux pas savoir si l'une des pertes déjà existantes fait exactement ce que je veux, ou si je dois créer une nouvelle perte et si c'est le cas, je ne sais pas vraiment comment faire.

Pour être très précis, je veux donner pour chaque élément du lot un vecteur de taille 11 (qui contient un nombre réel pour chaque étiquette (plus l'infini est proche, plus cette classe est présumée être 1), et 1 vecteur de taille 11 (qui contient un 1 à chaque vraie étiquette), et être capable de calculer la perte de log moyenne sur les 11 étiquettes, et d'optimiser mon classificateur en fonction de cette perte.

Toute aide serait grandement appréciée :)

8
Statistic Dean

Tu recherches torch.nn.BCELoss . Voici un exemple de code:

import torch

batch_size = 2
num_classes = 11

loss_fn = torch.nn.BCELoss()

outputs_before_sigmoid = torch.randn(batch_size, num_classes)
sigmoid_outputs = torch.sigmoid(outputs_before_sigmoid)
target_classes = torch.randint(0, 2, (batch_size, num_classes))  # randints in [0, 2).

loss = loss_fn(sigmoid_outputs, target_classes)

# alternatively, use BCE with logits, on outputs before sigmoid.
loss_fn_2 = torch.nn.BCEWithLogitsLoss()
loss2 = loss_fn_2(outputs_before_sigmoid, target_classes)
assert loss == loss2
7