Skip to content

Commit 1555803

Browse files
committed
fix taking address in argmin and argmax
fix #189
1 parent a3dc954 commit 1555803

File tree

1 file changed

+18
-4
lines changed
  • multik-openblas/src/nativeMain/kotlin/org.jetbrains.kotlinx.multik/openblas/math

1 file changed

+18
-4
lines changed

multik-openblas/src/nativeMain/kotlin/org.jetbrains.kotlinx.multik/openblas/math/JniMath.kt

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,15 @@ import org.jetbrains.kotlinx.multik.cinterop.*
55

66
@OptIn(ExperimentalForeignApi::class)
77
internal actual object JniMath {
8-
actual fun argMin(arr: Any, offset: Int, size: Int, shape: IntArray, strides: IntArray?, dtype: Int): Int =
9-
argmin(StableRef.create(arr).asCPointer(), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype)
10-
actual fun argMax(arr: Any, offset: Int, size: Int, shape: IntArray, strides: IntArray?, dtype: Int): Int =
11-
argmax(StableRef.create(arr).asCPointer(), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype)
8+
actual fun argMin(arr: Any, offset: Int, size: Int, shape: IntArray, strides: IntArray?, dtype: Int): Int = when(arr) {
9+
is DoubleArray -> arr.usePinned { argmin(it.addressOf(0), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype) }
10+
is FloatArray -> arr.usePinned { argmin(it.addressOf(0), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype) }
11+
is IntArray -> arr.usePinned { argmin(it.addressOf(0), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype) }
12+
is LongArray -> arr.usePinned { argmin(it.addressOf(0), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype) }
13+
is ByteArray -> arr.usePinned { argmin(it.addressOf(0), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype) }
14+
is ShortArray -> arr.usePinned { argmin(it.addressOf(0), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype) }
15+
else -> throw Exception("Only primitive arrays are supported for Kotlin/Native `argMin`")
16+
}
1217

1318
actual fun exp(arr: FloatArray, size: Int): Boolean {
1419
for (i in 0 until size) {
@@ -21,6 +26,15 @@ internal actual object JniMath {
2126
arr[i] = kotlin.math.exp(arr[i])
2227
}
2328
return true
29+
actual fun argMax(arr: Any, offset: Int, size: Int, shape: IntArray, strides: IntArray?, dtype: Int): Int = when(arr) {
30+
is DoubleArray -> arr.usePinned { argmax(it.addressOf(0), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype) }
31+
is FloatArray -> arr.usePinned { argmax(it.addressOf(0), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype) }
32+
is IntArray -> arr.usePinned { argmax(it.addressOf(0), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype) }
33+
is LongArray -> arr.usePinned { argmax(it.addressOf(0), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype) }
34+
is ByteArray -> arr.usePinned { argmax(it.addressOf(0), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype) }
35+
is ShortArray -> arr.usePinned { argmax(it.addressOf(0), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype) }
36+
else -> throw Exception("Only primitive arrays are supported for Kotlin/Native `argMin`")
37+
}
2438
}
2539
actual fun expC(arr: FloatArray, size: Int): Boolean {
2640
for (i in 0 until size step 2) {

0 commit comments

Comments
 (0)