72 changes: 72 additions & 0 deletions tests/v1/sample/test_sampler.py
@@ -77,6 +77,46 @@ def _create_allowed_token_ids(
return mask


def _create_bad_words_token_ids(
batch_size: int, vocab_size: int,
bad_words_lengths: list[tuple[int]]) -> dict[int, list[list[int]]]:
bad_words_token_ids = {}
for batch_idx in range(batch_size):
token_ids_single_batch = []
for bad_words_length in bad_words_lengths:
token_ids = np.random.choice(vocab_size,
size=bad_words_length,
replace=True).tolist()
token_ids_single_batch.append(token_ids)
bad_words_token_ids[batch_idx] = token_ids_single_batch
return bad_words_token_ids


def _update_output_token_ids_for_bad_words(metadata: SamplingMetadata,
vocab_size: int) -> list[list[int]]:
bad_words_last_tokens = []
for batch_idx in range(len(metadata.bad_words_token_ids)):
bad_words_token_ids = metadata.bad_words_token_ids[batch_idx]
output_token_ids = metadata.output_token_ids[batch_idx]
bad_words_last_token: list[int] = []
for i, bad_word_token_ids in enumerate(bad_words_token_ids):
if len(bad_word_token_ids) == 1:
# Single token id always affects logits
bad_words_last_token.append(bad_word_token_ids[0])
else:
prefix_length = len(bad_word_token_ids) - 1
has_bad_words = np.random.choice([True, False])
if has_bad_words:
output_token_ids[-prefix_length:] = bad_word_token_ids[:-1]
bad_words_last_token.append(bad_word_token_ids[-1])
break # Maximum one update to output_token_ids
else: # Make sure no accidental match to bad words
output_token_ids[-1] = (bad_word_token_ids[-2] +
1) % vocab_size
bad_words_last_tokens.append(bad_words_last_token)
return bad_words_last_tokens


def _create_default_sampling_metadata(
num_output_tokens: int,
batch_size: int,
@@ -112,6 +152,7 @@ def _create_default_sampling_metadata(
min_tokens={},
logit_bias=[None] * batch_size,
allowed_token_ids_mask=None,
bad_words_token_ids={},
)
return fake_sampling_metadata

@@ -467,3 +508,34 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int,
"inf"), f"{batch_idx}, {token_id}"
else:
assert logits_for_req[token_id] != -float("inf")


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("bad_words_lengths", [(1, ), (1, 3), (2, 2)])
def test_sampler_bad_words(device: str, batch_size: int,
bad_words_lengths: list[tuple[int]]):
"""
Test to verify that when the bad words restriction is present, tokens
are penalized based on their match with the bad words.
"""
torch.set_default_device(device)
# Create fake logits where each token is assigned the same
# logit value.
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
sampling_metadata.bad_words_token_ids = _create_bad_words_token_ids(
batch_size, VOCAB_SIZE, bad_words_lengths)
bad_words_last_tokens = _update_output_token_ids_for_bad_words(
sampling_metadata, VOCAB_SIZE)
sampler = Sampler()
logits = sampler.apply_bad_words(fake_logits, sampling_metadata)
logits = logits.cpu()
for batch_idx in range(batch_size):
logits_for_req = logits[batch_idx]
for token_id in range(VOCAB_SIZE):
if token_id in bad_words_last_tokens[batch_idx]:
assert logits_for_req[token_id] == -float("inf")
else:
assert logits_for_req[token_id] != -float("inf")
6 changes: 6 additions & 0 deletions tests/v1/worker/test_gpu_input_batch.py
@@ -100,6 +100,7 @@ def _construct_expected_sampling_metadata(
VOCAB_SIZE,
dtype=torch.bool,
device=device)
bad_words_token_ids = {}
for req in reqs:
if req.req_id not in req_ids_retained:
continue
@@ -123,6 +124,8 @@
if req.sampling_params.allowed_token_ids:
allowed_token_ids_mask[index_in_input_batch][
req.sampling_params.allowed_token_ids] = True
bad_words_token_ids[
index_in_input_batch] = req.sampling_params.bad_words_token_ids

return SamplingMetadata(
temperature=torch.tensor(temperature, dtype=torch.float,
@@ -159,6 +162,7 @@ def _construct_expected_sampling_metadata(
and all(x == 1 for x in repetition_penalties)),
logit_bias=logit_bias,
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=bad_words_token_ids,
)


@@ -284,6 +288,8 @@ def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
assert torch.allclose(
expected_sampling_metadata.allowed_token_ids_mask,
sampling_metadata.allowed_token_ids_mask)
assert expected_sampling_metadata.bad_words_token_ids == \
sampling_metadata.bad_words_token_ids


@pytest.mark.parametrize("device", CUDA_DEVICES)
52 changes: 51 additions & 1 deletion vllm/sampling_params.py
@@ -11,6 +11,8 @@

from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

logger = init_logger(__name__)

@@ -199,7 +201,6 @@ class SamplingParams(
seed: Optional[int] = None
stop: Optional[Union[str, list[str]]] = None
stop_token_ids: Optional[list[int]] = None
bad_words: Optional[list[str]] = None
ignore_eos: bool = False
max_tokens: Optional[int] = 16
min_tokens: int = 0
@@ -228,6 +229,10 @@ class SamplingParams(
logit_bias: Optional[dict[int, float]] = None
allowed_token_ids: Optional[list[int]] = None

# Fields used for bad words
bad_words: Optional[list[str]] = None
_bad_words_token_ids: list[list[int]] = msgspec.field(default_factory=list)

@staticmethod
def from_optional(
n: Optional[int] = 1,
@@ -458,6 +463,46 @@ def update_from_generation_config(
eos_ids.update(self.stop_token_ids)
self.stop_token_ids = list(eos_ids)

def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
if self.bad_words is None:
return
for bad_word in self.bad_words:
# To prohibit words both at the beginning
# and in the middle of text
# (related to add_prefix_space tokenizer parameter)
for add_prefix_space in [False, True]:
prefix = " " if add_prefix_space else ""
prompt = prefix + bad_word.lstrip()

if isinstance(tokenizer, MistralTokenizer):
# Mistral tokenizers should not add special tokens
prompt_token_ids = tokenizer.encode(text=prompt)
else:
prompt_token_ids = tokenizer.encode(
text=prompt, add_special_tokens=False)

# If no space at the beginning
# or if prefix space produces a new word token
if (not add_prefix_space) or (
add_prefix_space and prompt_token_ids[0]
!= self._bad_words_token_ids[-1][0]
and len(prompt_token_ids) == len(
self._bad_words_token_ids[-1])):
self._bad_words_token_ids.append(prompt_token_ids)

invalid_token_ids = [
token_id for bad_words_token_ids in self._bad_words_token_ids
for token_id in bad_words_token_ids
if token_id < 0 or token_id > tokenizer.max_token_id
]
if len(invalid_token_ids) > 0:
raise ValueError(
f"The model vocabulary size is {tokenizer.max_token_id+1},"
f" but the following tokens"
f" were specified as bad: {invalid_token_ids}."
f" All token id values should be integers satisfying:"
f" 0 <= token_id <= {tokenizer.max_token_id}.")

@cached_property
def sampling_type(self) -> SamplingType:
if self.temperature < _SAMPLING_EPS:
@@ -470,6 +515,11 @@ def sampling_type(self) -> SamplingType:
def all_stop_token_ids(self) -> set[int]:
return self._all_stop_token_ids

@property
def bad_words_token_ids(self) -> list[list[int]]:
# For internal use only. Backward compatibility not guaranteed
return self._bad_words_token_ids

def clone(self) -> "SamplingParams":
"""Deep copy, but maybe not the LogitsProcessor objects.

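The prefix-space handling in update_from_tokenizer above is the subtle part: each bad word is encoded both without and with a leading space, and the space-prefixed variant is kept only when it yields a different first token while having the same length. Below is a minimal sketch of that behavior, assuming this PR's SamplingParams; StubTokenizer is a made-up stand-in, not a real vLLM or HuggingFace class, providing only the two members the method touches.

from vllm import SamplingParams


class StubTokenizer:
    # Hypothetical stand-in; only encode() and max_token_id are
    # exercised by update_from_tokenizer.
    max_token_id = 9999

    def encode(self, text, add_special_tokens=False):
        # Pretend "word" and " word" map to different single tokens,
        # as many BPE vocabularies do.
        return [8] if text.startswith(" ") else [7]


params = SamplingParams(bad_words=["word"])
params.update_from_tokenizer(StubTokenizer())
print(params.bad_words_token_ids)  # [[7], [8]] -- both variants get banned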
5 changes: 2 additions & 3 deletions vllm/v1/engine/processor.py
@@ -100,9 +100,6 @@ def _validate_supported_sampling_params(
# Best of not yet supported.
if params.best_of is not None and params.best_of > 1:
raise ValueError("VLLM V1 does not yet support best_of.")
# Bad words not yet supported.
if params.bad_words:
raise ValueError("VLLM V1 does not yet support bad_words.")
# Logits processors not supported.
if params.logits_processors:
raise ValueError("VLLM V1 does not support per request "
@@ -209,6 +206,8 @@ def process_inputs(
sampling_params = params.clone()
sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id)
sampling_params.update_from_tokenizer(
self.tokenizer.get_lora_tokenizer(lora_request))

# Multimodal related.
# Compute MM hashes (if enabled)
3 changes: 3 additions & 0 deletions vllm/v1/sample/metadata.py
@@ -38,3 +38,6 @@ class SamplingMetadata:
# `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
# vocab size).
allowed_token_ids_mask: Optional[torch.Tensor]

# req_index -> bad_words_token_ids
bad_words_token_ids: dict[int, list[list[int]]]
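As a reference for the shape of this new field, here is a hypothetical entry (token ids are arbitrary, not from the diff) for a request that bans one single-token word and one three-token phrase:

bad_words_token_ids = {
    0: [[42], [17, 8, 23]],  # request index 0: a 1-token and a 3-token bad word
}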
38 changes: 38 additions & 0 deletions vllm/v1/sample/ops/bad_words.py
@@ -0,0 +1,38 @@
# SPDX-License-Identifier: Apache-2.0

import torch

_SMALLEST_LOGIT = float("-inf")


def _apply_bad_words_single_batch(
logits: torch.Tensor,
bad_words_token_ids: list[list[int]],
past_tokens_ids: list[int],
) -> None:
for bad_word_ids in bad_words_token_ids:
if len(bad_word_ids) > len(past_tokens_ids) + 1:
continue

prefix_length = len(bad_word_ids) - 1
last_token_id = bad_word_ids[-1]
if prefix_length > 0:
actual_prefix = past_tokens_ids[-prefix_length:]
else:
actual_prefix = []
expected_prefix = bad_word_ids[:prefix_length]

assert len(actual_prefix) == len(expected_prefix)

if actual_prefix == expected_prefix:
logits[last_token_id] = _SMALLEST_LOGIT


def apply_bad_words(
logits: torch.Tensor,
bad_words_token_ids: dict[int, list[list[int]]],
past_tokens_ids: list[list[int]],
) -> None:
for i in range(logits.shape[0]):
_apply_bad_words_single_batch(logits[i], bad_words_token_ids[i],
past_tokens_ids[i])
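A short usage sketch (not part of the diff) of the masking rule implemented above: a single-token bad word is always masked, while the last token of a multi-token bad word is masked only when the request's previous output tokens already end with the bad word's prefix.

import torch

from vllm.v1.sample.ops.bad_words import apply_bad_words

logits = torch.zeros(2, 10)
bad_words_token_ids = {
    0: [[4]],        # request 0: single-token bad word
    1: [[2, 3, 7]],  # request 1: multi-token bad word with prefix [2, 3]
}
past_tokens_ids = [
    [1, 5],     # request 0: prefix irrelevant for a single token
    [9, 2, 3],  # request 1: output already ends with the prefix [2, 3]
]
apply_bad_words(logits, bad_words_token_ids, past_tokens_ids)

assert logits[0, 4] == float("-inf")        # always banned
assert logits[1, 7] == float("-inf")        # banned because the prefix matched
assert torch.isfinite(logits[1, :7]).all()  # other tokens untouched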
16 changes: 16 additions & 0 deletions vllm/v1/sample/sampler.py
@@ -6,6 +6,7 @@

from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words
from vllm.v1.sample.ops.penalties import (apply_all_penalties,
apply_min_token_penalties)
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
@@ -38,6 +39,8 @@ def forward(
logits = logits.to(torch.float32)
# Apply allowed token ids.
logits = self.apply_allowed_token_ids(logits, sampling_metadata)
# Apply bad words exclusion.
logits = self.apply_bad_words(logits, sampling_metadata)
# Apply logits bias.
logits = self.apply_logits_bias(logits, sampling_metadata)
# Apply penalties (e.g., min_tokens, freq_penalties).
@@ -237,3 +240,16 @@ def apply_allowed_token_ids(
logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
float("-inf"))
return logits

def apply_bad_words(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
if sampling_metadata.bad_words_token_ids:
apply_bad_words(
logits,
sampling_metadata.bad_words_token_ids,
sampling_metadata.output_token_ids,
)
return logits