web-dev-qa-db-fra.com

Comment obtenir un élément par Index dans Spark RDD (Java)

Je connais la méthode rdd.first () qui me donne le premier élément d'un RDD.

Il y a aussi la méthode rdd.take (num) qui me donne les premiers éléments "num".

Mais n'y a-t-il pas une possibilité d'obtenir un élément par index?

Merci.

27
progNewbie

Cela devrait être possible en indexant d'abord le RDD. La transformation zipWithIndex fournit une indexation stable, numérotant chaque élément dans son ordre d'origine.

Étant donné: rdd = (a,b,c)

val withIndex = rdd.zipWithIndex // ((a,0),(b,1),(c,2))

Pour rechercher un élément par index, ce formulaire n'est pas utile. Nous devons d'abord utiliser l'index comme clé:

val indexKey = withIndex.map{case (k,v) => (v,k)}  //((0,a),(1,b),(2,c))

Maintenant, il est possible d'utiliser l'action lookup dans PairRDD pour trouver un élément par clé:

val b = indexKey.lookup(1) // Array(b)

Si vous prévoyez d'utiliser lookup souvent sur le même RDD, je vous recommande de mettre en cache le indexKey RDD pour améliorer les performances.

Comment faire cela en utilisant API Java est un exercice laissé au lecteur.

57
maasg

Je suis resté coincé là-dessus pendant un certain temps également, donc pour développer la réponse de Maasg mais en répondant pour rechercher une plage de valeurs par index pour Java (vous devrez définir les 4 variables à haut):

DataFrame df;
SQLContext sqlContext;
Long start;
Long end;

JavaPairRDD<Row, Long> indexedRDD = df.toJavaRDD().zipWithIndex();
JavaRDD filteredRDD = indexedRDD.filter((Tuple2<Row,Long> v1) -> v1._2 >= start && v1._2 < end);
DataFrame filteredDataFrame = sqlContext.createDataFrame(filteredRDD, df.schema());

N'oubliez pas que lorsque vous exécutez ce code, votre cluster devra avoir Java 8 (car une expression lambda est utilisée).

De plus, zipWithIndex est probablement cher!

2
Luke W

J'ai essayé cette classe pour récupérer un élément par index. Premièrement, lorsque vous construisez new IndexedFetcher(rdd, itemClass), il compte le nombre d'éléments dans chaque partition du RDD. Ensuite, lorsque vous appelez indexedFetcher.get(n), il exécute un travail uniquement sur la partition qui contient cet index.

Notez que je devais compiler cela en utilisant Java 1.7 au lieu de 1.8; à partir de Spark 1.1.0, le org.objectweb.asm fourni dans com.esotericsoftware .reflectasm ne peut pas encore lire Java 1,8 classes pour l'instant (lève IllegalStateException lorsque vous essayez d'exécuter Job = Java 1,8)).

import Java.io.Serializable;

import org.Apache.spark.SparkContext;
import org.Apache.spark.TaskContext;
import org.Apache.spark.rdd.RDD;

import scala.reflect.ClassTag;

public static class IndexedFetcher<E> implements Serializable {
    private static final long serialVersionUID = 1L;
    public final RDD<E> rdd;
    public Integer[] elementsPerPartitions;
    private Class<?> clazz;
    public IndexedFetcher(RDD<E> rdd, Class<?> clazz){
        this.rdd = rdd;
        this.clazz = clazz;
        SparkContext context = this.rdd.context();
        ClassTag<Integer> intClassTag = scala.reflect.ClassTag$.MODULE$.<Integer>apply(Integer.class);
        elementsPerPartitions = (Integer[]) context.<E, Integer>runJob(rdd, IndexedFetcher.<E>countFunction(), intClassTag);
    }
    public static class IteratorCountFunction<E> extends scala.runtime.AbstractFunction2<TaskContext, scala.collection.Iterator<E>, Integer> implements Serializable {
        private static final long serialVersionUID = 1L;
        @Override public Integer apply(TaskContext taskContext, scala.collection.Iterator<E> iterator) {
            int count = 0;
            while (iterator.hasNext()) {
                count++;
                iterator.next();
            }
            return count;
        }
    }
    static <E> scala.Function2<TaskContext, scala.collection.Iterator<E>, Integer> countFunction() {
        scala.Function2<TaskContext, scala.collection.Iterator<E>, Integer> function = new IteratorCountFunction<E>();
        return function;
    }
    public E get(long index) {
        long remaining = index;
        long totalCount = 0;
        for (int partition = 0; partition < elementsPerPartitions.length; partition++) {
            if (remaining < elementsPerPartitions[partition]) {
                return getWithinPartition(partition, remaining);
            }
            remaining -= elementsPerPartitions[partition];
            totalCount += elementsPerPartitions[partition];
        }
        throw new IllegalArgumentException(String.format("Get %d within RDD that has only %d elements", index, totalCount));
    }
    public static class FetchWithinPartitionFunction<E> extends scala.runtime.AbstractFunction2<TaskContext, scala.collection.Iterator<E>, E> implements Serializable {
        private static final long serialVersionUID = 1L;
        private final long indexWithinPartition;
        public FetchWithinPartitionFunction(long indexWithinPartition) {
            this.indexWithinPartition = indexWithinPartition;
        }
        @Override public E apply(TaskContext taskContext, scala.collection.Iterator<E> iterator) {
            int count = 0;
            while (iterator.hasNext()) {
                E element = iterator.next();
                if (count == indexWithinPartition)
                    return element;
                count++;
            }
            throw new IllegalArgumentException(String.format("Fetch %d within partition that has only %d elements", indexWithinPartition, count));
        }
    }
    public E getWithinPartition(int partition, long indexWithinPartition) {
        System.out.format("getWithinPartition(%d, %d)%n", partition, indexWithinPartition);
        SparkContext context = rdd.context();
        scala.Function2<TaskContext, scala.collection.Iterator<E>, E> function = new FetchWithinPartitionFunction<E>(indexWithinPartition);
        scala.collection.Seq<Object> partitions = new scala.collection.mutable.WrappedArray.ofInt(new int[] {partition});
        ClassTag<E> classTag = scala.reflect.ClassTag$.MODULE$.<E>apply(this.clazz);
        E[] result = (E[]) context.<E, E>runJob(rdd, function, partitions, true, classTag);
        return result[0];
    }
}
2
yonran