001/*
002The MIT License(MIT)
003Copyright(c) mxgmn 2016, modified by Tommy Ettinger 2018
004Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
005The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
006The software is provided "as is", without warranty of any kind, express or implied, including but not limited to the warranties of merchantability, fitness for a particular purpose and noninfringement. In no event shall the authors or copyright holders be liable for any claim, damages or other liability, whether in an action of contract, tort or otherwise, arising from, out of or in connection with the software or the use or other dealings in the software.
007*/
008
009package squidpony.squidgrid;
010
011import squidpony.squidmath.CrossHash;
012import squidpony.squidmath.GWTRNG;
013import squidpony.squidmath.IRNG;
014import squidpony.squidmath.IntIntOrderedMap;
015import squidpony.squidmath.IntVLA;
016import squidpony.squidmath.OrderedMap;
017
018/**
019 * A port of WaveFunctionCollapse by ExUtumno/mxgmn; takes a single sample of a grid to imitate and produces one or more
020 * grids of requested sizes that have a similar layout of cells to the sample. Samples are given as {@code int[][]}
021 * where an int is usually an index into an array, list, {@link squidpony.squidmath.Arrangement}, or some similar
022 * indexed collection of items (such as char values or colors) that would be used instead of an int directly. The
023 * original WaveFunctionCollapse code, <a href="https://github.com/mxgmn/WaveFunctionCollapse">here</a>, used colors in
024 * bitmap images, but this uses 2D int arrays that can stand as substitutes for colors or chars.
025 * <br>
026 * Created by Tommy Ettinger on 3/28/2018. Port of https://github.com/mxgmn/WaveFunctionCollapse
027 */
028public class MimicWFC {
029    private boolean[][] wave;
030
031    private int[][][] propagator;
032    private int[][][] compatible;
033    private int[] observed;
034
035    private int[] stack;
036    private int stacksize;
037
038    public IRNG random;
039    private int FMX, FMY, totalOptions;
040    private boolean periodic;
041
042    private double[] baseWeights;
043    private double[] weightLogWeights;
044
045    private int[] sumsOfOnes;
046    private double sumOfWeights, sumOfWeightLogWeights, startingEntropy;
047    private double[] sumsOfWeights, sumsOfWeightLogWeights, entropies;
048
049
050    private int order;
051    private int[][] patterns;
052    private IntIntOrderedMap choices;
053    private int ground;
054
055    public MimicWFC(int[][] itemGrid, int order, int width, int height, boolean periodicInput, boolean periodicOutput, int symmetry, int ground)
056    {
057        FMX = width;
058        FMY = height;
059
060        this.order = order;
061        periodic = periodicOutput;
062
063        int SMX = itemGrid.length, SMY = itemGrid[0].length;
064        //colors = new List<Color>();
065        choices = new IntIntOrderedMap(SMX * SMY);
066        int[][] sample = new int[SMX][SMY];
067        for (int y = 0; y < SMY; y++) {
068            for (int x = 0; x < SMX; x++)
069            {
070                int color = itemGrid[x][y];
071                int i = choices.getOrDefault(color, choices.size());
072                if(i == choices.size())
073                    choices.put(color, i);
074                sample[x][y] = i;
075            }
076        }
077
078        int C = choices.size();
079
080
081//        Dictionary<long, int> weights = new Dictionary<long, int>();
082//        List<long> ordering = new List<long>();
083        OrderedMap<int[], Integer> weights = new OrderedMap<>(CrossHash.intHasher);
084
085        for (int y = 0; y < (periodicInput ? SMY : SMY - order + 1); y++) {
086            for (int x = 0; x < (periodicInput ? SMX : SMX - order + 1); x++) {
087                int[][] ps = new int[8][];
088
089                ps[0] = patternFromSample(x, y, sample, SMX, SMY);
090                ps[1] = reflect(ps[0]);
091                ps[2] = rotate(ps[0]);
092                ps[3] = reflect(ps[2]);
093                ps[4] = rotate(ps[2]);
094                ps[5] = reflect(ps[4]);
095                ps[6] = rotate(ps[4]);
096                ps[7] = reflect(ps[6]);
097
098                for (int k = 0; k < symmetry; k++) {
099                    int[] ind = ps[k];
100                    Integer wt = weights.get(ind);
101                    if (wt != null) weights.put(ind, wt + 1);
102                    else {
103                        weights.put(ind, 1);
104                    }
105                }
106            }
107        }
108
109        totalOptions = weights.size();
110        this.ground = (ground + totalOptions) % totalOptions;
111        patterns = new int[totalOptions][];
112        baseWeights = new double[totalOptions];
113
114        for (int w = 0; w < totalOptions; w++) {
115            patterns[w] = weights.keyAt(w);
116            baseWeights[w] = weights.getAt(w);
117        }
118        
119
120        propagator = new int[4][][];
121        IntVLA list = new IntVLA(totalOptions);
122        for (int d = 0; d < 4; d++)
123        {
124            propagator[d] = new int[totalOptions][];
125            for (int t = 0; t < totalOptions; t++)
126            {
127                list.clear();
128                for (int t2 = 0; t2 < totalOptions; t2++) if (agrees(patterns[t], patterns[t2], DX[d], DY[d])) list.add(t2);
129                propagator[d][t] = list.toArray();
130            }
131        }
132    }
133
134    private long index(byte[] p, long C)
135    {
136        long result = 0, power = 1;
137        for (int i = 0; i < p.length; i++)
138        {
139            result += p[p.length - 1 - i] * power;
140            power *= C;
141        }
142        return result;
143    }
144
145    private byte[] patternFromIndex(long ind, long power, long C)
146    {
147        long residue = ind;
148        byte[] result = new byte[order * order];
149
150        for (int i = 0; i < result.length; i++)
151        {
152            power /= C;
153            int count = 0;
154
155            while (residue >= power)
156            {
157                residue -= power;
158                count++;
159            }
160
161            result[i] = (byte)count;
162        }
163
164        return result;
165    }
166
167    private int[] patternFromSample(int x, int y, int[][] sample, int SMX, int SMY) {
168        int[] result = new int[order * order];
169        for (int dy = 0; dy < order; dy++) {
170            for (int dx = 0; dx < order; dx++) {
171                result[dx + dy * order] = sample[(x + dx) % SMX][(y + dy) % SMY];
172            }
173        }
174        return result;
175    }
176    private int[] rotate(int[] p)
177    {
178        int[] result = new int[order * order];
179        for (int y = 0; y < order; y++) {
180            for (int x = 0; x < order; x++){
181                result[x + y * order] = p[order - 1 - y + x * order];
182            }
183        }
184        return result;
185    }
186    private int[] reflect(int[] p)
187    {
188        int[] result = new int[order * order];
189        for (int y = 0; y < order; y++) {
190            for (int x = 0; x < order; x++){
191                result[x + y * order] = p[order - 1 - x + y * order];
192            }
193        }
194        return result;
195    }
196    private boolean agrees(int[] p1, int[] p2, int dx, int dy)
197    {
198        int xmin = Math.max(dx, 0), xmax = dx < 0 ? dx + order : order,
199                ymin = Math.max(dy, 0), ymax = dy < 0 ? dy + order : order;
200        for (int y = ymin; y < ymax; y++) {
201            for (int x = xmin; x < xmax; x++) {
202                if (p1[x + order * y] != p2[x - dx + order * (y - dy)])
203                    return false;
204            }
205        }
206        return true;
207    }
208
209    private void init()
210    {
211        wave = new boolean[FMX * FMY][];
212        compatible = new int[wave.length][][];
213        for (int i = 0; i < wave.length; i++)
214        {
215            wave[i] = new boolean[totalOptions];
216            compatible[i] = new int[totalOptions][];
217            for (int t = 0; t < totalOptions; t++) compatible[i][t] = new int[4];
218        }
219
220        weightLogWeights = new double[totalOptions];
221        sumOfWeights = 0;
222        sumOfWeightLogWeights = 0;
223
224        for (int t = 0; t < totalOptions; t++)
225        {
226            weightLogWeights[t] = baseWeights[t] * Math.log(baseWeights[t]);
227            sumOfWeights += baseWeights[t];
228            sumOfWeightLogWeights += weightLogWeights[t];
229        }
230
231        startingEntropy = Math.log(sumOfWeights) - sumOfWeightLogWeights / sumOfWeights;
232
233        sumsOfOnes = new int[FMX * FMY];
234        sumsOfWeights = new double[FMX * FMY];
235        sumsOfWeightLogWeights = new double[FMX * FMY];
236        entropies = new double[FMX * FMY];
237
238        stack = new int[wave.length * totalOptions << 1];
239        stacksize = 0;
240    }
241
242    private Boolean observe()
243    {
244        double min = 1E+3;
245        int argmin = -1;
246
247        for (int i = 0; i < wave.length; i++)
248        {
249            if (onBoundary(i % FMX, i / FMX)) continue;
250
251            int amount = sumsOfOnes[i];
252            if (amount == 0) return false;
253
254            double entropy = entropies[i];
255            if (amount > 1 && entropy <= min)
256            {
257                double noise = 1E-6 * random.nextDouble();
258                if (entropy + noise < min)
259                {
260                    min = entropy + noise;
261                    argmin = i;
262                }
263            }
264        }
265
266        if (argmin == -1)
267        {
268            observed = new int[FMX * FMY];
269            for (int i = 0; i < wave.length; i++) {
270                for (int t = 0; t < totalOptions; t++) {
271                    if (wave[i][t]) { 
272                        observed[i] = t;
273                        break;
274                    }
275                }
276            }
277            return true;
278        }
279
280        double[] distribution = new double[totalOptions];
281        double sum = 0.0, x = 0.0;
282        for (int t = 0; t < totalOptions; t++)
283        {
284            sum += (distribution[t] = wave[argmin][t] ? baseWeights[t] : 0);
285        }
286        int r = 0;
287        sum = random.nextDouble(sum);
288        for (; r < totalOptions; r++) {
289            if((x += distribution[r]) > sum)
290                break;
291        }
292
293        boolean[] w = wave[argmin];
294        for (int t = 0; t < totalOptions; t++){
295            if (w[t] != (t == r))
296                ban(argmin, t);
297        }
298
299        return null;
300    }
301
302    private void propagate()
303    {
304        while (stacksize > 0)
305        {
306            int i1 = stack[stacksize - 2], e2 = stack[stacksize - 1];
307            stacksize -= 2;
308            int x1 = i1 % FMX, y1 = i1 / FMX;
309
310            for (int d = 0; d < 4; d++)
311            {
312                int dx = DX[d], dy = DY[d];
313                int x2 = x1 + dx, y2 = y1 + dy;
314                if (onBoundary(x2, y2)) continue;
315
316                if (x2 < 0) x2 += FMX;
317                else if (x2 >= FMX) x2 -= FMX;
318                if (y2 < 0) y2 += FMY;
319                else if (y2 >= FMY) y2 -= FMY;
320
321                int i2 = x2 + y2 * FMX;
322                int[] p = propagator[d][e2];
323                int[][] compat = compatible[i2];
324
325                for (int l = 0; l < p.length; l++)
326                {
327                    int t2 = p[l];
328                    int[] comp = compat[t2];
329
330                    comp[d]--;
331                    if (comp[d] == 0) ban(i2, t2);
332                }
333            }
334        }
335    }
336
337    public boolean run(long seed, int limit)
338    {
339        if (wave == null) init();
340
341        clear();
342        random = new GWTRNG(seed);
343
344        for (int l = 0; l < limit || limit == 0; l++)
345        {
346            Boolean result = observe();
347            if (result != null) return result;
348            propagate();
349        }
350
351        return true;
352    }
353
354    public boolean run(IRNG rng, int limit)
355    {
356        if (wave == null) init();
357
358        clear();
359        random = rng;
360
361        for (int l = 0; l < limit || limit == 0; l++)
362        {
363            Boolean result = observe();
364            if (result != null) return result;
365            propagate();
366        }
367
368        return true;
369    }
370
371    private void ban(int i, int t)
372    {
373        wave[i][t] = false;
374
375        int[] comp = compatible[i][t];
376        for (int d = 0; d < 4; d++) comp[d] = 0;
377        stack[stacksize++] = i;
378        stack[stacksize++] = t;
379
380        double sum = sumsOfWeights[i];
381        entropies[i] += sumsOfWeightLogWeights[i] / sum - Math.log(sum);
382
383        sumsOfOnes[i] -= 1;
384        sumsOfWeights[i] -= baseWeights[t];
385        sumsOfWeightLogWeights[i] -= weightLogWeights[t];
386
387        sum = sumsOfWeights[i];
388        entropies[i] -= sumsOfWeightLogWeights[i] / sum - Math.log(sum);
389    }
390
391
392    private boolean onBoundary(int x, int y) {
393        return !periodic && (x + order > FMX || y + order > FMY || x < 0 || y < 0);
394    }
395
396    public int[][] result()
397    {
398        int[][] result = new int[FMX][FMY];
399
400        if (observed != null)
401        {
402            for (int y = 0; y < FMY; y++)
403            {
404                int dy = y < FMY - order + 1 ? 0 : order - 1;
405                for (int x = 0; x < FMX; x++)
406                {
407                    int dx = x < FMX - order + 1 ? 0 : order - 1;
408                    result[x][y] = choices.keyAt(patterns[observed[x - dx + (y - dy) * FMX]][dx + dy * order]);
409                }
410            }
411        }
412        return result;
413    }
414
415    private void clear()
416    {
417        for (int i = 0; i < wave.length; i++)
418        {
419            for (int t = 0; t < totalOptions; t++)
420            {
421                wave[i][t] = true;
422                for (int d = 0; d < 4; d++) compatible[i][t][d] = propagator[OPPOSITE[d]][t].length;
423            }
424
425            sumsOfOnes[i] = baseWeights.length;
426            sumsOfWeights[i] = sumOfWeights;
427            sumsOfWeightLogWeights[i] = sumOfWeightLogWeights;
428            entropies[i] = startingEntropy;
429        }
430
431
432        if (ground != 0)
433        {
434            for (int x = 0; x < FMX; x++)
435            {
436                for (int t = 0; t < totalOptions; t++) if (t != ground) ban(x + (FMY - 1) * FMX, t);
437                for (int y = 0; y < FMY - 1; y++) ban(x + y * FMX, ground);
438            }
439
440            propagate();
441        }
442    }
443    private static final int[] DX = { -1, 0, 1, 0 };
444    private static final int[] DY = { 0, 1, 0, -1 };
445    private static final int[] OPPOSITE = { 2, 3, 0, 1 };
446
447}