megatron.text_generation.sampling

Description

Sampling utilities. Part of this code is inspired by:

Functions

modify_logits_for_top_k_filtering(logits, top_k)

Set the logits of all non-top-k values to -inf, so that only the top_k highest-scoring tokens can be sampled.
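The filtering idea can be illustrated with a minimal standard-library sketch. It is not the library's implementation: the name top_k_filter is hypothetical, nested lists stand in for a [b, v] tensor, and the sketch returns a copy, which may differ from how the real function treats its input.

```python
import math


def top_k_filter(logits, top_k):
    """Hypothetical helper: mask everything outside the top_k values per row.

    `logits` is a [b, v] nested list; a filtered copy is returned.
    """
    filtered = []
    for row in logits:
        # The k-th largest value in the row acts as the cutoff.
        threshold = sorted(row, reverse=True)[top_k - 1]
        # Values tied with the cutoff are all kept in this sketch.
        filtered.append([x if x >= threshold else -math.inf for x in row])
    return filtered


logits = [[2.0, 0.5, 1.0, -1.0],
          [0.1, 3.0, 0.2, 2.5]]
print(top_k_filter(logits, top_k=2))
# [[2.0, -inf, 1.0, -inf], [-inf, 3.0, -inf, 2.5]]
```

After a softmax, every entry set to -inf receives probability zero, which is why masking the logits is enough to exclude those tokens from sampling.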

modify_logits_for_top_p_filtering(logits, top_p)

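Set the logits of all non-top-p values to -inf, so that only the most probable tokens whose cumulative probability reaches top_p (the nucleus) can be sampled.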
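A rough sketch of top-p (nucleus) filtering, again using only the standard library. The name top_p_filter and the exact boundary rule (keep tokens until the cumulative probability first reaches top_p, including the token that crosses it) are assumptions made for illustration; the library function may handle the boundary differently.

```python
import math


def top_p_filter(logits, top_p):
    """Hypothetical helper: keep the smallest high-probability set per row.

    `logits` is a [b, v] nested list; a filtered copy is returned.
    """
    filtered = []
    for row in logits:
        # Numerically stable softmax over the row.
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        probs = [e / total for e in exps]

        # Visit tokens from most to least probable.
        order = sorted(range(len(row)), key=lambda i: probs[i], reverse=True)
        new_row = [-math.inf] * len(row)
        cumulative = 0.0
        for i in order:
            new_row[i] = row[i]  # keep this token's original logit
            cumulative += probs[i]
            if cumulative >= top_p:
                break  # nucleus is complete; the rest stay at -inf
        filtered.append(new_row)
    return filtered


# Logits chosen so the softmax is roughly [0.5, 0.3, 0.15, 0.05].
logits = [[math.log(0.5), math.log(0.3), math.log(0.15), math.log(0.05)]]
row = top_p_filter(logits, top_p=0.7)[0]
print([i for i, x in enumerate(row) if x != -math.inf])
# [0, 1]
```

Because the most probable token is always visited first, at least one token survives the filter for any top_p greater than zero.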

sample(logits[, top_k, top_p, temperature, ...])

Sample and generate a token. Note: logits has shape [b, v], where b is the batch size and v is the vocabulary size. If vocab_size is provided, the generated sample is guaranteed to lie in [0, vocab_size). This avoids out-of-vocabulary generations caused by padding of the vocabulary dimension.
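To make the vocab_size guarantee concrete, here is a simplified standard-library sketch. The function sample_sketch, its default values, the rng parameter, and the choice to enforce the range by clamping are all assumptions for illustration and are not taken from the library; top-p filtering is omitted for brevity.

```python
import math
import random


def sample_sketch(logits, top_k=0, temperature=1.0, vocab_size=None, rng=random):
    """Hypothetical sketch: draw one token id per row of [b, v] nested lists."""
    samples = []
    for row in logits:
        if top_k == 1:
            # Greedy choice: index of the highest logit.
            token = max(range(len(row)), key=lambda i: row[i])
        else:
            # Temperature scaling, then optional top-k masking.
            scaled = [x / temperature for x in row]
            if top_k > 1:
                threshold = sorted(scaled, reverse=True)[top_k - 1]
                scaled = [x if x >= threshold else -math.inf for x in scaled]
            # Unnormalized softmax weights; masked entries get weight 0.
            m = max(scaled)
            weights = [math.exp(x - m) for x in scaled]
            token = rng.choices(range(len(row)), weights=weights, k=1)[0]
        if vocab_size is not None:
            # Force the id into [0, vocab_size) in case a padded column won.
            token = min(max(token, 0), vocab_size - 1)
        samples.append(token)
    return samples


# Six columns, but only the first four are real vocabulary entries.
padded = [[0.1, 0.2, 0.3, 0.4, 0.0, 9.0]]
print(sample_sketch(padded, top_k=1))                # [5]
print(sample_sketch(padded, top_k=1, vocab_size=4))  # [3]
```

In the example the last two columns play the role of padding: without vocab_size the greedy choice lands on a padded index, while with vocab_size=4 the result stays inside the real vocabulary.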