web-dev-qa-db-fra.com

Sélection pondérée aléatoire en Java

Je veux choisir un article au hasard dans un ensemble, mais les chances de choisir n'importe quel article doivent être proportionnelles au poids associé.

Exemple d'entrées:

item                weight
----                ------
sword of misery         10
shield of happy          5
potion of dying          6
triple-edged sword       1

Donc, si j’ai 4 items possibles, la chance d’obtenir un item sans poids serait de 1 sur 4.

Dans ce cas, un utilisateur devrait être 10 fois plus susceptible de recevoir l'épée de la misère que l'épée à triple tranchant.

Comment faire une sélection aléatoire pondérée en Java?

54
yosi

Je voudrais utiliser un NavigableMap

public class RandomCollection<E> {
    private final NavigableMap<Double, E> map = new TreeMap<Double, E>();
    private final Random random;
    private double total = 0;

    public RandomCollection() {
        this(new Random());
    }

    public RandomCollection(Random random) {
        this.random = random;
    }

    public RandomCollection<E> add(double weight, E result) {
        if (weight <= 0) return this;
        total += weight;
        map.put(total, result);
        return this;
    }

    public E next() {
        double value = random.nextDouble() * total;
        return map.higherEntry(value).getValue();
    }
}

Disons que j'ai une liste d'animaux chien, chat, cheval avec des probabilités respectives de 40%, 35% et 25%

RandomCollection<String> rc = new RandomCollection<>()
                              .add(40, "dog").add(35, "cat").add(25, "horse");

for (int i = 0; i < 10; i++) {
    System.out.println(rc.next());
} 
92
Peter Lawrey

Vous ne trouverez pas de cadre pour ce type de problème, car la fonctionnalité demandée n'est rien de plus qu'une simple fonction. Faites quelque chose comme ça:

interface Item {
    double getWeight();
}

class RandomItemChooser {
    public Item chooseOnWeight(List<Item> items) {
        double completeWeight = 0.0;
        for (Item item : items)
            completeWeight += item.getWeight();
        double r = Math.random() * completeWeight;
        double countWeight = 0.0;
        for (Item item : items) {
            countWeight += item.getWeight();
            if (countWeight >= r)
                return item;
        }
        throw new RuntimeException("Should never be shown.");
    }
}
23
Arne Deutsch

Il existe maintenant une classe pour cela dans Apache Commons: EnumeratedDistribution

Item selectedItem = new EnumeratedDistribution(itemWeights).sample();

itemWeights est un List<Pair<Item,Double>>, comme (en supposant l'interface Item dans la réponse d'Arne):

List<Pair<Item,Double>> itemWeights = Collections.newArrayList();
for (Item i : itemSet) {
    itemWeights.add(new Pair(i, i.getWeight()));
}

ou en Java 8:

itemSet.stream().map(i -> new Pair(i, i.getWeight())).collect(toList());

Remarque: Pair doit ici être org.Apache.commons.math3.util.Pair et non org.Apache.commons.lang3.Tuple.Pair.

15
kdkeck

Utiliser une méthode d'alias

Si vous voulez rouler beaucoup de fois (comme dans un jeu), vous devriez utiliser une méthode d'alias.

Le code ci-dessous est une longue implémentation d'une telle méthode de pseudonyme. Mais c'est à cause de la partie d'initialisation. La récupération des éléments est très rapide (voir les méthodes next et applyAsInt qu’elles ne bouclent pas).

Usage

Set<Item> items = ... ;
ToDoubleFunction<Item> weighter = ... ;

Random random = new Random();

RandomSelector<T> selector = RandomSelector.weighted(items, weighter);
Item drop = selector.next(random);

La mise en oeuvre

Cette implémentation:

  • utilise Java 8;
  • est conçu pour être aussi rapide que possible (enfin, au moins, j'ai essayé de le faire en utilisant un micro-benchmarking);
  • est totalement thread-safe (conservez une Random dans chaque thread pour une performance maximale, utilisez ThreadLocalRandom?);
  • récupère des éléments dans O(1), contrairement à ce que vous trouvez principalement sur Internet ou sur StackOverflow, où des implémentations naïves fonctionnent dans O(n) ou O (log (n));
  • garde les items indépendants de leur poids, ainsi un poids peut être assigné à différents éléments dans différents contextes.

Quoi qu'il en soit, voici le code. (Notez que je tiens à jour une version de cette classe .)

import static Java.util.Objects.requireNonNull;

import Java.util.*;
import Java.util.function.*;

public final class RandomSelector<T> {

  public static <T> RandomSelector<T> weighted(Set<T> elements, ToDoubleFunction<? super T> weighter)
      throws IllegalArgumentException {
    requireNonNull(elements, "elements must not be null");
    requireNonNull(weighter, "weighter must not be null");
    if (elements.isEmpty()) { throw new IllegalArgumentException("elements must not be empty"); }

    // Array is faster than anything. Use that.
    int size = elements.size();
    T[] elementArray = elements.toArray((T[]) new Object[size]);

    double totalWeight = 0d;
    double[] discreteProbabilities = new double[size];

    // Retrieve the probabilities
    for (int i = 0; i < size; i++) {
      double weight = weighter.applyAsDouble(elementArray[i]);
      if (weight < 0.0d) { throw new IllegalArgumentException("weighter may not return a negative number"); }
      discreteProbabilities[i] = weight;
      totalWeight += weight;
    }
    if (totalWeight == 0.0d) { throw new IllegalArgumentException("the total weight of elements must be greater than 0"); }

    // Normalize the probabilities
    for (int i = 0; i < size; i++) {
      discreteProbabilities[i] /= totalWeight;
    }
    return new RandomSelector<>(elementArray, new RandomWeightedSelection(discreteProbabilities));
  }

  private final T[] elements;
  private final ToIntFunction<Random> selection;

  private RandomSelector(T[] elements, ToIntFunction<Random> selection) {
    this.elements = elements;
    this.selection = selection;
  }

  public T next(Random random) {
    return elements[selection.applyAsInt(random)];
  }

  private static class RandomWeightedSelection implements ToIntFunction<Random> {
    // Alias method implementation O(1)
    // using Vose's algorithm to initialize O(n)

    private final double[] probabilities;
    private final int[] alias;

    RandomWeightedSelection(double[] probabilities) {
      int size = probabilities.length;

      double average = 1.0d / size;
      int[] small = new int[size];
      int smallSize = 0;
      int[] large = new int[size];
      int largeSize = 0;

      // Describe a column as either small (below average) or large (above average).
      for (int i = 0; i < size; i++) {
        if (probabilities[i] < average) {
          small[smallSize++] = i;
        } else {
          large[largeSize++] = i;
        }
      }

      // For each column, saturate a small probability to average with a large probability.
      while (largeSize != 0 && smallSize != 0) {
        int less = small[--smallSize];
        int more = large[--largeSize];
        probabilities[less] = probabilities[less] * size;
        alias[less] = more;
        probabilities[more] += probabilities[less] - average;
        if (probabilities[more] < average) {
          small[smallSize++] = more;
        } else {
          large[largeSize++] = more;
        }
      }

      // Flush unused columns.
      while (smallSize != 0) {
        probabilities[small[--smallSize]] = 1.0d;
      }
      while (largeSize != 0) {
        probabilities[large[--largeSize]] = 1.0d;
      }
    }

    @Override public int applyAsInt(Random random) {
      // Call random once to decide which column will be used.
      int column = random.nextInt(probabilities.length);

      // Call random a second time to decide which will be used: the column or the alias.
      if (random.nextDouble() < probabilities[column]) {
        return column;
      } else {
        return alias[column];
      }
    }
  }
}
4
Olivier Grégoire

Si vous devez supprimer des éléments après avoir choisi, vous pouvez utiliser une autre solution. Ajoutez tous les éléments dans une 'LinkedList', chaque élément doit être ajouté autant de fois qu'il est nécessaire, puis utilisez Collections.shuffle() qui, selon JavaDoc

Permute de manière aléatoire la liste spécifiée en utilisant une source aléatoire par défaut. Toutes les permutations se produisent avec une probabilité approximativement égale.

Enfin, récupérez et supprimez des éléments en utilisant pop() ou removeFirst()

Map<String, Integer> map = new HashMap<String, Integer>() {{
    put("Five", 5);
    put("Four", 4);
    put("Three", 3);
    put("Two", 2);
    put("One", 1);
}};

LinkedList<String> list = new LinkedList<>();

for (Map.Entry<String, Integer> entry : map.entrySet()) {
    for (int i = 0; i < entry.getValue(); i++) {
        list.add(entry.getKey());
    }
}

Collections.shuffle(list);

int size = list.size();
for (int i = 0; i < size; i++) {
    System.out.println(list.pop());
}
1
Yuri Heiko
public class RandomCollection<E> {
  private final NavigableMap<Double, E> map = new TreeMap<Double, E>();
  private double total = 0;

  public void add(double weight, E result) {
    if (weight <= 0 || map.containsValue(result))
      return;
    total += weight;
    map.put(total, result);
  }

  public E next() {
    double value = ThreadLocalRandom.current().nextDouble() * total;
    return map.ceilingEntry(value).getValue();
  }
}
1
ronen

139

Il existe un algorithme simple pour choisir un article au hasard, les articles ayant des poids individuels:

  1. calculer la somme de tous les poids

  2. choisissez un nombre aléatoire égal ou supérieur à 0 et inférieur à la somme des poids

  3. parcourez les objets un par un en soustrayant leur poids de votre nombre aléatoire jusqu'à ce que vous obteniez l'article dont le nombre est inférieur au poids de cet objet

0
Quinton Gordon