megatron.checkpointing

Description

Input/output checkpointing.

Functions

check_checkpoint_args(checkpoint_args)

Ensure that the fixed arguments of a model are the same in the input arguments and in the arguments retrieved from the checkpoint.
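
A minimal sketch of this kind of consistency check, assuming both sets of arguments are `argparse.Namespace` objects. The helper `check_fixed_args` and the tuple `FIXED_ARG_NAMES` are hypothetical illustrations, not Megatron's own names or its actual list of fixed arguments.

```python
from argparse import Namespace

# Hypothetical list of arguments that must not change between runs.
FIXED_ARG_NAMES = ("num_layers", "hidden_size", "num_attention_heads")


def check_fixed_args(args: Namespace, checkpoint_args: Namespace) -> None:
    """Raise if any fixed argument differs between the current run and the checkpoint."""
    for name in FIXED_ARG_NAMES:
        current = getattr(args, name, None)
        saved = getattr(checkpoint_args, name, None)
        if current != saved:
            raise ValueError(
                f"{name} mismatch: input value {current!r}, checkpoint value {saved!r}"
            )


args = Namespace(num_layers=24, hidden_size=1024, num_attention_heads=16)
ckpt_args = Namespace(num_layers=24, hidden_size=1024, num_attention_heads=16)
check_fixed_args(args, ckpt_args)  # passes silently because all fixed values match
```

Failing loudly on a mismatch is deliberate: silently loading weights into a model with a different shape would either crash later or produce wrong results.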

ensure_directory_exists(filename)

Build the directory path of filename if it does not already exist.
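
A standard-library sketch of the same idea; `ensure_parent_dir` is a hypothetical name and the example path layout is only illustrative. The guard on an empty `dirname` matters because `os.makedirs("")` raises for a bare file name with no directory component.

```python
import os
import tempfile


def ensure_parent_dir(filename: str) -> None:
    """Create the parent directory of `filename` (and any missing ancestors)."""
    dirname = os.path.dirname(filename)
    if dirname:
        # exist_ok=True makes the call a no-op when the directory already exists.
        os.makedirs(dirname, exist_ok=True)


root = tempfile.mkdtemp()
target = os.path.join(root, "iter_0000100", "mp_rank_00", "model_optim_rng.pt")
ensure_parent_dir(target)
print(os.path.isdir(os.path.dirname(target)))  # True
```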

find_checkpoint_rank_0(checkpoints_path, ...)

Find the checkpoint for rank 0 without knowing whether pipeline parallelism is in use.
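
One way to locate a rank-0 checkpoint without knowing the parallel layout is to try each candidate file name in turn. The sketch below assumes a layout of `iter_<7-digit iteration>/mp_rank_00/model_optim_rng.pt` without pipeline parallelism and `mp_rank_00_000` with it; these names, and the helpers `candidate_rank0_paths` and `find_rank0_checkpoint`, are assumptions for illustration.

```python
import os
import tempfile
from typing import List, Optional


def candidate_rank0_paths(checkpoints_path: str, iteration: int) -> List[str]:
    """Hypothetical rank-0 file names with and without pipeline parallelism."""
    iter_dir = os.path.join(checkpoints_path, f"iter_{iteration:07d}")
    return [
        # Assumed layout without pipeline parallelism: tensor rank only.
        os.path.join(iter_dir, "mp_rank_00", "model_optim_rng.pt"),
        # Assumed layout with pipeline parallelism: tensor rank and pipeline rank.
        os.path.join(iter_dir, "mp_rank_00_000", "model_optim_rng.pt"),
    ]


def find_rank0_checkpoint(checkpoints_path: str, iteration: int) -> Optional[str]:
    """Return the first candidate that exists on disk, or None if none exist."""
    for path in candidate_rank0_paths(checkpoints_path, iteration):
        if os.path.isfile(path):
            return path
    return None


# Create only the pipeline-parallel variant and check that it is found.
root = tempfile.mkdtemp()
pp_file = candidate_rank0_paths(root, 100)[1]
os.makedirs(os.path.dirname(pp_file))
open(pp_file, "w").close()
print(find_rank0_checkpoint(root, 100) == pp_file)  # True
```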

fix_query_key_value_ordering(model, ...)

Fix up the query/key/value matrix ordering if the checkpoint version is smaller than 2.0.
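
The reordering amounts to a permutation of the rows of the fused query/key/value weight. This sketch assumes, for illustration only, that an older checkpoint groups rows as `[3, num_heads, head_dim]` (all query rows, then all key rows, then all value rows) and that the newer layout is `[num_heads, 3, head_dim]` (query, key and value rows interleaved per head). It works on a plain list of rows instead of a tensor, and `reorder_qkv_rows` is a hypothetical name.

```python
def reorder_qkv_rows(rows, num_heads, head_dim, num_splits=3):
    """Permute rows from [num_splits, num_heads, head_dim] to [num_heads, num_splits, head_dim]."""
    assert len(rows) == num_splits * num_heads * head_dim
    out = []
    for head in range(num_heads):
        for split in range(num_splits):
            # Start of this (split, head) block in the old layout.
            start = (split * num_heads + head) * head_dim
            out.extend(rows[start:start + head_dim])
    return out


# Label each row by (matrix, head, dim) so the permutation is visible.
rows = [(m, h, d) for m in "qkv" for h in range(2) for d in range(2)]
print(reorder_qkv_rows(rows, num_heads=2, head_dim=2)[:6])
# [('q', 0, 0), ('q', 0, 1), ('k', 0, 0), ('k', 0, 1), ('v', 0, 0), ('v', 0, 1)]
```

With real tensors the same permutation is usually expressed as a view, a transpose of the first two grouped dimensions, and a view back to the fused shape.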

get_checkpoint_name(checkpoints_path, iteration)

Determine the directory name for this rank's checkpoint.
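
A sketch of how a per-rank checkpoint path might be composed from the iteration and the parallel ranks. The directory names (`release`, `iter_<7 digits>`, `mp_rank_<tp>` or `mp_rank_<tp>_<pp>`) and the file name `model_optim_rng.pt` are assumptions about the layout, and `checkpoint_name` is a hypothetical helper, not the signature of `get_checkpoint_name`.

```python
import os
from typing import Optional


def checkpoint_name(checkpoints_path: str, iteration: int, tensor_rank: int,
                    pipeline_rank: Optional[int] = None, release: bool = False) -> str:
    """Build a per-rank checkpoint file path under an assumed directory layout."""
    # Assumed: a 'release' directory for released weights, else a zero-padded iteration.
    directory = "release" if release else f"iter_{iteration:07d}"
    if pipeline_rank is None:
        rank_dir = f"mp_rank_{tensor_rank:02d}"
    else:
        rank_dir = f"mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}"
    return os.path.join(checkpoints_path, directory, rank_dir, "model_optim_rng.pt")


print(checkpoint_name("/ckpts", 1500, tensor_rank=1, pipeline_rank=2))
# /ckpts/iter_0001500/mp_rank_01_002/model_optim_rng.pt  (on POSIX)
```

Zero-padding the iteration keeps directory listings sorted in training order.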

get_checkpoint_names(checkpoints_path, ...)

Determine the directory name for this rank's checkpoint.

get_checkpoint_tracker_filename(checkpoints_path)

The tracker file records the latest checkpoint saved during training, so that training can restart from it.
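
A minimal sketch of a tracker file, assuming it is a small text file named `latest_checkpointed_iteration.txt` that holds either the iteration number or the literal text `release`. The file name, the `release` convention, and the helpers `write_tracker` and `read_tracker` are assumptions for illustration.

```python
import os
import tempfile
from typing import Tuple

TRACKER_NAME = "latest_checkpointed_iteration.txt"  # assumed file name


def write_tracker(checkpoints_path: str, iteration: int) -> None:
    """Record the latest saved iteration as plain text."""
    with open(os.path.join(checkpoints_path, TRACKER_NAME), "w") as f:
        f.write(str(iteration))


def read_tracker(checkpoints_path: str) -> Tuple[int, bool]:
    """Return (iteration, release); the text 'release' marks released weights."""
    with open(os.path.join(checkpoints_path, TRACKER_NAME)) as f:
        text = f.read().strip()
    if text == "release":
        return 0, True
    return int(text), False


root = tempfile.mkdtemp()
write_tracker(root, 2000)
print(read_tracker(root))  # (2000, False)
```

Keeping the tracker as a separate tiny file lets a restarted job find the newest checkpoint without scanning every iteration directory.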

get_checkpoint_version()

get_rng_state()

Collect the RNG state across data-parallel ranks.
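
A sketch of capturing and restoring RNG state, limited to Python's `random` module. A full implementation would typically also capture NumPy and PyTorch (CPU and CUDA) generator states, and in a real distributed job the per-rank states would be gathered with a collective such as `torch.distributed.all_gather_object`; here that gather is simulated with a plain list. The helper names are hypothetical.

```python
import random
from typing import Any, Dict, List


def capture_rng_state() -> Dict[str, Any]:
    """Snapshot this process's Python `random` generator state."""
    return {"random_rng_state": random.getstate()}


def restore_rng_state(state: Dict[str, Any]) -> None:
    """Restore a state captured by capture_rng_state."""
    random.setstate(state["random_rng_state"])


# Stand-in for an all-gather: one state dict per simulated data-parallel rank.
world_states: List[Dict[str, Any]] = []
for rank in range(2):
    random.seed(1234 + rank)  # each rank is seeded differently
    world_states.append(capture_rng_state())

# On resume, a rank restores its own entry and continues the same random stream.
restore_rng_state(world_states[1])
a = random.random()
restore_rng_state(world_states[1])
b = random.random()
print(a == b)  # True
```

Saving one state per rank matters when ranks are seeded differently: restoring a single shared state would make every rank draw the same numbers after a restart.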

load_args_from_checkpoint(args[, load_arg])

Set required arguments from the checkpoint specified in the arguments.
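
A sketch of copying selected arguments stored with a checkpoint onto the current arguments, assuming both are `argparse.Namespace` objects. Which arguments get copied is an assumption here (`ARGS_FROM_CHECKPOINT` is hypothetical), as is the helper name `apply_checkpoint_args`.

```python
from argparse import Namespace

# Hypothetical subset of model-shape arguments to copy from a checkpoint.
ARGS_FROM_CHECKPOINT = ("num_layers", "hidden_size", "num_attention_heads")


def apply_checkpoint_args(args: Namespace, checkpoint_args: Namespace) -> Namespace:
    """Copy selected arguments saved with a checkpoint onto the current arguments."""
    for name in ARGS_FROM_CHECKPOINT:
        if hasattr(checkpoint_args, name):
            setattr(args, name, getattr(checkpoint_args, name))
    return args


args = Namespace(num_layers=None, hidden_size=None, num_attention_heads=None, lr=1e-4)
saved = Namespace(num_layers=24, hidden_size=1024, num_attention_heads=16, lr=3e-4)
apply_checkpoint_args(args, saved)
print(args.num_layers, args.lr)  # 24 0.0001
```

Only architecture arguments are copied; run-specific settings such as the learning rate keep their current values.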

load_biencoder_checkpoint(model[, ...])

Selectively load retrieval models from saved checkpoints for indexing or retrieval.
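
A biencoder usually has separate query and context encoders, so loading only one of them amounts to filtering the saved state dict by key prefix. The key names (`query_model.`, `context_model.`) and the helper `select_submodule_state` are hypothetical.

```python
from typing import Any, Dict


def select_submodule_state(state_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]:
    """Keep only entries under `prefix` and strip the prefix from their keys."""
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


full = {
    "query_model.encoder.weight": 1,
    "query_model.encoder.bias": 2,
    "context_model.encoder.weight": 3,
}
print(select_submodule_state(full, "query_model."))
# {'encoder.weight': 1, 'encoder.bias': 2}
```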

load_checkpoint(model, optimizer, ...[, ...])

Load a model checkpoint and return the iteration.

strict (bool): whether to strictly enforce that the keys in the checkpoint's state_dict match the names of the parameters and buffers in model.
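
The strict option follows the same idea as the missing/unexpected-key check in PyTorch's `torch.nn.Module.load_state_dict`. The sketch below reproduces that check on plain key sets; the helper `match_state_dict_keys` is hypothetical, and it raises `KeyError` where PyTorch raises `RuntimeError`.

```python
from typing import Any, Dict, Iterable, List, Tuple


def match_state_dict_keys(model_keys: Iterable[str], checkpoint: Dict[str, Any],
                          strict: bool = True) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) keys; raise if strict and either list is non-empty."""
    expected = set(model_keys)
    found = set(checkpoint)
    missing = sorted(expected - found)     # in the model, absent from the checkpoint
    unexpected = sorted(found - expected)  # in the checkpoint, absent from the model
    if strict and (missing or unexpected):
        raise KeyError(f"missing={missing}, unexpected={unexpected}")
    return missing, unexpected


model_keys = ["layer.weight", "layer.bias"]
ckpt = {"layer.weight": 0, "extra.buffer": 0}
print(match_state_dict_keys(model_keys, ckpt, strict=False))
# (['layer.bias'], ['extra.buffer'])
```

With strict=False the mismatches are reported rather than raised, which is useful when loading a checkpoint into a model whose head or buffers differ on purpose.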

read_metadata(tracker_filename)

save_checkpoint(iteration, model, optimizer, ...)

Save a model checkpoint.
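
A common pattern for saving is to write the checkpoint file first and update the tracker only afterwards, each through a temporary file renamed into place, so an interrupted save never leaves the tracker pointing at a partial checkpoint. This is a general technique sketched with `pickle` in place of a tensor serializer, not a description of what `save_checkpoint` itself does; the directory and file names and the helpers `atomic_write_bytes` and `save_state` are assumptions.

```python
import os
import pickle
import tempfile


def atomic_write_bytes(path: str, data: bytes) -> None:
    """Write to a temporary file in the same directory, then rename it over `path`."""
    directory = os.path.dirname(path) or "."
    fd, tmp_path = tempfile.mkstemp(dir=directory)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
        # On POSIX the rename is atomic when both paths are on the same filesystem.
        os.replace(tmp_path, path)
    except BaseException:
        os.remove(tmp_path)
        raise


def save_state(checkpoints_path: str, iteration: int, state: dict) -> str:
    """Save `state` for `iteration`, then point the (assumed) tracker file at it."""
    iter_dir = os.path.join(checkpoints_path, f"iter_{iteration:07d}")
    os.makedirs(iter_dir, exist_ok=True)
    ckpt_file = os.path.join(iter_dir, "model_optim_rng.pkl")
    atomic_write_bytes(ckpt_file, pickle.dumps(state))
    # Update the tracker only after the checkpoint itself is fully written.
    tracker = os.path.join(checkpoints_path, "latest_checkpointed_iteration.txt")
    atomic_write_bytes(tracker, str(iteration).encode())
    return ckpt_file


root = tempfile.mkdtemp()
path = save_state(root, 300, {"iteration": 300})
with open(path, "rb") as f:
    print(pickle.load(f))  # {'iteration': 300}
```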

set_checkpoint_version(value)