- Notifications
You must be signed in to change notification settings - Fork 36
Add gmean #259
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add gmean #259
Changes from 1 commit
7238ae5 407a8c2 c6ba9d5 3169fc0 2fa4ec0 1522469 7c7d16e 98cac69 89bf503 File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | @@ -24,7 +24,8 @@ import mir.math.common: fmamath; | |
| import mir.math.sum; | ||
| import mir.math.numeric: ProdAlgo; | ||
| import mir.primitives; | ||
| import std.traits: isArray, isFloatingPoint, isMutable, isIterable, isIntegral; | ||
| import std.traits: isArray, isMutable, isIterable, isIntegral; | ||
| import mir.internal.utility: isFloatingPoint; | ||
| | ||
| // version = mir_test_topN; | ||
| | ||
| | @@ -444,116 +445,41 @@ unittest | |
| } | ||
| | ||
| private | ||
| I nthroot(I, J)(in I x, in J n) | ||
| if (isIntegral!(I) && isIntegral!J) | ||
| F nthroot(F)(in F x, in size_t n) | ||
| if (isFloatingPoint!F) | ||
| { | ||
| import mir.math.common: powi, pow; | ||
| import std.traits: Largest; | ||
| | ||
| assert(x > 0 && n > 0, "nthroot: Can only take nth root of positive numbers with n > 0"); | ||
| assert(n <= int.max, "nthroot: powi can only handle powers that fit in an int"); | ||
| import mir.math.common: sqrt, pow; | ||
| | ||
| if (x < 2) return n; | ||
| | ||
| static if (is(J == Largest!(J, int))) | ||
| int n1 = cast(int) n - 1; | ||
| else | ||
| J n1 = n - 1; | ||
| | ||
| I n2 = I(n); | ||
| I n3 = I(n - 1); | ||
| I c = I(1); | ||
| I d = (n3 + x) / n2; | ||
| assert(d <= pow(cast(double) I.max, cast(double) 1 / cast(double) n1), "nthroot: the value of d would result in overflow"); | ||
| I e = (n3 * d + x / powi(d, n1)) / n2; | ||
| while (c != d && c != e) { | ||
| c = d; | ||
| d = e; | ||
| assert(e <= pow(cast(double) I.max, cast(double) 1 / cast(double) n1), "nthroot: the value of e would result in overflow"); | ||
| e = (n3 * e + x / powi(e, n1)) / n2; | ||
| } | ||
| if (d < e) return d; | ||
| return e; | ||
| } | ||
| | ||
| version(mir_test_gmean) | ||
| @safe @nogc pure nothrow | ||
| unittest { | ||
| assert(nthroot(9, 2) == 3); | ||
| assert(nthroot(8, 3) == 2); | ||
| assert(nthroot(9, 3) == 2); | ||
| } | ||
| assert(x > 0, "nthroot: Can only take nth root of positive numbers"); | ||
| | ||
| version(mir_test_gmean) | ||
| @safe @nogc pure nothrow | ||
| unittest { | ||
| import mir.ndslice.topology: repeat; | ||
| | ||
| auto x = cast(ulong) uint.max * uint.max; | ||
| assert(nthroot(x, 2) == uint.max); | ||
| } | ||
| | ||
| private | ||
| F nthroot(I, F)(in I x, in F n) | ||
| if (isIntegral!I && isFloatingPoint!F) | ||
| { | ||
| import std.traits: Unqual; | ||
| | ||
| return nthroot(cast(Unqual!F) x, cast(Unqual!F) n); | ||
| if (n > 2) { | ||
| return pow(x, cast(F) 1 / cast(F) n); | ||
| } else if (n == 2) { | ||
| return sqrt(x); | ||
| } else if (n == 1) { | ||
| return x; | ||
| } else { | ||
| return cast(F) 1; | ||
| } | ||
| } | ||
| | ||
| version(mir_test_gmean) | ||
| version(mir_test) | ||
| @safe @nogc pure nothrow | ||
| unittest { | ||
| import mir.math.common: approxEqual; | ||
| assert(nthroot(9, 2.5).approxEqual(2.40822468)); | ||
| } | ||
| | ||
| private | ||
| F nthroot(F, I)(in F x, in I n) | ||
| if (isFloatingPoint!F && isIntegral!I) | ||
| { | ||
| import std.traits: Unqual; | ||
| | ||
| return nthroot(x, cast(Unqual!F) n); | ||
| } | ||
| | ||
| version(mir_test_gmean) | ||
| @safe @nogc pure nothrow | ||
| unittest { | ||
| import mir.math.common: approxEqual; | ||
| assert(nthroot(9.0, 0).approxEqual(1)); | ||
| assert(nthroot(9.0, 1).approxEqual(9)); | ||
| assert(nthroot(9.0, 2).approxEqual(3)); | ||
| assert(nthroot(9.5, 2).approxEqual(3.08220700)); | ||
| } | ||
| | ||
| private | ||
| F nthroot(F, G)(in F x, in G n) | ||
| if (isFloatingPoint!F && isFloatingPoint!G) | ||
| { | ||
| import mir.math.common: sqrt; | ||
| import mir.math.common: pow; | ||
| | ||
| assert(x > 0 && n > 0, "nthroot: Can only take nth root of positive numbers with n > 0"); | ||
| | ||
| if (n == 2) { | ||
| return sqrt(x); | ||
| } else { | ||
| return pow(x, cast(G) 1 / n); | ||
| } | ||
| } | ||
| | ||
| version(mir_test_gmean) | ||
| @safe @nogc pure nothrow | ||
| unittest { | ||
| import mir.math.common: approxEqual; | ||
| assert(nthroot(9.5, 2.0).approxEqual(3.08220700)); | ||
| assert(nthroot(9.5, 2.5).approxEqual(2.46087436)); | ||
| assert(nthroot(9.0, 3).approxEqual(2.08008382)); | ||
| } | ||
| | ||
| /++ | ||
| Output range for gmean. | ||
| +/ | ||
| struct GMeanAccumulator(T, ProdAlgo prodAlgo) | ||
| if (isMutable!T) | ||
| if (isMutable!T && isFloatingPoint!T) | ||
| { | ||
| import mir.math.numeric: ProdAccumulator; | ||
| | ||
| | @@ -564,8 +490,9 @@ struct GMeanAccumulator(T, ProdAlgo prodAlgo) | |
| | ||
| /// | ||
| F gmean(F = T)() @property | ||
| if (isFloatingPoint!F) | ||
| { | ||
| return nthroot(cast(F) prodAccumulator.prod, cast(F) count); | ||
| return nthroot(cast(F) prodAccumulator.prod, count); | ||
| ||
| } | ||
| | ||
| /// | ||
| | @@ -596,7 +523,7 @@ struct GMeanAccumulator(T, ProdAlgo prodAlgo) | |
| } | ||
| | ||
| /// | ||
| version(mir_test_gmean) | ||
| version(mir_test) | ||
| @safe pure nothrow | ||
| unittest | ||
| { | ||
| | @@ -610,7 +537,7 @@ unittest | |
| assert(x.gmean.approxEqual(2.60517108)); | ||
| } | ||
| | ||
| version(mir_test_gmean) | ||
| version(mir_test) | ||
| @safe pure nothrow | ||
| unittest | ||
| { | ||
| | @@ -629,15 +556,17 @@ package template gmeanType(T) | |
| import mir.math.numeric: prodType; | ||
| | ||
| alias U = prodType!T; | ||
| static assert(isFloatingPoint!U, "gmeanType: U must be a floating point type, not " ~ U.stringof); | ||
| | ||
| static if (__traits(compiles, { | ||
| auto temp = U.init * U.init; | ||
| auto a = nthroot(temp, cast(U) 2); | ||
| auto a = nthroot(temp, 2); | ||
| temp *= U.init; | ||
| a = nthroot(temp, cast(U) 3); | ||
| a = nthroot(temp, 3); | ||
| })) | ||
| alias gmeanType = typeof(nthroot(U.init * U.init, cast(U) 2)); | ||
| alias gmeanType = typeof(nthroot(U.init * U.init, 2)); | ||
| else | ||
| static assert(0, "Can't gmean elements of type " ~ U.stringof); | ||
| static assert(0, "gmeanType: Can't gmean elements of type " ~ U.stringof); | ||
| } | ||
| | ||
| /++ | ||
| | @@ -649,17 +578,19 @@ Returns: | |
| See_also: $(SUBREF prod, ProdAlgo) | ||
| +/ | ||
| template gmean(F, ProdAlgo prodAlgo = ProdAlgo.appropriate) | ||
| if (isFloatingPoint!F) | ||
| { | ||
| import mir.math.numeric: ResolveProdAlgoType; | ||
| | ||
| /++ | ||
| Params: | ||
| r = range, must be finite iterable | ||
| +/ | ||
| @fmamath F gmean(Range)(Range r) | ||
| @fmamath gmeanType!F gmean(Range)(Range r) | ||
| if (isIterable!Range) | ||
| { | ||
| GMeanAccumulator!(F, ResolveProdAlgoType!(prodAlgo, F)) gmean; | ||
| alias G = typeof(return); | ||
| GMeanAccumulator!(G, ResolveProdAlgoType!(prodAlgo, G)) gmean; | ||
| gmean.put(r.move); | ||
| return gmean.gmean; | ||
| } | ||
| | @@ -668,9 +599,10 @@ template gmean(F, ProdAlgo prodAlgo = ProdAlgo.appropriate) | |
| Params: | ||
| val = values | ||
| +/ | ||
| @fmamath F gmean(scope const F[] val...) | ||
| @fmamath gmeanType!F gmean(scope const F[] val...) | ||
| { | ||
| GMeanAccumulator!(F, ResolveProdAlgoType!(prodAlgo, F)) gmean; | ||
| alias G = typeof(return); | ||
| GMeanAccumulator!(G, ResolveProdAlgoType!(prodAlgo, G)) gmean; | ||
| gmean.put(val); | ||
| return gmean.gmean; | ||
| } | ||
| | @@ -688,8 +620,8 @@ template gmean(ProdAlgo prodAlgo = ProdAlgo.appropriate) | |
| @fmamath gmeanType!Range gmean(Range)(Range r) | ||
| if (isIterable!Range) | ||
| { | ||
| alias F = typeof(return); | ||
| return .gmean!(F, prodAlgo)(r.move); | ||
| alias G = typeof(return); | ||
| return .gmean!(G, prodAlgo)(r.move); | ||
| } | ||
| | ||
| /++ | ||
| | @@ -700,13 +632,14 @@ template gmean(ProdAlgo prodAlgo = ProdAlgo.appropriate) | |
| if (T.length > 0 && | ||
| !is(CommonType!T == void)) | ||
| { | ||
| alias F = typeof(return); | ||
| return .gmean!(F, prodAlgo)(val); | ||
| alias G = typeof(return); | ||
| return .gmean!(G, prodAlgo)(val); | ||
| } | ||
| } | ||
| | ||
| /// ditto | ||
| template gmean(F, string prodAlgo) | ||
| if (isFloatingPoint!F) | ||
| { | ||
| mixin("alias gmean = .gmean!(F, ProdAlgo." ~ prodAlgo ~ ");"); | ||
| } | ||
| | @@ -718,7 +651,7 @@ template gmean(string prodAlgo) | |
| } | ||
| | ||
| /// | ||
| version(mir_test_gmean) | ||
| version(mir_test) | ||
| @safe pure nothrow | ||
| unittest | ||
| { | ||
| | @@ -733,7 +666,7 @@ unittest | |
| } | ||
| | ||
| /// Geometric mean of vector | ||
| version(mir_test_gmean) | ||
| version(mir_test) | ||
| @safe pure nothrow | ||
| unittest | ||
| { | ||
| | @@ -747,7 +680,7 @@ unittest | |
| } | ||
| | ||
| /// Geometric mean of matrix | ||
| version(mir_test_gmean) | ||
| version(mir_test) | ||
| @safe pure | ||
| unittest | ||
| { | ||
| | @@ -763,7 +696,7 @@ unittest | |
| } | ||
| | ||
| /// Column gmean of matrix | ||
| version(mir_test_gmean) | ||
| version(mir_test) | ||
| @safe pure | ||
| unittest | ||
| { | ||
| | @@ -789,7 +722,7 @@ unittest | |
| } | ||
| | ||
| /// Can also set algorithm or output type | ||
| version(mir_test_gmean) | ||
| version(mir_test) | ||
| @safe pure nothrow | ||
| unittest | ||
| { | ||
| | @@ -803,25 +736,14 @@ unittest | |
| assert(x.gmean!(float, "separateExponentAccumulation").approxEqual(259281.45295212)); | ||
| | ||
| auto y = uint.max.repeat(2); | ||
| assert(y.gmean!ulong == uint.max); | ||
| assert(y.gmean!float.approxEqual(cast(float) uint.max)); | ||
| } | ||
| | ||
| version(mir_test_gmean) | ||
| @safe pure nothrow | ||
| unittest | ||
| { | ||
| import mir.ndslice.topology: repeat; | ||
| | ||
| auto y = uint.max.repeat(3); | ||
| assert(y.gmean!ulong == uint.max); | ||
| } | ||
| | ||
| /++ | ||
| For integral slices, pass output type as template parameter to ensure output | ||
| type is correct | ||
| +/ | ||
| version(mir_test_gmean) | ||
| version(mir_test) | ||
| @safe pure nothrow | ||
| unittest | ||
| { | ||
| | @@ -834,7 +756,7 @@ unittest | |
| } | ||
| | ||
| /// Mean works for user-defined types, provided the nth root can be taken for them | ||
| version(mir_test_gmean) | ||
| version(mir_test) | ||
| @safe pure nothrow | ||
| unittest | ||
| { | ||
| | @@ -851,7 +773,7 @@ unittest | |
| } | ||
| | ||
| /// Compute gmean tensors along specified dimention of tensors | ||
| version(mir_test_gmean) | ||
| version(mir_test) | ||
| @safe pure | ||
| unittest | ||
| { | ||
| | @@ -893,7 +815,7 @@ unittest | |
| } | ||
| | ||
| /// Arbitrary gmean | ||
| version(mir_test_gmean) | ||
| version(mir_test) | ||
| @safe pure nothrow @nogc | ||
| unittest | ||
| { | ||
| | @@ -902,7 +824,7 @@ unittest | |
| assert(gmean!float(1, 2, 3).approxEqual(1.81712059)); | ||
| } | ||
| | ||
| version(mir_test_gmean) | ||
| version(mir_test) | ||
| @safe pure nothrow | ||
| unittest | ||
| { | ||
| | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sqrtandpoware defined and return NaN, let's give them to return it.