megatron.core.tensor_parallel.layers.RowParallelLinear#

class megatron.core.tensor_parallel.layers.RowParallelLinear(input_size, output_size, *, bias=True, input_is_parallel=False, init_method=<function xavier_normal_>, stride=1, keep_master_weight_for_test=False, skip_bias_add=False, params_dtype=torch.float32, use_cpu_initialization=False, perform_initialization=True, gradient_accumulation_fusion=False, sequence_parallel_enabled: bool = False, world_size: int | None = None)#

Bases: Module

Linear layer with row parallelism.

The linear layer is defined as Y = XA + b. A is parallelized along its first dimension and X along its second dimension as:

            -     -
           |  A_1  |
           |   .   |
       A = |   .   |      X = [X_1, ..., X_p]
           |   .   |
           |  A_p  |
            -     -

Each rank holds one row block A_i and one column block X_i, computes the partial product Y_i = X_i A_i, and the partial products are summed across ranks (an all-reduce) before the bias is added.
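
To make the partitioning concrete, here is a minimal single-process sketch in plain PyTorch (the parallel degree p and all shapes are illustrative, not part of the API) verifying that the summed per-partition products reproduce XA:

    import torch

    p = 4                            # tensor-parallel degree (illustrative)
    X = torch.randn(2, 16)           # full input, [batch, input_size]
    A = torch.randn(16, 8)           # full weight, Y = XA

    A_shards = A.chunk(p, dim=0)     # A_1 ... A_p: row blocks of A
    X_shards = X.chunk(p, dim=-1)    # X_1 ... X_p: column blocks of X

    # Each rank would compute one partial product; RowParallelLinear sums
    # them with an all-reduce. Here a plain Python sum plays that role.
    Y = sum(Xi @ Ai for Xi, Ai in zip(X_shards, A_shards))

    assert torch.allclose(Y, X @ A, atol=1e-5)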
Parameters:
  • input_size – first dimension of matrix A.

  • output_size – second dimension of matrix A.

Keyword Arguments:
  • bias – If true, add bias. Note that bias is not parallelized.

  • input_is_parallel – If true, we assume that the input is already split across the GPUs and we do not split again.

  • init_method – method to initialize weights. Note that bias is always set to zero.

  • stride – stride for the strided linear layers; used during weight initialization when one parameter packs several concatenated matrices.

  • keep_master_weight_for_test – added for testing and should normally be False. If True, the unpartitioned master weight used for initialization is kept so tests can compare against it.

  • skip_bias_add – enables a performance optimization in which the bias add is fused with a subsequent elementwise operation: the layer skips adding the bias and returns it separately instead (see the usage sketch under forward() below).

  • params_dtype – dtype used to create the weight and bias parameters (default: torch.float32).

  • use_cpu_initialization – if True, build and initialize the master weights on the CPU instead of the device.

  • perform_initialization – if True, initialize the weights at construction time; set to False when the weights will be initialized externally (e.g. loaded from a checkpoint).

  • gradient_accumulation_fusion – if True, fuse weight-gradient accumulation into the backward GEMM; requires APEX's fused_weight_gradient_mlp_cuda extension.

  • sequence_parallel_enabled – if True, the output all-reduce is replaced by a reduce-scatter along the sequence dimension, producing a sequence-sharded output for sequence parallelism; requires input_is_parallel=True.

  • world_size – tensor-parallel world size; if None, it is read from the current tensor model parallel group.
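
A construction sketch follows. It is hedged: it assumes torch.distributed is already initialized (e.g. the script was launched with torchrun and init_process_group was called), that megatron.core is importable, and all sizes plus the tensor-parallel degree are illustrative:

    import torch
    from megatron.core import parallel_state
    from megatron.core.tensor_parallel.layers import RowParallelLinear

    # Assumes torch.distributed.init_process_group("nccl") has run.
    parallel_state.initialize_model_parallel(tensor_model_parallel_size=2)

    layer = RowParallelLinear(
        input_size=1024,          # first dimension of A, sharded across ranks
        output_size=1024,         # second dimension of A, replicated
        bias=True,
        input_is_parallel=True,   # typical when fed by a ColumnParallelLinear
        skip_bias_add=True,       # return the bias instead of adding it
    )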

forward(input_)#
Parameters:

input_ – 3D input tensor of shape [sequence, batch, hidden].

Returns:

  • output – result of the linear transformation, summed across tensor-parallel ranks (all-reduce, or reduce-scatter when sequence_parallel_enabled=True); the bias is already added unless skip_bias_add=True.

  • bias – the bias tensor when skip_bias_add=True, otherwise None.
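
A hedged sketch of consuming the return value when skip_bias_add=True, continuing the construction sketch above (shapes are illustrative):

    # x: [sequence, batch, hidden]. With input_is_parallel=True the last
    # dimension is this rank's shard, input_size // tensor-parallel size.
    x = torch.randn(128, 4, 1024 // 2, device="cuda")

    output, bias = layer(x)
    if bias is not None:
        # skip_bias_add=True returns the bias unapplied so it can be
        # fused with a later elementwise op; here we simply add it.
        output = output + bias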