Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions source/mir/math/stat.d
Original file line number Diff line number Diff line change
Expand Up @@ -1549,6 +1549,17 @@ unittest
assert(x4.median.approxEqual(1));
}

// Check issue #328 fixed
version(mir_test)
@safe pure nothrow
unittest {
import mir.ndslice.topology: iota;

auto x = iota(18);
auto y = median(x);
assert(y == 8.5);
}

private pure @trusted nothrow @nogc
F smallMedianImpl(F, Iterator)(Slice!Iterator slice)
{
Expand Down
87 changes: 63 additions & 24 deletions source/mir/ndslice/sorting.d
Original file line number Diff line number Diff line change
Expand Up @@ -1099,16 +1099,27 @@ unittest {
assert(x[nth] == 2);
}

// Check issue #328 fixed
version(mir_test)
@safe pure nothrow
unittest {
import mir.ndslice.slice: sliced;

auto slice = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17].sliced;
partitionAt(slice, 8);
partitionAt(slice, 9);
}

version(unittest) {
template checkTopNAll(alias less = "a < b")
template checkPartitionAtAll(alias less = "a < b")
{
import mir.functional: naryFun;
import mir.ndslice.slice: SliceKind, Slice;

static if (__traits(isSame, naryFun!less, less))
{
@safe pure nothrow
static bool checkTopNAll
static bool checkPartitionAtAll
(Iterator, SliceKind kind)(
Slice!(Iterator, 1, kind) x)
{
Expand All @@ -1129,7 +1140,7 @@ version(unittest) {
return result;
}
} else {
alias checkTopNAll = .checkTopNAll!(naryFun!less);
alias checkPartitionAtAll = .checkPartitionAtAll!(naryFun!less);
}
}
}
Expand All @@ -1139,25 +1150,25 @@ version(mir_test)
unittest {
import mir.ndslice.slice: sliced;

assert(checkTopNAll([2, 2].sliced));
assert(checkPartitionAtAll([2, 2].sliced));

assert(checkTopNAll([3, 1, 5, 2, 0].sliced));
assert(checkTopNAll([3, 1, 5, 0, 2].sliced));
assert(checkTopNAll([0, 0, 4, 3, 3].sliced));
assert(checkTopNAll([5, 1, 5, 1, 5].sliced));
assert(checkTopNAll([2, 2, 0, 0, 0].sliced));

assert(checkTopNAll([ 2, 12, 10, 8, 1, 20, 19, 1, 2, 7].sliced));
assert(checkTopNAll([ 4, 18, 16, 0, 15, 6, 2, 17, 10, 16].sliced));
assert(checkTopNAll([ 7, 5, 9, 4, 4, 2, 12, 20, 15, 15].sliced));

assert(checkTopNAll([17, 87, 58, 50, 34, 98, 25, 77, 88, 79].sliced));

assert(checkTopNAll([ 6, 7, 10, 25, 5, 10, 9, 0, 2, 15, 7, 9, 11, 8, 13, 18, 17, 13, 25, 22].sliced));
assert(checkTopNAll([21, 3, 11, 22, 24, 12, 14, 12, 15, 15, 1, 3, 12, 15, 25, 19, 9, 16, 16, 19].sliced));
assert(checkTopNAll([22, 6, 18, 0, 1, 8, 13, 13, 16, 19, 23, 17, 4, 6, 12, 24, 15, 20, 11, 17].sliced));
assert(checkTopNAll([19, 23, 14, 5, 12, 3, 13, 7, 25, 25, 24, 9, 21, 25, 12, 22, 15, 22, 7, 11].sliced));
assert(checkTopNAll([ 0, 2, 7, 16, 2, 20, 1, 11, 17, 5, 22, 17, 25, 13, 14, 5, 22, 21, 24, 14].sliced));
assert(checkPartitionAtAll([3, 1, 5, 2, 0].sliced));
assert(checkPartitionAtAll([3, 1, 5, 0, 2].sliced));
assert(checkPartitionAtAll([0, 0, 4, 3, 3].sliced));
assert(checkPartitionAtAll([5, 1, 5, 1, 5].sliced));
assert(checkPartitionAtAll([2, 2, 0, 0, 0].sliced));

assert(checkPartitionAtAll([ 2, 12, 10, 8, 1, 20, 19, 1, 2, 7].sliced));
assert(checkPartitionAtAll([ 4, 18, 16, 0, 15, 6, 2, 17, 10, 16].sliced));
assert(checkPartitionAtAll([ 7, 5, 9, 4, 4, 2, 12, 20, 15, 15].sliced));

assert(checkPartitionAtAll([17, 87, 58, 50, 34, 98, 25, 77, 88, 79].sliced));

assert(checkPartitionAtAll([ 6, 7, 10, 25, 5, 10, 9, 0, 2, 15, 7, 9, 11, 8, 13, 18, 17, 13, 25, 22].sliced));
assert(checkPartitionAtAll([21, 3, 11, 22, 24, 12, 14, 12, 15, 15, 1, 3, 12, 15, 25, 19, 9, 16, 16, 19].sliced));
assert(checkPartitionAtAll([22, 6, 18, 0, 1, 8, 13, 13, 16, 19, 23, 17, 4, 6, 12, 24, 15, 20, 11, 17].sliced));
assert(checkPartitionAtAll([19, 23, 14, 5, 12, 3, 13, 7, 25, 25, 24, 9, 21, 25, 12, 22, 15, 22, 7, 11].sliced));
assert(checkPartitionAtAll([ 0, 2, 7, 16, 2, 20, 1, 11, 17, 5, 22, 17, 25, 13, 14, 5, 22, 21, 24, 14].sliced));
}

private @trusted pure nothrow @nogc
Expand Down Expand Up @@ -1372,6 +1383,35 @@ unittest {
assert(x[nth] == 10);
}

// Check all partitionAt
version(mir_test)
@trusted pure nothrow
unittest {
import mir.ndslice.slice: sliced;
import mir.ndslice.allocation: slice;

static immutable raw = [ 6, 7, 10, 25, 5, 10, 9, 0, 2, 15, 7, 9, 11, 8, 13, 18, 17, 13, 25, 22];

static void fill(T)(T x) {
for (size_t i = 0; i < x.length; i++) {
x[i] = raw[i];
}
}
auto x = slice!int(raw.length);
fill(x);
auto x_sort = x.dup;
x_sort = x_sort.sort;
size_t i = 0;
while (i < raw.length) {
auto frontI = x._iterator;
auto lastI = frontI + x.length - 1;
partitionAtImpl!((a, b) => (a < b))(frontI, lastI, i, true);
assert(x[i] == x_sort[i]);
fill(x);
i++;
}
}

private @trusted pure nothrow @nogc
Iterator partitionAtPartition(alias less, Iterator)(
ref Iterator frontI,
Expand All @@ -1381,7 +1421,7 @@ Iterator partitionAtPartition(alias less, Iterator)(
{
size_t len = lastI - frontI + 1;

assert(len >= 9 && n < len, "partitionAtImpl: length must be longer than 9 and n must be less than r.length");
assert(len >= 9 && n < len, "partitionAtPartition: length must be longer than 9 and n must be less than r.length");

size_t ninth = len / 9;
size_t pivot = ninth / 2;
Expand All @@ -1397,7 +1437,7 @@ Iterator partitionAtPartition(alias less, Iterator)(
// We have either one straggler on the left, one on the right, or none.
assert(loI - frontI <= lastI - hiI + 1 || lastI - hiI <= loI - frontI + 1, "partitionAtPartition: straggler check failed for loI, len, hiI");
assert(loI - frontI >= ninth * 4, "partitionAtPartition: loI - frontI >= ninth * 4");
assert(lastI - hiI >= ninth * 4, "partitionAtPartition: lastI - hiI >= ninth * 4");
assert((lastI + 1) - hiI >= ninth * 4, "partitionAtPartition: (lastI + 1) - hiI >= ninth * 4");

// Partition in groups of 3, and the mid tertile again in groups of 3
if (!useSampling) {
Expand Down Expand Up @@ -1428,7 +1468,6 @@ version(mir_test)
@trusted pure nothrow
unittest {
import mir.ndslice.slice: sliced;

auto x = [ 6, 7, 10, 25, 5, 10, 9, 0, 2, 15, 7, 9, 11, 8, 13, 18, 17, 13, 25, 22].sliced;
auto x_sort = x.dup;
x_sort = x_sort.sort;
Expand Down