Skip to content

Commit c60d942

Browse files
Hyeri1eeilayaperumalg
authored andcommitted
fix(tokenizer): use Base64 encoding for binary data token estimation
Signed-off-by: Hyeri1ee <haerizian10@gmail.com>
1 parent 921305b commit c60d942

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

spring-ai-commons/src/main/java/org/springframework/ai/tokenizer/JTokkitTokenCountEstimator.java

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package org.springframework.ai.tokenizer;
1818

19+
import java.util.Base64;
20+
1921
import com.knuddels.jtokkit.Encodings;
2022
import com.knuddels.jtokkit.api.Encoding;
2123
import com.knuddels.jtokkit.api.EncodingType;
@@ -34,43 +36,52 @@
3436
*/
3537
public class JTokkitTokenCountEstimator implements TokenCountEstimator {
3638

39+
/**
40+
* The JTokkit encoding instance used for token counting.
41+
*/
3742
private final Encoding estimator;
3843

44+
/**
45+
* Creates a new JTokkitTokenCountEstimator with default CL100K_BASE encoding.
46+
*/
3947
public JTokkitTokenCountEstimator() {
4048
this(EncodingType.CL100K_BASE);
4149
}
4250

43-
public JTokkitTokenCountEstimator(EncodingType tokenEncodingType) {
51+
/**
52+
* Creates a new JTokkitTokenCountEstimator with the specified encoding type.
53+
* @param tokenEncodingType the encoding type to use for token counting
54+
*/
55+
public JTokkitTokenCountEstimator(final EncodingType tokenEncodingType) {
4456
this.estimator = Encodings.newLazyEncodingRegistry().getEncoding(tokenEncodingType);
4557
}
4658

4759
@Override
48-
public int estimate(String text) {
60+
public int estimate(final String text) {
4961
if (text == null) {
5062
return 0;
5163
}
5264
return this.estimator.countTokens(text);
5365
}
5466

5567
@Override
56-
public int estimate(MediaContent content) {
68+
public int estimate(final MediaContent content) {
5769
int tokenCount = 0;
5870

5971
if (content.getText() != null) {
6072
tokenCount += this.estimate(content.getText());
6173
}
6274

6375
if (!CollectionUtils.isEmpty(content.getMedia())) {
64-
6576
for (Media media : content.getMedia()) {
66-
6777
tokenCount += this.estimate(media.getMimeType().toString());
6878

6979
if (media.getData() instanceof String textData) {
7080
tokenCount += this.estimate(textData);
7181
}
7282
else if (media.getData() instanceof byte[] binaryData) {
73-
tokenCount += binaryData.length; // This is likely incorrect.
83+
String base64 = Base64.getEncoder().encodeToString(binaryData);
84+
tokenCount += this.estimate(base64);
7485
}
7586
}
7687
}
@@ -79,7 +90,7 @@ else if (media.getData() instanceof byte[] binaryData) {
7990
}
8091

8192
@Override
82-
public int estimate(Iterable<MediaContent> contents) {
93+
public int estimate(final Iterable<MediaContent> contents) {
8394
int totalSize = 0;
8495
for (MediaContent mediaContent : contents) {
8596
totalSize += this.estimate(mediaContent);

0 commit comments

Comments
 (0)