1 /**
2  * Very basic (and probably buggy) numeric
3  * distribution / probability operations.
4  * WIP, do not use.
5  *
6  * License:
7  *   This Source Code Form is subject to the terms of
8  *   the Mozilla Public License, v. 2.0. If a copy of
9  *   the MPL was not distributed with this file, You
10  *   can obtain one at http://mozilla.org/MPL/2.0/.
11  *
12  * Authors:
13  *   Vladimir Panteleev <ae@cy.md>
14  */
15 
16 module ae.utils.math.distribution;
17 
18 import std.algorithm.comparison;
19 import std.algorithm.iteration;
20 
21 import ae.utils.array;
22 import ae.utils.math;
23 
/// A simplified representation of some probability distribution.
/// Supports uniform distributions and basic operations on them (sum / product).
///
/// `lo` and `hi` bound the values the distribution can take, and `avg` is
/// its mean (exact or approximate, depending on how the range was built).
/// The private `uniform` flag records whether the distribution is known to
/// be uniform between `lo` and `hi`; `cmp` only accepts uniform ranges.
struct Range(T)
{
	/// Low, high, and average points.
	T lo, hi, avg;
	private bool uniform;

	invariant
	{
		assert(lo <= avg);
		assert(avg <= hi);
	}

	/// `range op u`: apply `op` between each point of the range and the
	/// scalar `u`. The bounds are re-sorted, so they stay ordered when `op`
	/// reverses the order (e.g. multiplying by a negative number).
	/// The `uniform` flag is carried over unchanged.
	auto opBinary(string op, U)(U u) const
	if (is(U : real))
	{
		// NOTE(review): carrying over bounds and `uniform` is only accurate
		// for monotonic, uniformity-preserving operations such as `+` / `-`;
		// e.g. `^^` over a range spanning zero gives misleading bounds.
		alias V = typeof(mixin("T.init " ~ op ~ " u"));
		V a   = mixin("lo "  ~ op ~ " u");
		V b   = mixin("hi "  ~ op ~ " u");
		V avg = mixin("avg " ~ op ~ " u");
		return Range!V(min(a, b), max(a, b), avg, uniform);
	}

	/// `u op range`: like `range op u`, with the scalar on the left.
	/// The `uniform` flag is carried over unchanged.
	auto opBinaryRight(string op, U)(U u) const
	if (is(U : real))
	{
		alias V = typeof(mixin("u " ~ op ~ " T.init"));
		V a   = mixin("u " ~ op ~ " lo");
		V b   = mixin("u " ~ op ~ " hi");
		V avg = mixin("u " ~ op ~ " avg");
		return Range!V(min(a, b), max(a, b), avg, uniform);
	}

	/// `range op range`: combine two ranges point-wise.
	/// The bounds are the extremes of `op` over the four endpoint pairs, and
	/// the average is `avg op r.avg` (the exact mean for `+` / `-`, and for
	/// `*` when the two ranges are independent). The result is not marked
	/// uniform, since e.g. the sum of two uniform distributions is not uniform.
	auto opBinary(string op, R)(R r) const
	if (is(R : Range!U, U))
	{
		auto a = mixin("lo " ~ op ~ " r.lo");
		auto b = mixin("lo " ~ op ~ " r.hi");
		auto c = mixin("hi " ~ op ~ " r.lo");
		auto d = mixin("hi " ~ op ~ " r.hi");
		auto avg = mixin("avg " ~ op ~ " r.avg");
		return range(min(a, b, c, d), max(a, b, c, d), avg);
	}

	/// `cast(Range!U) r`: convert to a range of another element type.
	/// Equivalent to `to!U`.
	auto opCast(R)() const
	if (is(R : Range!U, U))
	{
		// The constraint guarantees this branch; `static if` brings `U` into scope.
		static if (is(R : Range!U, U))
			return to!U();
		else
			assert(false);
	}

	/// Convert each point to `U`.
	///
	/// The `uniform` flag is kept when `T` and `U` are both integral or both
	/// non-integral. `cmp` treats integral ranges as discrete and other ranges
	/// as continuous distributions. A conversion between the two kinds
	/// therefore no longer describes a uniform distribution over the same
	/// values, and its result is not marked uniform.
	Range!U to(U)() const
	{
		enum bool sameKind = is(T : long) == is(U : long);
		return Range!U(cast(U)lo, cast(U)hi, cast(U)avg, uniform && sameKind);
	}

	/// Return the distribution obtained if, with probability `prob`, the
	/// value is replaced by `val` (and otherwise keeps this distribution).
	///
	/// `prob == 0` returns `this` and `prob == 1` returns `val` unchanged.
	/// Otherwise:
	/// $(UL
	///   $(LI the bounds become the union of both ranges;)
	///   $(LI `avg` is interpolated (via `itpl`) between `this.avg`, at
	///        `prob == 0`, and `val.avg`, at `prob == 1`;)
	///   $(LI the result is not marked uniform.)
	/// )
	/// `this` is not modified.
	Range!T fuzzyAssign(Range!T val, double prob = 0.5) const
	{
		assert(prob >= 0 && prob <= 1);
		if (prob == 0)
			return this;
		if (prob == 1)
			return val;

		Range!T r = this; // mutable copy of the (const) receiver
		if (r.lo > val.lo)
			r.lo = val.lo;
		if (r.hi < val.hi)
			r.hi = val.hi;
		r.avg = itpl(r.avg, val.avg, prob, 0.0, 1.0);
		r.uniform = false; // a mixture of two distributions is generally not uniform
		return r;
	}

	/// ditto
	Range!T fuzzyAssign(T val, double prob = 0.5) const
	{
		return fuzzyAssign(range(val), prob);
	}

	/// Format as `x` for a single point, `lo..hi` when `avg` is the
	/// midpoint `(lo + hi) / 2`, and `lo..avg..hi` otherwise.
	string toString() const
	{
		import std.format : format;
		if (lo == hi)
			return format("%s", lo);
		else
		if (avg == (lo + hi) / 2)
			return format("%s..%s", lo, hi);
		else
			return format("%s..%s..%s", lo, avg, hi);
	}
}
120 
/// Build a range from explicit low, high and average points.
/// The result is not marked as uniform.
Range!T range(T)(T lo, T hi, T avg)
{
	enum bool isUniform = false;
	return Range!T(lo, hi, avg, isUniform);
}

/// Build a uniform range over `[lo, hi]`, averaging at `(lo + hi) / 2`
/// (evaluated with `T`'s own arithmetic, so it truncates for integers).
Range!T range(T)(T lo, T hi)
{
	enum bool isUniform = true;
	return Range!T(lo, hi, (lo + hi) / 2, isUniform);
}

/// Build a uniform range containing the single point `val`.
Range!T range(T)(T val)
{
	enum bool isUniform = true;
	return Range!T(val, val, val, isUniform);
}
124 
///
unittest
{
	// A scalar shifts both bounds, whichever side it is on.
	auto base = range(1, 2);
	auto expected = range(2, 3);
	assert(base + 1 == expected);
	assert(1 + base == expected);
}

///
unittest
{
	// The average of a range product is the product of the averages.
	auto factor = range(10, 20);
	assert((factor * factor).avg == 225);
}

unittest
{
	// Assigning 25 with the default 50% chance widens the upper bound and
	// moves the average halfway from 15 towards 25.
	auto r = range(10, 20).fuzzyAssign(25);
	assert(r == range(10, 25, 20));
}
147 
148 // ****************************************************************************
149 
/// The likelihood of some event, as a number between 0 and 1.
struct Probability
{
	double p; /// [0,1]

	/// `true` if the event never happens (`p` is exactly zero).
	bool isImpossible() const @nogc
	{
		return p == 0.0;
	}

	/// `true` if the event has any chance of happening (`p` above zero).
	bool isPossible() const @nogc
	{
		return p > 0.0;
	}

	/// `true` if the event always happens (`p` is exactly one).
	bool isCertain() const @nogc
	{
		return p == 1.0;
	}
}
159 
/// Call `doIf` when the event described by `p` can happen (`p.p > 0`),
/// passing it that non-zero probability. Otherwise do nothing.
void cond(alias doIf)(Probability p)
{
	double chance = p.p;
	if (!(chance > 0))
		return; // impossible event: nothing to do
	doIf(chance);
}
167 
/// Call `doIf` if the event described by `p` can happen, and/or `doElse`
/// if it can fail to happen.
/// Each callback receives the non-zero probability of its own outcome:
/// `p.p` for `doIf`, `1 - p.p` for `doElse`. When both are possible,
/// `doIf` runs first.
void cond(alias doIf, alias doElse)(Probability p)
{
	double chance = p.p;

	// Outcome: the event happens.
	if (chance > 0)
		doIf(chance);

	// Outcome: the event does not happen.
	if (chance < 1)
		doElse(1 - chance);
}
178 
/// Probability that event `a` does not happen.
Probability not(Probability a)
{
	return Probability(1 - a.p);
}

/// Probability that both unrelated (independent) events `a` and `b` happen.
Probability and(Probability a, Probability b)
{
	return Probability(a.p * b.p);
}

/// Probability that at least one of the unrelated (independent) events
/// `a` and `b` happens: one minus the chance that neither happens.
Probability or(Probability a, Probability b)
{
	auto neitherHappens = and(not(a), not(b));
	return not(neitherHappens);
}
185 
/// Return the probability that `a` `op` `b`, where `op` is `<` / `<=` / `>` / `>=`,
/// and `a` and `b` are numbers or ranges representing a uniform distribution.
///
/// When the operands are integral, ranges are treated as discrete uniform
/// distributions over every integer in `[lo, hi]`. Otherwise they are treated
/// as continuous uniform distributions over `[lo, hi]`.
/// Plain numbers and single-point ranges (`lo == hi`) are compared exactly.
template cmp(string op)
if (op.isOneOf("<", "<=", ">", ">="))
{
	// Number (left) vs. number (right)

	Probability cmp(A, B)(A a, B b)
	if (!is(A : Range!AV, AV) && !is(B : Range!BV, BV))
	{
		return Probability(mixin("a" ~ op ~ "b") ? 1 : 0);
	}

	// Range (left) vs. number (right)

	Probability cmp(A, B)(A a, B b)
	if ( is(A : Range!AV, AV) && !is(B : Range!BV, BV))
	{
		// A single point compares exactly. For continuous ranges this also
		// avoids the 0/0 interpolation below when lo == hi.
		if (a.lo == a.hi)
			return .cmp!op(a.lo, b);

		// For `>` / `>=`, `p` first holds the complementary probability
		// (`a <= b` / `a < b`) and is inverted at the end.
		double p;

		if (a.hi < b)
			p = 1;
		else
		if (a.lo <= b)
		{
			assert(a.uniform, "Can't compare a non-uniform distribution");

			auto lo = cast()a.lo;
			auto hi = cast()a.hi;

			static if (is(typeof(lo + b) : long))
			{
				// Discrete case: widen the interval by one so that the
				// interpolation below counts integers rather than lengths.
				static if (op[0] == '<')
					hi++;
				else
					lo--;

				static if (op.length == 2) // >=, <=
				{
					static if (op[0] == '<')
						b++;
					else
						b--;
				}
			}
			p = itpl(0.0, 1.0, b, lo, hi);
		}
		else
			p = 0;

		static if (op[0] == '>')
			p = 1 - p;

		return Probability(p);
	}

	unittest // int unittest
	{
		auto a = range(1, 2);
		foreach (b; 0..4)
		{
			auto p0 = cmp(a, b).p;
			double p1 = 0;
			foreach (x; a.lo .. a.hi+1)
				if (mixin("x" ~ op ~ "b"))
					p1 += 0.5;
			debug
			{
				import std.conv : text;
				assert(p0 == p1, text("a", op, b, " -> ", p0, " / ", p1));
			}
		}
	}

	// Number (left) vs. range (right): mirror the operator and swap operands.

	Probability cmp(A, B)(A a, B b)
	if (!is(A : Range!AV, AV) &&  is(B : Range!BV, BV))
	{
		static if (op[0] == '>')
			return .cmp!("<" ~ op[1..$])(b, a);
		else
			return .cmp!(">" ~ op[1..$])(b, a);
	}

	unittest
	{
		auto b = range(1, 2);
		foreach (a; 0..4)
		{
			auto p0 = cmp(a, b).p;
			double p1 = 0;
			foreach (x; b.lo .. b.hi+1)
				if (mixin("a" ~ op ~ "x"))
					p1 += 0.5;
			debug
			{
				import std.conv : text;
				assert(p0 == p1, text(a, op, "b", " -> ", p0, " / ", p1));
			}
		}
	}

	// Range (left) vs. range (right)

	Probability cmp(A, B)(A a, B b)
	if (is(A : Range!AV, AV) &&  is(B : Range!BV, BV))
	{
		// Two single points: the outcome is deterministic. The interval
		// formulas below would wrongly return 1 for `x < x`.
		if (a.lo == a.hi && b.lo == b.hi)
			return .cmp!op(a.lo, b.lo);

		assert(a.uniform && b.uniform, "Can't compare non-uniform distributions");

		// Reduce every operator to P(x < y).
		static if (op[0] == '<')
		{
			auto x0 = cast()a.lo;
			auto x1 = cast()a.hi;
			auto y0 = cast()b.lo;
			auto y1 = cast()b.hi;
		}
		else
		{
			auto x0 = cast()b.lo;
			auto x1 = cast()b.hi;
			auto y0 = cast()a.lo;
			auto y1 = cast()a.hi;
		}

		// Integer ranges are discrete. Model each one as the continuous
		// interval [lo, hi + 1) and remove the tie term at the end.
		enum bool discrete = is(typeof(x0 + y0) : long);
		static if (discrete)
		{
			x1++, y1++;

			static if (op.length == 2) // >=, <=: x <= y is x < y + 1
				y0++, y1++;
		}

		double p;

		// No intersection
		if (x1 <= y0) // x0 ≤ x1 ≤ y0 ≤ y1
			p = 1;
		else
		if (y1 <= x0) // y0 ≤ y1 ≤ x0 ≤ x1
			p = 0;
		else
		if (x0 <= y0)
		{
			// y is subset of x
			if (y1 <= x1) // x0 ≤ y0 ≤ y1 ≤ x1
				p = ((y0 - x0) + ((y1 - y0) / 2.)) / (x1 - x0);

			// x is mostly less than y
			else // x0 ≤ y0 ≤ x1 ≤ y1
			{
				// Computed in double to avoid integer overflow in the products.
				double d = x1 - y0;
				p = 1 - ((d * d) / 2) / (cast(double)(x1 - x0) * (y1 - y0));
			}
		}
		else
		if (y0 <= x0)
		{
			// x is subset of y
			if (x1 <= y1) // y0 ≤ x0 ≤ x1 ≤ y1
				p = ((y1 - x1) + ((x1 - x0) / 2.)) / (y1 - y0);

			// y is mostly less than x
			else // y0 ≤ x0 ≤ y1 ≤ x1
			{
				double d = y1 - x0;
				p = ((d * d) / 2) / (cast(double)(y1 - y0) * (x1 - x0));
			}
		}
		else
			assert(false);

		static if (discrete)
		{
			// With X' = X + U and Y' = Y + V (U, V independent and uniform on
			// [0, 1)): P(X' < Y') = P(X < Y) + P(X == Y) / 2.
			// Subtract the tie term to get the discrete probability.
			auto overlapLo = max(x0, y0);
			auto overlapHi = min(x1, y1);
			if (overlapHi > overlapLo)
				p -= (overlapHi - overlapLo) / (2.0 * (x1 - x0) * (y1 - y0));
		}

		return Probability(p);
	}

	unittest // range-to-range, int: compare against brute-force enumeration
	{
		import std.math : abs;
		foreach (alo; 0 .. 3)
			foreach (ahi; alo .. 3)
				foreach (blo; 0 .. 3)
					foreach (bhi; blo .. 3)
					{
						auto a = range(alo, ahi);
						auto b = range(blo, bhi);
						auto p0 = cmp(a, b).p;
						int hits, total;
						foreach (x; a.lo .. a.hi + 1)
							foreach (y; b.lo .. b.hi + 1)
							{
								total++;
								if (mixin("x" ~ op ~ "y"))
									hits++;
							}
						double p1 = cast(double)hits / total;
						debug
						{
							import std.conv : text;
							assert(abs(p0 - p1) < 1e-9, text(a, op, b, " -> ", p0, " / ", p1));
						}
					}
	}

	unittest // single points, floating-point
	{
		foreach (x; 0 .. 3)
			foreach (y; 0 .. 3)
			{
				double dx = x, dy = y;
				double expected = mixin("dx" ~ op ~ "dy") ? 1 : 0;
				assert(cmp(range(dx), range(dy)).p == expected);
				assert(cmp(range(dx), dy).p == expected);
				assert(cmp(dx, range(dy)).p == expected);
			}
	}
}
355 
unittest
{
	// Plain numbers: the outcome is certain (1) or impossible (0).
	assert(cmp!"<"(0, 1).p == 1);
	assert(cmp!"<"(1, 0).p == 0);
	assert(cmp!">"(1, 0).p == 1);
	assert(cmp!">"(0, 1).p == 0);
}

unittest
{
	// Continuous range [1, 3] against single numbers.
	auto a = range(1.0, 3.0);
	static immutable double[3] points   = [0.0, 2.0, 5.0];
	static immutable double[3] lessThan = [0.0, 0.5, 1.0];
	foreach (i, double x; points)
	{
		assert(cmp!"<"(a, x).p == lessThan[i]);
		assert(cmp!">"(a, x).p == 1 - lessThan[i]);
	}
}

unittest // range vs. int (range on the left)
{
	auto a = range(1, 2); // 1 or 2, equally likely

	static immutable double[3] lessThan = [0, 0.5, 1];    // b = 1, 2, 3
	foreach (i, expected; lessThan)
		assert(cmp!"<"(a, cast(int)i + 1).p == expected);

	static immutable double[3] greaterThan = [1, 0.5, 0]; // b = 0, 1, 2
	foreach (i, expected; greaterThan)
		assert(cmp!">"(a, cast(int)i).p == expected);

	// Referencing these instantiates the unittests inside `cmp`.
	alias le = cmp!"<=";
	alias ge = cmp!">=";
}

unittest
{
	// Distinct single points.
	auto zero = range(0.), one = range(1.);
	assert(cmp!"<"(zero, one).p == 1);
	assert(cmp!"<"(one, zero).p == 0);

	// Comparing a point with itself (disabled):
	// assert(cmp!"<" (range(0.), range(0.)).p == 0);
	// assert(cmp!"<="(range(0.), range(0.)).p == 1);

	// Continuous ranges: identical, disjoint, nested, and overlapping.
	auto unit = range(0., 1.);
	assert(cmp!"<"(unit, unit).p == 0.5);
	assert(cmp!"<"(unit, range(2., 3.)).p == 1.0);
	assert(cmp!"<"(range(2., 3.), unit).p == 0.0);
	assert(cmp!"<"(unit, range(0., 2.)).p == 0.75);
	assert(cmp!"<"(range(0., 2.), range(1., 3.)).p == 7./8);
}
408 
409 // ****************************************************************************
410 
// Disabled work-in-progress code: `version (none)` means nothing in this
// block is compiled.
version (none)
{
	/// A quantized representation of the probability distribution of
	/// some continuous function returning a value between 0 and 1.
	struct QuantizedDistribution(size_t numSegments, P = float, V = double)
	{
		enum V minValue = 0.0;
		enum V maxValue = 1.0;

		/// Represents the relative probability that the function will
		/// return a value in the represented interval.
		P[numSegments] buckets = 1.0;

		// NOTE(review): `bucketSize` is not referenced anywhere in this block.
		/// The length of one segment (of the function's return value) represented by one bucket.
		private enum V bucketSize = (maxValue - minValue) / numSegments;

		/// Map a value in [minValue, maxValue] to the index of the bucket containing it.
		private static size_t toBucketIndex(V value)
		{
			assert(value >= minValue && value <= maxValue);
			auto bucketIndex = cast(size_t)((value - minValue) / (maxValue - minValue) * numSegments);
			assert(bucketIndex <= numSegments);
			if (bucketIndex == numSegments)
				bucketIndex = numSegments - 1; // 1.0 goes into the last bucket, together with 0.999...
			return bucketIndex;
		}

		/// Lowest value covered by bucket `bucketIndex`.
		private V bucketLowValue(size_t bucketIndex)
		{
			return minValue + ((maxValue - minValue) * bucketIndex / numSegments);
		}
		/// Upper edge of bucket `bucketIndex` (the low value of the next bucket).
		private V bucketHighValue(size_t bucketIndex)
		{
			return bucketLowValue(bucketIndex + 1);
		}

		/// Normalizes `buckets` so that they add up to 1.
		// NOTE(review): the local `sum` shares its name with the `sum`
		// algorithm called in its own initializer — confirm this compiles
		// once the block is enabled.
		typeof(this) normalize()
		{
			typeof(this) result = this;
			P sum = result.buckets[].sum;
			result.buckets[] /= sum;
			return result;
		}

		/// Call `fun` a `numSamples` number of times, and return a distribution representing the result.
		static typeof(this) sample(V delegate() fun, size_t numSamples)
		{
			typeof(this) result;
			// Start from zero (not the default 1.0) so buckets hold raw sample counts.
			result.buckets[] = 0;
			foreach (_; 0 .. numSamples)
				result.buckets[toBucketIndex(fun())]++;
			return result;
		}

		// NOTE(review): this adds up the buckets *below* `value` plus an
		// interpolated share of `value`'s own bucket, which reads as
		// P(result < value) rather than P(result > value). The unittest
		// below expects ~0.25 for `gt(0.25)` on uniform samples, which is
		// consistent with that reading. Confirm the intended meaning of `gt`.
		// NOTE(review): `itpl` receives an `int` literal `0` and a `P` bucket
		// value as its endpoints — confirm those argument types are accepted.
		Probability gt(V value)
		{
			auto bucketIndex = toBucketIndex(value);
			auto lowValue  = bucketLowValue (bucketIndex);
			auto highValue = bucketHighValue(bucketIndex);
			auto total = buckets[].sum;
			return Probability((
				buckets[0 .. bucketIndex].sum +
				itpl(0, buckets[bucketIndex], value, lowValue, highValue)
			) / total);
		}
	}

	unittest
	{
		import std.random : Random, uniform, uniform01;
		// NOTE(review): `isClose` is imported but not used in this unittest.
		import std.math.operations : isClose;

		auto rng = Random(0);
		{
			auto d = QuantizedDistribution!256.sample(() => uniform01!double(rng), 10_000);
			assert(d.gt(0.25).p.between(0.24, 0.26));
		}
		{
			// NOTE(review): this lambda returns `bool`, while `sample` takes a
			// `V delegate()` — confirm the conversion is accepted.
			auto d = QuantizedDistribution!256.sample(() => uniform(0, 4) == 0, 10_000);
			assert(d.gt(0.5).p.between(0.74, 0.76));
		}
	}
}