web-dev-qa-db-fra.com

Comment puis-je filtrer tf.data.Dataset par des valeurs spécifiques?

Je crée un jeu de données en lisant les TFRecords, je mappe les valeurs et je veux filtrer le jeu de données pour des valeurs spécifiques, mais comme le résultat est un dict avec des tenseurs, je ne peux pas obtenir la valeur réelle d'un tenseur ni le vérifier. avec tf.cond()/tf.equal. Comment puis je faire ça?

def mapping_func(serialized_example):
    feature = { 'label': tf.FixedLenFeature([1], tf.string) }
    features = tf.parse_single_example(serialized_example, features=feature)
    return features

def filter_func(features):
    # this doesn't work
    #result = features['label'] == 'some_label_value'
    # neither this
    result = tf.reshape(tf.equal(features['label'], 'some_label_value'), [])
    return result

def main():
    file_names = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
    dataset = tf.contrib.data.TFRecordDataset(file_names)
    dataset = dataset.map(mapping_func)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.filter(filter_func)
    dataset = dataset.repeat()
    iterator = dataset.make_one_shot_iterator()
    sample = iterator.get_next()
7
tsveti_iko

Je réponds à ma propre question. J'ai trouvé le problème!

Ce que je devais faire, c’était tf.unstack() l’étiquette comme ceci:

label = tf.unstack(features['label'])
label = label[0]

avant de le donner à tf.equal():

result = tf.reshape(tf.equal(label, 'some_label_value'), [])

Je suppose que le problème est que l’étiquette est définie comme un tableau avec un élément de type chaîne tf.FixedLenFeature([1], tf.string), donc pour obtenir le premier et unique élément, je devais le décompresser (ce qui crée une liste), puis obtenir l’élément avec l’indice 0 , Corrige moi si je me trompe.

2
tsveti_iko

Je pense que vous n'avez pas besoin de faire en premier lieu pour étiqueter un tableau à 1 dimension.

avec:

feature = {'label': tf.FixedLenFeature((), tf.string)}

vous n'aurez pas besoin de désempiler l'étiquette dans votre filter_func

0