Skip to content

Commit 2593093

Browse files
committed
Add a prioriProbs method and test case.
1 parent f4e170c commit 2593093

File tree

4 files changed

+68
-3
lines changed

4 files changed

+68
-3
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
.idea/
22
Recomenda-Engine.iml
3-
target/
3+
target/
4+
knn-java-library.iml

src/main/java/com/github/felipexw/classifier/bayes/MultinomialNaiveBayesClassifier.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
import com.github.felipexw.classifier.CrossValidateClassifier;
55
import com.github.felipexw.types.LabeledTrainingInstance;
66
import com.github.felipexw.types.PredictedInstance;
7+
import java.util.HashMap;
78
import java.util.List;
9+
import java.util.Map;
810

911
/**
1012
* Created by felipe.appio on 29/08/2016.
1113
*/
12-
public class MultinomialNaiveBayesClassifier implements Classifier, CrossValidateClassifier{
14+
public class MultinomialNaiveBayesClassifier extends NaiveBayes
15+
implements Classifier, CrossValidateClassifier {
1316

1417
@Override public void train(List<LabeledTrainingInstance> instances) {
1518

@@ -26,4 +29,19 @@ public class MultinomialNaiveBayesClassifier implements Classifier, CrossValida
2629
@Override public void train(List<LabeledTrainingInstance> instances, int k) {
2730

2831
}
32+
33+
@Override
34+
public Map<String, Integer> calculateAPrioriProbs(List<LabeledTrainingInstance> instanceList) {
35+
Map<String, Integer> probs = new HashMap<>();
36+
37+
for (LabeledTrainingInstance instance : instanceList) {
38+
if (!probs.containsKey(instance.getLabel())) {
39+
probs.put(instance.getLabel(), 1);
40+
} else {
41+
probs.put(instance.getLabel(), probs.get(instance.getLabel()) + 1);
42+
}
43+
}
44+
45+
return probs;
46+
}
2947
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package com.github.felipexw.classifier.bayes;
2+
3+
import com.github.felipexw.types.LabeledTrainingInstance;
4+
import java.util.List;
5+
import java.util.Map;
6+
7+
/**
8+
* Created by felipe.appio on 29/08/2016.
9+
*/
10+
public abstract class NaiveBayes {
11+
public abstract Map<String, Integer> calculateAPrioriProbs (List<LabeledTrainingInstance> instanceList);
12+
}
Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,45 @@
1+
import com.github.felipexw.classifier.Classifier;
2+
import com.github.felipexw.classifier.bayes.MultinomialNaiveBayesClassifier;
3+
import com.github.felipexw.classifier.bayes.NaiveBayes;
4+
import com.github.felipexw.types.LabeledTrainingInstance;
5+
import java.util.ArrayList;
6+
import java.util.Arrays;
7+
import java.util.List;
8+
import java.util.Map;
9+
import org.junit.Before;
110
import org.junit.Test;
211

12+
import static com.google.common.truth.Truth.*;
13+
314
/**
415
* Created by felipe.appio on 29/08/2016.
516
*/
617
public class MultinomialNaiveBayesClassifierTest {
718

19+
private NaiveBayes naiveBayesClassifier;
20+
21+
@Before
22+
public void setUp() {
23+
naiveBayesClassifier = new MultinomialNaiveBayesClassifier();
24+
}
25+
826
@Test
9-
public void it_should_fail(){
27+
public void it_should_calculate_a_priori_probs() {
28+
String negativeLabel = "negative";
29+
String positiveLabel = "positive";
30+
31+
List<LabeledTrainingInstance> training =
32+
Arrays.asList(new LabeledTrainingInstance(new double[] {2}, negativeLabel),
33+
new LabeledTrainingInstance(new double[] {2}, negativeLabel),
34+
new LabeledTrainingInstance(new double[] {2}, positiveLabel));
35+
36+
Map<String, Integer> probs = naiveBayesClassifier.calculateAPrioriProbs(training);
37+
38+
assertThat(probs.get(negativeLabel))
39+
.isEqualTo(2);
40+
assertThat(probs.get(positiveLabel))
41+
.isEqualTo(1);
1042
}
43+
44+
1145
}

0 commit comments

Comments
 (0)