@@ -5,10 +5,15 @@ import org.jetbrains.kotlinx.multik.cinterop.*
55
66@OptIn(ExperimentalForeignApi ::class )
77internal actual object JniMath {
8- actual fun argMin (arr : Any , offset : Int , size : Int , shape : IntArray , strides : IntArray? , dtype : Int ): Int =
9- argmin(StableRef .create(arr).asCPointer(), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype)
10- actual fun argMax (arr : Any , offset : Int , size : Int , shape : IntArray , strides : IntArray? , dtype : Int ): Int =
11- argmax(StableRef .create(arr).asCPointer(), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype)
8+ actual fun argMin (arr : Any , offset : Int , size : Int , shape : IntArray , strides : IntArray? , dtype : Int ): Int = when (arr) {
9+ is DoubleArray -> arr.usePinned { argmin(it.addressOf(0 ), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype) }
10+ is FloatArray -> arr.usePinned { argmin(it.addressOf(0 ), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype) }
11+ is IntArray -> arr.usePinned { argmin(it.addressOf(0 ), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype) }
12+ is LongArray -> arr.usePinned { argmin(it.addressOf(0 ), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype) }
13+ is ByteArray -> arr.usePinned { argmin(it.addressOf(0 ), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype) }
14+ is ShortArray -> arr.usePinned { argmin(it.addressOf(0 ), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype) }
15+ else -> throw Exception (" Only primitive arrays are supported for Kotlin/Native `argMin`" )
16+ }
1217
1318 actual fun exp (arr : FloatArray , size : Int ): Boolean {
1419 for (i in 0 until size) {
@@ -21,6 +26,15 @@ internal actual object JniMath {
2126 arr[i] = kotlin.math.exp(arr[i])
2227 }
2328 return true
29+ actual fun argMax (arr : Any , offset : Int , size : Int , shape : IntArray , strides : IntArray? , dtype : Int ): Int = when (arr) {
30+ is DoubleArray -> arr.usePinned { argmax(it.addressOf(0 ), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype) }
31+ is FloatArray -> arr.usePinned { argmax(it.addressOf(0 ), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype) }
32+ is IntArray -> arr.usePinned { argmax(it.addressOf(0 ), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype) }
33+ is LongArray -> arr.usePinned { argmax(it.addressOf(0 ), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype) }
34+ is ByteArray -> arr.usePinned { argmax(it.addressOf(0 ), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype) }
35+ is ShortArray -> arr.usePinned { argmax(it.addressOf(0 ), offset, size, shape.size, shape.toCValues(), strides?.toCValues(), dtype) }
36+ else -> throw Exception (" Only primitive arrays are supported for Kotlin/Native `argMin`" )
37+ }
2438 }
2539 actual fun expC (arr : FloatArray , size : Int ): Boolean {
2640 for (i in 0 until size step 2 ) {
0 commit comments