web-dev-qa-db-fra.com

PyTorch: réglage manuel des paramètres de poids avec tableau numpy pour GRU / LSTM

J'essaie de remplir GRU/LSTM avec des paramètres définis manuellement dans pytorch.

J'ai des tableaux numpy pour les paramètres avec des formes définies dans leur documentation ( https://pytorch.org/docs/stable/nn.html#torch.nn.GR ).

Cela semble fonctionner mais je ne sais pas si les valeurs retournées sont correctes.

Est-ce une bonne façon de remplir GRU/LSTM avec des paramètres numpy?

gru = nn.GRU(input_size, hidden_size, num_layers,
              bias=True, batch_first=False, dropout=dropout, bidirectional=bidirectional)

def set_nn_wih(layer, parameter_name, w, l0=True):
    param = getattr(layer, parameter_name)
    if l0:
        for i in range(3*hidden_size):
            param.data[i] = w[i*input_size:(i+1)*input_size]
    else:
        for i in range(3*hidden_size):
            param.data[i] = w[i*num_directions*hidden_size:(i+1)*num_directions*hidden_size]

def set_nn_whh(layer, parameter_name, w):
    param = getattr(layer, parameter_name)
    for i in range(3*hidden_size):
        param.data[i] = w[i*hidden_size:(i+1)*hidden_size]

l0=True

for i in range(num_directions):
    for j in range(num_layers):
        if j == 0:
            wih = w0[i, :, :3*input_size]
            whh = w0[i, :, 3*input_size:]  # check
            l0=True
        else:
            wih = w[j-1, i, :, :num_directions*3*hidden_size]
            whh = w[j-1, i, :, num_directions*3*hidden_size:]
            l0=False

        if i == 0:
            set_nn_wih(
                gru, "weight_ih_l{}".format(j), torch.from_numpy(wih.flatten()),l0)
            set_nn_whh(
                gru, "weight_hh_l{}".format(j), torch.from_numpy(whh.flatten()))
        else:
            set_nn_wih(
                gru, "weight_ih_l{}_reverse".format(j), torch.from_numpy(wih.flatten()),l0)
            set_nn_whh(
                gru, "weight_hh_l{}_reverse".format(j), torch.from_numpy(whh.flatten()))

y, hn = gru(x_t, h_t)

les tableaux numpy sont définis comme suit:

rng = np.random.RandomState(313)
w0 = rng.randn(num_directions, hidden_size, 3*(input_size +
               hidden_size)).astype(np.float32)
w = rng.randn(max(1, num_layers-1), num_directions, hidden_size,
              3*(num_directions*hidden_size + hidden_size)).astype(np.float32)
10
ytrewq

C'est une bonne question, et vous donnez déjà une réponse décente. Cependant, il réinvente la roue - il existe une routine interne Pytorch très élégante qui vous permettra de faire la même chose sans trop d'effort - et qui s'applique à n'importe quel réseau.

Le concept de base ici est le state_dict De PyTorch. Le dictionnaire d'état contient effectivement le parameters organisé par l'arborescence donnée par la relation entre le nn.Modules Et leurs sous-modules, etc.

La réponse courte

Si vous voulez seulement que le code charge une valeur dans un tenseur en utilisant le state_dict, Essayez cette ligne (où le dict contient un state_dict Valide):

`model.load_state_dict(dict, strict=False)`

strict=False est crucial si vous voulez charger seulement certaines valeurs de paramètres .

La réponse longue - y compris une introduction au PyTorch state_dict

Voici un exemple de la façon dont un dict d'état recherche un GRU (j'ai choisi input_size = hidden_size = 2 Afin que je puisse imprimer le dict d'état entier):

rnn = torch.nn.GRU(2, 2, 1)
rnn.state_dict()
# Out[10]: 
#     OrderedDict([('weight_ih_l0', tensor([[-0.0023, -0.0460],
#                         [ 0.3373,  0.0070],
#                         [ 0.0745, -0.5345],
#                         [ 0.5347, -0.2373],
#                         [-0.2217, -0.2824],
#                         [-0.2983,  0.4771]])),
#                 ('weight_hh_l0', tensor([[-0.2837, -0.0571],
#                         [-0.1820,  0.6963],
#                         [ 0.4978, -0.6342],
#                         [ 0.0366,  0.2156],
#                         [ 0.5009,  0.4382],
#                         [-0.7012, -0.5157]])),
#                 ('bias_ih_l0',
#                 tensor([-0.2158, -0.6643, -0.3505, -0.0959, -0.5332, -0.6209])),
#                 ('bias_hh_l0',
#                 tensor([-0.1845,  0.4075, -0.1721, -0.4893, -0.2427,  0.3973]))])

Donc le state_dict Tous les paramètres du réseau. Si nous avons "imbriqué" nn.Modules, Nous obtenons l'arbre représenté par les noms des paramètres:

class MLP(torch.nn.Module):      
    def __init__(self):
        torch.nn.Module.__init__(self)
        self.lin_a = torch.nn.Linear(2, 2)
        self.lin_b = torch.nn.Linear(2, 2)


mlp = MLP()
mlp.state_dict()
#    Out[23]: 
#        OrderedDict([('lin_a.weight', tensor([[-0.2914,  0.0791],
#                            [-0.1167,  0.6591]])),
#                    ('lin_a.bias', tensor([-0.2745, -0.1614])),
#                    ('lin_b.weight', tensor([[-0.4634, -0.2649],
#                            [ 0.4552,  0.3812]])),
#                    ('lin_b.bias', tensor([ 0.0273, -0.1283]))])


class NestedMLP(torch.nn.Module):
    def __init__(self):
        torch.nn.Module.__init__(self)
        self.mlp_a = MLP()
        self.mlp_b = MLP()


n_mlp = NestedMLP()
n_mlp.state_dict()
#   Out[26]: 
#        OrderedDict([('mlp_a.lin_a.weight', tensor([[ 0.2543,  0.3412],
#                            [-0.1984, -0.3235]])),
#                    ('mlp_a.lin_a.bias', tensor([ 0.2480, -0.0631])),
#                    ('mlp_a.lin_b.weight', tensor([[-0.4575, -0.6072],
#                            [-0.0100,  0.5887]])),
#                    ('mlp_a.lin_b.bias', tensor([-0.3116,  0.5603])),
#                    ('mlp_b.lin_a.weight', tensor([[ 0.3722,  0.6940],
#                            [-0.5120,  0.5414]])),
#                    ('mlp_b.lin_a.bias', tensor([0.3604, 0.0316])),
#                    ('mlp_b.lin_b.weight', tensor([[-0.5571,  0.0830],
#                            [ 0.5230, -0.1020]])),
#                    ('mlp_b.lin_b.bias', tensor([ 0.2156, -0.2930]))])

Alors - que faire si vous ne voulez pas extraire le dict d'état, mais le changer - et donc les paramètres du réseau? Utiliser nn.Module.load_state_dict(state_dict, strict=True) ( lien vers les documents ) Cette méthode vous permet de charger un state_dict entier avec des valeurs arbitraires dans un modèle instancié du même type tant que les touches (ie les noms des paramètres) sont correctes et les valeurs (ie les paramètres) sont torch.tensors de la bonne forme. Si le strict kwarg est défini sur True (par défaut), le dict que vous chargez doit correspondre exactement au dict de l'état d'origine, à l'exception des valeurs des paramètres. Autrement dit, il doit y avoir une nouvelle valeur pour chaque paramètre.

Pour l'exemple GRU ci-dessus, nous avons besoin d'un tenseur de la bonne taille (et du bon appareil, btw) pour chacun de 'weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0'. Comme nous voulons parfois seulement charger certaines valeurs (comme je pense que vous voulez le faire), nous pouvons définir le strict kwarg sur False - et nous ne pouvons alors charger que des états partiels, comme par exemple celui qui ne contient que des valeurs de paramètre pour 'weight_ih_l0'.

Comme conseil pratique, je créerais simplement le modèle dans lequel vous souhaitez charger des valeurs, puis j'imprimerais le dict d'état (ou au moins une liste des clés et les tailles de tenseur respectives)

print([k, v.shape for k, v in model.state_dict().items()])

Cela vous indique le nom exact du paramètre que vous souhaitez modifier. Vous créez ensuite simplement un dict d'état avec le nom de paramètre et le tenseur respectifs, et le chargez:

from dollections import OrderedDict
new_state_dict = OrderedDict({'tensor_name_retrieved_from_original_dict': new_tensor_value})
model.load_state_dict(new_state_dict, strict=False)
9
cleros