Skip to content

Commit 06caeca

Browse files
Ray-EldathRay-Eldath
authored andcommitted
Untracked ignored files. Add Matrix, Vector, RowVector and ColumnVector (from ND4J). Optimized class access control policy. Add some advices (See TODO). WILL add SGD algorithm.
1 parent 7097d7d commit 06caeca

File tree

12 files changed

+246
-1
lines changed

12 files changed

+246
-1
lines changed

build.gradle

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ apply plugin: 'kotlin'
66
// apply plugin: 'maven'
77
// apply plugin: 'cpp'
88

9+
String os = "windows" // TODO for different build, change here
910
sourceCompatibility = 1.8
1011
targetCompatibility = 1.8
1112

@@ -19,7 +20,7 @@ buildscript {
1920

2021
repositories {
2122
// mavenCentral()
22-
maven{ url 'http://maven.aliyun.com/nexus/content/groups/public/' }
23+
maven { url 'http://maven.aliyun.com/nexus/content/groups/public/' }
2324
}
2425

2526
dependencies {
@@ -37,9 +38,32 @@ sourceSets {
3738
test.java.srcDirs += 'demos'
3839
}
3940

41+
switch (os) {
42+
case 'windows':
43+
os = 'windows-x86_64'
44+
break
45+
case 'linux':
46+
os = 'linux-x86_64'
47+
break
48+
case 'linux-ppc64':
49+
os = 'linux-ppc64'
50+
break
51+
case 'linux-ppc64le':
52+
os = 'linux-ppc64le'
53+
break
54+
case 'macosx':
55+
os = 'macosx-x86_64'
56+
break
57+
default:
58+
throw new Exception('Unknown OS defined for -Plibnd4jOS parameter. ND4J will be unable to find platform-specific binaries and thus unable to run.')
59+
}
60+
4061
dependencies {
4162
compile group: 'org.jetbrains', name: 'annotations', version: '15.0'
4263
compile "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version"
64+
// https://mvnrepository.com/artifact/org.nd4j/nd4j-native
65+
compile group: 'org.nd4j', name: 'nd4j-native', version: '0.8.0'
66+
compile 'org.nd4j:nd4j-native:0.8.0:' + os
4367

4468
testCompile 'junit:junit:4.12'
4569
testCompile "org.jetbrains.kotlin:kotlin-test-junit:$kotlin_version"
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package org.algo4j.linear;
2+
3+
@SuppressWarnings({"WeakerAccess", "unused"})
4+
public class ColumnVector extends Vector {
5+
public ColumnVector(double[] data) {
6+
super(true, data);
7+
}
8+
9+
public RowVector toRowVector() {
10+
int L = nativeData().rows();
11+
double[] data = new double[L];
12+
for (int i = 0; i < L; i++)
13+
data[i] = nativeData().getDouble(0, i);
14+
return new RowVector(data);
15+
}
16+
}
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
package org.algo4j.linear;
2+
3+
import org.algo4j.util.MatrixCaster;
4+
import org.nd4j.linalg.api.ndarray.INDArray;
5+
import org.nd4j.linalg.factory.Nd4j;
6+
import org.nd4j.linalg.ops.transforms.Transforms;
7+
8+
/**
9+
* Matrix class.
10+
*
11+
* @author Ray Eldath
12+
* @since 1.0.6
13+
*/
14+
15+
@SuppressWarnings("unused")
16+
public class Matrix {
17+
private INDArray data;
18+
19+
public Matrix(int[][] data) {
20+
this.data = Nd4j.create(MatrixCaster.cast(data));
21+
}
22+
23+
public Matrix(float[][] data) {
24+
this.data = Nd4j.create(data);
25+
}
26+
27+
public Matrix(double[][] data) {
28+
this.data = Nd4j.create(data);
29+
}
30+
31+
Matrix(INDArray array) {
32+
this.data = array;
33+
}
34+
35+
public INDArray nativeData() {
36+
return data;
37+
}
38+
39+
public Matrix addEach(int n) {
40+
return new Matrix(data.add(n));
41+
}
42+
43+
public Matrix multiplyEach(int n) {
44+
return new Matrix(data.mul(n));
45+
}
46+
47+
public Matrix minusEach(int n) {
48+
return new Matrix(data.subi(n));
49+
}
50+
51+
public Matrix divideEach(int n) {
52+
return new Matrix(data.divi(n));
53+
}
54+
55+
public Matrix add(ColumnVector vector) {
56+
return new Matrix(data.add(vector.nativeData()));
57+
}
58+
59+
public Matrix add(RowVector vector) {
60+
return new Matrix(data.add(vector.nativeData()));
61+
}
62+
63+
public Matrix multiply(ColumnVector vector) {
64+
return new Matrix(data.mul(vector.nativeData()));
65+
}
66+
67+
public Matrix multiply(RowVector vector) {
68+
return new Matrix(data.mul(vector.nativeData()));
69+
}
70+
71+
public Matrix minus(ColumnVector vector) {
72+
return new Matrix(data.sub(vector.nativeData()));
73+
}
74+
75+
public Matrix minus(RowVector vector) {
76+
return new Matrix(data.sub(vector.nativeData()));
77+
}
78+
79+
public Matrix divide(ColumnVector vector) {
80+
return new Matrix(data.div(vector.nativeData()));
81+
}
82+
83+
public Matrix divide(RowVector vector) {
84+
return new Matrix(data.div(vector.nativeData()));
85+
}
86+
87+
public Matrix sigmoid() {
88+
return new Matrix(Transforms.sigmoid(data));
89+
}
90+
91+
public Matrix tanh() {
92+
return new Matrix(Transforms.tanh(data));
93+
}
94+
95+
public Matrix abs() {
96+
return new Matrix(Transforms.abs(data));
97+
}
98+
99+
public Matrix sqrt() {
100+
return new Matrix(Transforms.sqrt(data));
101+
}
102+
103+
public Matrix exp() {
104+
return new Matrix(Transforms.exp(data));
105+
}
106+
107+
public Matrix transpose() {
108+
return new Matrix(data.transpose());
109+
}
110+
111+
public Matrix reshape(int rowN, int columnN) {
112+
return new Matrix(data.reshape(rowN, columnN));
113+
}
114+
115+
public double get(int row, int column) {
116+
return data.getDouble(row, column);
117+
}
118+
119+
public double[][] cast() {
120+
return MatrixCaster.cast(this);
121+
}
122+
123+
@Override
124+
public String toString() {
125+
return "Matrix{" +
126+
"data=" + data.toString() +
127+
'}';
128+
}
129+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package org.algo4j.linear;
2+
3+
import org.algo4j.util.MatrixCaster;
4+
5+
@SuppressWarnings({"WeakerAccess", "unused"})
6+
public class RowVector extends Vector {
7+
public RowVector(double[] data) {
8+
super(false, data);
9+
}
10+
11+
public ColumnVector toColumnVector() {
12+
return new ColumnVector(MatrixCaster.cast(new Matrix(nativeData()))[0]);
13+
}
14+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package org.algo4j.linear;
2+
3+
import org.nd4j.linalg.api.ndarray.INDArray;
4+
import org.nd4j.linalg.factory.Nd4j;
5+
6+
class Vector {
7+
private INDArray array;
8+
9+
Vector(boolean isColumnVector, double[] data) {
10+
int n = data.length;
11+
array = isColumnVector ?
12+
Nd4j.create(data, new int[]{n, 1}) :
13+
Nd4j.create(data, new int[]{n});
14+
}
15+
16+
INDArray nativeData() {
17+
return array;
18+
}
19+
}

src/main/java/org/algo4j/math/Constants.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66
* @author ice1000
77
*/
88
@SuppressWarnings({"unused", "WeakerAccess"})
9+
10+
//TODO 我认为将此类放入package-info中并限制访问权限更好。但如此Test可能无法运行。
911
public final class Constants {
1012
private Constants() {
13+
throw new Error("do not instantiation me");
1114
}
1215

1316
/** m * s^-1 */

src/main/java/org/algo4j/math/MathUtils.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
public final class MathUtils {
1515

1616
private MathUtils() {
17+
throw new Error("do not instantiation me");
1718
}
1819

1920
public static final double E = java.lang.Math.E;

src/main/java/org/algo4j/math/Trigonometric.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
@SuppressWarnings("WeakerAccess")
1111
public final class Trigonometric {
1212
private Trigonometric() {
13+
throw new Error("do not instantiation me");
1314
}
1415

1516
/**
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package org.algo4j.util;
2+
3+
import org.algo4j.linear.Matrix;
4+
import org.nd4j.linalg.api.ndarray.INDArray;
5+
6+
public class MatrixCaster {
7+
private MatrixCaster() {
8+
throw new Error("do not instantiation me");
9+
}
10+
11+
public static float[][] cast(int[][] in) {
12+
int inL = in.length;
13+
float[][] out = new float[inL][in[0].length];
14+
15+
for (int i = 0; i < inL; i++) {
16+
int[] n = in[i];
17+
for (int j = 0; j < n.length; j++)
18+
out[i][j] = n[j];
19+
}
20+
21+
return out;
22+
}
23+
24+
public static double[][] cast(Matrix inM) {
25+
INDArray in = inM.nativeData();
26+
int l = in.length();
27+
double[][] out = new double[l][in.size(0)];
28+
for (int i = 0; i < l; i++) {
29+
int n = in.size(i);
30+
for (int j = 0; j < n; j++)
31+
out[i][j] = in.getDouble(i, j);
32+
}
33+
return out;
34+
}
35+
}

src/main/java/org/algo4j/util/SeqUtils.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
@SuppressWarnings({"WeakerAccess", "unused"})
1414
public final class SeqUtils {
1515
private SeqUtils() {
16+
throw new Error("do not instantiation me");
1617
}
1718

1819
/**

0 commit comments

Comments
 (0)