11package org .algo4j .linear ;
22
33import org .algo4j .util .MatrixCaster ;
4+ import org .jetbrains .annotations .NotNull ;
45import org .nd4j .linalg .api .ndarray .INDArray ;
56import org .nd4j .linalg .factory .Nd4j ;
67import org .nd4j .linalg .ops .transforms .Transforms ;
1617public class Matrix {
1718private INDArray data ;
1819
19- public Matrix (int [][] data ) {
20+ public Matrix (@ NotNull int [][] data ) {
2021this .data = Nd4j .create (MatrixCaster .cast (data ));
2122}
2223
23- public Matrix (float [][] data ) {
24+ public Matrix (@ NotNull float [][] data ) {
2425this .data = Nd4j .create (data );
2526}
2627
27- public Matrix (double [][] data ) {
28+ public Matrix (@ NotNull double [][] data ) {
2829this .data = Nd4j .create (data );
2930}
3031
31- Matrix (INDArray array ) {
32+ Matrix (@ NotNull INDArray array ) {
3233this .data = array ;
3334}
3435
36+ @ NotNull
3537public INDArray nativeData () {
3638return data ;
3739}
3840
41+ @ NotNull
3942public Matrix addEach (int n ) {
4043return new Matrix (data .add (n ));
4144}
4245
46+ @ NotNull
4347public Matrix multiplyEach (int n ) {
4448return new Matrix (data .mul (n ));
4549}
4650
51+ @ NotNull
4752public Matrix minusEach (int n ) {
4853return new Matrix (data .subi (n ));
4954}
5055
56+ @ NotNull
5157public Matrix divideEach (int n ) {
5258return new Matrix (data .divi (n ));
5359}
5460
55- public Matrix add (ColumnVector vector ) {
61+ @ NotNull
62+ public Matrix add (@ NotNull ColumnVector vector ) {
5663return new Matrix (data .add (vector .nativeData ()));
5764}
5865
59- public Matrix add (RowVector vector ) {
66+ @ NotNull
67+ public Matrix add (@ NotNull RowVector vector ) {
6068return new Matrix (data .add (vector .nativeData ()));
6169}
6270
63- public Matrix multiply (ColumnVector vector ) {
71+ @ NotNull
72+ public Matrix multiply (@ NotNull ColumnVector vector ) {
6473return new Matrix (data .mul (vector .nativeData ()));
6574}
6675
67- public Matrix multiply (RowVector vector ) {
76+ @ NotNull
77+ public Matrix multiply (@ NotNull RowVector vector ) {
6878return new Matrix (data .mul (vector .nativeData ()));
6979}
7080
71- public Matrix minus (ColumnVector vector ) {
81+ @ NotNull
82+ public Matrix minus (@ NotNull ColumnVector vector ) {
7283return new Matrix (data .sub (vector .nativeData ()));
7384}
7485
75- public Matrix minus (RowVector vector ) {
86+ @ NotNull
87+ public Matrix minus (@ NotNull RowVector vector ) {
7688return new Matrix (data .sub (vector .nativeData ()));
7789}
7890
79- public Matrix divide (ColumnVector vector ) {
91+ @ NotNull
92+ public Matrix divide (@ NotNull ColumnVector vector ) {
8093return new Matrix (data .div (vector .nativeData ()));
8194}
8295
83- public Matrix divide (RowVector vector ) {
96+ @ NotNull
97+ public Matrix divide (@ NotNull RowVector vector ) {
8498return new Matrix (data .div (vector .nativeData ()));
8599}
86100
101+ @ NotNull
87102public Matrix sigmoid () {
88103return new Matrix (Transforms .sigmoid (data ));
89104}
90105
106+ @ NotNull
91107public Matrix tanh () {
92108return new Matrix (Transforms .tanh (data ));
93109}
94110
111+ @ NotNull
95112public Matrix abs () {
96113return new Matrix (Transforms .abs (data ));
97114}
98115
116+ @ NotNull
99117public Matrix sqrt () {
100118return new Matrix (Transforms .sqrt (data ));
101119}
102120
121+ @ NotNull
103122public Matrix exp () {
104123return new Matrix (Transforms .exp (data ));
105124}
106125
126+ @ NotNull
107127public Matrix transpose () {
108128return new Matrix (data .transpose ());
109129}
110130
131+ @ NotNull
111132public Matrix reshape (int rowN , int columnN ) {
112133return new Matrix (data .reshape (rowN , columnN ));
113134}
@@ -116,14 +137,14 @@ public double get(int row, int column) {
116137return data .getDouble (row , column );
117138}
118139
140+ @ NotNull
119141public double [][] cast () {
120142return MatrixCaster .cast (this );
121143}
122144
123145@ Override
146+ @ NotNull
124147public String toString () {
125- return "Matrix{" +
126- "data=" + data .toString () +
127- '}' ;
148+ return "Matrix{data=" + data .toString () + "}" ;
128149}
129150}
0 commit comments