Implementation for sampling strategies (go/decoding-tf-nlp).
```
tfm.nlp.ops.SamplingModule(
    symbols_to_logits_fn,
    vocab_size: int,
    max_decode_length: int,
    eos_id: int,
    padded_decode: bool,
    length_normalization_fn: Optional[Callable[[int, tf.DType], float]] = None,
    top_k=0,
    top_p=1.0,
    sample_temperature=0.0,
    enable_greedy: bool = True,
    dtype: tf.DType = tf.float32,
    decoding_name: Optional[str] = None,
    extra_cache_output: bool = False
)
```
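The `top_k`, `top_p`, and `sample_temperature` arguments control how the next token is drawn from the model's logits. A minimal numpy sketch of that filtering logic (the function name and structure here are illustrative, not the module's internals; `sample_temperature=0` falls back to greedy argmax, matching `enable_greedy`):

```python
import numpy as np

def sample_next_token(logits, top_k=0, top_p=1.0, temperature=1.0, rng=None):
    """Sample one token id after top-k / top-p (nucleus) filtering.

    Illustrative sketch only; mirrors the meaning of SamplingModule's
    top_k / top_p / sample_temperature arguments.
    """
    rng = rng or np.random.default_rng()
    if temperature == 0.0:
        # Greedy decoding: just take the most likely token.
        return int(np.argmax(logits))
    logits = np.asarray(logits, dtype=np.float64) / temperature
    if top_k > 0:
        # Keep only the k largest logits; mask the rest to -inf.
        kth = np.sort(logits)[-top_k]
        logits = np.where(logits < kth, -np.inf, logits)
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    if top_p < 1.0:
        # Nucleus filtering: smallest set of tokens with mass >= top_p.
        order = np.argsort(probs)[::-1]
        cutoff = np.searchsorted(np.cumsum(probs[order]), top_p) + 1
        mask = np.zeros_like(probs, dtype=bool)
        mask[order[:cutoff]] = True
        probs = np.where(mask, probs, 0.0)
        probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))
```

For example, `sample_next_token([1.0, 5.0, 2.0], temperature=0.0)` returns `1`, and with `top_k=1` only the argmax token can ever be sampled.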
Methods
generate
```
generate(
    initial_ids: tf.Tensor,
    initial_cache: Dict[str, tf.Tensor],
    initial_log_probs: Optional[tf.Tensor] = None
) -> Output
```
Implements the decoding strategy (beam_search or sampling).
| Args | |
|---|---|
| `initial_ids` | Initial ids to pass into `symbols_to_logits_fn`; an int tensor with shape `[batch_size, 1]`. |
| `initial_cache` | Dictionary caching model outputs from the previous step. |
| `initial_log_probs` | Optional initial log probs, used when decoding should continue from a prefix sequence. |
| Returns | |
|---|---|
| A tuple of tensors: `finished_sequence` with shape `[batch, max_seq_length]`, `finished_scores` with shape `[batch]`, and `first_cache`, the cache after the init token. |
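Conceptually, `generate` repeatedly calls `symbols_to_logits_fn` on the ids decoded so far, appending one token per step until `eos_id` is produced or `max_decode_length` is reached. A simplified greedy sketch of that loop (this omits batching, caching, and scores, and is not the module's actual implementation):

```python
import numpy as np

def greedy_generate(symbols_to_logits_fn, initial_ids, max_decode_length, eos_id):
    """Illustrative greedy decode loop.

    symbols_to_logits_fn maps the ids decoded so far to next-token
    logits, mirroring the callable passed to SamplingModule.
    """
    ids = list(initial_ids)
    for _ in range(max_decode_length):
        logits = symbols_to_logits_fn(ids)
        next_id = int(np.argmax(logits))
        ids.append(next_id)
        if next_id == eos_id:  # stop once the end-of-sequence token appears
            break
    return ids

# Toy model: always favors the token equal to the current length mod 5.
def toy_logits(ids):
    v = np.zeros(5)
    v[len(ids) % 5] = 1.0
    return v

print(greedy_generate(toy_logits, [0], max_decode_length=3, eos_id=4))
# -> [0, 1, 2, 3]
```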
inf
```
inf()
```
Returns a value close to infinity, but still finite in `dtype`.

This is useful for getting a very large value that is still zero when multiplied by zero; the floating-point `inf` value yields NaN when multiplied by zero.
| Returns | |
|---|---|
| A very large, finite value. |
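The distinction matters when masking: a true infinity poisons masked-out terms, while a finite "infinity" stays well-behaved. A small numpy sketch (the `finite_inf` helper is illustrative, analogous in spirit to this method; it uses the dtype's largest representable value):

```python
import math
import numpy as np

def finite_inf(dtype=np.float32):
    """A huge but finite value for the given dtype (illustrative helper)."""
    return np.finfo(dtype).max

# True infinity turns a zero-masked term into NaN ...
assert math.isnan(float("inf") * 0.0)
# ... while a finite "inf" is simply zeroed out by the mask.
assert finite_inf() * 0.0 == 0.0
```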