Top-k Sampling

Here we first pick the k tokens with the highest logits, and then sample only from those tokens.
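The idea can be sketched without any tensor library. The helper below is hypothetical (its name `top_k_sample` and its signature are illustrative assumptions, not part of `labml_nn`): it keeps the k highest logits, applies a softmax over just those, and draws one index.

```python
import math
import random


def top_k_sample(logits, k, rng=random):
    """Sample an index from the k highest logits (hypothetical helper)."""
    # Indices of the k largest logits
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Softmax weights over the kept logits only; subtract the max for stability
    m = max(logits[i] for i in top)
    weights = [math.exp(logits[i] - m) for i in top]
    # Draw one of the kept indices in proportion to its weight
    return rng.choices(top, weights=weights, k=1)[0]


logits = [2.0, 0.5, -1.0, 3.0, 0.0]
# With k=2 only indices 3 and 0 (logits 3.0 and 2.0) can ever be drawn
samples = {top_k_sample(logits, k=2) for _ in range(200)}
```

Tokens outside the top k are never drawn, however many samples are taken.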

Here's an experiment that uses these sampling techniques.

import torch

from labml_nn.sampling import Sampler

Top-k Sampler

class TopKSampler(Sampler):
  • k is the number of tokens to pick
  • sampler is the sampler to use for the top-k tokens

sampler can be any sampler that takes a logits tensor as input and returns a token tensor; e.g. `TemperatureSampler`.
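As a rough illustration of that interface, here is a minimal stand-in sampler. The class name `SimpleTemperatureSampler` is hypothetical and is not the library's `TemperatureSampler`. It is a plain callable written only to show the expected input and output shapes, and it assumes logits of shape `[batch_size, n_tokens]`.

```python
import torch
from torch.distributions import Categorical


class SimpleTemperatureSampler:
    """Hypothetical stand-in: maps logits [..., n_tokens] to token ids [...]."""

    def __init__(self, temperature: float = 1.0):
        self.temperature = temperature

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        # Scale the logits by the temperature and sample from the resulting
        # categorical distribution; -inf logits get zero probability
        dist = Categorical(logits=logits / self.temperature)
        return dist.sample()


sampler = SimpleTemperatureSampler(temperature=0.7)
neg_inf = float('-inf')
# Only token 2 has a finite logit, so it is the only possible sample
tokens = sampler(torch.tensor([[neg_inf, neg_inf, 1.0, neg_inf]]))
```

Because masked tokens carry a logit of negative infinity, a sampler of this shape can consume the output of the top-k filtering step directly.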

    def __init__(self, k: int, sampler: Sampler):
        self.k = k
        self.sampler = sampler

Sample from logits

    def __call__(self, logits: torch.Tensor):

New logits filled with $-\infty$; i.e. zero probability

        zeros = logits.new_ones(logits.shape) * float('-inf')

Pick the k largest logits and their indices

        values, indices = torch.topk(logits, self.k, dim=-1)

Set the values of the top-k selected indices to the actual logits. Logits of other tokens remain $-\infty$.

        zeros.scatter_(-1, indices, values)
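To make the masking concrete, the same three operations can be run on a small tensor. The input values and `k = 2` are arbitrary choices for illustration, and the final softmax is added only to show the resulting probabilities.

```python
import torch

logits = torch.tensor([[2.0, 0.5, -1.0, 3.0, 0.0]])
k = 2

# Start with every logit at -inf
masked = logits.new_ones(logits.shape) * float('-inf')
# values: [[3., 2.]], indices: [[3, 0]]
values, indices = torch.topk(logits, k, dim=-1)
# Copy the top-k logits back to their original positions
masked.scatter_(-1, indices, values)
# masked: [[2., -inf, -inf, 3., -inf]]

# Softmax gives exactly zero probability to the masked positions
probs = torch.softmax(masked, dim=-1)
```

Only positions 0 and 3 keep non-zero probability, so any sampler applied to `masked` can only return one of those two tokens.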

Sample from the top-k logits with the specified sampler.

        return self.sampler(zeros)