1
1
package javax .visrec .ml .classification ;
2
2
3
- import javax .visrec .ml .classification .ImageClassifier ;
3
+ import javax .visrec .ImageFactory ;
4
+ import javax .visrec .ml .model .ModelProvider ;
4
5
import javax .visrec .spi .ServiceProvider ;
5
6
import java .awt .image .BufferedImage ;
6
7
import java .io .File ;
9
10
import java .util .Map ;
10
11
import java .util .Objects ;
11
12
import java .util .Optional ;
12
- import javax .visrec .ImageFactory ;
13
- import javax .visrec .ml .model .ModelProvider ;
14
13
15
14
/**
16
15
* Skeleton abstract class to make it easier to implement image classifier.
17
16
* It provides implementation of Classifier interface for images, along with
18
17
* image factory for specific type of images.
19
18
* This class solves the problem of using various implementation of images and machine learning models in Java,
20
19
* and provides standard Classifier API for clients.
21
- *
20
+ * <p>
22
21
* By default the type of key in the Map the {@link ImageClassifier} is {@code String}
23
22
*
24
- * @author Zoran Sevarac
25
- *
23
+ * @param <IMAGE_CLASS> class to classify
26
24
* @param <MODEL_CLASS> class of machine learning model
25
+ * @author Zoran Sevarac
26
+ * @since 1.0
27
27
*/
28
- public abstract class AbstractImageClassifier <IMAGE_CLASS , MODEL_CLASS > implements ImageClassifier <IMAGE_CLASS >, ModelProvider { // could also implement binary classifier
28
+ public abstract class AbstractImageClassifier <IMAGE_CLASS , MODEL_CLASS > implements ImageClassifier <IMAGE_CLASS >, ModelProvider < MODEL_CLASS > {
29
29
30
- private ImageFactory <IMAGE_CLASS > imageFactory ; // image factory impl for the specified image class
31
- private MODEL_CLASS model ; // the model could be injected from machine learning container?
30
+ private final ImageFactory <IMAGE_CLASS > imageFactory ;
31
+ private MODEL_CLASS model ;
32
32
33
- private float threshold =0.0f ; // this should ba a part of every classifier
33
+ // TODO: this should ba a part of every classifier
34
+ private float threshold = 0.0f ;
34
35
35
36
protected AbstractImageClassifier (final Class <IMAGE_CLASS > imgCls , final MODEL_CLASS model ) {
36
37
final Optional <ImageFactory <IMAGE_CLASS >> optionalImageFactory = ServiceProvider .current ()
@@ -48,27 +49,27 @@ public ImageFactory<IMAGE_CLASS> getImageFactory() {
48
49
}
49
50
50
51
@ Override
51
- public Map <String , Float > classify (File file ) {
52
+ public Map <String , Float > classify (File file ) throws ClassificationException {
52
53
IMAGE_CLASS image ;
53
54
try {
54
55
image = imageFactory .getImage (file );
55
56
return classify (image );
56
57
} catch (IOException e ) {
57
- throw new RuntimeException ( "Couldn't transform input into a BufferedImage" , e );
58
+ throw new ClassificationException ( "Failed to transform input into a BufferedImage" , e );
58
59
}
59
60
}
60
61
61
62
@ Override
62
- public Map <String , Float > classify (InputStream inputStream ) {
63
+ public Map <String , Float > classify (InputStream inputStream ) throws ClassificationException {
63
64
IMAGE_CLASS image ;
64
65
try {
65
66
image = imageFactory .getImage (inputStream );
66
67
return classify (image );
67
68
} catch (IOException e ) {
68
- throw new RuntimeException ("Couldn't transform input into a BufferedImage" , e );
69
- }
69
+ throw new RuntimeException ("Failed to transform input into a BufferedImage" , e );
70
+ }
70
71
}
71
-
72
+
72
73
// todo: provide get top 1, 3, 5 results; sort and get
73
74
74
75
@ Override
@@ -77,7 +78,7 @@ public MODEL_CLASS getModel() {
77
78
}
78
79
79
80
protected final void setModel (MODEL_CLASS model ) {
80
- this .model = Objects .requireNonNull (model , "Model cannot bu null!" );
81
+ this .model = Objects .requireNonNull (model , "Model cannot bu null!" );
81
82
}
82
83
83
84
public float getThreshold () {
0 commit comments