megatron.model.biencoder_model.BiEncoderModel
- class megatron.model.biencoder_model.BiEncoderModel(num_tokentypes=1, parallel_output=True, only_query_model=False, only_context_model=False, biencoder_shared_query_context_model=False, pre_process=True, post_process=True, model_type=None)
Bases: MegatronModule
BERT-based module for the biencoder model.
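For orientation, a hedged construction sketch. It is not runnable standalone: the constructor reads model configuration (hidden size, layer count, etc.) from Megatron's global arguments, so Megatron must already be initialized.

```python
from megatron.model.biencoder_model import BiEncoderModel

# Assumes initialize_megatron(...) has already parsed the command-line
# arguments and set up distributed state.
biencoder = BiEncoderModel(
    num_tokentypes=1,
    parallel_output=True,
    biencoder_shared_query_context_model=False,  # separate query and context towers
    pre_process=True,
    post_process=True,
)
```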
- static embed_text(model, tokens, attention_mask, token_types)
Embed a batch of tokens using the given model.
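A sketch of embedding one batch through a single tower. The dummy shapes, the vocabulary size, and the `query_model` attribute name are assumptions made for illustration:

```python
import torch

batch_size, seq_len = 4, 64
tokens = torch.randint(0, 30522, (batch_size, seq_len)).cuda()        # dummy token IDs
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long).cuda()
token_types = torch.zeros(batch_size, seq_len, dtype=torch.long).cuda()

# `biencoder.query_model` is assumed to be the underlying query tower;
# the static method runs it and returns one embedding per sequence.
query_embeds = BiEncoderModel.embed_text(
    biencoder.query_model, tokens, attention_mask, token_types)
```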
- forward(query_tokens, query_attention_mask, query_types, context_tokens, context_attention_mask, context_types)
Run a forward pass for each of the models and return the respective embeddings.
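In ICT pretraining the query and context batches come from the dataset; the placeholder batches and the in-batch loss below are illustrative assumptions, not the library's training loop:

```python
import torch

def dummy_batch(batch_size=4, seq_len=64):
    # Placeholder token IDs / masks / segment IDs for illustration.
    return (torch.randint(0, 30522, (batch_size, seq_len)).cuda(),
            torch.ones(batch_size, seq_len, dtype=torch.long).cuda(),
            torch.zeros(batch_size, seq_len, dtype=torch.long).cuda())

query_tokens, query_attention_mask, query_types = dummy_batch()
context_tokens, context_attention_mask, context_types = dummy_batch()

# Embeds both batches in one call; returns one embedding per query
# and one per context.
query_embeds, context_embeds = biencoder(
    query_tokens, query_attention_mask, query_types,
    context_tokens, context_attention_mask, context_types)

# A typical in-batch retrieval loss scores every query against every
# context and treats the aligned pair as the positive class.
scores = torch.matmul(query_embeds, context_embeds.t())
labels = torch.arange(scores.size(0), device=scores.device)
loss = torch.nn.functional.cross_entropy(scores, labels)
```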
- init_state_dict_from_bert()
Initialize the state from a pretrained BERT model on iteration zero of ICT pretraining.
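In a pretraining loop this is typically invoked once, before the first ICT step; a minimal sketch, where the `iteration` variable is illustrative:

```python
# Warm-start the towers from a pretrained BERT checkpoint; the
# checkpoint location is taken from Megatron's global arguments.
if iteration == 0:
    biencoder.init_state_dict_from_bert()
```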
- load_state_dict(state_dict, strict=True)
Load the state dicts of each of the models.
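A hedged restore sketch; the checkpoint path and the `'model'` key follow a common Megatron checkpoint layout but are assumptions here:

```python
import torch

state = torch.load('biencoder_checkpoint.pt', map_location='cpu')  # hypothetical path
biencoder.load_state_dict(state['model'], strict=True)             # assumed 'model' key
```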
- set_input_tensor(input_tensor)
See megatron.model.transformer.set_input_tensor().
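Under pipeline parallelism the schedule delivers the previous stage's activations through this hook rather than through forward's token arguments; a hedged sketch, where `recv_tensor` stands in for whatever the pipeline schedule received:

```python
# On intermediate pipeline stages, recv_tensor holds the activations
# received from the previous stage; on the first stage it is None and
# the model computes embeddings from the input tokens instead.
biencoder.set_input_tensor(recv_tensor)
query_embeds, context_embeds = biencoder(
    query_tokens, query_attention_mask, query_types,
    context_tokens, context_attention_mask, context_types)
```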
- state_dict_for_save_checkpoint(prefix='', keep_vars=False)
Return a dict containing the state dicts of each of the models, for saving a checkpoint.
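A hedged save sketch mirroring the restore example above; the surrounding checkpoint dict layout and file name are assumptions:

```python
import torch

checkpoint = {
    'iteration': iteration,                                # illustrative
    'model': biencoder.state_dict_for_save_checkpoint(),
}
torch.save(checkpoint, 'biencoder_checkpoint.pt')          # hypothetical path
```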