Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/111852.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 111852
summary: Add DeBERTa-V2/V3 tokenizer
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenizationUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DebertaV2Tokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.DebertaV2TokenizationUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigUpdate;
Expand Down Expand Up @@ -547,6 +549,13 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
(p, c) -> XLMRobertaTokenization.fromXContent(p, (boolean) c)
)
);
namedXContent.add(
new NamedXContentRegistry.Entry(
Tokenization.class,
new ParseField(DebertaV2Tokenization.NAME),
(p, c) -> DebertaV2Tokenization.fromXContent(p, (boolean) c)
)
);

namedXContent.add(
new NamedXContentRegistry.Entry(
Expand Down Expand Up @@ -583,6 +592,13 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
(p, c) -> XLMRobertaTokenizationUpdate.fromXContent(p)
)
);
namedXContent.add(
new NamedXContentRegistry.Entry(
TokenizationUpdate.class,
DebertaV2TokenizationUpdate.NAME,
(p, c) -> DebertaV2TokenizationUpdate.fromXContent(p)
)
);

return namedXContent;
}
Expand Down Expand Up @@ -791,6 +807,7 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
);
namedWriteables.add(new NamedWriteableRegistry.Entry(Tokenization.class, RobertaTokenization.NAME, RobertaTokenization::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(Tokenization.class, XLMRobertaTokenization.NAME, XLMRobertaTokenization::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(Tokenization.class, DebertaV2Tokenization.NAME, DebertaV2Tokenization::new));

namedWriteables.add(
new NamedWriteableRegistry.Entry(
Expand Down Expand Up @@ -827,6 +844,9 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
XLMRobertaTokenizationUpdate::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TokenizationUpdate.class, DebertaV2Tokenization.NAME, DebertaV2TokenizationUpdate::new)
);

return namedWriteables;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;

public class DebertaV2Tokenization extends Tokenization {

    public static final String NAME = "deberta_v2";
    public static final String MASK_TOKEN = "[MASK]";

    // Build-once parsers; the lenient variant tolerates unknown fields (used when
    // reading persisted configs), the strict one rejects them (used for REST input).
    private static final ConstructingObjectParser<DebertaV2Tokenization, Void> LENIENT_PARSER = createParser(true);
    private static final ConstructingObjectParser<DebertaV2Tokenization, Void> STRICT_PARSER = createParser(false);

    /**
     * Creates a parser for the DeBERTa-V2 tokenization config. All fields are the
     * common {@link Tokenization} fields; this subclass adds none of its own.
     *
     * @param ignoreUnknownFields whether unrecognised fields are tolerated
     * @return a parser producing {@link DebertaV2Tokenization} instances
     */
    public static ConstructingObjectParser<DebertaV2Tokenization, Void> createParser(boolean ignoreUnknownFields) {
        ConstructingObjectParser<DebertaV2Tokenization, Void> objectParser = new ConstructingObjectParser<>(
            NAME,
            ignoreUnknownFields,
            args -> {
                // Positional args match the order registered by declareCommonFields.
                Truncate truncateValue = args[3] == null ? null : Truncate.fromString((String) args[3]);
                return new DebertaV2Tokenization((Boolean) args[0], (Boolean) args[1], (Integer) args[2], truncateValue, (Integer) args[4]);
            }
        );
        declareCommonFields(objectParser);
        return objectParser;
    }

    /**
     * Parses a {@link DebertaV2Tokenization} from x-content.
     *
     * @param parser  the x-content parser positioned at the config object
     * @param lenient when true unknown fields are ignored rather than rejected
     */
    public static DebertaV2Tokenization fromXContent(XContentParser parser, boolean lenient) {
        if (lenient) {
            return LENIENT_PARSER.apply(parser, null);
        }
        return STRICT_PARSER.apply(parser, null);
    }

    public DebertaV2Tokenization(
        Boolean doLowerCase,
        Boolean withSpecialTokens,
        Integer maxSequenceLength,
        Truncate truncate,
        Integer span
    ) {
        super(doLowerCase, withSpecialTokens, maxSequenceLength, truncate, span);
    }

    public DebertaV2Tokenization(StreamInput in) throws IOException {
        super(in);
    }

    @Override
    Tokenization buildWindowingTokenization(int updatedMaxSeqLength, int updatedSpan) {
        // Same settings as this instance, but with the window size and span
        // recalculated by the caller for chunked inference.
        return new DebertaV2Tokenization(doLowerCase, withSpecialTokens, updatedMaxSeqLength, truncate, updatedSpan);
    }

    @Override
    public String getMaskToken() {
        return MASK_TOKEN;
    }

    @Override
    XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
        // No subclass-specific fields; the superclass serialises the common ones.
        return builder;
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public String getName() {
        return NAME;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Optional;

/**
 * A partial update to a {@link DebertaV2Tokenization}: may change the truncate
 * behaviour and/or the span, leaving all other settings untouched.
 */
public class DebertaV2TokenizationUpdate extends AbstractTokenizationUpdate {
    public static final ParseField NAME = new ParseField(DebertaV2Tokenization.NAME);

    // final: the parser is an immutable, build-once object and must not be reassignable.
    public static final ConstructingObjectParser<DebertaV2TokenizationUpdate, Void> PARSER = new ConstructingObjectParser<>(
        "deberta_v2_tokenization_update",
        a -> new DebertaV2TokenizationUpdate(a[0] == null ? null : Tokenization.Truncate.fromString((String) a[0]), (Integer) a[1])
    );

    static {
        declareCommonParserFields(PARSER);
    }

    public static DebertaV2TokenizationUpdate fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /**
     * @param truncate new truncate setting, or null to keep the original's
     * @param span     new span setting, or null to keep the original's
     */
    public DebertaV2TokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) {
        super(truncate, span);
    }

    public DebertaV2TokenizationUpdate(StreamInput in) throws IOException {
        super(in);
    }

    /**
     * Applies this update to an existing tokenization config.
     *
     * @param originalConfig the config being updated; must be a {@link DebertaV2Tokenization}
     * @return the updated config (the original instance if this update is a no-op)
     * @throws org.elasticsearch.ElasticsearchStatusException if {@code originalConfig}
     *         is a different tokenization type
     */
    @Override
    public Tokenization apply(Tokenization originalConfig) {
        if (originalConfig instanceof DebertaV2Tokenization debertaV2Tokenization) {
            if (isNoop()) {
                return debertaV2Tokenization;
            }

            Tokenization.validateSpanAndTruncate(getTruncate(), getSpan());

            if (getTruncate() != null && getTruncate().isInCompatibleWithSpan() == false) {
                // When truncate value is incompatible with span wipe out
                // the existing span setting to avoid an invalid combination of settings.
                // This avoids the user have to set span to the special unset value
                return new DebertaV2Tokenization(
                    debertaV2Tokenization.doLowerCase(),
                    debertaV2Tokenization.withSpecialTokens(),
                    debertaV2Tokenization.maxSequenceLength(),
                    getTruncate(),
                    null
                );
            }

            // Fields not present in the update fall back to the original's values.
            return new DebertaV2Tokenization(
                debertaV2Tokenization.doLowerCase(),
                debertaV2Tokenization.withSpecialTokens(),
                debertaV2Tokenization.maxSequenceLength(),
                Optional.ofNullable(this.getTruncate()).orElse(originalConfig.getTruncate()),
                Optional.ofNullable(this.getSpan()).orElse(originalConfig.getSpan())
            );
        }
        throw ExceptionsHelper.badRequestException(
            "Tokenization config of type [{}] can not be updated with a request of type [{}]",
            originalConfig.getName(),
            getName()
        );
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ public static TokenizationUpdate tokenizationFromMap(Map<String, Object> map) {
RobertaTokenizationUpdate.NAME.getPreferredName(),
RobertaTokenizationUpdate::new,
XLMRobertaTokenizationUpdate.NAME.getPreferredName(),
XLMRobertaTokenizationUpdate::new
XLMRobertaTokenizationUpdate::new,
DebertaV2Tokenization.NAME,
DebertaV2TokenizationUpdate::new
);

Map<String, Object> tokenizationConfig = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ public enum Truncate {
public boolean isInCompatibleWithSpan() {
return false;
}
};
},
BALANCED;

public boolean isInCompatibleWithSpan() {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ public void testTokenizationFromMap() {
);
assertThat(
e.getMessage(),
containsString("unknown tokenization type expecting one of [bert, bert_ja, mpnet, roberta, xlm_roberta] got [not_bert]")
containsString(
"unknown tokenization type expecting one of [bert, bert_ja, deberta_v2, mpnet, roberta, xlm_roberta] got [not_bert]"
)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,6 @@ boolean isWithSpecialTokens() {
return withSpecialTokens;
}

@Override
int defaultSpanForChunking(int maxWindowSize) {
return (maxWindowSize - numExtraTokensForSingleSequence()) / 2;
}

@Override
int getNumExtraTokensForSeqPair() {
return 3;
Expand Down
Loading
Loading