View source on GitHub
Implementation for sampling strategies (go/decoding-tf-nlp).
```python
tfm.nlp.ops.SamplingModule(
    symbols_to_logits_fn,
    vocab_size: int,
    max_decode_length: int,
    eos_id: int,
    padded_decode: bool,
    length_normalization_fn: Optional[Callable[[int, tf.DType], float]] = None,
    top_k=0,
    top_p=1.0,
    sample_temperature=0.0,
    enable_greedy: bool = True,
    dtype: tf.DType = tf.float32,
    decoding_name: Optional[str] = None,
    extra_cache_output: bool = False
)
```

Methods
generate

```python
generate(
    initial_ids: tf.Tensor,
    initial_cache: Dict[str, tf.Tensor],
    initial_log_probs: Optional[tf.Tensor] = None
) -> Output
```

Implements the decoding strategy (beam_search or sampling).
| Args | |
|---|---|
| `initial_ids` | Initial ids to pass into `symbols_to_logits_fn`. An int tensor with shape `[batch_size, 1]`. |
| `initial_cache` | Dictionary for caching model outputs from the previous step. |
| `initial_log_probs` | Optional initial log probs, used when decoding should continue from a prefix sequence. |
| Returns | |
|---|---|
| A tuple of tensors: `finished_sequence` with shape `[batch, max_seq_length]`, `finished_scores` with shape `[batch]`, and `first_cache`, the cache after the initial token. |
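The constructor's `top_k`, `top_p`, and `sample_temperature` arguments control how the next token is drawn from the logits at each decoding step. A minimal pure-Python sketch of that filtering logic (illustrative only, not the actual TensorFlow implementation; the function name `filter_and_sample` is hypothetical):

```python
import math
import random

def filter_and_sample(logits, top_k=0, top_p=1.0, temperature=1.0, rng=random):
    """Sample a token id after temperature scaling, top-k, and top-p filtering."""
    # Temperature scaling: lower temperature sharpens the distribution.
    scaled = [l / temperature for l in logits]
    # Numerically stable softmax over the scaled logits.
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Sort token ids by probability, highest first.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    # top-k: keep only the k most likely tokens (0 disables the filter).
    if top_k > 0:
        order = order[:top_k]
    # top-p (nucleus): keep the smallest prefix whose cumulative
    # probability reaches top_p (1.0 disables the filter).
    if top_p < 1.0:
        kept, cum = [], 0.0
        for i in order:
            kept.append(i)
            cum += probs[i]
            if cum >= top_p:
                break
        order = kept
    # Renormalize over the surviving tokens and sample one of them.
    mass = sum(probs[i] for i in order)
    r = rng.random() * mass
    for i in order:
        r -= probs[i]
        if r <= 0:
            return i
    return order[-1]
```

With `top_k=1` the sketch always returns the argmax, mirroring the greedy behavior that `SamplingModule` defaults to when `sample_temperature=0.0` and `enable_greedy=True`.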
inf

```python
inf()
```

Returns a value close to infinity, but still finite in `dtype`.

This is useful for getting a very large value that is still zero when multiplied by zero; the true floating-point "inf" value yields NaN when multiplied by zero.
| Returns | |
|---|---|
| A very large value. |
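The rationale behind `inf` can be checked with plain floating-point arithmetic (a hedged illustration; the exact constant the method returns depends on `dtype`, and the value below is only a stand-in):

```python
import math

# True IEEE infinity times zero is undefined, so it yields NaN.
assert math.isnan(math.inf * 0.0)

# A very large but finite stand-in (approximately float32's max,
# an assumption for illustration) stays well-behaved:
# multiplying it by zero gives exactly zero.
large_but_finite = 3.4028235e38
assert large_but_finite * 0.0 == 0.0
```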