Skip to content

Commit fcb3ad9

Browse files
committed
add tests
1 parent 3330c8c commit fcb3ad9

File tree

2 files changed

+51
-8
lines changed
  • multik-kotlin/src/commonTest/kotlin/org/jetbrains/kotlinx/multik_kotlin/linAlg
  • multik-openblas/src/commonTest/kotlin/org/jetbrains/kotlinx/multik/openblas/linalg

2 files changed

+51
-8
lines changed

multik-kotlin/src/commonTest/kotlin/org/jetbrains/kotlinx/multik_kotlin/linAlg/KELinAlgTest.kt

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,49 @@
55
package org.jetbrains.kotlinx.multik_kotlin.linAlg
66

77

8-
import org.jetbrains.kotlinx.multik.api.*
8+
import kotlin.math.abs
9+
import kotlin.math.max
10+
import kotlin.math.min
11+
import kotlin.random.Random
12+
import kotlin.test.Test
13+
import kotlin.test.assertContentEquals
14+
import kotlin.test.assertEquals
15+
import kotlin.test.assertFailsWith
16+
import kotlin.test.assertTrue
17+
import org.jetbrains.kotlinx.multik.api.d1array
18+
import org.jetbrains.kotlinx.multik.api.d2array
919
import org.jetbrains.kotlinx.multik.api.linalg.Norm
1020
import org.jetbrains.kotlinx.multik.api.linalg.dot
1121
import org.jetbrains.kotlinx.multik.api.linalg.norm
12-
import org.jetbrains.kotlinx.multik.kotlin.linalg.*
22+
import org.jetbrains.kotlinx.multik.api.mk
23+
import org.jetbrains.kotlinx.multik.api.ndarray
24+
import org.jetbrains.kotlinx.multik.api.zeros
25+
import org.jetbrains.kotlinx.multik.kotlin.linalg.KELinAlg
26+
import org.jetbrains.kotlinx.multik.kotlin.linalg.KELinAlgEx
1327
import org.jetbrains.kotlinx.multik.kotlin.linalg.KELinAlgEx.solve
1428
import org.jetbrains.kotlinx.multik.kotlin.linalg.KELinAlgEx.solveC
29+
import org.jetbrains.kotlinx.multik.kotlin.linalg.conjTranspose
30+
import org.jetbrains.kotlinx.multik.kotlin.linalg.dotMatrixComplex
31+
import org.jetbrains.kotlinx.multik.kotlin.linalg.gramShmidtComplexDouble
32+
import org.jetbrains.kotlinx.multik.kotlin.linalg.qrComplexDouble
33+
import org.jetbrains.kotlinx.multik.kotlin.linalg.schurDecomposition
34+
import org.jetbrains.kotlinx.multik.kotlin.linalg.upperHessenbergDouble
1535
import org.jetbrains.kotlinx.multik.ndarray.complex.Complex
1636
import org.jetbrains.kotlinx.multik.ndarray.complex.ComplexDouble
1737
import org.jetbrains.kotlinx.multik.ndarray.complex.ComplexFloat
1838
import org.jetbrains.kotlinx.multik.ndarray.complex.toComplexDouble
19-
import org.jetbrains.kotlinx.multik.ndarray.data.*
39+
import org.jetbrains.kotlinx.multik.ndarray.data.D1Array
40+
import org.jetbrains.kotlinx.multik.ndarray.data.D2
41+
import org.jetbrains.kotlinx.multik.ndarray.data.D2Array
42+
import org.jetbrains.kotlinx.multik.ndarray.data.DataType
43+
import org.jetbrains.kotlinx.multik.ndarray.data.Dim2
44+
import org.jetbrains.kotlinx.multik.ndarray.data.MultiArray
45+
import org.jetbrains.kotlinx.multik.ndarray.data.NDArray
46+
import org.jetbrains.kotlinx.multik.ndarray.data.get
47+
import org.jetbrains.kotlinx.multik.ndarray.data.set
2048
import org.jetbrains.kotlinx.multik.ndarray.operations.map
2149
import org.jetbrains.kotlinx.multik.ndarray.operations.minus
2250
import org.jetbrains.kotlinx.multik.ndarray.operations.plus
23-
import kotlin.math.abs
24-
import kotlin.math.max
25-
import kotlin.math.min
26-
import kotlin.random.Random
27-
import kotlin.test.*
2851

2952
class KELinAlgTest {
3053

@@ -492,6 +515,16 @@ class KELinAlgTest {
492515
}
493516

494517
}
518+
519+
@Test
520+
fun compute_norm_for_vector() {
521+
val vector = mk.ndarray(mk[1.1, 0.0, 3.2, 2.3, 5.0])
522+
523+
assertEquals(6.460650122085238, mk.linalg.norm(vector, Norm.Fro))
524+
assertEquals(11.600000000000001, mk.linalg.norm(vector, Norm.Inf))
525+
assertEquals(5.0, mk.linalg.norm(vector, Norm.N1))
526+
assertEquals(5.0, mk.linalg.norm(vector, Norm.Max))
527+
}
495528
}
496529

497530

multik-openblas/src/commonTest/kotlin/org/jetbrains/kotlinx/multik/openblas/linalg/NativeLinAlgTest.kt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,4 +324,14 @@ class NativeLinAlgTest {
324324
assertFloatingNumber(7.0, NativeLinAlg.norm(b, Norm.N1))
325325
assertFloatingNumber(4.0, NativeLinAlg.norm(b, Norm.Max))
326326
}
327+
328+
@Test
329+
fun `compute norm for vector`() {
330+
val vector = mk.ndarray(mk[1.1, 0.0, 3.2, 2.3, 5.0])
331+
332+
assertEquals(6.460650122085238, mk.linalg.norm(vector, Norm.Fro))
333+
assertEquals(11.600000000000001, mk.linalg.norm(vector, Norm.Inf))
334+
assertEquals(5.0, mk.linalg.norm(vector, Norm.N1))
335+
assertEquals(5.0, mk.linalg.norm(vector, Norm.Max))
336+
}
327337
}

0 commit comments

Comments
 (0)