16
16
17
17
package org .springframework .ai .tokenizer ;
18
18
19
+ import java .util .Base64 ;
20
+
19
21
import com .knuddels .jtokkit .Encodings ;
20
22
import com .knuddels .jtokkit .api .Encoding ;
21
23
import com .knuddels .jtokkit .api .EncodingType ;
34
36
*/
35
37
public class JTokkitTokenCountEstimator implements TokenCountEstimator {
36
38
39
+ /**
40
+ * The JTokkit encoding instance used for token counting.
41
+ */
37
42
private final Encoding estimator ;
38
43
44
+ /**
45
+ * Creates a new JTokkitTokenCountEstimator with default CL100K_BASE encoding.
46
+ */
39
47
public JTokkitTokenCountEstimator () {
40
48
this (EncodingType .CL100K_BASE );
41
49
}
42
50
43
- public JTokkitTokenCountEstimator (EncodingType tokenEncodingType ) {
51
+ /**
52
+ * Creates a new JTokkitTokenCountEstimator with the specified encoding type.
53
+ * @param tokenEncodingType the encoding type to use for token counting
54
+ */
55
+ public JTokkitTokenCountEstimator (final EncodingType tokenEncodingType ) {
44
56
this .estimator = Encodings .newLazyEncodingRegistry ().getEncoding (tokenEncodingType );
45
57
}
46
58
47
59
@ Override
48
- public int estimate (String text ) {
60
+ public int estimate (final String text ) {
49
61
if (text == null ) {
50
62
return 0 ;
51
63
}
52
64
return this .estimator .countTokens (text );
53
65
}
54
66
55
67
@ Override
56
- public int estimate (MediaContent content ) {
68
+ public int estimate (final MediaContent content ) {
57
69
int tokenCount = 0 ;
58
70
59
71
if (content .getText () != null ) {
60
72
tokenCount += this .estimate (content .getText ());
61
73
}
62
74
63
75
if (!CollectionUtils .isEmpty (content .getMedia ())) {
64
-
65
76
for (Media media : content .getMedia ()) {
66
-
67
77
tokenCount += this .estimate (media .getMimeType ().toString ());
68
78
69
79
if (media .getData () instanceof String textData ) {
70
80
tokenCount += this .estimate (textData );
71
81
}
72
82
else if (media .getData () instanceof byte [] binaryData ) {
73
- tokenCount += binaryData .length ; // This is likely incorrect.
83
+ String base64 = Base64 .getEncoder ().encodeToString (binaryData );
84
+ tokenCount += this .estimate (base64 );
74
85
}
75
86
}
76
87
}
@@ -79,7 +90,7 @@ else if (media.getData() instanceof byte[] binaryData) {
79
90
}
80
91
81
92
@ Override
82
- public int estimate (Iterable <MediaContent > contents ) {
93
+ public int estimate (final Iterable <MediaContent > contents ) {
83
94
int totalSize = 0 ;
84
95
for (MediaContent mediaContent : contents ) {
85
96
totalSize += this .estimate (mediaContent );
0 commit comments