Skip to content

Commit 1a6deda

Browse files
committed
minor changed
1 parent 608558c commit 1a6deda

File tree

2 files changed

+40
-21
lines changed

2 files changed

+40
-21
lines changed

src/Tests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ public void polynomialRegressionTest1() throws FileNotFoundException {
265265
}
266266

267267
PolynomialRegression polynomialRegression = new PolynomialRegression(points, 2);
268-
polynomialRegression.computeCoefficients();
268+
//polynomialRegression.computeCoefficients();
269269
double[][] coefficients = polynomialRegression.getCoefficients();
270270

271271
assertTrue(StatisticUtils.isApproxEqual(coefficients[0][0], -1216.143887));
@@ -289,7 +289,7 @@ public void polynomialRegressionTest2() throws FileNotFoundException {
289289
}
290290

291291
PolynomialRegression polynomialRegression = new PolynomialRegression(points, 2);
292-
polynomialRegression.computeCoefficients();
292+
//polynomialRegression.computeCoefficients();
293293
double[][] coefficients = polynomialRegression.getCoefficients();
294294

295295
//coefficient real data has been rounded in the website that supplied the data, hence the big epsilon

src/polynomialRegression/PolynomialRegression.java

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,25 @@ public PolynomialRegression(List<Point> points, int polynomialDegree) {
4242
throw new IllegalArgumentException("Polynomial degree should be >= 0");
4343
}
4444
this.polynomialDegree = polynomialDegree;
45+
computeCoefficients();
46+
}
47+
48+
/**
49+
* Constructor of a polynomial regression.
50+
* @param points the training data
51+
* @param polynomialDegree the desired degree of the polynomial regression
52+
* @param computeCoeff should coefficients be computed
53+
*/
54+
private PolynomialRegression(List<Point> points, int polynomialDegree, boolean computeCoeff) {
55+
this.points = points;
56+
if (polynomialDegree < 0) {
57+
throw new IllegalArgumentException("Polynomial degree should be >= 0");
58+
}
59+
60+
this.polynomialDegree = polynomialDegree;
61+
if (computeCoeff) {
62+
computeCoefficients();
63+
}
4564
}
4665

4766
/**
@@ -155,7 +174,7 @@ public double getPrediction(double value) {
155174
* @return the integer representing the optimal degree for the polynomial regression for the trained data
156175
* @throws InterruptedException
157176
*/
158-
public int getOptimalPolynomialDegreeWithTestData(List<Point> testData, boolean terminalOutput)
177+
public static int getOptimalPolynomialDegreeWithTestData(List<Point> trainingData, List<Point> testData, boolean terminalOutput)
159178
throws InterruptedException {
160179

161180
long startTime = System.nanoTime();
@@ -187,7 +206,7 @@ public int getOptimalPolynomialDegreeWithTestData(List<Point> testData, boolean
187206
for (int i = (sequence.length * threadIndex) / threadNum;
188207
i < (sequence.length * (threadIndex + 1)) / threadNum; i++) {
189208

190-
PolynomialRegression regression = new PolynomialRegression(getPoints(), sequence[i]);
209+
PolynomialRegression regression = new PolynomialRegression(trainingData, sequence[i], true);
191210
double error = regression.getTestDataRootMeanSquareError(testData);
192211
if (terminalOutput) {
193212
System.out.println("Thread: " + threadIndex + ", Degree: " +sequence[i] + ", Error: " + error);
@@ -226,6 +245,18 @@ public int getOptimalPolynomialDegreeWithTestData(List<Point> testData, boolean
226245
return threadPolyDegrees[minimalDegreeIndex];
227246
}
228247

248+
/**
249+
* Multithreaded implementation of a method based on Root Mean Square Error (RMSE) comparison to obtain the optimal polynomial
250+
* degree that minimises the RMSE error, and therefore improves the accuracy of the trained data, given test data.
251+
* By default, it prints to terminal the errors, polynomial degrees, the thread ids and the elapsed time.
252+
* @param testData the data we want to optimise the regression for
253+
* @return the integer representing the optimal degree for the polynomial regression for the trained data
254+
* @throws InterruptedException
255+
*/
256+
public int getOptimalPolynomialDegreeWithTestData(List<Point> trainingData, List<Point> testData) throws InterruptedException {
257+
return getOptimalPolynomialDegreeWithTestData(trainingData, testData, true);
258+
}
259+
229260
/**
230261
* Returns the optimal polynomial regression (with minimised RMSE) for the supplied training data and test data. Also
231262
* prints the errors, poly-degrees and time elapsed in the computation. This also includes the thread ids, as this is
@@ -241,22 +272,10 @@ public int getOptimalPolynomialDegreeWithTestData(List<Point> testData, boolean
241272
public PolynomialRegression getOptimalPolynomialRegression(List<Point> trainingData, List<Point> testData,
242273
boolean terminalOutput) throws InterruptedException {
243274
PolynomialRegression plr = new PolynomialRegression(trainingData, 0);
244-
int optimalDegree = getOptimalPolynomialDegreeWithTestData(testData, terminalOutput);
275+
int optimalDegree = getOptimalPolynomialDegreeWithTestData(trainingData, testData, terminalOutput);
245276
return new PolynomialRegression(trainingData, optimalDegree);
246277
}
247278

248-
/**
249-
* Multithreaded implementation of a method based on Root Mean Square Error (RMSE) comparison to obtain the optimal polynomial
250-
* degree that minimises the RMSE error, and therefore improves the accuracy of the trained data, given test data.
251-
* By default, it prints to terminal the errors, polynomial degrees, the thread ids and the elapsed time.
252-
* @param testData the data we want to optimise the regression for
253-
* @return the integer representing the optimal degree for the polynomial regression for the trained data
254-
* @throws InterruptedException
255-
*/
256-
public int getOptimalPolynomialDegreeWithTestData(List<Point> testData) throws InterruptedException {
257-
return getOptimalPolynomialDegreeWithTestData(testData, true);
258-
}
259-
260279
/**
261280
* Splits an array into n smaller arrays containing the supplied size. Used in the distribute function that
262281
* gives each thread a number of tasks with a similar combined difficulty, so each thread works approximately
@@ -293,7 +312,7 @@ private static List<int[]> splitArray(int[] items, int maxSubArraySize) {
293312
* @param array the array containing the task ids by order of difficulty.
294313
* @param numThreads the number of threads we will be using
295314
*/
296-
private void distribute(int[] array, int numThreads) {
315+
private static void distribute(int[] array, int numThreads) {
297316

298317
List<int[]> list = splitArray(array, array.length / numThreads);
299318
int count = 0;
@@ -312,7 +331,7 @@ private void distribute(int[] array, int numThreads) {
312331
* @param doubles an array of doubles
313332
* @return the index of the minimum double in the array
314333
*/
315-
private int getIndexOfMinDouble(double[] doubles) {
334+
private static int getIndexOfMinDouble(double[] doubles) {
316335

317336
double current = Double.MAX_VALUE;
318337
int currentIndex = 0;
@@ -387,7 +406,7 @@ public static void main(String[] args) throws FileNotFoundException, Interrupted
387406
+ 0.7483924 * i
388407
+ Math.random());
389408

390-
if (count < 200) {
409+
if (count < 205) {
391410
points.add(point);
392411
} else {
393412
testData.add(point);
@@ -398,7 +417,7 @@ public static void main(String[] args) throws FileNotFoundException, Interrupted
398417
System.out.println("Points to analyse: " + testData.size());
399418

400419
PolynomialRegression regression = new PolynomialRegression(points, 0);
401-
System.out.println(regression.getOptimalPolynomialDegreeWithTestData(testData));
420+
System.out.println(regression.getOptimalPolynomialDegreeWithTestData(points, testData));
402421

403422
}
404423

0 commit comments

Comments
 (0)