@@ -49,7 +49,13 @@ public final class UnigramTokenizer extends Tokenizer {
    private final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
    private final OffsetAttribute offsetAtt = addAttribute(OffsetAttribute.class);

-    static UnigramTokenizer build(List<String> neverSplit, List<String> dictionary, double[] scores, String unknownToken) {
+    static UnigramTokenizer build(
+        List<String> neverSplit,
+        List<String> dictionary,
+        double[] scores,
+        String unknownToken,
+        boolean byteFallback
+    ) {
        if (dictionary.isEmpty()) {
            throw new IllegalArgumentException("vocab empty");
        }
@@ -84,7 +90,8 @@ static UnigramTokenizer build(List<String> neverSplit, List<String> dictionary,
            Optional.ofNullable(tokenToId.get(new BytesRef(unknownToken)))
                .orElseThrow(
                    () -> new IllegalArgumentException("provided vocabulary does not contain the unknown token of [" + unknownToken + "]")
-                )
+                ),
+            byteFallback
        );
    }

@@ -94,7 +101,7 @@ static UnigramTokenizer build(List<String> neverSplit, List<String> dictionary,

    private final double minScore;
    // This may be configurable in the future
-    private final boolean fuseUnk = true;
+    private boolean fuseUnk = true;
    private final double[] vocabScores;
    private final CharTrie neverSplit;
    private final CharArraySet neverSplitHash;
@@ -104,6 +111,7 @@ static UnigramTokenizer build(List<String> neverSplit, List<String> dictionary,
    // This is a buffer that is reused per token for decoding the normalized char-sequence into utf-8 bytes
    // It's usage is NOT thread safe
    private byte[] normalizedByteBuffer = new byte[128];
+    private boolean byteFallback = false; // If true, decompose unknown pieces into UTF-8 byte pieces

    public UnigramTokenizer(
        double minScore,
@@ -127,6 +135,31 @@ public UnigramTokenizer(
        this.whitespaceTokenizer = new SimpleWhitespaceTokenizer();
    }

+    public UnigramTokenizer(
+        double minScore,
+        double[] vocabScores,
+        CharTrie neverSplit,
+        CharArraySet neverSplitHash,
+        Map<BytesRef, Integer> vocabToId,
+        BytesTrie vocabTrie,
+        int unknownTokenId,
+        boolean byteFallback
+    ) {
+        super();
+        this.tokens = new LinkedList<>();
+        this.tokenizedValues = new ArrayList<>();
+        this.minScore = minScore;
+        this.neverSplit = neverSplit;
+        this.neverSplitHash = neverSplitHash;
+        this.vocabToId = vocabToId;
+        this.vocabTrie = vocabTrie;
+        this.unknownTokenId = unknownTokenId;
+        this.vocabScores = vocabScores;
+        this.whitespaceTokenizer = new SimpleWhitespaceTokenizer();
+        this.byteFallback = byteFallback;
+        this.fuseUnk = byteFallback == false;
+    }
+
    List<DelimitedToken.Encoded> getTokenizedValues() {
        return tokenizedValues;
    }
@@ -231,6 +264,22 @@ public boolean incrementToken() throws IOException {
        return false;
    }

+    private int[] decomposeBytePieces(CharSequence maybeTokenized) {
+        assert this.byteFallback;
+
+        byte[] bytes = maybeTokenized.toString().getBytes(StandardCharsets.UTF_8);
+        int[] pieces = new int[bytes.length];
+        for (int i = 0; i < bytes.length; i++) {
+            BytesRef decomposedToken = new BytesRef(String.format("<0x%02X>", bytes[i]));
+            Integer piece = vocabToId.get(decomposedToken);
+            if (piece == null) {
+                piece = unknownTokenId;
+            }
+            pieces[i] = piece;
+        }
+        return pieces;
+    }
+
    /**
     * This algorithm does the following:
     *
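For reviewers unfamiliar with SentencePiece-style byte fallback: each UTF-8 byte of an unknown span is looked up as a literal piece of the form <0xNN>, and the new constructor above sets fuseUnk to byteFallback == false, so exactly one unknown-handling strategy is active at a time. A minimal self-contained sketch of the lookup follows; the toy VOCAB map and UNKNOWN_ID are illustrative stand-ins, not part of this change (a real model maps all 256 byte pieces):

import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Map;

class ByteFallbackSketch {
    static final int UNKNOWN_ID = 0; // hypothetical id of "<unk>"
    // Toy vocabulary fragment; a real model carries every "<0x00>".."<0xFF>" piece.
    static final Map<String, Integer> VOCAB = Map.of("<0xC3>", 5, "<0xA9>", 6);

    static int[] decompose(String unknownSpan) {
        byte[] bytes = unknownSpan.getBytes(StandardCharsets.UTF_8);
        int[] pieces = new int[bytes.length];
        for (int i = 0; i < bytes.length; i++) {
            // %X on a boxed negative Byte prints its unsigned two's-complement
            // value within byte width, e.g. (byte) 0xC3 -> "C3".
            String piece = String.format("<0x%02X>", bytes[i]);
            pieces[i] = VOCAB.getOrDefault(piece, UNKNOWN_ID);
        }
        return pieces;
    }

    public static void main(String[] args) {
        // U+00E9 ("é") encodes to UTF-8 bytes 0xC3 0xA9 -> pieces <0xC3>, <0xA9>
        System.out.println(Arrays.toString(decompose("\u00E9"))); // prints [5, 6]
    }
}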
@@ -309,7 +358,21 @@ List<DelimitedToken.Encoded> tokenize(CharSequence inputSequence, IntToIntFuncti
        while (endsAtBytes > 0) {
            BestPathNode node = bestPathNodes[endsAtBytes];
            int startsAtBytes = node.startsAtBytePos;
-            if (node.id == unknownTokenId && fuseUnk) {
+            if (node.id == unknownTokenId && byteFallback) {
+                CharSequence multiByteSequence = inputSequence.subSequence(node.startsAtCharPos, endsAtChars);
+                byte[] bytes = multiByteSequence.toString().getBytes(StandardCharsets.UTF_8);
+                int[] pieces = decomposeBytePieces(multiByteSequence);
+                for (int i = pieces.length - 1; i >= 0; i--) {
+                    results.add(
+                        new DelimitedToken.Encoded(
+                            String.format("<0x%02X>", bytes[i]),
+                            pieces[i],
+                            offsetCorrection.apply(node.startsAtCharPos),
+                            offsetCorrection.apply(startsAtBytes + i)
+                        )
+                    );
+                }
+            } else if (node.id == unknownTokenId && fuseUnk) {
                unknownTokens.add(
                    new DelimitedToken.Encoded(
                        new String(normalizedByteBuffer, startsAtBytes, endsAtBytes - startsAtBytes, StandardCharsets.UTF_8),
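End to end, a hypothetical call site (the vocab and score loading helpers here are assumed, not part of this change) would opt in through the extended build signature:

// Hypothetical wiring; dictionary and scores come from the loaded model and
// must include "<unk>" plus the 256 "<0x00>".."<0xFF>" byte pieces.
List<String> neverSplit = List.of("<s>", "</s>");
List<String> dictionary = loadVocab();   // assumed helper
double[] scores = loadScores();          // assumed helper, parallel to dictionary
UnigramTokenizer tokenizer = UnigramTokenizer.build(neverSplit, dictionary, scores, "<unk>", true);

With byteFallback enabled, an out-of-vocabulary character such as U+00E9 should surface as the two pieces <0xC3> and <0xA9> instead of one fused unknown token. Note that the per-byte pieces are appended last-byte-first in tokenize above, matching the surrounding loop, which walks the best path backwards from the end of the buffer.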