megatron.core.tensor_parallel.random.CudaRNGStatesTracker

class megatron.core.tensor_parallel.random.CudaRNGStatesTracker

Bases: object

Tracker for CUDA RNG states.

The add method initializes a CUDA RNG state from an input seed and assigns it to a name. Later, forking that state lets us perform operations with it and then return to the starting CUDA RNG state.

add(name, seed)

Initialize a CUDA RNG state from seed and track it under name.
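A minimal sketch of the save, seed, capture, restore pattern that add describes. It assumes Python's global `random` generator as a stand-in for the CUDA RNG and does not call Megatron or PyTorch; the helper `make_named_state` and the plain `states` dictionary are hypothetical.

```python
import random


def make_named_state(states, name, seed):
    """Hypothetical stand-in for add(): derive an RNG state from `seed`,
    store it under `name`, and leave the global RNG state untouched."""
    orig = random.getstate()          # remember the caller's current state
    random.seed(seed)                 # reseed the global generator
    states[name] = random.getstate()  # capture the freshly seeded state
    random.setstate(orig)             # put the caller's state back


# Usage: register one named state seeded with 1234.
states = {}
make_named_state(states, "model-parallel-rng", 1234)
```

Because the original state is restored before returning, adding a tracked state does not perturb the random stream the caller was already drawing from.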

fork(name='model-parallel-rng')

Fork the CUDA RNG state tracked under name: switch to it, perform operations, and then restore the original CUDA RNG state on exit.
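The fork behavior can be sketched as a context manager, again with Python's global `random` generator standing in for the CUDA RNG. The standalone `fork` function and its `states` argument are hypothetical. Writing the advanced state back on exit is an assumption of this sketch; it makes successive forks continue the tracked stream instead of replaying it.

```python
import contextlib
import random


@contextlib.contextmanager
def fork(states, name="model-parallel-rng"):
    """Hypothetical stand-in for fork(): run the enclosed block on the
    tracked state `name`, then restore the caller's global RNG state."""
    if name not in states:
        raise KeyError(f"RNG state {name!r} has not been added")
    orig = random.getstate()
    random.setstate(states[name])      # switch to the tracked stream
    try:
        yield
    finally:
        # Save the advanced tracked state so the next fork continues from here,
        # then return the global generator to where the caller left it.
        states[name] = random.getstate()
        random.setstate(orig)


# Usage: build one tracked state from seed 42 without touching the global state.
saved = random.getstate()
random.seed(42)
states = {"model-parallel-rng": random.getstate()}
random.setstate(saved)

with fork(states):
    tracked_draw = random.random()   # drawn from the tracked stream
default_draw = random.random()       # drawn from the restored default stream
```

Restoring the original state in a `finally` clause means an exception raised inside the block still leaves the caller's stream where it was.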

get_states()

Return the tracked RNG states. The dictionary is copied so that the caller holds direct references to the state objects rather than a reference to the tracker's internal dictionary.
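The effect of copying the dictionary can be seen in a small sketch that uses Python `random` states as stand-ins for CUDA RNG states and a hypothetical `TrackerStub` class. The stub assumes a stream advance replaces the dictionary entry with a new state object, so a snapshot taken earlier with get_states keeps the pre-advance state.

```python
import random


class TrackerStub:
    """Hypothetical minimal tracker holding named RNG states in a dict."""

    def __init__(self):
        self.states_ = {}

    def get_states(self):
        # Shallow copy: a new dict whose values are the same state objects,
        # so the caller does not share the tracker's own dict.
        return dict(self.states_)


tracker = TrackerStub()
random.seed(7)
tracker.states_["model-parallel-rng"] = random.getstate()

snapshot = tracker.get_states()

# Simulate a fork advancing the tracked stream: the entry is replaced.
random.setstate(tracker.states_["model-parallel-rng"])
random.random()
tracker.states_["model-parallel-rng"] = random.getstate()
```

Had get_states returned `self.states_` itself, `snapshot` would now show the advanced state instead of the one captured before the draw.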

reset()

Reset the tracker to its initial state, with no RNG states tracked.

set_states(states)

Set the RNG states, for example from a dictionary returned by get_states(). For efficiency, the provided states are not checked for size compatibility.
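Putting the methods together, a minimal sketch of the whole tracker follows. It assumes Python's global `random` generator in place of the CUDA RNG and is not Megatron's implementation; the class name `RNGStatesTracker` is hypothetical, and rejecting a duplicate name in add is a choice of this sketch. The usage treats get_states and set_states as a checkpoint that can be rewound.

```python
import contextlib
import random


class RNGStatesTracker:
    """Sketch of a named-RNG-state tracker. Python's global `random`
    generator stands in for the CUDA RNG (an assumption for illustration)."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Initial state: no tracked RNG states.
        self.states_ = {}

    def get_states(self):
        # Shallow copy of the name -> state mapping.
        return dict(self.states_)

    def set_states(self, states):
        # No compatibility check on the incoming states.
        self.states_ = states

    def add(self, name, seed):
        if name in self.states_:  # duplicate check: a choice of this sketch
            raise ValueError(f"RNG state {name!r} already exists")
        orig = random.getstate()
        random.seed(seed)
        self.states_[name] = random.getstate()
        random.setstate(orig)

    @contextlib.contextmanager
    def fork(self, name="model-parallel-rng"):
        if name not in self.states_:
            raise KeyError(f"RNG state {name!r} has not been added")
        orig = random.getstate()
        random.setstate(self.states_[name])
        try:
            yield
        finally:
            self.states_[name] = random.getstate()
            random.setstate(orig)


# Usage: checkpoint the tracked states, draw, then rewind with set_states.
tracker = RNGStatesTracker()
tracker.add("model-parallel-rng", 1234)
checkpoint = tracker.get_states()
with tracker.fork():
    first = random.random()
tracker.set_states(checkpoint)
with tracker.fork():
    replay = random.random()
# first == replay: restoring the checkpoint replays the same draw.
```

Because set_states skips any compatibility check, in this sketch an invalid state in the dictionary is only detected later, when fork tries to load it.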