1 /**
2  * Very basic (and probably buggy) numeric
3  * distribution / probability operations.
4  * WIP, do not use.
5  *
6  * License:
7  *   This Source Code Form is subject to the terms of
8  *   the Mozilla Public License, v. 2.0. If a copy of
9  *   the MPL was not distributed with this file, You
10  *   can obtain one at http://mozilla.org/MPL/2.0/.
11  *
12  * Authors:
13  *   Vladimir Panteleev <ae@cy.md>
14  */
15 
16 module ae.utils.math.distribution;
17 
18 import std.algorithm.comparison;
19 
20 import ae.utils.array;
21 import ae.utils.math;
22 
/// A simplified representation of some probability distribution.
/// Supports uniform distributions and basic operations on them (sum / product).
struct Range(T)
{
	/// Low, high, and average points.
	T lo, hi, avg;
	// Set when the value is known to be uniformly distributed between
	// `lo` and `hi`; cleared by operations that cannot guarantee that.
	private bool uniform;

	// The average must always lie between the two end points.
	invariant
	{
		assert(lo <= avg);
		assert(avg <= hi);
	}

	/// Apply `this op u` for a scalar `u` to all three points.
	/// The two transformed end points are re-ordered with `min` / `max`,
	/// and the uniformity flag is carried over unchanged.
	auto opBinary(string op, U)(U u) const
	if (is(U : real))
	{
		alias V = typeof(mixin("T.init " ~ op ~ " u"));
		V fromLo = mixin("lo "  ~ op ~ " u");
		V fromHi = mixin("hi "  ~ op ~ " u");
		V mid    = mixin("avg " ~ op ~ " u");
		return Range!V(min(fromLo, fromHi), max(fromLo, fromHi), mid, uniform);
	}

	/// Apply `u op this` for a scalar `u` to all three points.
	/// Mirrors the scalar `opBinary`, with the scalar on the left-hand side.
	auto opBinaryRight(string op, U)(U u) const
	if (is(U : real))
	{
		alias V = typeof(mixin("u " ~ op ~ " T.init"));
		V fromLo = mixin("u " ~ op ~ " lo");
		V fromHi = mixin("u " ~ op ~ " hi");
		V mid    = mixin("u " ~ op ~ " avg");
		return Range!V(min(fromLo, fromHi), max(fromLo, fromHi), mid, uniform);
	}

	/// Apply `this op r` for another range `r`.
	/// The new bounds are the smallest and largest of the four
	/// end-point combinations; the new average is `avg op r.avg`.
	/// The result is built with the three-argument `range`,
	/// so it is not flagged as uniform.
	auto opBinary(string op, R)(R r) const
	if (is(R : Range!U, U))
	{
		auto loLo = mixin("lo " ~ op ~ " r.lo");
		auto loHi = mixin("lo " ~ op ~ " r.hi");
		auto hiLo = mixin("hi " ~ op ~ " r.lo");
		auto hiHi = mixin("hi " ~ op ~ " r.hi");
		auto mid  = mixin("avg " ~ op ~ " r.avg");
		return range(min(loLo, loHi, hiLo, hiHi), max(loLo, loHi, hiLo, hiHi), mid);
	}

	/// Cast to a `Range` of another point type.
	/// Each point is converted with `cast(U)`; the result is not flagged as uniform.
	auto opCast(T)() const
	if (is(T : Range!U, U))
	{
		// The `static if` re-runs the pattern match to bring `U` into scope.
		static if (is(T : Range!U, U))
			return this.to!U();
		else
			assert(false);
	}

	/// Convert to a `Range!U` by casting each point to `U`.
	/// The result is not flagged as uniform.
	Range!U to(U)() const
	{
		return range(cast(U)lo, cast(U)hi, cast(U)avg);
	}

	/// Apply a `prob` chance that `this` equals `val`.
	Range!T fuzzyAssign(Range!T val, double prob = 0.5)
	{
		assert(prob >= 0 && prob <= 1);

		// The two extremes need no blending.
		if (prob == 0)
			return this;
		if (prob == 1)
			return val;

		auto blended = this;
		// Widen the bounds so that they cover both ranges.
		blended.lo = min(blended.lo, val.lo);
		blended.hi = max(blended.hi, val.hi);
		// Move the average towards `val.avg` in proportion to `prob`.
		blended.avg = itpl(blended.avg, val.avg, prob, 0.0, 1.0);
		blended.uniform = false;
		return blended;
	}

	/// ditto
	Range!T fuzzyAssign(T val, double prob = 0.5)
	{
		return fuzzyAssign(range(val), prob);
	}

	/// Format as `lo` for a single point, `lo..hi` when the average
	/// equals `(lo + hi) / 2`, and `lo..avg..hi` otherwise.
	string toString() const
	{
		import std.format : format;

		if (lo == hi)
			return format("%s", lo);
		if (avg == (lo + hi) / 2)
			return format("%s..%s", lo, hi);
		return format("%s..%s..%s", lo, avg, hi);
	}
}
119 
/// Construct a `Range` from explicit low, high and average points.
/// The result is not flagged as uniform.
Range!T range(T)(T lo, T hi, T avg)
{
	return typeof(return)(lo, hi, avg, false);
}

/// Construct a uniform `Range` from low and high points.
/// The average is `(lo + hi) / 2`.
Range!T range(T)(T lo, T hi)
{
	return typeof(return)(lo, hi, (lo + hi) / 2, true);
}

/// Construct a uniform `Range` in which all three points equal `val`.
Range!T range(T)(T val)
{
	return typeof(return)(val, val, val, true);
}
123 
///
unittest
{
	// Adding a scalar on either side shifts both end points by that amount.
	auto shifted = range(2, 3);
	assert(range(1, 2) + 1 == shifted);
	assert(1 + range(1, 2) == shifted);
}
130 
///
unittest
{
	// The average of a product of two ranges is the product of their averages.
	auto lhs = range(10, 20);
	auto rhs = range(10, 20);
	auto product = lhs * rhs;
	assert(product.avg == 225);
}
139 
unittest
{
	// A 50% chance of becoming 25 widens the upper bound to 25
	// and moves the average half-way from 15 to 25.
	auto r = range(10, 20);
	r = r.fuzzyAssign(25);
	assert(r == range(10, 25, 20));
}
146 
147 // ****************************************************************************
148 
/// Indicates the probability of a certain event.
struct Probability
{
	double p; /// [0,1]

	/// True when the probability is exactly zero.
	bool isImpossible() const @nogc
	{
		return p == 0;
	}

	/// True when the probability is greater than zero.
	bool isPossible() const @nogc
	{
		return p > 0;
	}

	/// True when the probability is exactly one.
	bool isCertain() const @nogc
	{
		return p == 1;
	}
}
158 
/// Apply `doIf` if `p` is possible.
/// `doIf` receives the probability of the event (non-zero).
void cond(alias doIf)(Probability p)
{
	// Nothing to do for an event that cannot happen.
	if (!p.isPossible)
		return;
	doIf(p.p);
}
166 
/// Apply `doIf` if `p` is possible,
/// and/or `doElse` if `!p` is possible,
/// `doIf` and `doElse` receive the probability of their respective event (non-zero).
void cond(alias doIf, alias doElse)(Probability p)
{
	// The event itself can happen.
	if (p.isPossible)
	{
		doIf(p.p);
	}

	// The complementary event can happen; both branches run when 0 < p < 1.
	if (p.p < 1)
	{
		doElse(1 - p.p);
	}
}
177 
/// Return the probability of event `a` not occurring.
Probability not(Probability a)
{
	return Probability(1 - a.p);
}

/// Return the probability of both unrelated events `a` and `b` occurring.
Probability and(Probability a, Probability b)
{
	return Probability(a.p * b.p);
}

/// Return the probability of at least one of the unrelated events `a` and `b` occurring.
Probability or (Probability a, Probability b)
{
	// At least one occurs unless neither does: !(!a && !b).
	auto neither = and(not(a), not(b));
	return not(neither);
}
184 
/// Return the probability that `a` `op` `b`, where `op` is `<` / `<=` / `>` / `>=`,
/// and `a` and `b` are numbers or ranges representing a uniform distribution.
///
/// Ranges must have been created as uniform (see the `assert`s below).
/// A floating-point range consisting of a single point is treated as that
/// exact value, so comparing it with an equal value yields 0 for the strict
/// operators and 1 for the non-strict ones.
template cmp(string op)
if (op.isOneOf("<", "<=", ">", ">="))
{
	// Number `op` number

	/// Both operands are plain numbers: the result is either 0 or 1.
	Probability cmp(A, B)(A a, B b)
	if (!is(A : Range!AV, AV) && !is(B : Range!BV, BV))
	{
		return Probability(mixin("a" ~ op ~ "b") ? 1 : 0);
	}

	// Range `op` number

	/// `a` is a range, `b` is a plain number.
	Probability cmp(A, B)(A a, B b)
	if ( is(A : Range!AV, AV) && !is(B : Range!BV, BV))
	{
		// `p` is first computed as P(a < b) for `<` and `>=`,
		// or as P(a <= b) for `<=` and `>`.
		// For the `>` family it is complemented at the very end.
		double p;

		if (a.hi < b)
			p = 1; // the whole range lies below b
		else
		if (a.lo <= b)
		{
			assert(a.uniform, "Can't compare a non-uniform distribution");

			auto lo = cast()a.lo;
			auto hi = cast()a.hi;

			static if (is(typeof(lo + b) : long))
			{
				// Integers: the range holds the hi - lo + 1 discrete values
				// lo .. hi, so widen the interpolation span by one.
				static if (op[0] == '<')
					hi++;
				else
					lo--;

				// Shift `b` by one when the comparison being computed
				// must or must not count the value `b` itself.
				static if (op.length == 2) // >=, <=
				{
					static if (op[0] == '<')
						b++;
					else
						b--;
				}
				p = itpl(0.0, 1.0, b, lo, hi);
			}
			else
			{
				if (lo == hi)
				{
					// Single-point range, and lo <= b <= hi, so a == b exactly.
					// Interpolating would divide by zero (0/0 = NaN), so
					// answer directly: before the complement step, `<=` and
					// `>` need P(a <= b) = 1; `<` and `>=` need P(a < b) = 0.
					p = (op == "<=" || op == ">") ? 1 : 0;
				}
				else
					p = itpl(0.0, 1.0, b, lo, hi);
			}
		}
		else
			p = 0; // the whole range lies above b

		static if (op[0] == '>')
			p = 1 - p;

		return Probability(p);
	}

	unittest // int unittest
	{
		// Compare against a brute-force count over every value in the range.
		auto a = range(1, 2);
		foreach (b; 0..4)
		{
			auto p0 = cmp(a, b).p;
			double p1 = 0;
			foreach (x; a.lo .. a.hi+1)
				if (mixin("x" ~ op ~ "b"))
					p1 += 0.5;
			debug
			{
				import std.conv : text;
				assert(p0 == p1, text("a", op, b, " -> ", p0, " / ", p1));
			}
		}
	}

	// Number `op` range

	/// `a` is a plain number, `b` is a range.
	/// Delegates to the range-vs-number overload with the operands swapped
	/// and the direction of the operator mirrored.
	Probability cmp(A, B)(A a, B b)
	if (!is(A : Range!AV, AV) &&  is(B : Range!BV, BV))
	{
		static if (op[0] == '>')
			return .cmp!("<" ~ op[1..$])(b, a);
		else
			return .cmp!(">" ~ op[1..$])(b, a);
	}

	unittest
	{
		// Compare against a brute-force count over every value in the range.
		auto b = range(1, 2);
		foreach (a; 0..4)
		{
			auto p0 = cmp(a, b).p;
			double p1 = 0;
			foreach (x; b.lo .. b.hi+1)
				if (mixin("a" ~ op ~ "x"))
					p1 += 0.5;
			debug
			{
				import std.conv : text;
				assert(p0 == p1, text(a, op, "b", " -> ", p0, " / ", p1));
			}
		}
	}

	// Range `op` range

	/// Both operands are ranges.
	Probability cmp(A, B)(A a, B b)
	if (is(A : Range!AV, AV) &&  is(B : Range!BV, BV))
	{
		assert(a.uniform && b.uniform, "Can't compare non-uniform distributions");

		// Normalize to "x is less than y": for the `>` family swap the operands.
		static if (op[0] == '<')
		{
			auto x0 = a.lo;
			auto x1 = a.hi;
			auto y0 = b.lo;
			auto y1 = b.hi;
		}
		else
		{
			auto x0 = b.lo;
			auto x1 = b.hi;
			auto y0 = a.lo;
			auto y1 = a.hi;
		}

		static if (is(typeof(x0 + y0) : long))
		{
			// Integers: treat each value as a unit-wide cell, making the
			// upper bounds exclusive.
			// NOTE(review): with this continuous approximation, ties between
			// the two ranges count as half; e.g. range(1, 2) < range(1, 2)
			// evaluates to 0.5, whereas counting discrete pairs gives 0.25.
			// Confirm whether this is intended.
			x1++;
			y1++;

			static if (op.length == 2) // >=, <=
			{
				y0++;
				y1++;
			}
		}
		else
		{
			// Both ranges are the same single point: the values are equal,
			// so only the non-strict operators hold. Without this check the
			// "no intersection" test below would report 1 for `<` and `>` too.
			if (x0 == x1 && y0 == y1 && x0 == y0)
				return Probability(op.length == 2 ? 1 : 0);
		}

		double p;

		// No intersection
		if (x1 <= y0) // x0 ≤ x1 ≤ y0 ≤ y1
			p = 1;
		else
		if (y1 <= x0) // y0 ≤ y1 ≤ x0 ≤ x1
			p = 0;
		else
		if (x0 <= y0)
		{
			// y is subset of x
			if (x0 <= y0 && y1 <= x1) // x0 ≤ y0 ≤ y1 ≤ x1
				p = ((y0 - x0) + ((y1 - y0) / 2.)) / (x1 - x0);

			// x is mostly less than y
			else // x0 ≤ y0 ≤ x1 ≤ y1
				p = 1 - (((x1 - y0) * (x1 - y0)) / 2.) / ((x1 - x0) * (y1 - y0));
		}
		else
		if (y0 <= x0)
		{
			// x is subset of y
			if (y0 <= x0 && x1 <= y1) // y0 ≤ x0 ≤ x1 ≤ y1
				p = ((y1 - x1) + ((x1 - x0) / 2.)) / (y1 - y0);

			// y is mostly less than x
			else // y0 ≤ x0 ≤ y1 ≤ x1
				p =     (((y1 - x0) * (y1 - x0)) / 2.) / ((y1 - y0) * (x1 - x0));
		}
		else
			assert(false);

		return Probability(p);
	}
}
354 
unittest
{
	// Plain numbers: the outcome is always either impossible or certain.
	alias gt = cmp!">";
	alias lt = cmp!"<";

	assert(gt(0, 1).p == 0);
	assert(gt(1, 0).p == 1);

	assert(lt(1, 0).p == 0);
	assert(lt(0, 1).p == 1);
}
363 
unittest
{
	// Floating-point range against a number below, in the middle of, and above it.
	auto span = range(1.0, 3.0);

	assert(cmp!"<"(span, 0.0).p == 0);
	assert(cmp!"<"(span, 2.0).p == 0.5);
	assert(cmp!"<"(span, 5.0).p == 1);

	assert(cmp!">"(span, 0.0).p == 1);
	assert(cmp!">"(span, 2.0).p == 0.5);
	assert(cmp!">"(span, 5.0).p == 0);
}
375 
unittest // range against number, int
{
	// The range holds the two discrete values 1 and 2.
	auto span = range(1, 2);

	assert(cmp!"<"(span, 1).p == 0);
	assert(cmp!"<"(span, 2).p == 0.5);
	assert(cmp!"<"(span, 3).p == 1);

	assert(cmp!">"(span, 0).p == 1);
	assert(cmp!">"(span, 1).p == 0.5);
	assert(cmp!">"(span, 2).p == 0);

	// instantiate template unittest
	alias le = cmp!"<=";
	alias ge = cmp!">=";
}
392 
unittest
{
	// Shorthand for P(x < y) between two floating-point ranges.
	static double lt(Range!double x, Range!double y)
	{
		return cmp!"<"(x, y).p;
	}

	// Distinct single points.
	assert(lt(range(0.), range(1.)) == 1);
	assert(lt(range(1.), range(0.)) == 0);

	// Identical single points (disabled in the original; presumably
	// because this case was not handled - verify before enabling).
	// assert(cmp!"<" (range(0.), range(0.)).p == 0);
	// assert(cmp!"<="(range(0.), range(0.)).p == 1);

	// Identical, disjoint, nested and partially overlapping ranges.
	assert(lt(range(0., 1.), range(0., 1.)) == 0.5);
	assert(lt(range(0., 1.), range(2., 3.)) == 1.0);
	assert(lt(range(2., 3.), range(0., 1.)) == 0.0);
	assert(lt(range(0., 1.), range(0., 2.)) == 0.75);
	assert(lt(range(0., 2.), range(1., 3.)) == 7./8);
}