megatron.core.tensor_parallel.random

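Description

Utilities for managing CUDA random number generator (RNG) states in model-parallel training, including activation checkpointing that restores those states when the forward pass is recomputed.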

Classes

CheckpointFunction(*args, **kwargs)

This autograd Function is adapted from torch.utils.checkpoint with two main changes: 1) torch.cuda.set_rng_state is replaced with _set_cuda_rng_state; and 2) the RNG states held in the model-parallel tracker are also saved, set, and restored during recomputation.
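Why the RNG states matter: when a checkpointed region is recomputed during the backward pass, any random operation in it (dropout, for example) must draw exactly the same numbers as in the original forward pass. The following is a minimal pure-Python sketch of that save/rewind/restore pattern, assuming a single random.Random instance stands in for the CUDA generator and the tracker's states. The helpers dropout, forward_with_saved_state and recompute are hypothetical illustrations, not Megatron APIs.

```python
import random
from typing import List


def dropout(xs: List[float], p: float, rng: random.Random) -> List[float]:
    """Toy dropout: zero each element with probability p, scale survivors by 1/(1-p)."""
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else x * scale for x in xs]


def forward_with_saved_state(xs: List[float], p: float, rng: random.Random):
    # Capture the RNG state *before* the forward pass, as a checkpointing
    # function must do if it will recompute this region later.
    saved_state = rng.getstate()
    out = dropout(xs, p, rng)
    return out, saved_state


def recompute(xs: List[float], p: float, rng: random.Random, saved_state) -> List[float]:
    # Remember where the RNG is now, rewind it to the saved state, replay the
    # forward pass, then restore it so later code sees an undisturbed stream.
    current_state = rng.getstate()
    rng.setstate(saved_state)
    out = dropout(xs, p, rng)
    rng.setstate(current_state)
    return out


rng = random.Random(1234)
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
first, state = forward_with_saved_state(xs, 0.5, rng)
for _ in range(10):  # other work advances the RNG between forward and backward
    rng.random()
replayed = recompute(xs, 0.5, rng, state)
assert replayed == first  # same dropout mask, hence the same activations
```

In the real function the same idea is applied to several states at once (the CPU RNG state, the CUDA RNG state, and every state in the model-parallel tracker); the sketch collapses them into one generator for clarity.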

CudaRNGStatesTracker()

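Tracker for the CUDA RNG states. It stores named RNG states so that code can temporarily switch the CUDA generator to one of them and then restore the default state afterwards.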
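A pure-Python analogue of such a tracker can be sketched as follows. It assumes that a single random.Random instance plays the role of the one CUDA default generator, and that the tracker exposes an add(name, seed) method plus a fork(name) context manager; the class ToyRNGStatesTracker is hypothetical and only mirrors the pattern, not the exact Megatron implementation.

```python
import contextlib
import random
from typing import Dict, Iterator, Set


class ToyRNGStatesTracker:
    """Hypothetical analogue of a named-RNG-state tracker.

    `generator` stands in for the single CUDA default generator; the tracker
    keeps extra named states and swaps them in on demand.
    """

    def __init__(self, generator: random.Random) -> None:
        self.generator = generator
        self.states: Dict[str, tuple] = {}
        self.seeds: Set[int] = set()

    def add(self, name: str, seed: int) -> None:
        if name in self.states:
            raise ValueError(f"RNG state {name!r} already exists")
        if seed in self.seeds:
            raise ValueError(f"seed {seed} is already in use")
        self.seeds.add(seed)
        # Seed the generator to create the new state, then restore the
        # original state so the default stream is left untouched.
        original = self.generator.getstate()
        self.generator.seed(seed)
        self.states[name] = self.generator.getstate()
        self.generator.setstate(original)

    @contextlib.contextmanager
    def fork(self, name: str) -> Iterator[None]:
        if name not in self.states:
            raise KeyError(f"RNG state {name!r} has not been added")
        original = self.generator.getstate()
        self.generator.setstate(self.states[name])
        try:
            yield
        finally:
            # Keep the advanced named state for the next fork, and hand the
            # default stream back exactly where it was.
            self.states[name] = self.generator.getstate()
            self.generator.setstate(original)


gen = random.Random(42)
tracker = ToyRNGStatesTracker(gen)
tracker.add("model-parallel-rng", 1234)
with tracker.fork("model-parallel-rng"):
    a = gen.random()  # first draw from the named stream
with tracker.fork("model-parallel-rng"):
    b = gen.random()  # continues the named stream, it does not restart it
default = gen.random()  # default stream is unaffected by the forks
```

The design point is that both streams advance independently: the named stream resumes where the previous fork left it, while the default stream never notices the detour.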

Functions

checkpoint(function, ...)

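Checkpoint a model or part of the model: instead of storing the intermediate activations of function during the forward pass, recompute them during the backward pass, trading extra compute for lower memory use.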
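The memory-for-compute trade-off behind activation checkpointing can be illustrated without PyTorch. This sketch assumes a model is simply a chain of scalar layers; it keeps only the input of each segment and replays the segment when an intermediate activation is needed, which is what a backward pass over a checkpointed region does. The helpers forward_checkpointed and recompute_input are hypothetical and ignore autograd entirely.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Layer = Callable[[float], float]


def forward_checkpointed(
    layers: Sequence[Layer], x: float, every: int
) -> Tuple[Dict[int, float], float]:
    """Run every layer, but keep only the input of each `every`-th layer."""
    checkpoints: Dict[int, float] = {}
    for i, layer in enumerate(layers):
        if i % every == 0:
            checkpoints[i] = x  # segment boundary: store this input
        x = layer(x)  # other intermediate activations are discarded
    return checkpoints, x


def recompute_input(
    layers: Sequence[Layer], checkpoints: Dict[int, float], i: int, every: int
) -> float:
    """Rebuild the input of layer i by replaying its segment from the checkpoint."""
    start = (i // every) * every
    x = checkpoints[start]
    for layer in layers[start:i]:
        x = layer(x)
    return x


layers: List[Layer] = [lambda v, k=k: 2.0 * v + k for k in range(8)]
checkpoints, output = forward_checkpointed(layers, 1.0, every=4)
# Only 2 of the 8 layer inputs are stored; the rest are recomputed on demand.
fifth_input = recompute_input(layers, checkpoints, 5, every=4)
```

With random layers, the replay would only be correct if the RNG state were rewound first, which is why the checkpointing here also saves and restores the tracked RNG states.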

get_cuda_rng_tracker()

Get the CUDA RNG states tracker.

model_parallel_cuda_manual_seed(seed)

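Initialize the model-parallel CUDA RNG seeds: seed the default CUDA generator and register a separate model-parallel RNG state with the CUDA RNG states tracker.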
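The seed arithmetic can be sketched as below. As we understand the upstream implementation, the default generator receives seed unchanged, while the state registered as "model-parallel-rng" is seeded with seed + 2718 + tensor_model_parallel_rank, so tensor-parallel ranks draw different numbers for sharded tensors but identical numbers for replicated ones. Treat the 2718 offset as an assumption to verify against your Megatron-LM version; derive_seeds is a hypothetical helper, not part of the library.

```python
from typing import Dict

# Assumed offset; check the constant used by your Megatron-LM version.
TP_SEED_OFFSET = 2718


def derive_seeds(seed: int, tp_rank: int, offset: int = TP_SEED_OFFSET) -> Dict[str, int]:
    """Return the default and model-parallel seeds for one tensor-parallel rank (sketch)."""
    return {
        # Identical on every tensor-parallel rank given the same base seed,
        # e.g. for dropout on activations replicated across the TP group.
        "default": seed,
        # Distinct per tensor-parallel rank, e.g. for randomness on sharded tensors.
        "model-parallel-rng": seed + offset + tp_rank,
    }


# Four tensor-parallel ranks sharing base seed 1234:
per_rank = [derive_seeds(1234, rank) for rank in range(4)]
```

With seed 1234 and four ranks, every rank gets default seed 1234, and the model-parallel seeds are 3952, 3953, 3954 and 3955.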