Skip to content

Commit 64d8dc6

Browse files
authored
[Fix][Tokenizer] Fix failure in decoding tokens for ByteLevel BPE (#2649)
This PR fixes the issue where the tokenizer would fail to decode tokens for ByteLevel BPE when a token is not recognized by ByteLevel. For example, in decoding: ``` "hello" -> "hello" (recognized by ByteLevel) "Ġthere" -> " there" (recognized by ByteLevel) "\n" -> not recognized by ByteLevel "\u203c" -> not recognized by ByteLevel ``` This PR adds logic so that, during decoding, when a token is not recognized by ByteLevel, the original token is returned unchanged. With this change: ``` "hello" -> "hello" (recognized by ByteLevel) "Ġthere" -> " there" (recognized by ByteLevel) "\n" -> "\n" (not recognized by ByteLevel) "\u203c" -> "\u203c" (not recognized by ByteLevel) ``` This behavior aligns with Hugging Face tokenizers.
1 parent 16a79ab commit 64d8dc6

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

cpp/tokenizers/tokenizers.cc

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ inline std::string SpaceReplacerDecoder(const std::string& token) {
375375
inline std::string ByteLevelDecoder(const std::string& token) {
376376
// clang-format off
377377
// The inverse map of bytes_to_unicode. -1 means there is no mapping to this unicode.
378-
static const std::array<int, 324> unicode_to_byte_map = {
378+
static const std::array<int, 324> char_to_byte_map = {
379379
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
380380
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
381381
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68,
@@ -396,20 +396,20 @@ inline std::string ByteLevelDecoder(const std::string& token) {
396396
// clang-format on
397397

398398
auto unicode_codepoints = ParseUTF8(token.c_str(), UTF8ErrorPolicy::kReturnInvalid);
399-
ICHECK(unicode_codepoints.size() != 1 || unicode_codepoints[0] != kInvalidUTF8);
399+
if (unicode_codepoints.size() == 1 && unicode_codepoints[0] == kInvalidUTF8) {
400+
return token;
401+
}
402+
400403
std::string decoded;
401404

402405
for (auto unicode_codepoint : unicode_codepoints) {
403-
ICHECK(unicode_codepoint >= 0 &&
404-
unicode_codepoint < static_cast<int>(unicode_to_byte_map.size()));
405-
int byte = unicode_to_byte_map[unicode_codepoint];
406-
if (byte == -1) {
407-
// If there is no mapping, add the codepoint itself to the result string
408-
// Some tokenizer like Phi-2 have raw tokens like \t\t
409-
decoded += static_cast<char>(unicode_codepoint);
410-
} else {
411-
decoded += static_cast<char>(byte);
406+
ICHECK(unicode_codepoint >= 0);
407+
if (unicode_codepoint >= static_cast<int>(char_to_byte_map.size()) ||
408+
char_to_byte_map[unicode_codepoint] == -1) {
409+
// If there is no mapping, return the original token
410+
return token;
412411
}
412+
decoded += static_cast<char>(char_to_byte_map[unicode_codepoint]);
413413
}
414414
return decoded;
415415
}

0 commit comments

Comments
 (0)