web-dev-qa-db-fra.com

pytorch skip connection dans un modèle séquentiel

J'essaie d'envelopper ma tête autour de sauter les connexions dans un modèle séquentiel. Avec l'API fonctionnelle, je ferais quelque chose d'aussi simple que (exemple rapide, peut-être pas 100% correct sur le plan syntaxique, mais je devrais avoir l'idée):

x1 = self.conv1(inp)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)

x = self.deconv4(x)
x = self.deconv3(x)
x = self.deconv2(x)
x = torch.cat((x, x1), 1))
x = self.deconv1(x)

J'utilise maintenant un modèle séquentiel et j'essaie de faire quelque chose de similaire, de créer une connexion de saut qui amène les activations de la première couche conv jusqu'à la dernière convTranspose. J'ai jeté un œil à l'architecture U-net implémentée ici et c'est un peu déroutant, cela fait quelque chose comme ceci:

upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                    kernel_size=4, stride=2,
                                    padding=1, bias=use_bias)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]

if use_dropout:
    model = down + [submodule] + up + [nn.Dropout(0.5)]
else:
    model = down + [submodule] + up

N'est-ce pas simplement ajouter des couches au modèle séquentiel bien, séquentiellement? Il y a le down conv qui est suivi par submodule (qui ajoute récursivement des couches internes) puis concaténé à up qui est la couche de conversion ascendante. Il me manque probablement quelque chose d'important sur le fonctionnement de l'API Sequential, mais comment le code extrait de U-NET implémente-t-il réellement le saut?

9
powder

Vos observations sont correctes, mais vous avez peut-être manqué la définition de UnetSkipConnectionBlock.forward() (UnetSkipConnectionBlock étant le Module définissant le bloc U-Net que vous avez partagé), ce qui peut clarifier cette implémentation :

(à partir de pytorch-CycleGAN-and-pix2pix/models/networks.py#L259 )

# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):

    # ...

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)

La dernière ligne est la clé (appliquée à tous les blocs internes). La couche de saut se fait simplement en concaténant l'entrée x et la sortie du bloc (récursif) self.model(x), avec self.model La liste des opérations que vous avez mentionnées - donc pas si différemment à partir du code Functional que vous avez écrit.

4
benjaminplanche