@@ -329,38 +329,40 @@ void Algorithm::train_perceptrons_BPG(Eigen::MatrixXd &weights) {
329329 }
330330
331331 std::vector<int > misclassified_elements; // Row indexes of the criterion function
332+ misclassified_elements.push_back (1 ); // Add a value to start the loop
332333 Eigen::MatrixXd criterion_function (input_data->getNbClasses (), input_data->getNbTrainingElements ());
333334 criterion_function.setZero ();
334335
335- c = 200 ;
336+ c = 200 ; // Safety counter: stop if still misclassified elements anyway
336337 /* Iterate while there are misclassified elements and counter not equal to zero */
337338 while (!misclassified_elements.empty () && c--) {
338- /* Empty the misclassified elements */
339- misclassified_elements.clear ();
340-
341339/* Update the criterion function */
342340for (int i = 0 ; i < input_data->getNbTrainingElements (); i++)
343341 criterion_function.col (i) = outputVectors.col (i).cwiseProduct (weights.transpose () * augmented_data.col (i));
344342
343+ int n = 0 ;
345344for (int i = 0 ; i < input_data->getNbClasses (); i++) {
345+ /* Empty the misclassified elements */
346+ misclassified_elements.clear ();
347+
346348 /* Find all misclassified elements */
347- for (int j = 0 ; j < criterion_function.rows (); j++)
348- for (int k = 0 ; j < criterion_function.cols (); j++)
349- if (criterion_function (j, k) < 0 ) // Misclassified !
350- misclassified_elements.push_back (k); // Add the row index to the list
349+ for (int k = 0 ; k < criterion_function.cols (); k++)
350+ if (criterion_function (n, k) < 0 ) // Misclassified !
351+ misclassified_elements.push_back (k); // Add the col index to the list
351352
352- /* Calculate the gradiant and update the weights */
353- Eigen::VectorXd gradiant ;
354- gradiant .setZero ();
353+ /* Calculate the gradient and update the weights */
354+ Eigen::VectorXd gradient (weights. rows ()) ;
355+ gradient .setZero ();
355356
356- int k = 0 ;
357357 for (auto const &misclassified_element : misclassified_elements)
358- gradiant += misclassified_element * outputVectors. row (k++ );
358+ gradient += augmented_data. col ( misclassified_element) * outputVectors (n, misclassified_element );
359359
360- gradiant *= LEARNING_RATE;
360+ gradient *= LEARNING_RATE;
361361
362362 /* Update the weights */
363- weights.col (i) += gradiant;
363+ weights.col (n) += gradient;
364+
365+ n++; // Next class index (don't use i because of the offset of ORL
364366}
365367 }
366368}
0 commit comments