megatron.training.pretrain
- megatron.training.pretrain(args, train_valid_test_dataset_provider, model_provider_func, model_type: ModelType, forward_step_func, process_non_loss_data_func=None, collate_fn=None)
Main training program.
- This function runs the following steps, in order:
1) Initialize Megatron.
2) Set up the model, optimizer, and learning-rate schedule using model_provider_func.
3) Call train_valid_test_dataset_provider to get the train/valid/test datasets.
4) Train the model using forward_step_func.
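The four steps can be outlined in plain Python to show the order in which the callables are invoked. This is a hypothetical sketch: `pretrain_sketch` and the toy providers are illustrative names, not Megatron APIs, and everything framework-specific (distributed initialization, fp16/DDP wrapping, the backward pass and optimizer step) is omitted.

```python
from typing import Any, Callable, Dict, Iterator, List, Sequence, Tuple


def pretrain_sketch(
    train_valid_test_dataset_provider: Callable[[List[int]], Tuple[Any, Any, Any]],
    model_provider_func: Callable[[], Any],
    forward_step_func: Callable[[Iterator[Any], Any], Tuple[float, Dict[str, float]]],
    dataset_sizes: Sequence[int] = (8, 2, 2),
    train_iters: int = 4,
) -> Tuple[List[str], List[Dict[str, float]]]:
    """Hypothetical, framework-free outline of the order of operations."""
    steps: List[str] = []

    # 1) Initialize: stand-in for Megatron's argument and distributed setup.
    steps.append("initialize")

    # 2) Build the vanilla model. The real function would also wrap it (fp16, DDP)
    #    and create the optimizer and learning-rate scheduler.
    model = model_provider_func()
    steps.append("setup")

    # 3) Ask the provider for the three datasets, given their sizes.
    train_ds, _valid_ds, _test_ds = train_valid_test_dataset_provider(list(dataset_sizes))
    steps.append("datasets")

    # 4) Train: one forward_step_func call per iteration, recording the metrics dict.
    #    (No backward pass or optimizer step in this sketch.)
    data_iterator = iter(train_ds)
    history: List[Dict[str, float]] = []
    for _ in range(train_iters):
        _loss, metrics = forward_step_func(data_iterator, model)
        history.append(metrics)
    steps.append("train")
    return steps, history


# Usage with toy callables (all hypothetical).
def toy_dataset_provider(sizes: List[int]) -> Tuple[List[int], List[int], List[int]]:
    # One list of integers per split; sizes = [train, valid, test].
    train, valid, test = (list(range(n)) for n in sizes)
    return train, valid, test


def toy_model_provider() -> Callable[[int], float]:
    return lambda x: 0.5 * x  # a "model" that halves its input


def toy_forward_step(data_iterator: Iterator[int], model: Callable[[int], float]) -> Tuple[float, Dict[str, float]]:
    loss = float(model(next(data_iterator)))
    return loss, {"lm-loss": loss}


steps, history = pretrain_sketch(toy_dataset_provider, toy_model_provider, toy_forward_step)
print(steps)  # → ['initialize', 'setup', 'datasets', 'train']
print([m["lm-loss"] for m in history])  # → [0.0, 0.5, 1.0, 1.5]
```

Passing the providers as plain callables keeps the training driver independent of any particular model or dataset, which is the design the parameter list describes.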
- Parameters:
train_valid_test_dataset_provider – a function that takes the sizes of the train/valid/test datasets and returns the train, valid, and test datasets.
model_provider_func – a function that returns a vanilla version of the model. By vanilla we mean a plain model on the CPU, without fp16 conversion or DDP wrapping.
model_type – a ModelType enum value that specifies the type of model being trained.
forward_step_func – a function that takes a data iterator and the model, and returns a scalar loss together with a dictionary whose key/value pairs are the quantities to monitor during training, for example lm-loss: value. This function must also add a batch-generator timer to the timers class.
process_non_loss_data_func – a function to post-process the outputs of the network. It can be used to dump output tensors (e.g., images) to TensorBoard. It takes the collected data (a list of tensors), the current iteration index, and the TensorBoard writer as arguments.
extra_args_provider – a function that takes a parser and adds arguments to it. It lets programs add their own command-line arguments.
args_defaults – a dictionary mapping argument names to argument values. It is used to set defaults on the already-parsed arguments.
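The contracts for forward_step_func and process_non_loss_data_func can be illustrated with a framework-free sketch. Everything here is hypothetical: `Timers`, `get_timers`, `RecordingWriter`, and the toy model are stand-ins rather than Megatron or TensorBoard classes, and the loss is a toy mean of squared outputs. In a real run the writer would be a TensorBoard `SummaryWriter`, whose `add_image(tag, img_tensor, global_step)` call has the same shape used here.

```python
import time
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple


class Timer:
    """Hypothetical named timer that accumulates elapsed wall-clock seconds."""

    def __init__(self) -> None:
        self.elapsed = 0.0
        self._t0: Optional[float] = None

    def start(self) -> None:
        self._t0 = time.perf_counter()

    def stop(self) -> None:
        assert self._t0 is not None, "stop() called before start()"
        self.elapsed += time.perf_counter() - self._t0
        self._t0 = None


class Timers:
    """Hypothetical registry: timers(name) returns the named timer, creating it on first use."""

    def __init__(self) -> None:
        self._timers: Dict[str, Timer] = {}

    def __call__(self, name: str) -> Timer:
        return self._timers.setdefault(name, Timer())

    def names(self) -> List[str]:
        return list(self._timers)


_TIMERS = Timers()


def get_timers() -> Timers:
    # Hypothetical global accessor, so forward_step keeps the (data_iterator, model) signature.
    return _TIMERS


def forward_step(
    data_iterator: Iterator[List[float]], model: Callable[[List[float]], List[float]]
) -> Tuple[float, Dict[str, float]]:
    timers = get_timers()
    # Time batch generation under a "batch-generator" timer.
    timers("batch-generator").start()
    batch = next(data_iterator)
    timers("batch-generator").stop()
    outputs = model(batch)
    # Toy loss: mean of squared outputs (stands in for the LM loss).
    loss = sum(o * o for o in outputs) / len(outputs)
    return loss, {"lm-loss": loss}


class RecordingWriter:
    """Hypothetical stand-in for a TensorBoard writer that records add_image calls."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, int]] = []

    def add_image(self, tag: str, img_tensor: Any, global_step: int) -> None:
        self.calls.append((tag, global_step))


def process_non_loss_data(collected_data: List[Any], iteration: int, writer: RecordingWriter) -> None:
    # Dump each collected output tensor as an image, tagged by its index.
    for i, output in enumerate(collected_data):
        writer.add_image(f"outputs/{i}", output, iteration)


# Usage with toy data: a "model" that halves each value.
loss, metrics = forward_step(iter([[1.0, 2.0], [3.0, 4.0]]), lambda batch: [0.5 * x for x in batch])
writer = RecordingWriter()
process_non_loss_data([[0.1], [0.2]], 100, writer)
print(metrics, get_timers().names(), writer.calls)
```

Fetching the timers through a global accessor is one way to satisfy the timer requirement without widening the two-argument forward-step signature.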
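The extra_args_provider and args_defaults hooks can be sketched with the standard `argparse` module. `parse_args_sketch` and `my_extra_args` are hypothetical names, and the rule that a default only fills an argument the user left unset (None) is an assumption made for this sketch, not a statement of Megatron's exact behavior.

```python
import argparse
from typing import Any, Callable, Dict, List, Optional


def my_extra_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # Hypothetical program-specific option, grouped under its own title.
    group = parser.add_argument_group(title="my-task")
    group.add_argument("--my-dropout", type=float, default=None)
    return parser


def parse_args_sketch(
    argv: List[str],
    extra_args_provider: Optional[Callable[[argparse.ArgumentParser], argparse.ArgumentParser]] = None,
    args_defaults: Optional[Dict[str, Any]] = None,
) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Toy Megatron-style arguments")
    parser.add_argument("--micro-batch-size", type=int, default=None)
    parser.add_argument("--tokenizer-type", type=str, default=None)
    # Let the calling program register its own arguments.
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)
    args = parser.parse_args(argv)
    # Assumed policy: a default only fills an argument the user left unset (None).
    for key, value in (args_defaults or {}).items():
        if getattr(args, key, None) is None:
            setattr(args, key, value)
    return args


args = parse_args_sketch(
    ["--micro-batch-size", "4"],
    extra_args_provider=my_extra_args,
    args_defaults={"micro_batch_size": 8, "tokenizer_type": "GPT2BPETokenizer"},
)
print(args.micro_batch_size, args.tokenizer_type, args.my_dropout)  # → 4 GPT2BPETokenizer None
```

Here the explicit `--micro-batch-size 4` wins over the default of 8, while the unset `tokenizer_type` is filled in from `args_defaults`.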