web-dev-qa-db-fra.com

Comment puis-je combiner ImageDataGenerator avec des ensembles de données TensorFlow dans TF2?

J'ai un jeu de données TF pour classer les chats et les chiens:

import tensorflow_datasets as tfds
SPLIT_WEIGHTS = (8, 1, 1)
splits = tfds.Split.TRAIN.subsplit(weighted=SPLIT_WEIGHTS)

(raw_train, raw_validation, raw_test), metadata = tfds.load(
    'cats_vs_dogs', split=list(splits),
    with_info=True, as_supervised=True)

Dans l'exemple, ils utilisent une augmentation d'image avec une fonction de carte. Je me demandais si cela pouvait également être fait avec la classe Nice ImageDataGenerator telle que décrite ici :

from tensorflow.keras.preprocessing.image import ImageDataGenerator
train_image_generator = ImageDataGenerator(rescale=1./255) # Generator for our training data
train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,
                                                           directory=train_dir,
                                                           shuffle=True,
                                                           target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                           class_mode='binary')

Le problème auquel je suis confronté est que je ne vois que 3 façons d'utiliser le ImageDataGenerator: pandas dataframe, numpy tableau et répertoire d'images. Existe-t-il un moyen d'utiliser également un ensemble de données Tensorflow et de combiner ces méthodes?

6
user2874583

Oui, mais c'est un peu délicat.
Keras ImageDataGenerator fonctionne sur les numpy.array Et non sur les tf.Tensor, Nous devons donc utiliser numpy_function de Tensorflow. Cela nous permettra d'effectuer des opérations sur le contenu tf.data.Dataset Comme s'il s'agissait de tableaux numpy.

Tout d'abord, déclarons la fonction que nous allons .map Sur notre ensemble de données (en supposant que votre ensemble de données se compose de paires d'images et d'étiquettes):

# We will take 1 original image and create 5 augmented images:
HOW_MANY_TO_AUGMENT = 5

def augment(image, label):

  # Create generator and fit it to an image
  img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
  img_gen.fit(image)

  # We want to keep original image and label
  img_results = [(image/255.).astype(np.float32)] 
  label_results = [label]

  # Perform augmentation and keep the labels
  augmented_images = [next(img_gen.flow(image)) for _ in range(HOW_MANY_TO_AUGMENT)]
  labels = [label for _ in range(HOW_MANY_TO_AUGMENT)]

  # Append augmented data and labels to original data
  img_results.extend(augmented_images)
  label_results.extend(labels)

  return img_results, label_results

Maintenant, pour utiliser cette fonction dans tf.data.Dataset Nous devons déclarer un numpy_function:

def py_augment(image, label):
  func = tf.numpy_function(augment, [image, label], [tf.float32, tf.int32])
  return func

py_augment Peut être utilisé en toute sécurité comme:

augmented_dataset_ds = image_label_dataset.map(py_augment)

La partie image de l'ensemble de données est maintenant en forme (HOW_MANY_TO_AUGMENT, image_height, image_width, channels). Pour le convertir en simple (1, image_height, image_width, channels), Vous pouvez simplement utiliser unbatch:

unbatched_augmented_dataset_ds = augmented_dataset_ds.unbatch()

Donc, toute la section ressemble à ceci:

HOW_MANY_TO_AUGMENT = 5

def augment(image, label):

  # Create generator and fit it to an image
  img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
  img_gen.fit(image)

  # We want to keep original image and label
  img_results = [(image/255.).astype(np.float32)] 
  label_results = [label]

  # Perform augmentation and keep the labels
  augmented_images = [next(img_gen.flow(image)) for _ in range(HOW_MANY_TO_AUGMENT)]
  labels = [label for _ in range(HOW_MANY_TO_AUGMENT)]

  # Append augmented data and labels to original data
  img_results.extend(augmented_images)
  label_results.extend(labels)

  return img_results, label_results

def py_augment(image, label):
  func = tf.numpy_function(augment, [image, label], [tf.float32, tf.int32])
  return func

unbatched_augmented_dataset_ds = augmented_dataset_ds.map(py_augment).unbatch()

# Iterate over the dataset for preview:
for image, label in unbatched_augmented_dataset_ds:
    ...
2
sebastian-sz