web-dev-qa-db-fra.com

Comment faire une norme de lot entièrement connectée dans PyTorch?

torch.nn a des classes BatchNorm1d, BatchNorm2d, BatchNorm3d, mais il n'a pas de classe BatchNorm entièrement connectée? Quelle est la façon standard de faire la norme de lot normale dans PyTorch?

9
patapouf_ai

D'accord. Je l'ai compris. BatchNorm1d peut également gérer les tenseurs de rang 2, il est donc possible d'utiliser BatchNorm1d pour le boîtier normal entièrement connecté.

Ainsi, par exemple:

import torch.nn as nn


class Policy(nn.Module):
def __init__(self, num_inputs, action_space, hidden_size1=256, hidden_size2=128):
    super(Policy2, self).__init__()
    self.action_space = action_space
    num_outputs = action_space

    self.linear1 = nn.Linear(num_inputs, hidden_size1)
    self.linear2 = nn.Linear(hidden_size1, hidden_size2)
    self.linear3 = nn.Linear(hidden_size2, num_outputs)
    self.bn1 = nn.BatchNorm1d(hidden_size1)
    self.bn2 = nn.BatchNorm1d(hidden_size2)

def forward(self, inputs):
    x = inputs
    x = self.bn1(F.relu(self.linear1(x)))
    x = self.bn2(F.relu(self.linear2(x)))
    out = self.linear3(x)


    return out
19
patapouf_ai