001package squidpony.squidmath;
002
003import squidpony.StringKit;
004
005import java.io.Serializable;
006
007/**
008 * A different approach to the same task {@link ProbabilityTable} solves, though this only looks up an appropriate index
009 * instead of also storing items it can choose; allows positive doubles for weights but does not allow nested tables for
010 * simplicity. This doesn't store an RNG (or RandomnessSource) in this class, and instead expects a long to be given for
011 * each random draw from the table (these long parameters can be random, sequential, or in some other way different
012 * every time). Uses <a href="http://www.keithschwarz.com/darts-dice-coins/">Vose's Alias Method</a>, and is based
013 * fairly-closely on the code given by Keith Schwarz at that link. Because Vose's Alias Method is remarkably fast (O(1)
014 * generation time in use, and O(n) time to construct a WeightedTable instance), this may be useful to consider if you
015 * don't need all the features of ProbabilityTable or if you want deeper control over the random aspects of it.
016 * <br>
017 * Internally, this uses DiverRNG's algorithm as found in {@link DiverRNG#determineBounded(long, int)} and
018 * {@link DiverRNG#determine(long)} to generate two ints, one used for probability and treated as a 31-bit integer
019 * and the other used to determine the chosen column, which is bounded to an arbitrary positive int. It does this with
020 * just one randomized 64-bit value, allowing the state given to {@link #random(long)} to be just one long.
021 * <br>
022 * Created by Tommy Ettinger on 1/5/2018.
023 */
024public class WeightedTable implements Serializable {
025    private static final long serialVersionUID = 101L;
026//    protected final int[] alias;
027//    protected final int[] probability;
028    protected final int[] mixed;
029    public final int size;
030
031    /**
032     * Constructs a useless WeightedTable that always returns the index 0.
033     */
034    public WeightedTable()
035    {
036        this(1);
037    }
038
039    /**
040     * Constructs a WeightedTable with the given array of weights for each index. The array can also be a varargs for
041     * convenience. The weights can be any positive non-zero doubles, but should usually not be so large or small that
042     * precision loss is risked. Each weight will be used to determine the likelihood of that weight's index being
043     * returned by {@link #random(long)}.
044     * @param probabilities an array or varargs of positive doubles representing the weights for their own indices
045     */
046    public WeightedTable(double... probabilities) {
047        /* Begin by doing basic structural checks on the inputs. */
048        if (probabilities == null)
049            throw new NullPointerException("Array 'probabilities' given to WeightedTable cannot be null");
050        if ((size = probabilities.length) == 0)
051            throw new IllegalArgumentException("Array 'probabilities' given to WeightedTable must be nonempty.");
052
053        mixed = new int[size<<1];
054
055        double sum = 0.0;
056
057        /* Make a copy of the probabilities array, since we will be making
058         * changes to it.
059         */
060        double[] probs = new double[size];
061        for (int i = 0; i < size; ++i) {
062            if(probabilities[i] <= 0) continue;
063            sum += (probs[i] = probabilities[i]);
064        }
065        if(sum <= 0)
066            throw new IllegalArgumentException("At least one probability must be positive");
067        final double average = sum / size, invAverage = 1.0 / average;
068
069        /* Create two stacks to act as worklists as we populate the tables. */
070        IntVLA small = new IntVLA(size);
071        IntVLA large = new IntVLA(size);
072
073        /* Populate the stacks with the input probabilities. */
074        for (int i = 0; i < size; ++i) {
075            /* If the probability is below the average probability, then we add
076             * it to the small list; otherwise we add it to the large list.
077             */
078            if (probs[i] >= average)
079                large.add(i);
080            else
081                small.add(i);
082        }
083
084        /* As a note: in the mathematical specification of the algorithm, we
085         * will always exhaust the small list before the big list.  However,
086         * due to floating point inaccuracies, this is not necessarily true.
087         * Consequently, this inner loop (which tries to pair small and large
088         * elements) will have to check that both lists aren't empty.
089         */
090        while (!small.isEmpty() && !large.isEmpty()) {
091            /* Get the index of the small and the large probabilities. */
092            int less = small.pop(), less2 = less << 1;
093            int more = large.pop();
094
095            /* These probabilities have not yet been scaled up to be such that
096             * sum/n is given weight 1.0.  We do this here instead.
097             */
098            mixed[less2] = (int)(0x7FFFFFFF * (probs[less] * invAverage));
099            mixed[less2|1] = more;
100
101            probs[more] += probs[less] - average;
102
103            if (probs[more] >= average)
104                large.add(more);
105            else
106                small.add(more);
107        }
108
109        while (!small.isEmpty())
110            mixed[small.pop()<<1] = 0x7FFFFFFF;
111        while (!large.isEmpty())
112            mixed[large.pop()<<1] = 0x7FFFFFFF;
113    }
114
115    private WeightedTable(int[] mixed, boolean ignored)
116    {
117        size = mixed.length >> 1;
118        this.mixed = mixed;
119    }
120    /**
121     * Gets an index of one of the weights in this WeightedTable, with the choice determined deterministically by the
122     * given long, but higher weights will be returned by more possible inputs than lower weights. The state parameter
123     * can be from a random source, but this will randomize it again anyway, so it is also fine to just give sequential
124     * longs. The important thing is that each state input this is given will produce the same result for this
125     * WeightedTable every time, so you should give different state values when you want random-seeming results. You may
126     * want to call this like {@code weightedTable.random(++state)}, where state is a long, to ensure the inputs change.
127     * This will always return an int between 0 (inclusive) and {@link #size} (exclusive).
128     * @param state a long that should be different every time; consider calling with {@code ++state}
129     * @return a random-seeming index from 0 to {@link #size} - 1, determined by weights and the given state
130     */
131    public int random(long state)
132    {
133        // This is DiverRNG's algorithm to generate a random long given sequential states
134        state = (state = ((state = (((state * 0x632BE59BD9B4E019L) ^ 0x9E3779B97F4A7C15L) * 0xC6BC279692B5CC83L)) ^ 
135                state >>> 27) * 0xAEF17502108EF2D9L) ^ state >>> 25;
136        // get a random int (using half the bits of our previously-calculated state) that is less than size
137        int column = (int)((size * (state & 0xFFFFFFFFL)) >> 32);
138        // use the other half of the bits of state to get a 31-bit int, compare to probability and choose either the
139        // current column or the alias for that column based on that probability
140        return ((state >>> 33) <= mixed[column << 1]) ? column : mixed[column << 1 | 1];
141    }
142    
143    public String serializeToString()
144    {
145        return StringKit.join(",", mixed);
146    }
147    public static WeightedTable deserializeFromString(String data)
148    {
149        if(data == null || data.isEmpty())
150            return null;
151        int pos = -1;//data.indexOf(':');
152        //int size = StringKit.intFromDec(data, 0, pos);
153        int count = StringKit.count(data, ',') + 1;
154        int[] mixed = new int[count];
155        for (int i = 0; i < count; i++) {
156            mixed[i] = StringKit.intFromDec(data, pos+1, pos = data.indexOf(',', pos+1));
157        }
158        return new WeightedTable(mixed, true);
159    }
160
161}