@@ -42,6 +42,25 @@ public PolynomialRegression(List<Point> points, int polynomialDegree) {
4242 throw new IllegalArgumentException ("Polynomial degree should be >= 0" );
4343 }
4444 this .polynomialDegree = polynomialDegree ;
45+ computeCoefficients ();
46+ }
47+
48+ /**
49+ * Constructor of a polynomial regression.
50+ * @param points the training data
51+ * @param polynomialDegree the desired degree of the polynomial regression
52+ * @param computeCoeff should coefficients be computed
53+ */
54+ private PolynomialRegression (List <Point > points , int polynomialDegree , boolean computeCoeff ) {
55+ this .points = points ;
56+ if (polynomialDegree < 0 ) {
57+ throw new IllegalArgumentException ("Polynomial degree should be >= 0" );
58+ }
59+
60+ this .polynomialDegree = polynomialDegree ;
61+ if (computeCoeff ) {
62+ computeCoefficients ();
63+ }
4564 }
4665
4766 /**
@@ -155,7 +174,7 @@ public double getPrediction(double value) {
155174 * @return the integer representing the optimal degree for the polynomial regression for the trained data
156175 * @throws InterruptedException
157176 */
158- public int getOptimalPolynomialDegreeWithTestData (List <Point > testData , boolean terminalOutput )
177+ public static int getOptimalPolynomialDegreeWithTestData (List < Point > trainingData , List <Point > testData , boolean terminalOutput )
159178 throws InterruptedException {
160179
161180 long startTime = System .nanoTime ();
@@ -187,7 +206,7 @@ public int getOptimalPolynomialDegreeWithTestData(List<Point> testData, boolean
187206 for (int i = (sequence .length * threadIndex ) / threadNum ;
188207 i < (sequence .length * (threadIndex + 1 )) / threadNum ; i ++) {
189208
190- PolynomialRegression regression = new PolynomialRegression (getPoints () , sequence [i ]);
209+ PolynomialRegression regression = new PolynomialRegression (trainingData , sequence [i ], true );
191210 double error = regression .getTestDataRootMeanSquareError (testData );
192211 if (terminalOutput ) {
193212 System .out .println ("Thread: " + threadIndex + ", Degree: " +sequence [i ] + ", Error: " + error );
@@ -226,6 +245,18 @@ public int getOptimalPolynomialDegreeWithTestData(List<Point> testData, boolean
226245 return threadPolyDegrees [minimalDegreeIndex ];
227246 }
228247
248+ /**
249+ * Multithreaded implementation of a method based on Root Mean Square Error (RMSE) comparison to obtain the optimal polynomial
250+ * degree that minimises the RMSE error, and therefore improves the accuracy of the trained data, given test data.
251+ * By default, it prints to terminal the errors, polynomial degrees, the thread ids and the elapsed time.
252+ * @param testData the data we want to optimise the regression for
253+ * @return the integer representing the optimal degree for the polynomial regression for the trained data
254+ * @throws InterruptedException
255+ */
256+ public int getOptimalPolynomialDegreeWithTestData (List <Point > trainingData , List <Point > testData ) throws InterruptedException {
257+ return getOptimalPolynomialDegreeWithTestData (trainingData , testData , true );
258+ }
259+
229260 /**
230261 * Returns the optimal polynomial regression (with minimised RMSE) for the supplied training data and test data. Also
231262 * prints the errors, poly-degrees and time elapsed in the computation. This also includes the thread ids, as this is
@@ -241,22 +272,10 @@ public int getOptimalPolynomialDegreeWithTestData(List<Point> testData, boolean
241272 public PolynomialRegression getOptimalPolynomialRegression (List <Point > trainingData , List <Point > testData ,
242273 boolean terminalOutput ) throws InterruptedException {
243274 PolynomialRegression plr = new PolynomialRegression (trainingData , 0 );
244- int optimalDegree = getOptimalPolynomialDegreeWithTestData (testData , terminalOutput );
275+ int optimalDegree = getOptimalPolynomialDegreeWithTestData (trainingData , testData , terminalOutput );
245276 return new PolynomialRegression (trainingData , optimalDegree );
246277 }
247278
248- /**
249- * Multithreaded implementation of a method based on Root Mean Square Error (RMSE) comparison to obtain the optimal polynomial
250- * degree that minimises the RMSE error, and therefore improves the accuracy of the trained data, given test data.
251- * By default, it prints to terminal the errors, polynomial degrees, the thread ids and the elapsed time.
252- * @param testData the data we want to optimise the regression for
253- * @return the integer representing the optimal degree for the polynomial regression for the trained data
254- * @throws InterruptedException
255- */
256- public int getOptimalPolynomialDegreeWithTestData (List <Point > testData ) throws InterruptedException {
257- return getOptimalPolynomialDegreeWithTestData (testData , true );
258- }
259-
260279 /**
261280 * Splits an array into n smaller arrays containing the supplied size. Used in the distribute function that
262281 * gives each thread a number of tasks with a similar combined difficulty, so each thread works approximately
@@ -293,7 +312,7 @@ private static List<int[]> splitArray(int[] items, int maxSubArraySize) {
293312 * @param array the array containing the task ids by order of difficulty.
294313 * @param numThreads the number of threads we will be using
295314 */
296- private void distribute (int [] array , int numThreads ) {
315+ private static void distribute (int [] array , int numThreads ) {
297316
298317 List <int []> list = splitArray (array , array .length / numThreads );
299318 int count = 0 ;
@@ -312,7 +331,7 @@ private void distribute(int[] array, int numThreads) {
312331 * @param doubles an array of doubles
313332 * @return the index of the minimum double in the array
314333 */
315- private int getIndexOfMinDouble (double [] doubles ) {
334+ private static int getIndexOfMinDouble (double [] doubles ) {
316335
317336 double current = Double .MAX_VALUE ;
318337 int currentIndex = 0 ;
@@ -387,7 +406,7 @@ public static void main(String[] args) throws FileNotFoundException, Interrupted
387406 + 0.7483924 * i
388407 + Math .random ());
389408
390- if (count < 200 ) {
409+ if (count < 205 ) {
391410 points .add (point );
392411 } else {
393412 testData .add (point );
@@ -398,7 +417,7 @@ public static void main(String[] args) throws FileNotFoundException, Interrupted
398417 System .out .println ("Points to analyse: " + testData .size ());
399418
400419 PolynomialRegression regression = new PolynomialRegression (points , 0 );
401- System .out .println (regression .getOptimalPolynomialDegreeWithTestData (testData ));
420+ System .out .println (regression .getOptimalPolynomialDegreeWithTestData (points , testData ));
402421
403422 }
404423
0 commit comments