megatron.model.biencoder_model.BiEncoderModel

class megatron.model.biencoder_model.BiEncoderModel(num_tokentypes=1, parallel_output=True, only_query_model=False, only_context_model=False, biencoder_shared_query_context_model=False, pre_process=True, post_process=True, model_type=None)

Bases: MegatronModule

BERT-based module for the biencoder model.
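
A minimal construction sketch, assuming Megatron's global state (command-line args, tokenizer, fused kernels) has already been set up; the flag values shown are illustrative, not required:

    from megatron.initialize import initialize_megatron
    from megatron.model.biencoder_model import BiEncoderModel

    # Assumption: standard Megatron startup. The class reads hidden size,
    # layer count, etc. from the global args, not from its constructor.
    initialize_megatron()

    model = BiEncoderModel(
        num_tokentypes=2,   # two token-type embeddings, as in standard BERT
        parallel_output=True,
        biencoder_shared_query_context_model=False,  # separate query/context towers
    )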

static embed_text(model, tokens, attention_mask, token_types)

Embed a batch of tokens using the given model.
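
A hedged usage sketch; the query_model attribute name and the tensor shapes are assumptions for illustration:

    import torch

    batch, seq_len, vocab = 8, 128, 30522  # illustrative sizes
    tokens = torch.randint(0, vocab, (batch, seq_len)).cuda()
    attention_mask = torch.ones(batch, seq_len, dtype=torch.long).cuda()
    token_types = torch.zeros(batch, seq_len, dtype=torch.long).cuda()

    # embed_text is static: the sub-model to run (here assumed to be the
    # query tower, model.query_model) is passed in explicitly.
    query_embeds = BiEncoderModel.embed_text(
        model.query_model, tokens, attention_mask, token_types)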

forward(query_tokens, query_attention_mask, query_types, context_tokens, context_attention_mask, context_types)

Run a forward pass for each of the models and return the respective embeddings.
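
For example, reusing tensors shaped like those in the embed_text sketch above (the 2-tuple unpacking and the dot-product scoring are the usual biencoder pattern, assumed here rather than stated by the signature):

    # Forward both towers in one call; the respective embeddings come back.
    query_embeds, context_embeds = model(
        query_tokens, query_attention_mask, query_types,
        context_tokens, context_attention_mask, context_types)

    # Typical retrieval scoring: dot product of every query embedding
    # against every context embedding in the batch.
    scores = torch.matmul(query_embeds, context_embeds.t())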

init_state_dict_from_bert()

Initialize the state from a pretrained BERT model at iteration zero of ICT pretraining.
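
A sketch of where this fits in a training loop; the iteration variable and the --bert-load flag as the source of the pretrained weights are assumptions:

    # At iteration zero of ICT pretraining, seed the towers from a
    # pretrained BERT checkpoint (assumed to be given via --bert-load).
    if iteration == 0:
        model.init_state_dict_from_bert()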

load_state_dict(state_dict, strict=True)

Load the state dict of each of the constituent models.

set_input_tensor(input_tensor)

See megatron.model.transformer.set_input_tensor().

state_dict_for_save_checkpoint(prefix='', keep_vars=False)

Return a dict containing the state dicts of each of the models, for saving a checkpoint.
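
An illustrative save/restore round trip using this method together with load_state_dict above; the path and the checkpoint wrapping are assumptions (production Megatron runs go through megatron.checkpointing instead):

    # Collect the per-model state dicts into one checkpointable dict.
    state = model.state_dict_for_save_checkpoint()
    torch.save({'model': state}, '/tmp/biencoder.pt')

    # Later: restore each sub-model's weights from the saved dict.
    ckpt = torch.load('/tmp/biencoder.pt')
    model.load_state_dict(ckpt['model'], strict=True)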