Skip to content

Commit 0c0311d

Browse files
committed
Fix perceptron BPG
1 parent 86a7bf8 commit 0c0311d

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed

Algorithm.cpp

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -329,38 +329,40 @@ void Algorithm::train_perceptrons_BPG(Eigen::MatrixXd &weights) {
329329
}
330330

331331
std::vector<int> misclassified_elements; //Row indexes of the criterion function
332+
misclassified_elements.push_back(1); //Add a value to start the loop
332333
Eigen::MatrixXd criterion_function(input_data->getNbClasses(), input_data->getNbTrainingElements());
333334
criterion_function.setZero();
334335

335-
c = 200;
336+
c = 200; //Safety counter: stop if still misclassified elements anyway
336337
/* Iterate while there are misclassified elements and counter not equal to zero */
337338
while (!misclassified_elements.empty() && c--) {
338-
/* Empty the misclassified elements */
339-
misclassified_elements.clear();
340-
341339
/* Update the criterion function */
342340
for (int i = 0; i < input_data->getNbTrainingElements(); i++)
343341
criterion_function.col(i) = outputVectors.col(i).cwiseProduct(weights.transpose() * augmented_data.col(i));
344342

343+
int n = 0;
345344
for (int i = 0; i < input_data->getNbClasses(); i++) {
345+
/* Empty the misclassified elements */
346+
misclassified_elements.clear();
347+
346348
/* Find all misclassified elements */
347-
for (int j = 0; j < criterion_function.rows(); j++)
348-
for (int k = 0; j < criterion_function.cols(); j++)
349-
if (criterion_function(j, k) < 0) //Misclassified !
350-
misclassified_elements.push_back(k); //Add the row index to the list
349+
for (int k = 0; k < criterion_function.cols(); k++)
350+
if (criterion_function(n, k) < 0) //Misclassified !
351+
misclassified_elements.push_back(k); //Add the col index to the list
351352

352-
/* Calculate the gradiant and update the weights */
353-
Eigen::VectorXd gradiant;
354-
gradiant.setZero();
353+
/* Calculate the gradient and update the weights */
354+
Eigen::VectorXd gradient(weights.rows());
355+
gradient.setZero();
355356

356-
int k = 0;
357357
for (auto const &misclassified_element : misclassified_elements)
358-
gradiant += misclassified_element * outputVectors.row(k++);
358+
gradient += augmented_data.col(misclassified_element) * outputVectors(n, misclassified_element);
359359

360-
gradiant *= LEARNING_RATE;
360+
gradient *= LEARNING_RATE;
361361

362362
/* Update the weights */
363-
weights.col(i) += gradiant;
363+
weights.col(n) += gradient;
364+
365+
n++; //Next class index (don't use i because of the offset of ORL
364366
}
365367
}
366368
}

0 commit comments

Comments
 (0)