web-dev-qa-db-fra.com

Comment mapper une fonction avec un paramètre supplémentaire en utilisant la nouvelle API Dataset dans TF1.3?

Je joue avec l'API Dataset dans Tensorflow v1. . C'est bien. Il est possible de mapper un ensemble de données avec une fonction comme décrit ici . Je souhaite savoir comment passer une fonction qui a un argument supplémentaire, par exemple arg1:

def _parse_function(example_proto, arg1):
  features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
              "label": tf.FixedLenFeature((), tf.int32, default_value=0)}
  parsed_features = tf.parse_single_example(example_proto, features)
  return parsed_features["image"], parsed_features["label"]

Bien sûr,

dataset = dataset.map(_parse_function)

ne fonctionnera pas car il n'y a aucun moyen de passer arg1.

16
AmirHJ

Voici un exemple utilisant une expression lambda pour encapsuler la fonction à laquelle nous voulons passer un argument:

import tensorflow as tf
def fun(x, arg):
    return x * arg

my_arg = tf.constant(2, dtype=tf.int64)
ds = tf.data.Dataset.range(5)
ds = ds.map(lambda x: fun(x, my_arg))

Dans ce qui précède, la signature de la fonction fournie à map doit correspondre au contenu de notre ensemble de données. Nous devons donc écrire notre expression lambda pour correspondre à cela. Ici, c'est simple, car il n'y a qu'un seul élément contenu dans l'ensemble de données, le x qui contient des éléments compris entre 0 et 4.

Si nécessaire, vous pouvez transmettre un nombre arbitraire d'arguments externes provenant de l'extérieur de l'ensemble de données: ds = ds.map(lambda x: my_other_fun(x, arg1, arg2, arg3), etc.

Pour vérifier que ce qui précède fonctionne, nous pouvons observer que le mappage multiplie en effet chaque élément de l'ensemble de données par deux:

iterator = ds.make_initializable_iterator()
next_x = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer)

    while True:
      try:
        print(sess.run(next_x))
      except tf.errors.OutOfRangeError:
        break

Le résultat:

0
2
4
6
8
24
mikkola

Une autre solution consiste à utiliser un wrapper de classe. Dans le code suivant, j'ai passé le paramètre shape à la fonction d'analyse.

class MyDataSets:

    def __init__(self, shape):
        self.shape = shape

    def parse_sample(self.sample):
        features = { ... }
        f = tf.parse_example([example], features=features)

        image_raw = tf.decode_raw(f['image_raw'], tf.uint8)
        image = image.reshape(image_raw, self.shape)

        label = tf.cast(f['label'], tf.int32)

        return image, label

    def init(self):
        ds = tf.data.TFRecordDataSets(...)
        ds = ds.map(self.parse_sample)
        ...
        return ds.make_initializable_iterator()
1
redice li

Vous pouvez également utiliser une fonction Partial à la place pour encapsuler votre paramètre:

def _parse_function(arg1, example_proto):
  features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
              "label": tf.FixedLenFeature((), tf.int32, default_value=0)}
  parsed_features = tf.parse_single_example(example_proto, features)
  return parsed_features["image"], parsed_features["label"]

L'ordre des paramètres de votre fonction est modifié afin de s'adapter à la partialité, vous pouvez ensuite envelopper votre fonction avec une valeur de paramètre comme suit:

from functools import partial

arg1 = ...
dataset = dataset.map(partial(_parse_function, arg1))
1
Faylixe