Skip to content

Commit bba680d

Browse files
committed
add annotations
1 parent 06caeca commit bba680d

File tree

4 files changed

+48
-19
lines changed

4 files changed

+48
-19
lines changed

src/main/java/org/algo4j/linear/ColumnVector.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
package org.algo4j.linear;
22

3+
import org.jetbrains.annotations.NotNull;
4+
35
@SuppressWarnings({"WeakerAccess", "unused"})
46
public class ColumnVector extends Vector {
5-
public ColumnVector(double[] data) {
7+
public ColumnVector(@NotNull double[] data) {
68
super(true, data);
79
}
810

11+
@NotNull
912
public RowVector toRowVector() {
1013
int L = nativeData().rows();
1114
double[] data = new double[L];
Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package org.algo4j.linear;
22

33
import org.algo4j.util.MatrixCaster;
4+
import org.jetbrains.annotations.NotNull;
45
import org.nd4j.linalg.api.ndarray.INDArray;
56
import org.nd4j.linalg.factory.Nd4j;
67
import org.nd4j.linalg.ops.transforms.Transforms;
@@ -16,98 +17,118 @@
1617
public class Matrix {
1718
private INDArray data;
1819

19-
public Matrix(int[][] data) {
20+
public Matrix(@NotNull int[][] data) {
2021
this.data = Nd4j.create(MatrixCaster.cast(data));
2122
}
2223

23-
public Matrix(float[][] data) {
24+
public Matrix(@NotNull float[][] data) {
2425
this.data = Nd4j.create(data);
2526
}
2627

27-
public Matrix(double[][] data) {
28+
public Matrix(@NotNull double[][] data) {
2829
this.data = Nd4j.create(data);
2930
}
3031

31-
Matrix(INDArray array) {
32+
Matrix(@NotNull INDArray array) {
3233
this.data = array;
3334
}
3435

36+
@NotNull
3537
public INDArray nativeData() {
3638
return data;
3739
}
3840

41+
@NotNull
3942
public Matrix addEach(int n) {
4043
return new Matrix(data.add(n));
4144
}
4245

46+
@NotNull
4347
public Matrix multiplyEach(int n) {
4448
return new Matrix(data.mul(n));
4549
}
4650

51+
@NotNull
4752
public Matrix minusEach(int n) {
4853
return new Matrix(data.subi(n));
4954
}
5055

56+
@NotNull
5157
public Matrix divideEach(int n) {
5258
return new Matrix(data.divi(n));
5359
}
5460

55-
public Matrix add(ColumnVector vector) {
61+
@NotNull
62+
public Matrix add(@NotNull ColumnVector vector) {
5663
return new Matrix(data.add(vector.nativeData()));
5764
}
5865

59-
public Matrix add(RowVector vector) {
66+
@NotNull
67+
public Matrix add(@NotNull RowVector vector) {
6068
return new Matrix(data.add(vector.nativeData()));
6169
}
6270

63-
public Matrix multiply(ColumnVector vector) {
71+
@NotNull
72+
public Matrix multiply(@NotNull ColumnVector vector) {
6473
return new Matrix(data.mul(vector.nativeData()));
6574
}
6675

67-
public Matrix multiply(RowVector vector) {
76+
@NotNull
77+
public Matrix multiply(@NotNull RowVector vector) {
6878
return new Matrix(data.mul(vector.nativeData()));
6979
}
7080

71-
public Matrix minus(ColumnVector vector) {
81+
@NotNull
82+
public Matrix minus(@NotNull ColumnVector vector) {
7283
return new Matrix(data.sub(vector.nativeData()));
7384
}
7485

75-
public Matrix minus(RowVector vector) {
86+
@NotNull
87+
public Matrix minus(@NotNull RowVector vector) {
7688
return new Matrix(data.sub(vector.nativeData()));
7789
}
7890

79-
public Matrix divide(ColumnVector vector) {
91+
@NotNull
92+
public Matrix divide(@NotNull ColumnVector vector) {
8093
return new Matrix(data.div(vector.nativeData()));
8194
}
8295

83-
public Matrix divide(RowVector vector) {
96+
@NotNull
97+
public Matrix divide(@NotNull RowVector vector) {
8498
return new Matrix(data.div(vector.nativeData()));
8599
}
86100

101+
@NotNull
87102
public Matrix sigmoid() {
88103
return new Matrix(Transforms.sigmoid(data));
89104
}
90105

106+
@NotNull
91107
public Matrix tanh() {
92108
return new Matrix(Transforms.tanh(data));
93109
}
94110

111+
@NotNull
95112
public Matrix abs() {
96113
return new Matrix(Transforms.abs(data));
97114
}
98115

116+
@NotNull
99117
public Matrix sqrt() {
100118
return new Matrix(Transforms.sqrt(data));
101119
}
102120

121+
@NotNull
103122
public Matrix exp() {
104123
return new Matrix(Transforms.exp(data));
105124
}
106125

126+
@NotNull
107127
public Matrix transpose() {
108128
return new Matrix(data.transpose());
109129
}
110130

131+
@NotNull
111132
public Matrix reshape(int rowN, int columnN) {
112133
return new Matrix(data.reshape(rowN, columnN));
113134
}
@@ -116,14 +137,14 @@ public double get(int row, int column) {
116137
return data.getDouble(row, column);
117138
}
118139

140+
@NotNull
119141
public double[][] cast() {
120142
return MatrixCaster.cast(this);
121143
}
122144

123145
@Override
146+
@NotNull
124147
public String toString() {
125-
return "Matrix{" +
126-
"data=" + data.toString() +
127-
'}';
148+
return "Matrix{data=" + data.toString() + "}";
128149
}
129150
}

src/main/java/org/algo4j/linear/RowVector.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
package org.algo4j.linear;
22

33
import org.algo4j.util.MatrixCaster;
4+
import org.jetbrains.annotations.NotNull;
45

56
@SuppressWarnings({"WeakerAccess", "unused"})
67
public class RowVector extends Vector {
7-
public RowVector(double[] data) {
8+
public RowVector(@NotNull double[] data) {
89
super(false, data);
910
}
1011

12+
@NotNull
1113
public ColumnVector toColumnVector() {
1214
return new ColumnVector(MatrixCaster.cast(new Matrix(nativeData()))[0]);
1315
}

src/main/java/org/algo4j/util/MatrixCaster.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
package org.algo4j.util;
22

33
import org.algo4j.linear.Matrix;
4+
import org.jetbrains.annotations.NotNull;
45
import org.nd4j.linalg.api.ndarray.INDArray;
56

67
public class MatrixCaster {
78
private MatrixCaster() {
89
throw new Error("do not instantiation me");
910
}
1011

11-
public static float[][] cast(int[][] in) {
12+
@NotNull
13+
public static float[][] cast(@NotNull int[][] in) {
1214
int inL = in.length;
1315
float[][] out = new float[inL][in[0].length];
1416

@@ -21,7 +23,8 @@ public static float[][] cast(int[][] in) {
2123
return out;
2224
}
2325

24-
public static double[][] cast(Matrix inM) {
26+
@NotNull
27+
public static double[][] cast(@NotNull Matrix inM) {
2528
INDArray in = inM.nativeData();
2629
int l = in.length();
2730
double[][] out = new double[l][in.size(0)];

0 commit comments

Comments
 (0)