
Commit d2e347e

Author: Max Hniebergall (committed)
Add support for byte_fallback, which is enabled for DeBERTa
byte_fallback decomposes unknown tokens into multiple single-byte tokens, provided those byte tokens are present in the vocabulary.
1 parent cdface2 commit d2e347e
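
For illustration, a minimal self-contained sketch of what byte_fallback does (not the committed code; the class name, vocabulary entries, and token ids below are hypothetical): a piece that is not in the vocabulary is re-encoded as its UTF-8 bytes, each byte is looked up as a SentencePiece-style <0xNN> token, and any byte that is still missing from the vocabulary falls back to the unknown id.

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Illustrative sketch only -- not the committed UnigramTokenizer code.
// The vocabulary and token ids here are hypothetical.
public class ByteFallbackSketch {

    private static final int UNKNOWN_ID = 0;
    private static final Map<String, Integer> VOCAB = Map.of("<unk>", UNKNOWN_ID, "<0xC3>", 7, "<0xA9>", 8);

    // Decompose an out-of-vocabulary piece into one token id per UTF-8 byte,
    // falling back to the unknown id for bytes whose <0xNN> piece is missing.
    static List<Integer> decomposeBytePieces(String unknownPiece) {
        byte[] bytes = unknownPiece.getBytes(StandardCharsets.UTF_8);
        List<Integer> pieces = new ArrayList<>(bytes.length);
        for (byte b : bytes) {
            pieces.add(VOCAB.getOrDefault(String.format("<0x%02X>", b), UNKNOWN_ID));
        }
        return pieces;
    }

    public static void main(String[] args) {
        System.out.println(decomposeBytePieces("\u00E9")); // é -> bytes 0xC3 0xA9 -> [7, 8]
        System.out.println(decomposeBytePieces("\u00AD")); // soft hyphen -> 0xC2 0xAD -> [0, 0], bytes not in this toy vocab
    }
}

In the committed change, enabling byte fallback also disables unknown-token fusing (this.fuseUnk = byteFallback == false), since unknown spans are decomposed into byte pieces rather than fused into a single unknown token.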

5 files changed: +92 -10 lines changed


x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/DebertaV2Tokenizer.java

Lines changed: 1 addition & 1 deletion
@@ -291,7 +291,7 @@ protected Reader initReader(String fieldName, Reader reader) {
 
     @Override
     protected TokenStreamComponents createComponents(String fieldName) {
-        this.innerTokenizer = UnigramTokenizer.build(neverSplit, vocabulary, scores, unknownToken);
+        this.innerTokenizer = UnigramTokenizer.build(neverSplit, vocabulary, scores, unknownToken, true);
         return new TokenStreamComponents(this.innerTokenizer);
     }

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/UnigramTokenizer.java

Lines changed: 67 additions & 4 deletions
@@ -49,7 +49,13 @@ public final class UnigramTokenizer extends Tokenizer {
     private final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
     private final OffsetAttribute offsetAtt = addAttribute(OffsetAttribute.class);
 
-    static UnigramTokenizer build(List<String> neverSplit, List<String> dictionary, double[] scores, String unknownToken) {
+    static UnigramTokenizer build(
+        List<String> neverSplit,
+        List<String> dictionary,
+        double[] scores,
+        String unknownToken,
+        boolean byteFallback
+    ) {
         if (dictionary.isEmpty()) {
             throw new IllegalArgumentException("vocab empty");
         }

@@ -84,7 +90,8 @@ static UnigramTokenizer build(List<String> neverSplit, List<String> dictionary,
             Optional.ofNullable(tokenToId.get(new BytesRef(unknownToken)))
                 .orElseThrow(
                     () -> new IllegalArgumentException("provided vocabulary does not contain the unknown token of [" + unknownToken + "]")
-                )
+                ),
+            byteFallback
         );
     }

@@ -94,7 +101,7 @@ static UnigramTokenizer build(List<String> neverSplit, List<String> dictionary,
 
     private final double minScore;
     // This may be configurable in the future
-    private final boolean fuseUnk = true;
+    private boolean fuseUnk = true;
     private final double[] vocabScores;
     private final CharTrie neverSplit;
     private final CharArraySet neverSplitHash;

@@ -104,6 +111,7 @@ static UnigramTokenizer build(List<String> neverSplit, List<String> dictionary,
     // This is a buffer that is reused per token for decoding the normalized char-sequence into utf-8 bytes
     // It's usage is NOT thread safe
     private byte[] normalizedByteBuffer = new byte[128];
+    private boolean byteFallback = false; // If true, decompose unknown pieces into UTF-8 byte pieces
 
     public UnigramTokenizer(
         double minScore,

@@ -127,6 +135,31 @@ public UnigramTokenizer(
         this.whitespaceTokenizer = new SimpleWhitespaceTokenizer();
     }
 
+    public UnigramTokenizer(
+        double minScore,
+        double[] vocabScores,
+        CharTrie neverSplit,
+        CharArraySet neverSplitHash,
+        Map<BytesRef, Integer> vocabToId,
+        BytesTrie vocabTrie,
+        int unknownTokenId,
+        boolean byteFallback
+    ) {
+        super();
+        this.tokens = new LinkedList<>();
+        this.tokenizedValues = new ArrayList<>();
+        this.minScore = minScore;
+        this.neverSplit = neverSplit;
+        this.neverSplitHash = neverSplitHash;
+        this.vocabToId = vocabToId;
+        this.vocabTrie = vocabTrie;
+        this.unknownTokenId = unknownTokenId;
+        this.vocabScores = vocabScores;
+        this.whitespaceTokenizer = new SimpleWhitespaceTokenizer();
+        this.byteFallback = byteFallback;
+        this.fuseUnk = byteFallback == false;
+    }
+
     List<DelimitedToken.Encoded> getTokenizedValues() {
         return tokenizedValues;
     }

@@ -231,6 +264,22 @@ public boolean incrementToken() throws IOException {
         return false;
     }
 
+    private int[] decomposeBytePieces(CharSequence maybeTokenized) {
+        assert this.byteFallback;
+
+        byte[] bytes = maybeTokenized.toString().getBytes(StandardCharsets.UTF_8);
+        int[] pieces = new int[bytes.length];
+        for (int i = 0; i < bytes.length; i++) {
+            BytesRef decomposedToken = new BytesRef(String.format("<0x%02X>", bytes[i]));
+            Integer piece = vocabToId.get(decomposedToken);
+            if (piece == null) {
+                piece = unknownTokenId;
+            }
+            pieces[i] = piece;
+        }
+        return pieces;
+    }
+
     /**
      * This algorithm does the following:
      *

@@ -309,7 +358,21 @@ List<DelimitedToken.Encoded> tokenize(CharSequence inputSequence, IntToIntFuncti
         while (endsAtBytes > 0) {
             BestPathNode node = bestPathNodes[endsAtBytes];
             int startsAtBytes = node.startsAtBytePos;
-            if (node.id == unknownTokenId && fuseUnk) {
+            if (node.id == unknownTokenId && byteFallback) {
+                CharSequence multiByteSequence = inputSequence.subSequence(node.startsAtCharPos, endsAtChars);
+                byte[] bytes = multiByteSequence.toString().getBytes(StandardCharsets.UTF_8);
+                int[] pieces = decomposeBytePieces(multiByteSequence);
+                for (int i = pieces.length - 1; i >= 0; i--) {
+                    results.add(
+                        new DelimitedToken.Encoded(
+                            String.format("<0x%02X>", bytes[i]),
+                            pieces[i],
+                            offsetCorrection.apply(node.startsAtCharPos),
+                            offsetCorrection.apply(startsAtBytes + i)
+                        )
+                    );
+                }
+            } else if (node.id == unknownTokenId && fuseUnk) {
                 unknownTokens.add(
                     new DelimitedToken.Encoded(
                         new String(normalizedByteBuffer, startsAtBytes, endsAtBytes - startsAtBytes, StandardCharsets.UTF_8),

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/XLMRobertaTokenizer.java

Lines changed: 1 addition & 1 deletion
@@ -284,7 +284,7 @@ protected Reader initReader(String fieldName, Reader reader) {
 
     @Override
     protected TokenStreamComponents createComponents(String fieldName) {
-        this.innerTokenizer = UnigramTokenizer.build(neverSplit, vocabulary, scores, unknownToken);
+        this.innerTokenizer = UnigramTokenizer.build(neverSplit, vocabulary, scores, unknownToken, false);
         return new TokenStreamComponents(this.innerTokenizer);
     }

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/DebertaV2TokenizerTests.java

Lines changed: 22 additions & 3 deletions
@@ -43,7 +43,9 @@ public class DebertaV2TokenizerTests extends ESTestCase {
         "▁😀",
         "▁🇸🇴",
         MASK_TOKEN,
-        "."
+        ".",
+        "<0xC2>",
+        "<0xAD>"
     );
     private static final List<Double> TEST_CASE_SCORES = List.of(
         0.0,

@@ -65,7 +67,9 @@ public class DebertaV2TokenizerTests extends ESTestCase {
         -10.230172157287598,
         -9.451579093933105,
         0.0,
-        -3.0
+        -3.0,
+        1.0,
+        2.0
     );
 
     private List<String> tokenStrings(List<? extends DelimitedToken> tokens) {

@@ -96,14 +100,29 @@ public void testSurrogatePair() throws IOException {
                 new DebertaV2Tokenization(false, false, null, Tokenization.Truncate.NONE, -1)
             ).build()
         ) {
-            TokenizationResult.Tokens tokenization = tokenizer.tokenize("😀", Tokenization.Truncate.NONE, -1, 0, null).get(0);
+            TokenizationResult.Tokens tokenization = tokenizer.tokenize(
+                "Elastic" + "\u00AD" + "search 😀" + "\u00AD" + " fun",
+                Tokenization.Truncate.NONE,
+                -1,
+                0,
+                null
+            ).get(0);
+            assertArrayEquals(new int[] { 4, 5, 20, 21, 6, 16, 20, 21, 8 }, tokenization.tokenIds());
+            System.out.println(tokenization.tokens().get(0));
+            assertThat(
+                tokenStrings(tokenization.tokens().get(0)),
+                contains("▁Ela", "stic", "<0xC2>", "<0xAD>", "search", "▁\uD83D\uDE00", "<0xC2>", "<0xAD>", "▁fun")
+            );
+
+            tokenization = tokenizer.tokenize("😀", Tokenization.Truncate.NONE, -1, 0, null).get(0);
             assertThat(tokenStrings(tokenization.tokens().get(0)), contains("▁\uD83D\uDE00"));
 
             tokenization = tokenizer.tokenize("Elasticsearch 😀", Tokenization.Truncate.NONE, -1, 0, null).get(0);
            assertThat(tokenStrings(tokenization.tokens().get(0)), contains("▁Ela", "stic", "search", "▁\uD83D\uDE00"));
 
             tokenization = tokenizer.tokenize("Elasticsearch 😀 fun", Tokenization.Truncate.NONE, -1, 0, null).get(0);
             assertThat(tokenStrings(tokenization.tokens().get(0)), contains("▁Ela", "stic", "search", "▁\uD83D\uDE00", "▁fun"));
+
         }
     }

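As a quick sanity check on the new test expectations above (a hypothetical standalone snippet, not part of the commit): the soft hyphen U+00AD encodes to the two UTF-8 bytes 0xC2 and 0xAD, which is why the byte pieces <0xC2> and <0xAD> (ids 20 and 21 in the test vocabulary) appear around it in the expected tokens.

import java.nio.charset.StandardCharsets;

// Hypothetical check, not part of this commit: U+00AD (soft hyphen) is two UTF-8 bytes,
// so byte_fallback emits the pieces <0xC2> and <0xAD> for it.
public class SoftHyphenBytes {
    public static void main(String[] args) {
        for (byte b : "\u00AD".getBytes(StandardCharsets.UTF_8)) {
            System.out.printf("<0x%02X>%n", b); // prints <0xC2> then <0xAD>
        }
    }
}
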
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/UnigramTokenizerTests.java

Lines changed: 1 addition & 1 deletion
@@ -153,7 +153,7 @@ protected Reader initReader(String fieldName, Reader reader) {
 
         @Override
         protected TokenStreamComponents createComponents(String fieldName) {
-            UnigramTokenizer tokenizer = UnigramTokenizer.build(NEVER_SPLIT, dictionary, scores, unknownToken);
+            UnigramTokenizer tokenizer = UnigramTokenizer.build(NEVER_SPLIT, dictionary, scores, unknownToken, false);
             return new TokenStreamComponents(tokenizer);
         }
     }
