megatron.training.pretrain

megatron.training.pretrain(args, train_valid_test_dataset_provider, model_provider_func, model_type: ModelType, forward_step_func, process_non_loss_data_func=None, collate_fn=None)

Main training program.

This function runs the following steps, in order:
  1. Initialize Megatron.

  2. Set up the model, optimizer, and learning-rate schedule using model_provider_func.

  3. Call train_valid_test_dataset_provider to get the train/valid/test datasets.

  4. Train the model using forward_step_func.
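The order of the four steps can be illustrated with a plain-Python mock. This is a hypothetical sketch, not Megatron's implementation: pretrain_sketch, ToyModel, and the toy_* callbacks are invented for illustration, the validation/test sizes are arbitrary, and the backward pass, optimizer, and learning-rate schedule are omitted. It only shows the sequence of calls and the rough shape of the callback contracts from the parameter list.

```python
from typing import Callable, Dict, Iterator, List, Tuple

# Hypothetical toy callbacks; real Megatron callbacks build torch modules
# and datasets, which this sketch does not attempt.


def toy_dataset_provider(sizes: List[int]) -> Tuple[List[int], List[int], List[int]]:
    """Receives the train/valid/test sizes and returns three datasets."""
    train_n, valid_n, test_n = sizes
    return list(range(train_n)), list(range(valid_n)), list(range(test_n))


class ToyModel:
    """A 'vanilla' model: a plain CPU object with no fp16 or DDP wrapping."""

    def __init__(self) -> None:
        self.weight = 0.0

    def __call__(self, x: float) -> float:
        return self.weight * x


def toy_model_provider() -> ToyModel:
    return ToyModel()


def toy_forward_step(
    data_iterator: Iterator[int], model: ToyModel
) -> Tuple[float, Dict[str, float]]:
    """Returns a loss scalar plus a dict of values to monitor."""
    x = float(next(data_iterator))
    loss = (model(x) - x) ** 2
    return loss, {"lm-loss": loss}


def pretrain_sketch(
    dataset_provider: Callable[[List[int]], Tuple[List[int], List[int], List[int]]],
    model_provider: Callable[[], ToyModel],
    forward_step: Callable[[Iterator[int], ToyModel], Tuple[float, Dict[str, float]]],
    train_iters: int,
) -> Tuple[List[str], List[Dict[str, float]]]:
    steps: List[str] = []

    # 1. Initialize (stands in for argument parsing and distributed setup).
    steps.append("initialize")

    # 2. Build the model; the optimizer and LR schedule are omitted here.
    model = model_provider()
    steps.append("setup_model")

    # 3. Build the datasets (validation/test sizes are arbitrary).
    train_ds, _valid_ds, _test_ds = dataset_provider([train_iters, 2, 2])
    steps.append("build_datasets")

    # 4. Train: one forward step per iteration (no backward pass or update).
    data_iterator = iter(train_ds)
    logs = [forward_step(data_iterator, model)[1] for _ in range(train_iters)]
    steps.append("train")
    return steps, logs


steps, logs = pretrain_sketch(
    toy_dataset_provider, toy_model_provider, toy_forward_step, train_iters=3
)
print(steps)  # ['initialize', 'setup_model', 'build_datasets', 'train']
print([d["lm-loss"] for d in logs])  # [0.0, 1.0, 4.0]
```

Because the model provider returns an unwrapped object, any precision or parallelism wrapping is left to the training driver rather than to the user callback, which is the point of the "vanilla" requirement.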

Parameters:
  • train_valid_test_dataset_provider – a function that takes the sizes of the train/valid/test datasets and returns the train, valid, and test datasets.

  • model_provider_func – a function that returns a vanilla version of the model. By vanilla we mean a simple model on the CPU with no fp16 or DDP wrapping.

  • model_type – an enum that specifies the type of model being trained.

  • forward_step_func – a function that takes a data iterator and a model, and returns a loss scalar together with a dictionary of key:value pairs holding the information to monitor during training, for example lm-loss: value. This function is also required to add the batch generator to the timers class.

  • process_non_loss_data_func – a function to post-process outputs of the network. It can be used to dump output tensors (e.g., images) to TensorBoard. It takes the collected data (a list of tensors), the current iteration index, and the TensorBoard writer as arguments.

  • extra_args_provider – a function that takes a parser and adds arguments to it. It lets programs add their own arguments.

  • args_defaults – a dictionary mapping argument names to argument values. It is used to set values on the already-parsed arguments.
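A minimal sketch of how extra_args_provider and args_defaults could fit together, using Python's standard argparse module. build_base_parser, my_extra_args, parse_with_defaults, and the --num-classes flag are hypothetical, and the tokenizer name is only an example value. The sketch assumes that a default is applied only when the argument was left unset (None) after parsing, so an explicit command-line value wins; treat that as an assumption about the intended behavior, not a description of Megatron's own parser.

```python
import argparse
from typing import Any, Callable, Dict, List, Optional

ParserHook = Callable[[argparse.ArgumentParser], argparse.ArgumentParser]


def build_base_parser() -> argparse.ArgumentParser:
    # Hypothetical stand-in for the framework's own argument parser.
    parser = argparse.ArgumentParser(description="toy trainer")
    parser.add_argument("--tokenizer-type", type=str, default=None)
    parser.add_argument("--micro-batch-size", type=int, default=None)
    return parser


def my_extra_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # Program-specific arguments go in their own group (hypothetical flag).
    group = parser.add_argument_group(title="my program")
    group.add_argument("--num-classes", type=int, default=2)
    return parser


def parse_with_defaults(
    argv: List[str],
    extra_args_provider: Optional[ParserHook] = None,
    args_defaults: Optional[Dict[str, Any]] = None,
) -> argparse.Namespace:
    parser = build_base_parser()
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)
    args = parser.parse_args(argv)
    # Assumption: a default only fills an argument left unset (None);
    # a value given explicitly on the command line wins.
    for key, value in (args_defaults or {}).items():
        if getattr(args, key, None) is None:
            setattr(args, key, value)
    return args


args = parse_with_defaults(
    ["--micro-batch-size", "4"],
    extra_args_provider=my_extra_args,
    args_defaults={"tokenizer_type": "GPT2BPETokenizer"},
)
print(args.tokenizer_type, args.micro_batch_size, args.num_classes)
# GPT2BPETokenizer 4 2
```

Keying args_defaults by the argparse destination name (tokenizer_type rather than --tokenizer-type) lets the dictionary be applied directly with setattr on the parsed namespace.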