megatron.optimizer.distrib_optimizer.DistributedOptimizer#

class megatron.optimizer.distrib_optimizer.DistributedOptimizer(optimizer, clip_grad, log_num_zeros_in_grad, params_have_main_grad, use_contiguous_buffers_in_local_ddp, fp16, bf16, params_dtype, grad_scaler, models)#

Bases: MixedPrecisionOptimizer

Distributed optimizer, for all data types (fp16, bf16, and fp32).

Parameters:
  • optimizer – base optimizer such as Adam or SGD

  • clip_grad – clip gradients to this global L2 norm. Note that clipping is ignored if clip_grad == 0.

  • log_num_zeros_in_grad – return number of zeros in the gradients.

  • params_have_main_grad – flag indicating if parameters have a main_grad field. If this is set, we assume that the model parameters are stored in the main_grad field instead of the typical grad field. This happens in the DDP cases where there is a contiguous buffer holding the gradients. For example, for bfloat16 we want to do gradient accumulation and all-reduces in float32, and as a result we store those gradients in main_grad. Note that main_grad is not necessarily in float32.

  • use_contiguous_buffers_in_local_ddp – if true, the local DDP model is using a contiguous buffer to hold the model grads.

  • fp16 – if true, the model is running in fp16.

  • bf16 – if true, the model is running in bfloat16.

  • grad_scaler – used for scaling gradients. Note that this can be None, which happens when bf16 = True and we don’t use any loss scale. Note that for bf16 = True, we can have a constant gradient scaler. Also, for bf16 = False, we always require a grad scaler.

  • models – list of models (i.e., the virtual pipelining models). This is used by the distributed optimizer for mapping parameters.

classmethod build_model_and_main_param_groups(model_gbuf_ranges, param_gbuf_map, opt_group_ranges)#

Create main parameter groups needed for the optimizer step.

These groups encompass both: 1) groups used by this class, for reducing/gathering, and 2) groups used by the inner optimizer for the parameter update. Given that the conceptual grad buffer partitioning (created in an earlier method) doesn’t respect parameter boundaries, the optimizer operates on shards of the model parameters, rather than the full parameters.

classmethod build_model_gbuf_param_range_map(model, dtype, gbuf_world_range)#

Build mapping from param reference to grad buffer shard ranges.

This method builds a mapping from parameter references to grad buffer shard ranges, specific to each data-parallel (DP) rank’s set of ‘owned’ parameters. Each grad buffer (padded to be an even multiple of DP-world-size) is conceptually divided into DP-world-size contiguous regions, where each DP rank ‘owns’ one contiguous region. Ownership in this sense means the DP rank is responsible for reducing the relevant subset of grads, and updating the relevant subset of params.

This conceptual partitioning of the grad buffer does NOT respect parameter boundaries, and as such it is assumed that each created range references a shard (or subset) of the full parameter. It is easiest to think of each DP rank as operating (i.e., reducing, gathering) purely on views into the grad buffer, for all model-to-main & main-to-model operations.

This method creates three ranges:

  • The param’s range within the entire grad buffer (i.e., world index).

  • The param’s range within the DP rank’s local view of the grad buffer.

  • The param’s range within itself (i.e., its shard).
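The range arithmetic can be sketched in plain Python. This is a minimal illustration under the partitioning described above, not Megatron’s actual implementation; the function and variable names are hypothetical, and ranges are represented as simple (start, end) tuples over flat buffer indices.

```python
def param_shard_ranges(param_world_start, param_world_end, dp_rank, shard_size):
    """Return the three ranges (world, local-to-rank, within-param) for the
    slice of a parameter that falls inside dp_rank's contiguous grad-buffer
    shard, or None if the parameter does not overlap this rank's shard."""
    shard_start = dp_rank * shard_size
    shard_end = shard_start + shard_size
    # Intersect the param's world range with the rank's owned region.
    world_start = max(param_world_start, shard_start)
    world_end = min(param_world_end, shard_end)
    if world_start >= world_end:
        return None  # no overlap: other ranks own this param's elements
    world = (world_start, world_end)
    # Range within this rank's local view of the grad buffer.
    local = (world_start - shard_start, world_end - shard_start)
    # Range within the parameter itself (its shard).
    in_param = (world_start - param_world_start, world_end - param_world_start)
    return world, local, in_param

# A 12-element param at world offset 10, with shards of 8 elements per rank:
# it straddles the boundary between DP ranks 1 and 2.
print(param_shard_ranges(10, 22, dp_rank=1, shard_size=8))
# ((10, 16), (2, 8), (0, 6))
print(param_shard_ranges(10, 22, dp_rank=2, shard_size=8))
# ((16, 22), (0, 6), (6, 12))
```

The example shows how a single parameter can be split across two DP ranks, each rank receiving ranges describing only its own slice.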

classmethod build_model_gbuf_range(model, dtype)#

Build mapping between params and their grad buffers.

This method does the initial setup for the method above. This setup includes determining the shard ranges into the DDP’s grad buffer for each data-parallel (DP) rank. Each DP rank keeps range info for all other DP ranks, for the purpose of creating args for reduce-scatter and all-gather.
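The padding and per-rank shard computation can be sketched as follows. This is an illustrative reconstruction, not Megatron’s code; `gbuf_world_ranges` is a hypothetical name.

```python
def gbuf_world_ranges(num_elements, dp_world_size):
    """Pad the grad buffer to an even multiple of dp_world_size and return
    each DP rank's contiguous (start, end) range within the padded buffer."""
    padded = ((num_elements + dp_world_size - 1) // dp_world_size) * dp_world_size
    shard = padded // dp_world_size
    return [(r * shard, (r + 1) * shard) for r in range(dp_world_size)]

# 22 elements across 4 DP ranks: padded to 24, 6 elements per rank.
print(gbuf_world_ranges(22, 4))
# [(0, 6), (6, 12), (12, 18), (18, 24)]
```

Every rank computes this full list (not just its own entry), since it needs the other ranks’ ranges to build arguments for reduce-scatter and all-gather.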

classmethod build_model_gbuf_range_map(model)#

Create param-to-grad-buffer mappings, for grad buffer data types within a specific virtual model.

classmethod build_model_param_gbuf_map(model_gbuf_ranges)#

Create the reverse of model_gbuf_ranges, for referencing in the opposite direction.

classmethod build_optimizer_group_ranges(param_groups, model_gbuf_ranges)#

Create optimizer groups.

Given the set of parameter shard ranges that are owned by the current data-parallel (DP) rank, gather the set of parameters that will be used (in the method below) to create the current DP rank’s optimizer groups.
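A simplified sketch of this filtering, with hypothetical names and parameters represented as strings for clarity (the real code operates on parameter tensors and preserves the inner optimizer’s group settings):

```python
def local_group_params(param_groups, owned_params):
    """For each inner-optimizer param group, keep only the params for which
    this DP rank owns a shard; group order and membership are preserved."""
    return [[p for p in group if p in owned_params] for group in param_groups]

groups = [["w1", "w2"], ["b1", "b2"]]      # inner optimizer's param groups
print(local_group_params(groups, owned_params={"w2", "b1"}))
# [['w2'], ['b1']]
```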

gather_model_params(args, timers)#

All-gather updated model params.

The DDP’s param buffer is used for the all-gather, and thus no tensors are dynamically allocated. After the all-gather, the params can be copied from the param buffer to the param.
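The all-gather’s effect can be simulated in pure Python: each rank contributes the shard it updated, and every rank ends up with the full (padded) param buffer. This is only a semantic sketch; the real code calls a collective (e.g. an all-gather into the preallocated param buffer) rather than concatenating lists.

```python
def simulated_all_gather(local_param_shards):
    """local_param_shards[r] is the shard updated by DP rank r. Returns what
    every rank's full param buffer looks like after the all-gather."""
    full = [x for shard in local_param_shards for x in shard]
    return [list(full) for _ in local_param_shards]

shards = [[1.0, 2.0], [3.0, 4.0]]   # 2 DP ranks, shard size 2
print(simulated_all_gather(shards))
# [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]]
```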

static get_model_buffer_dp_views(model_buffers)#

Get shard views of each of the DDP’s param/grad buffers.

In this nested list, the top level is grouped by the virtual model index and the buffer’s data type. The sub-level is a list of shards of that buffer, where each shard in the list represents a contiguous view of the buffer that is owned by a data-parallel rank. The shard boundaries do not respect parameter boundaries, so the elements of some parameters are split across data-parallel ranks.

Additionally, return references to the entire buffers, for use in reduce_scatter_tensor and _all_gather_base.
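A schematic sketch of building such views, using plain lists and slices in place of tensor views; the structure and names are illustrative, not the method’s actual return format.

```python
def buffer_dp_views(model_buffers, dp_world_size):
    """model_buffers: {model_idx: {dtype: flat_buffer}} where each buffer's
    length is already padded to a multiple of dp_world_size. Returns
    (model_idx, dtype, shard_view, full_buffer) tuples: one shard per DP
    rank, plus a reference to the whole buffer for the collective call."""
    views = []
    for midx, dtype_bufs in model_buffers.items():
        for dtype, buf in dtype_bufs.items():
            shard = len(buf) // dp_world_size
            for r in range(dp_world_size):
                views.append((midx, dtype, buf[r * shard:(r + 1) * shard], buf))
    return views

# One virtual model, one fp32 buffer of 8 elements, 2 DP ranks:
v = buffer_dp_views({0: {"fp32": [0.0] * 8}}, dp_world_size=2)
print(len(v), len(v[0][2]))
# 2 4
```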

get_model_parallel_group()#

With the distributed optimizer, the model parallel group is the entire world.

get_model_param_range_map(param)#

Given a model param, get the index sub-range of the param that this data-parallel rank owns.

load_state_dict(state_dict)#

Load the state dict.

reduce_model_grads(args, timers)#

Reduce-scatter model grads.

The DDP’s grad buffer is used for the reduce-scatter, and thus no tensors are dynamically allocated.

Note: this is a different order of reduction versus the non-distributed optimizer, which reduces: 1) layernorm grads, 2) all grads, 3) embedding grads.
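The reduce-scatter’s effect can be simulated in pure Python: gradients are summed element-wise across DP ranks, and each rank keeps only the summed shard it owns. This is a semantic sketch only; the real code issues a collective over the preallocated grad buffer (and gradient averaging, when applicable, is handled separately).

```python
def simulated_reduce_scatter(grad_buffers):
    """grad_buffers[r] is DP rank r's flat grad buffer; all buffers have the
    same padded length. Returns the summed shard each rank ends up owning."""
    world = len(grad_buffers)
    shard = len(grad_buffers[0]) // world
    summed = [sum(vals) for vals in zip(*grad_buffers)]  # element-wise sum
    return [summed[r * shard:(r + 1) * shard] for r in range(world)]

bufs = [[1, 2, 3, 4], [10, 20, 30, 40]]   # 2 DP ranks, 4 elements each
print(simulated_reduce_scatter(bufs))
# [[11, 22], [33, 44]]
```

Compared with an all-reduce, each rank receives only 1/DP-world-size of the reduced data, which is exactly the shard it must update.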

state_dict()#

The state dict must contain the fp32-from-float16 shards.

zero_grad(set_to_none=True)#

Zero grads.

We only need to zero the model related parameters, i.e., model_float16_groups & model_fp32_groups. We additionally zero the remaining groups as a memory optimization to reduce fragmentation; in the case of set_to_none==True, the space used by this field can be safely deallocated at this point.