/**
 * Very basic (and probably buggy) numeric
 * distribution / probability operations.
 * WIP, do not use.
 *
 * License:
 *   This Source Code Form is subject to the terms of
 *   the Mozilla Public License, v. 2.0. If a copy of
 *   the MPL was not distributed with this file, You
 *   can obtain one at http://mozilla.org/MPL/2.0/.
 *
 * Authors:
 *   Vladimir Panteleev <vladimir@thecybershadow.net>
 */

module ae.utils.math.distribution;

import std.algorithm.comparison;

import ae.utils.array;
import ae.utils.math;

/// Whether `x op c` (or `c op x`, when `constOnLeft` is true) maps a
/// uniformly distributed `x` of type `V` onto a distribution that is
/// again uniform between the new bounds.
///
/// For integral types only shifts qualify. Scaling a discrete uniform
/// range leaves gaps (`{1, 2} * 3` is `{3, 6}`, not every integer in
/// `3 .. 6`), and integer division produces repeated values.
/// For floating-point types, affine maps qualify: `+`, `-`, `*`, and
/// dividing the variable by a constant. `c / x` does not qualify.
private enum bool preservesUniform(string op, V, bool constOnLeft) =
	op == "+" || op == "-" ||
	(__traits(isFloating, V) && (op == "*" || (op == "/" && !constOnLeft)));

/// A simplified representation of some probability distribution.
/// Supports uniform distributions and basic operations on them (sum / product).
///
/// `lo` and `hi` bound the possible values, and `avg` is the (estimated) mean.
/// `uniform` records whether the distribution is known to be uniform over
/// `lo .. hi`; for integral `T`, that means every integer in between is
/// equally likely. `cmp` asserts this flag before comparing.
struct Range(T)
{
	T lo, hi, avg;
	private bool uniform;

	invariant
	{
		assert(lo <= avg);
		assert(avg <= hi);
	}

	/// Combine with a constant: `this op u`.
	/// The result stays uniform only if `op` preserves uniformity for the
	/// result type (see `preservesUniform`).
	/// NOTE(review): for non-affine ops the mean is not transformed exactly;
	/// `avg op u` is only an approximation there.
	auto opBinary(string op, U)(U u) const
	if (is(U : real))
	{
		alias V = typeof(mixin("T.init " ~ op ~ " u"));
		V x = mixin("lo " ~ op ~ " u");
		V y = mixin("hi " ~ op ~ " u");
		V m = mixin("avg " ~ op ~ " u");
		// Negative factors / divisors swap the bounds, hence min/max.
		return Range!V(min(x, y), max(x, y), m,
			uniform && preservesUniform!(op, V, false));
	}

	/// Combine with a constant on the left: `u op this`.
	/// The same uniformity rules apply as for `opBinary`, except that
	/// `u / this` is never uniform.
	auto opBinaryRight(string op, U)(U u) const
	if (is(U : real))
	{
		alias V = typeof(mixin("u " ~ op ~ " T.init"));
		V x = mixin("u " ~ op ~ " lo");
		V y = mixin("u " ~ op ~ " hi");
		V m = mixin("u " ~ op ~ " avg");
		return Range!V(min(x, y), max(x, y), m,
			uniform && preservesUniform!(op, V, true));
	}

	/// Combine with another range: `this op r`.
	/// The bounds are the extremes over all four corner combinations.
	/// The result is never marked uniform.
	/// NOTE(review): `avg op r.avg` is the exact mean for `+`/`-`. For `*`
	/// it is exact only if the operands are independent, and for `/` it is
	/// an approximation.
	auto opBinary(string op, R)(R r) const
	if (is(R : Range!U, U))
	{
		auto ll = mixin("lo " ~ op ~ " r.lo");
		auto lh = mixin("lo " ~ op ~ " r.hi");
		auto hl = mixin("hi " ~ op ~ " r.lo");
		auto hh = mixin("hi " ~ op ~ " r.hi");
		auto m  = mixin("avg " ~ op ~ " r.avg");
		return range(min(ll, lh, hl, hh), max(ll, lh, hl, hh), m);
	}

	/// Explicit conversion to another `Range` instantiation; same as `to`.
	auto opCast(R)() const
	if (is(R : Range!U, U))
	{
		static if (is(R : Range!U, U)) // re-match to bring `U` into scope
			return to!U();
	}

	/// Convert bounds and mean to `U` via `cast`.
	/// The result is not marked uniform.
	Range!U to(U)() const
	{
		return range(cast(U)lo, cast(U)hi, cast(U)avg);
	}

	/// Model "with probability `prob`, the value becomes `val`".
	/// The bounds widen to cover both ranges, the mean is interpolated
	/// between the two means, and the result is no longer uniform.
	/// `prob` must be in [0, 1]; 0 and 1 return one of the inputs unchanged.
	Range!T fuzzyAssign(Range!T val, double prob = 0.5)
	{
		assert(prob >= 0 && prob <= 1);
		if (prob == 0)
			return this;
		if (prob == 1)
			return val;

		auto r = this;
		if (r.lo > val.lo)
			r.lo = val.lo;
		if (r.hi < val.hi)
			r.hi = val.hi;
		r.avg = itpl(r.avg, val.avg, prob, 0.0, 1.0);
		r.uniform = false;
		return r;
	}

	/// ditto
	Range!T fuzzyAssign(T val, double prob = 0.5)
	{
		return fuzzyAssign(range(val), prob);
	}

	/// Format as `v` for a single point, or `lo..hi` when the mean is the
	/// midpoint; otherwise as `lo..avg..hi`.
	string toString() const
	{
		import std.format : format;
		if (lo == hi)
			return format("%s", lo);
		else
		if (avg == (lo + hi) / 2)
			return format("%s..%s", lo, hi);
		else
			return format("%s..%s..%s", lo, avg, hi);
	}
}

/// A non-uniform range with an explicit mean.
Range!T range(T)(T lo, T hi, T avg) { return Range!T(lo, hi, avg, false); }
/// A uniform range over `lo .. hi`, with the midpoint as its mean.
Range!T range(T)(T lo, T hi) { return Range!T(lo, hi, (lo + hi) / 2, true); }
/// A single known value (a degenerate uniform range).
Range!T range(T)(T val) { return Range!T(val, val, val, true); }

unittest
{
	assert(range(1, 2) + 1 == range(2, 3));
	assert(1 + range(1, 2) == range(2, 3));
}

unittest
{
	auto a = range(10, 20);
	auto b = range(10, 20);
	auto c = a * b;
	assert(c.avg == 225);
}

unittest
{
	auto a = range(10, 20);
	a = a.fuzzyAssign(25);
	assert(a == range(10, 25, 20));
}

unittest
{
	// Shifts keep integer ranges uniform...
	assert((range(1, 2) + 1).uniform);
	assert((5 - range(1, 2)).uniform);
	// ...but scaling leaves gaps ({1, 2} * 3 == {3, 6}).
	assert(!(range(1, 2) * 3).uniform);
	// Floating-point affine maps keep uniformity; reciprocals do not.
	assert((range(1.0, 2.0) * 3).uniform);
	assert((range(1.0, 2.0) / 2).uniform);
	assert(!(1.0 / range(1.0, 2.0)).uniform);
}

/// Indicates the probability of a certain event.
struct Probability
{
	/// The probability value. Nominally in [0, 1]; this is not enforced.
	double p;

	/// True if the event can never happen (`p == 0`).
	bool isImpossible() const @nogc { return p == 0; }
	/// True if the event has a nonzero chance of happening (`p > 0`).
	bool isPossible() const @nogc { return p > 0; }
	/// True if the event always happens (`p == 1`).
	bool isCertain() const @nogc { return p == 1; }
}

/// Invoke `doIf(p.p)` if the event is possible.
void cond(alias doIf)(Probability p)
{
	if (p.p > 0)
		doIf(p.p);
}

/// Invoke `doIf(p.p)` if the event is possible, and `doElse(1 - p.p)` if
/// it is not certain. Both are called when 0 < p < 1.
void cond(alias doIf, alias doElse)(Probability p)
{
	if (p.p > 0)
		doIf(p.p);
	if (p.p < 1)
		doElse(1 - p.p);
}

/// Complement of an event.
Probability not(Probability a) { return Probability(1 - a.p); }
/// Both events happen. Multiplies the probabilities, which is exact only
/// for independent events.
Probability and(Probability a, Probability b) { return Probability(a.p * b.p); }
/// At least one event happens. Computed via De Morgan from `not`/`and`,
/// so it likewise assumes independence.
Probability or (Probability a, Probability b) { return not(and(not(a), not(b))); }

/// Probability that `a op b` holds, where either side may be a plain
/// number or a `Range`. Ranges taking part in a non-trivial comparison
/// must be uniform (asserted). Two ranges are treated as independent.
template cmp(string op)
if (op.isOneOf("<", "<=", ">", ">="))
{
	// Number-to-number: deterministic, so the result is 0 or 1.

	Probability cmp(A, B)(A a, B b)
	if (!is(A : Range!AV, AV) && !is(B : Range!BV, BV))
	{
		return Probability(mixin("a" ~ op ~ "b") ? 1 : 0);
	}

	// Range-to-number: `a` is a `Range`, `b` a plain number.

	Probability cmp(A, B)(A a, B b)
	if ( is(A : Range!AV, AV) && !is(B : Range!BV, BV))
	{
		double p;

		if (a.hi < b)
			p = 1;
		else
		if (a.lo <= b)
		{
			assert(a.uniform, "Can't compare a non-uniform distribution");

			auto lo = cast()a.lo;
			auto hi = cast()a.hi;

			static if (is(typeof(lo + b) : long))
			{
				// Discrete: widen the interval so that linear interpolation
				// counts whole integers on the correct side of `b`.
				static if (op[0] == '<')
					hi++;
				else
					lo--;

				static if (op.length == 2) // >=, <=
				{
					static if (op[0] == '<')
						b++;
					else
						b--;
				}
			}
			else
			{
				// Continuous point mass sitting exactly on `b`. Interpolation
				// would divide by zero (hi - lo == 0), so decide directly.
				if (lo == hi)
					return Probability(mixin("lo" ~ op ~ "b") ? 1 : 0);
			}
			p = itpl(0.0, 1.0, b, lo, hi);
		}
		else
			p = 0;

		static if (op[0] == '>')
			p = 1 - p;

		return Probability(p);
	}

	unittest // int unittest
	{
		auto a = range(1, 2);
		foreach (b; 0..4)
		{
			auto p0 = cmp(a, b).p;
			double p1 = 0;
			foreach (x; a.lo .. a.hi+1)
				if (mixin("x" ~ op ~ "b"))
					p1 += 0.5;
			debug
			{
				import std.conv : text;
				assert(p0 == p1, text("a", op, b, " -> ", p0, " / ", p1));
			}
		}
	}

	// Number-to-range: `a` is a plain number, `b` a `Range`.
	// Rewritten as the mirrored range-to-number comparison.

	Probability cmp(A, B)(A a, B b)
	if (!is(A : Range!AV, AV) && is(B : Range!BV, BV))
	{
		static if (op[0] == '>')
			return .cmp!("<" ~ op[1..$])(b, a);
		else
			return .cmp!(">" ~ op[1..$])(b, a);
	}

	unittest
	{
		auto b = range(1, 2);
		foreach (a; 0..4)
		{
			auto p0 = cmp(a, b).p;
			double p1 = 0;
			foreach (x; b.lo .. b.hi+1)
				if (mixin("a" ~ op ~ "x"))
					p1 += 0.5;
			debug
			{
				import std.conv : text;
				assert(p0 == p1, text(a, op, "b", " -> ", p0, " / ", p1));
			}
		}
	}

	// Range-to-range: both operands are uniform `Range`s, assumed independent.

	Probability cmp(A, B)(A a, B b)
	if (is(A : Range!AV, AV) && is(B : Range!BV, BV))
	{
		assert(a.uniform && b.uniform, "Can't compare non-uniform distributions");

		// Two point masses: the outcome is certain either way. Without this
		// check, the interval formulas below would report 1 for `x < x` on
		// floats, and count the tie as half a win on integers.
		if (a.lo == a.hi && b.lo == b.hi)
			return Probability(mixin("a.lo" ~ op ~ "b.lo") ? 1 : 0);

		// Always compute P(x < y), swapping the operands for `>` / `>=`.
		static if (op[0] == '<')
		{
			auto x0 = cast()a.lo, x1 = cast()a.hi;
			auto y0 = cast()b.lo, y1 = cast()b.hi;
		}
		else
		{
			auto x0 = cast()b.lo, x1 = cast()b.hi;
			auto y0 = cast()a.lo, y1 = cast()a.hi;
		}

		enum discrete = is(typeof(x0 + y0) : long);

		static if (discrete)
		{
			// P(x == y) = (number of shared integers) / (number of pairs).
			// Compared before subtracting so unsigned types cannot wrap.
			immutable overlapLo = max(x0, y0), overlapHi = min(x1, y1);
			immutable double tieP = overlapLo <= overlapHi
				? (overlapHi - overlapLo + 1) / (cast(double)(x1 - x0 + 1) * (y1 - y0 + 1))
				: 0;

			// Spread each integer k uniformly over [k, k+1), so the
			// continuous formulas below apply to [lo, hi + 1).
			x1++;
			y1++;
		}

		double p;

		// No intersection
		if (x1 <= y0) // x0 ≤ x1 ≤ y0 ≤ y1
			p = 1;
		else
		if (y1 <= x0) // y0 ≤ y1 ≤ x0 ≤ x1
			p = 0;
		else
		if (x0 <= y0)
		{
			// y is subset of x
			if (x0 <= y0 && y1 <= x1) // x0 ≤ y0 ≤ y1 ≤ x1
				p = ((y0 - x0) + ((y1 - y0) / 2.)) / (x1 - x0);

			// x is mostly less than y
			else // x0 ≤ y0 ≤ x1 ≤ y1
				p = 1 - (((x1 - y0) * (x1 - y0)) / 2.) / ((x1 - x0) * (y1 - y0));
		}
		else
		if (y0 <= x0)
		{
			// x is subset of y
			if (y0 <= x0 && x1 <= y1) // y0 ≤ x0 ≤ x1 ≤ y1
				p = ((y1 - x1) + ((x1 - x0) / 2.)) / (y1 - y0);

			// y is mostly less than x
			else // y0 ≤ x0 ≤ y1 ≤ x1
				p = (((y1 - x0) * (y1 - x0)) / 2.) / ((y1 - y0) * (x1 - x0));
		}
		else
			assert(false);

		static if (discrete)
		{
			// In the spread-out model a tie (x == y) wins exactly half the
			// time: P(x' < y') = P(x < y) + P(x == y) / 2. Undo that bias.
			static if (op.length == 2) // <=, >=: a tie satisfies the comparison
				p += tieP / 2;
			else
				p -= tieP / 2;
		}

		return Probability(p);
	}

	unittest // range-to-range, int: compare against brute-force enumeration
	{
		foreach (alo; 0 .. 3) foreach (ahi; alo .. 3)
		foreach (blo; 0 .. 3) foreach (bhi; blo .. 3)
		{
			auto a = range(alo, ahi);
			auto b = range(blo, bhi);
			auto p0 = cmp(a, b).p;
			double wins = 0;
			foreach (x; alo .. ahi + 1)
				foreach (y; blo .. bhi + 1)
					if (mixin("x" ~ op ~ "y"))
						wins++;
			double p1 = wins / ((ahi - alo + 1) * (bhi - blo + 1));
			debug
			{
				import std.conv : text;
				assert((p0 > p1 ? p0 - p1 : p1 - p0) < 1e-9, text("a", op, "b", " -> ", p0, " / ", p1));
			}
		}
	}
}

unittest
{
	assert(cmp!">"(0, 1).p == 0 );
	assert(cmp!">"(1, 0).p == 1 );

	assert(cmp!"<"(1, 0).p == 0 );
	assert(cmp!"<"(0, 1).p == 1 );
}

unittest
{
	auto a = range(1.0, 3.0);
	assert(cmp!"<"(a, 0.0).p == 0 );
	assert(cmp!"<"(a, 2.0).p == 0.5);
	assert(cmp!"<"(a, 5.0).p == 1 );

	assert(cmp!">"(a, 0.0).p == 1 );
	assert(cmp!">"(a, 2.0).p == 0.5);
	assert(cmp!">"(a, 5.0).p == 0 );
}

unittest // range-to-number, floating-point point mass
{
	auto a = range(2.0);
	assert(cmp!"<" (a, 2.0).p == 0);
	assert(cmp!"<="(a, 2.0).p == 1);
	assert(cmp!">" (a, 2.0).p == 0);
	assert(cmp!">="(a, 2.0).p == 1);
}

unittest // range-to-number, int
{
	auto a = range(1, 2);

	assert(cmp!"<"(a, 1).p == 0 );
	assert(cmp!"<"(a, 2).p == 0.5);
	assert(cmp!"<"(a, 3).p == 1 );

	assert(cmp!">"(a, 0).p == 1 );
	assert(cmp!">"(a, 1).p == 0.5);
	assert(cmp!">"(a, 2).p == 0 );

	// instantiate template unittest
	alias le = cmp!"<=";
	alias ge = cmp!">=";
}

unittest
{
	assert(cmp!"<" (range(0.), range(1.)).p == 1);
	assert(cmp!"<" (range(1.), range(0.)).p == 0);

	assert(cmp!"<" (range(0.), range(0.)).p == 0);
	assert(cmp!"<="(range(0.), range(0.)).p == 1);

	assert(cmp!"<" (range(0., 1.), range(0., 1.)).p == 0.5);
	assert(cmp!"<" (range(0., 1.), range(2., 3.)).p == 1.0);
	assert(cmp!"<" (range(2., 3.), range(0., 1.)).p == 0.0);
	assert(cmp!"<" (range(0., 1.), range(0., 2.)).p == 0.75);
	assert(cmp!"<" (range(0., 2.), range(1., 3.)).p == 7./8);
}

unittest // range-to-range, int
{
	// {1, 2} vs {1, 2}: only the pair (1, 2) out of four is strictly increasing.
	assert(cmp!"<" (range(1, 2), range(1, 2)).p == 0.25);
	assert(cmp!"<="(range(1, 2), range(1, 2)).p == 0.75);
	assert(cmp!">" (range(1, 2), range(1, 2)).p == 0.25);

	assert(cmp!"<" (range(1), range(1)).p == 0);
	assert(cmp!"<="(range(1), range(1)).p == 1);
}