megatron.optimizer.optimizer.MixedPrecisionOptimizer
- class megatron.optimizer.optimizer.MixedPrecisionOptimizer(optimizer, clip_grad, log_num_zeros_in_grad, params_have_main_grad, use_contiguous_buffers_in_local_ddp, fp16, bf16, params_dtype, grad_scaler, models)
Bases: MegatronOptimizer
Base class for both the float-16 and the distributed optimizer.
- Parameters:
optimizer – base optimizer, such as Adam or SGD.
clip_grad – clip gradients to this global L2 norm. Clipping is disabled when clip_grad == 0.
log_num_zeros_in_grad – if true, count the zeros in the gradients and return that count.
params_have_main_grad – flag indicating whether parameters have a main_grad field. If it is set, the model gradients are assumed to be stored in the main_grad field instead of the usual grad field. This happens in the DDP cases where a contiguous buffer holds the gradients. For example, with bfloat16 we want gradient accumulation and all-reduces to happen in float32, so those gradients are stored in main_grad. Note that main_grad is not necessarily in float32.
use_contiguous_buffers_in_local_ddp – if true, the local DDP model is using a contiguous buffer to hold the model grads.
fp16 – if true, the model is running in fp16.
bf16 – if true, the model is running in bfloat16.
params_dtype – data type of the model parameters; used by the distributed optimizer.
grad_scaler – used for scaling gradients. It can be None, which happens when bf16 = True and no loss scaling is used; with bf16 = True, a constant gradient scaler can also be used. When bf16 = False, a grad scaler is always required.
models – list of models (i.e., the virtual-pipeline model chunks). The distributed optimizer uses it to map parameters.
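The clip_grad and params_have_main_grad descriptions can be illustrated with a framework-free sketch of global-norm clipping. It is not Megatron's implementation: plain Python lists stand in for tensors, and Param and clip_grad_by_global_norm are hypothetical names. The gradient is read from main_grad when params_have_main_grad is true, a single L2 norm is taken over all gradient elements, and every gradient is rescaled by the same factor when that norm exceeds clip_grad.

```python
import math
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Param:
    """Hypothetical stand-in for a model parameter (not a PyTorch or Megatron type)."""
    grad: Optional[List[float]] = None       # the usual grad field
    main_grad: Optional[List[float]] = None  # used instead when params_have_main_grad is True


def clip_grad_by_global_norm(params: List[Param], clip_grad: float,
                             params_have_main_grad: bool) -> Optional[float]:
    """Rescale gradients in place so their global L2 norm is at most clip_grad.

    Returns the pre-clipping global norm, or None when clip_grad == 0
    (clipping disabled, so this sketch skips the norm computation entirely).
    """
    if clip_grad == 0:
        return None

    # Choose the field that holds the gradient, then drop parameters without one.
    grads = [p.main_grad if params_have_main_grad else p.grad for p in params]
    grads = [g for g in grads if g is not None]

    # Global norm: one L2 norm over every element of every gradient.
    total_norm = math.sqrt(sum(x * x for g in grads for x in g))

    # One shared coefficient for all gradients; the epsilon avoids division by zero.
    clip_coeff = clip_grad / (total_norm + 1e-6)
    if clip_coeff < 1.0:
        for g in grads:
            for i in range(len(g)):
                g[i] *= clip_coeff
    return total_norm


if __name__ == "__main__":
    params = [Param(main_grad=[3.0, 0.0]), Param(main_grad=[0.0, 4.0])]
    norm = clip_grad_by_global_norm(params, clip_grad=1.0, params_have_main_grad=True)
    print(norm)  # 5.0
    print([p.main_grad for p in params])
```

Scaling every gradient by one shared coefficient preserves the direction of the overall update, which is why the norm is computed globally rather than per parameter.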
- get_loss_scale()
Returns the current loss scale. The output should be a CUDA tensor of size 1; when grad_scaler is None, it holds the constant 1.0.
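The grad_scaler rules from the parameter list and the size-1 output of get_loss_scale() can be sketched as follows. This is a hypothetical, framework-free illustration: ConstantGradScaler and LossScaleSketch are made-up names, and a one-element list stands in for the size-1 CUDA tensor. In loss scaling generally, the loss is multiplied by the scale before the backward pass and the gradients are divided by the same value before the parameter update.

```python
from typing import List, Optional


class ConstantGradScaler:
    """Hypothetical scaler with a fixed scale (stands in for a real grad scaler)."""

    def __init__(self, scale: float):
        self.scale = scale


class LossScaleSketch:
    """Illustrates the grad_scaler rules; not Megatron's MixedPrecisionOptimizer."""

    def __init__(self, bf16: bool, grad_scaler: Optional[ConstantGradScaler]):
        # When bf16 is False, a grad scaler is always required.
        if not bf16 and grad_scaler is None:
            raise ValueError("a grad scaler is required when bf16 is False")
        self.grad_scaler = grad_scaler
        # One-element list standing in for a size-1 CUDA tensor holding 1.0.
        self._scale_one = [1.0]

    def get_loss_scale(self) -> List[float]:
        # bf16 without loss scaling: the scale is the constant 1.0.
        if self.grad_scaler is None:
            return self._scale_one
        return [self.grad_scaler.scale]

    def scale_loss(self, loss: float) -> float:
        # Applied before backward so small fp16 gradients do not underflow.
        return self.get_loss_scale()[0] * loss

    def unscale(self, grads: List[float]) -> List[float]:
        # Undo the scaling on the gradients before the parameter update.
        inv_scale = 1.0 / self.get_loss_scale()[0]
        return [g * inv_scale for g in grads]


if __name__ == "__main__":
    bf16_opt = LossScaleSketch(bf16=True, grad_scaler=None)
    print(bf16_opt.get_loss_scale())  # [1.0]
    fp16_opt = LossScaleSketch(bf16=False, grad_scaler=ConstantGradScaler(1024.0))
    print(fp16_opt.scale_loss(0.5))  # 512.0
    print(fp16_opt.unscale([2048.0, 512.0]))  # [2.0, 0.5]
```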
- reload_model_params()
Refreshes any internal state from the current model parameters. Call this whenever the parameters are changed outside of the optimizer. For example, when a model is loaded from a checkpoint without loading the optimizer state, the model parameters are updated; for the fp16 optimizer with main parameters, the main parameters must then be updated as well.
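The checkpoint scenario described for reload_model_params() can be illustrated with a small, hypothetical sketch: an optimizer keeps float32 main copies of float16 model parameters and writes them back to the model after each update. MainParamsSketch, to_fp16, and to_fp32 are invented names, plain lists stand in for tensors, and the struct module's half ("e") and single ("f") formats emulate the two precisions. If the model is reloaded but the main copies are not, the next write-back overwrites the freshly loaded weights with stale values.

```python
import struct
from typing import List


def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    """Round x to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


class MainParamsSketch:
    """Hypothetical fp16 optimizer that keeps fp32 main copies of the model params."""

    def __init__(self, model_params: List[List[float]]):
        self.model_params = model_params  # shared with the model; fp16 values
        self.main_params = [[to_fp32(x) for x in p] for p in model_params]

    def reload_model_params(self) -> None:
        # Refresh the main copies after the model was changed outside the optimizer.
        for main, model in zip(self.main_params, self.model_params):
            main[:] = [to_fp32(x) for x in model]

    def step(self, updates: List[List[float]]) -> None:
        # Update the fp32 main params, then write them back to the fp16 model.
        for main, model, upd in zip(self.main_params, self.model_params, updates):
            main[:] = [to_fp32(m + u) for m, u in zip(main, upd)]
            model[:] = [to_fp16(m) for m in main]


if __name__ == "__main__":
    model = [[1.0, 2.0]]
    opt = MainParamsSketch(model)
    model[0][:] = [0.5, 0.25]      # e.g. weights loaded from a checkpoint
    opt.reload_model_params()      # without this, step() would restore [1.0, 2.0]
    opt.step([[0.0, 0.0]])
    print(model)  # [[0.5, 0.25]]
```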