/**
 * Very basic (and probably buggy) numeric
 * distribution / probability operations.
 * WIP, do not use.
 *
 * License:
 *   This Source Code Form is subject to the terms of
 *   the Mozilla Public License, v. 2.0. If a copy of
 *   the MPL was not distributed with this file, You
 *   can obtain one at http://mozilla.org/MPL/2.0/.
 *
 * Authors:
 *   Vladimir Panteleev <ae@cy.md>
 */

module ae.utils.math.distribution;

import std.algorithm.comparison;
import std.algorithm.iteration;

import ae.utils.array;
import ae.utils.math;

/// A simplified model of a probability distribution, described only by
/// its lowest point, its highest point and its average.
/// Supports uniform distributions and basic operations on them (sum / product).
struct Range(T)
{
	/// Low, high, and average points.
	T lo, hi, avg;
	/// Set by the two-point and one-point `range` factories, cleared by the
	/// three-point one and by `fuzzyAssign`.
	private bool uniform;

	invariant
	{
		// The average must lie within the bounds.
		assert(lo <= avg && avg <= hi);
	}

	/// Applies `op` with this range on the left and the scalar `u` on the right.
	/// The `uniform` flag is carried over unchanged.
	auto opBinary(string op, U)(U u) const
	if (is(U : real))
	{
		alias Result = typeof(mixin("T.init " ~ op ~ " u"));
		Result atLo = mixin("lo " ~ op ~ " u");
		Result atHi = mixin("hi " ~ op ~ " u");
		Result mean = mixin("avg " ~ op ~ " u");
		// The operation may reverse the order of the end points, so re-sort them.
		return Range!Result(min(atLo, atHi), max(atLo, atHi), mean, uniform);
	}

	/// Applies `op` with the scalar `u` on the left and this range on the right.
	/// The `uniform` flag is carried over unchanged.
	auto opBinaryRight(string op, U)(U u) const
	if (is(U : real))
	{
		alias Result = typeof(mixin("u " ~ op ~ " T.init"));
		Result atLo = mixin("u " ~ op ~ " lo");
		Result atHi = mixin("u " ~ op ~ " hi");
		Result mean = mixin("u " ~ op ~ " avg");
		// The operation may reverse the order of the end points, so re-sort them.
		return Range!Result(min(atLo, atHi), max(atLo, atHi), mean, uniform);
	}

	/// Applies `op` between this range and another range `r`.
	/// The bounds are the extremes over all four corner combinations;
	/// the result is built with the three-point `range`, i.e. it is not flagged uniform.
	auto opBinary(string op, R)(R r) const
	if (is(R : Range!U, U))
	{
		auto loLo = mixin("lo " ~ op ~ " r.lo");
		auto loHi = mixin("lo " ~ op ~ " r.hi");
		auto hiLo = mixin("hi " ~ op ~ " r.lo");
		auto hiHi = mixin("hi " ~ op ~ " r.hi");
		auto mean = mixin("avg " ~ op ~ " r.avg");
		return range(
			min(loLo, loHi, hiLo, hiHi),
			max(loLo, loHi, hiLo, hiHi),
			mean,
		);
	}

	/// Converts to a `Range` of another element type by casting each point.
	/// The result is built with the three-point `range`, i.e. it is not flagged uniform.
	auto opCast(Target)() const
	if (is(Target : Range!U, U))
	{
		// Match again to obtain the target element type `U`.
		static if (is(Target : Range!U, U))
			return range(cast(U)lo, cast(U)hi, cast(U)avg);
		else
			assert(false);
	}

	/// Converts to `Range!U` by casting each point.
	/// The result is built with the three-point `range`, i.e. it is not flagged uniform.
	Range!U to(U)() const
	{
		auto newLo  = cast(U)lo;
		auto newHi  = cast(U)hi;
		auto newAvg = cast(U)avg;
		return range(newLo, newHi, newAvg);
	}

	/// Apply a `prob` chance that `this` equals `val`.
	/// The bounds grow to include those of `val`, and the average moves
	/// from `this.avg` towards `val.avg` in proportion to `prob`.
	Range!T fuzzyAssign(Range!T val, double prob = 0.5)
	{
		assert(prob >= 0 && prob <= 1);

		// Trivial cases: the assignment never / always happens.
		if (prob == 0)
			return this;
		if (prob == 1)
			return val;

		auto result = this;
		result.lo = result.lo > val.lo ? val.lo : result.lo;
		result.hi = result.hi < val.hi ? val.hi : result.hi;
		result.avg = itpl(result.avg, val.avg, prob, 0.0, 1.0);
		result.uniform = false;
		return result;
	}

	/// ditto
	Range!T fuzzyAssign(T val, double prob = 0.5)
	{
		auto point = range(val);
		return fuzzyAssign(point, prob);
	}

	/// Formats as `lo` for a single point, `lo..hi` when the average is
	/// the midpoint, and `lo..avg..hi` otherwise.
	string toString() const
	{
		import std.format : format;

		if (lo == hi)
			return format("%s", lo);

		immutable bool centered = avg == (lo + hi) / 2;
		return centered
			? format("%s..%s", lo, hi)
			: format("%s..%s..%s", lo, avg, hi);
	}
}

/// Constructs a `Range` from explicit low, high and average points (not flagged uniform).
Range!T range(T)(T lo, T hi, T avg)
{
	return Range!T(lo, hi, avg, false);
}

/// Constructs a uniform `Range` between `lo` and `hi`; the average is their midpoint.
Range!T range(T)(T lo, T hi)
{
	return Range!T(lo, hi, (lo + hi) / 2, true);
}

/// Constructs a `Range` consisting of the single point `val` (flagged uniform).
Range!T range(T)(T val)
{
	return Range!T(val, val, val, true);
}

///
unittest
{
	auto expected = range(2, 3);
	assert(range(1, 2) + 1 == expected);
	assert(1 + range(1, 2) == expected);
}

///
unittest
{
	auto lhs = range(10, 20);
	auto rhs = range(10, 20);
	auto product = lhs * rhs;
	assert(product.avg == 225);
}

unittest
{
	auto start = range(10, 20);
	auto widened = start.fuzzyAssign(25);
	assert(widened == range(10, 25, 20));
}

// ****************************************************************************

/// Indicates the probability of a certain event.
struct Probability
{
	double p; /// [0,1]

	/// True when the probability is exactly zero.
	bool isImpossible() const @nogc
	{
		return p == 0;
	}

	/// True when the probability is greater than zero.
	bool isPossible() const @nogc
	{
		return p > 0;
	}

	/// True when the probability is exactly one.
	bool isCertain() const @nogc
	{
		return p == 1;
	}
}

/// Apply `doIf` if `p` is possible.
/// `doIf` receives the probability of the event (non-zero).
void cond(alias doIf)(Probability p)
{
	if (p.p > 0)
		doIf(p.p);
}

/// Apply `doIf` if `p` is possible,
/// and/or `doElse` if `!p` is possible,
/// `doIf` and `doElse` receive the probability of their respective event (non-zero).
void cond(alias doIf, alias doElse)(Probability p)
{
	if (p.p > 0)
		doIf(p.p);
	if (p.p < 1)
		doElse(1 - p.p);
}

/// Return the probability of event `a` not occurring.
Probability not(Probability a) { return Probability(1 - a.p); }
/// Return the probability of both unrelated events `a` and `b` occurring.
Probability and(Probability a, Probability b) { return Probability(a.p * b.p); }
/// Return the probability of at least one of the unrelated events `a` and `b` occurring.
Probability or (Probability a, Probability b) { return not(and(not(a), not(b))); }

/// Return the probability that `a` `op` `b`, where `op` is `<` / `<=` / `>` / `>=`,
/// and `a` and `b` are numbers or ranges representing a uniform distribution.
///
/// Ranges of integers are treated as discrete distributions over every
/// integer in `[lo, hi]` (both ends inclusive). Other ranges are treated as
/// continuous distributions, for which `<` and `<=` (and `>` / `>=`) only
/// differ when a single-point range is compared against an equal value.
template cmp(string op)
if (op.isOneOf("<", "<=", ">", ">="))
{
	// Number-to-number

	Probability cmp(A, B)(A a, B b)
	if (!is(A : Range!AV, AV) && !is(B : Range!BV, BV))
	{
		return Probability(mixin("a" ~ op ~ "b") ? 1 : 0);
	}

	// Range-to-number (`a` is the range, `b` is the number)

	Probability cmp(A, B)(A a, B b)
	if ( is(A : Range!AV, AV) && !is(B : Range!BV, BV))
	{
		// `p` is first computed as the probability of `a` being below `b`,
		// and inverted at the end for the `>` family.
		double p;

		if (a.hi < b)
			p = 1;
		else
		if (a.lo <= b)
		{
			assert(a.uniform, "Can't compare a non-uniform distribution");

			auto lo = cast()a.lo;
			auto hi = cast()a.hi;

			static if (is(typeof(lo + b) : long))
			{
				// Integers: widen the inclusive range by one so that every
				// integer value owns an interval of equal length.
				static if (op[0] == '<')
					hi++;
				else
					lo--;

				static if (op.length == 2) // >=, <=
				{
					static if (op[0] == '<')
						b++;
					else
						b--;
				}
			}

			if (lo == hi)
			{
				// Single-point range equal to `b` (only reachable for
				// non-integers, since integer ranges were widened above).
				// Interpolating would divide by zero, so decide directly:
				// "a <= b" holds, "a < b" does not. For the `>` family the
				// value is inverted below, so `>` needs 1 here and `>=` needs 0.
				enum bool pointCountsAsBelow = op == "<=" || op == ">";
				p = pointCountsAsBelow ? 1 : 0;
			}
			else
				p = itpl(0.0, 1.0, b, lo, hi);
		}
		else
			p = 0;

		static if (op[0] == '>')
			p = 1 - p;

		return Probability(p);
	}

	unittest // int unittest
	{
		auto a = range(1, 2);
		foreach (b; 0..4)
		{
			auto p0 = cmp(a, b).p;
			double p1 = 0;
			foreach (x; a.lo .. a.hi+1)
				if (mixin("x" ~ op ~ "b"))
					p1 += 0.5;
			debug
			{
				import std.conv : text;
				assert(p0 == p1, text("a", op, b, " -> ", p0, " / ", p1));
			}
		}
	}

	// Number-to-range (`a` is the number, `b` is the range)

	Probability cmp(A, B)(A a, B b)
	if (!is(A : Range!AV, AV) &&  is(B : Range!BV, BV))
	{
		// "a op b" is the same event as "b mirrored-op a".
		static if (op[0] == '>')
			return .cmp!("<" ~ op[1..$])(b, a);
		else
			return .cmp!(">" ~ op[1..$])(b, a);
	}

	unittest
	{
		auto b = range(1, 2);
		foreach (a; 0..4)
		{
			auto p0 = cmp(a, b).p;
			double p1 = 0;
			foreach (x; b.lo .. b.hi+1)
				if (mixin("a" ~ op ~ "x"))
					p1 += 0.5;
			debug
			{
				import std.conv : text;
				assert(p0 == p1, text(a, op, "b", " -> ", p0, " / ", p1));
			}
		}
	}

	// Range-to-range

	Probability cmp(A, B)(A a, B b)
	if (is(A : Range!AV, AV) && is(B : Range!BV, BV))
	{
		assert(a.uniform && b.uniform, "Can't compare non-uniform distributions");

		// Reduce every operator to "x < y" (or "x <= y" for the inclusive ones).
		static if (op[0] == '<')
		{
			auto x0 = a.lo;
			auto x1 = a.hi;
			auto y0 = b.lo;
			auto y1 = b.hi;
		}
		else
		{
			auto x0 = b.lo;
			auto x1 = b.hi;
			auto y0 = a.lo;
			auto y1 = a.hi;
		}

		enum bool inclusive = op.length == 2; // >=, <=

		static if (is(typeof(x0 + y0) : long))
		{
			// Integers: count the favourable (x, y) pairs exactly.
			// For a given y, the number of x in [x0, x1] with x < y is
			// clamp(y - x0, 0, xCount); summing that over y in [y0, y1]
			// is a difference of two prefix sums.
			// All arithmetic is done in `double` to avoid integer overflow.

			// Sum of min(j, width) for j = 0 .. m (zero when m <= 0).
			static double prefixSum(double m, double width)
			{
				if (m <= 0)
					return 0;
				if (m <= width)
					return m * (m + 1) / 2;
				return width * (width + 1) / 2 + (m - width) * width;
			}

			immutable double xCount = cast(double)x1 - x0 + 1;
			immutable double yCount = cast(double)y1 - y0 + 1;
			// "x <= y" is the same event as "x < y + 1".
			immutable double shift = inclusive ? 1 : 0;

			immutable double favourable =
				prefixSum(cast(double)y1 - x0 + shift    , xCount) -
				prefixSum(cast(double)y0 - x0 + shift - 1, xCount);

			return Probability(favourable / (xCount * yCount));
		}
		else
		{
			double p;

			// Both are the same single point: the strict comparison never
			// holds, the inclusive one always does.
			if (x0 == x1 && y0 == y1 && x0 == y0)
				p = inclusive ? 1 : 0;
			else
			// No intersection
			if (x1 <= y0) // x0 ≤ x1 ≤ y0 ≤ y1
				p = 1;
			else
			if (y1 <= x0) // y0 ≤ y1 ≤ x0 ≤ x1
				p = 0;
			else
			if (x0 <= y0)
			{
				// y is subset of x
				if (y1 <= x1) // x0 ≤ y0 ≤ y1 ≤ x1
					p = ((y0 - x0) + ((y1 - y0) / 2.)) / (x1 - x0);

				// x is mostly less than y
				else // x0 ≤ y0 ≤ x1 ≤ y1
					p = 1 - (((x1 - y0) * (x1 - y0)) / 2.) / ((x1 - x0) * (y1 - y0));
			}
			else
			if (y0 <= x0)
			{
				// x is subset of y
				if (x1 <= y1) // y0 ≤ x0 ≤ x1 ≤ y1
					p = ((y1 - x1) + ((x1 - x0) / 2.)) / (y1 - y0);

				// y is mostly less than x
				else // y0 ≤ x0 ≤ y1 ≤ x1
					p = (((y1 - x0) * (y1 - x0)) / 2.) / ((y1 - y0) * (x1 - x0));
			}
			else
				assert(false); // only reachable with unordered values (NaN)

			return Probability(p);
		}
	}

	unittest // range-to-range, int: check against brute-force enumeration
	{
		foreach (aLo; 0 .. 3)
		foreach (aHi; aLo .. 3)
		foreach (bLo; 0 .. 3)
		foreach (bHi; bLo .. 3)
		{
			auto p0 = cmp(range(aLo, aHi), range(bLo, bHi)).p;
			int hits, total;
			foreach (x; aLo .. aHi + 1)
				foreach (y; bLo .. bHi + 1)
				{
					total++;
					if (mixin("x" ~ op ~ "y"))
						hits++;
				}
			double p1 = double(hits) / total;
			debug
			{
				import std.conv : text;
				assert(p0 == p1, text(aLo, "..", aHi, op, bLo, "..", bHi, " -> ", p0, " / ", p1));
			}
		}
	}
}

unittest
{
	assert(cmp!">"(0, 1).p == 0  );
	assert(cmp!">"(1, 0).p == 1  );

	assert(cmp!"<"(1, 0).p == 0  );
	assert(cmp!"<"(0, 1).p == 1  );
}

unittest
{
	auto a = range(1.0, 3.0);
	assert(cmp!"<"(a, 0.0).p == 0  );
	assert(cmp!"<"(a, 2.0).p == 0.5);
	assert(cmp!"<"(a, 5.0).p == 1  );

	assert(cmp!">"(a, 0.0).p == 1  );
	assert(cmp!">"(a, 2.0).p == 0.5);
	assert(cmp!">"(a, 5.0).p == 0  );
}

unittest // range-to-number, single-point floating range
{
	auto pt = range(1.0);
	assert(cmp!"<" (pt, 1.0).p == 0);
	assert(cmp!"<="(pt, 1.0).p == 1);
	assert(cmp!">" (pt, 1.0).p == 0);
	assert(cmp!">="(pt, 1.0).p == 1);
}

unittest // range-to-number, int
{
	auto a = range(1, 2);

	assert(cmp!"<"(a, 1).p == 0  );
	assert(cmp!"<"(a, 2).p == 0.5);
	assert(cmp!"<"(a, 3).p == 1  );

	assert(cmp!">"(a, 0).p == 1  );
	assert(cmp!">"(a, 1).p == 0.5);
	assert(cmp!">"(a, 2).p == 0  );

	// instantiate template unittest
	alias le = cmp!"<=";
	alias ge = cmp!">=";
}

unittest // range-to-range, int
{
	assert(cmp!"<" (range(1, 2), range(1, 2)).p == 0.25);
	assert(cmp!"<="(range(1, 2), range(1, 2)).p == 0.75);
	assert(cmp!">" (range(1, 2), range(1, 2)).p == 0.25);
	assert(cmp!">="(range(1, 2), range(1, 2)).p == 0.75);
}

unittest
{
	assert(cmp!"<" (range(0.), range(1.)).p == 1);
	assert(cmp!"<" (range(1.), range(0.)).p == 0);

	assert(cmp!"<" (range(0.), range(0.)).p == 0);
	assert(cmp!"<="(range(0.), range(0.)).p == 1);

	assert(cmp!"<" (range(0., 1.), range(0., 1.)).p == 0.5);
	assert(cmp!"<" (range(0., 1.), range(2., 3.)).p == 1.0);
	assert(cmp!"<" (range(2., 3.), range(0., 1.)).p == 0.0);
	assert(cmp!"<" (range(0., 1.), range(0., 2.)).p == 0.75);
	assert(cmp!"<" (range(0., 2.), range(1., 3.)).p == 7./8);
}

// ****************************************************************************

version (none)
{
/// A quantized representation of the probability distribution of
/// some continuous function returning a value between 0 and 1.
struct QuantizedDistribution(size_t numSegments, P = float, V = double)
{
	enum V minValue = 0.0;
	enum V maxValue = 1.0;

	/// Represents the relative probability that the function will
	/// return a value in the represented interval.
	P[numSegments] buckets = 1.0;

	/// The length of one segment (of the function's return value) represented by one bucket.
	private enum V bucketSize = (maxValue - minValue) / numSegments;

	/// Maps a value in `[minValue, maxValue]` to the index of the bucket holding it.
	private static size_t toBucketIndex(V value)
	{
		assert(value >= minValue && value <= maxValue);
		immutable V fraction = (value - minValue) / (maxValue - minValue);
		auto index = cast(size_t)(fraction * numSegments);
		assert(index <= numSegments);
		// 1.0 goes into the last bucket, together with 0.999...
		return index == numSegments ? numSegments - 1 : index;
	}

	/// The lowest value covered by bucket `bucketIndex`.
	private V bucketLowValue(size_t bucketIndex)
	{
		immutable V span = maxValue - minValue;
		return minValue + (span * bucketIndex / numSegments);
	}

	/// The highest value covered by bucket `bucketIndex`,
	/// which is the lowest value of the bucket after it.
	private V bucketHighValue(size_t bucketIndex)
	{
		return bucketLowValue(bucketIndex + 1);
	}

	/// Normalizes `buckets` so that they add up to 1.
	/// Returns a scaled copy; `this` is left untouched.
	typeof(this) normalize()
	{
		typeof(this) scaled = this;
		P total = scaled.buckets[].sum;
		scaled.buckets[] /= total;
		return scaled;
	}

	/// Call `fun` a `numSamples` number of times, and return a distribution representing the result.
	/// Each bucket of the result holds the number of samples that fell into it.
	static typeof(this) sample(V delegate() fun, size_t numSamples)
	{
		typeof(this) histogram;
		histogram.buckets[] = 0;
		foreach (_; 0 .. numSamples)
		{
			auto index = toBucketIndex(fun());
			histogram.buckets[index]++;
		}
		return histogram;
	}

	/// Returns the share of the total bucket weight made up of all buckets
	/// before the one containing `value`, plus a linearly interpolated part
	/// of that bucket.
	/// NOTE(review): this is the weight *below* `value`, which looks at odds
	/// with the name `gt` - confirm the intended meaning.
	Probability gt(V value)
	{
		auto index     = toBucketIndex(value);
		auto lowValue  = bucketLowValue (index);
		auto highValue = bucketHighValue(index);
		auto total     = buckets[].sum;

		auto wholeBuckets = buckets[0 .. index].sum;
		auto partialBucket = itpl(0, buckets[index], value, lowValue, highValue);
		return Probability((wholeBuckets + partialBucket) / total);
	}
}

unittest
{
	import std.random : Random, uniform, uniform01;
	import std.math.operations : isClose;

	auto rng = Random(0);
	{
		auto dist = QuantizedDistribution!256.sample(() => uniform01!double(rng), 10_000);
		assert(dist.gt(0.25).p.between(0.24, 0.26));
	}
	{
		auto dist = QuantizedDistribution!256.sample(() => uniform(0, 4) == 0, 10_000);
		assert(dist.gt(0.5).p.between(0.74, 0.76));
	}
}
} // version (none)