mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -36,17 +36,16 @@ from ._operation import (
|
||||
from .parallel_module import ParallelModule
|
||||
from .utils import create_randomizer_with_offset
|
||||
|
||||
__all__ = ['FusedLinear1D_Col', 'FusedLinear1D_Row', 'GPT2FusedLinearConv1D_Col', 'GPT2FusedLinearConv1D_Row']
|
||||
__all__ = ["FusedLinear1D_Col", "FusedLinear1D_Row", "GPT2FusedLinearConv1D_Col", "GPT2FusedLinearConv1D_Row"]
|
||||
|
||||
# ====================================
|
||||
# For GPT Only
|
||||
# ====================================
|
||||
|
||||
|
||||
def split_fused_qkv_in_gpt2_style(qkv: torch.Tensor,
|
||||
n_fused: int,
|
||||
process_group: ProcessGroup,
|
||||
is_transposed: bool = False):
|
||||
def split_fused_qkv_in_gpt2_style(
|
||||
qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False
|
||||
):
|
||||
"""
|
||||
The fused qkv tensor looks like [Q1, Q2, K1, K2, V1, V2]; this function splits it into [Q1, K1, V1] and [Q2, K2, V2].
|
||||
|
||||
@@ -85,10 +84,9 @@ def split_fused_qkv_in_gpt2_style(qkv: torch.Tensor,
|
||||
return weight_of_current_rank
|
||||
|
||||
|
||||
def gather_fused_qkv_in_gpt2_style(qkv: torch.Tensor,
|
||||
n_fused: int,
|
||||
process_group: ProcessGroup,
|
||||
is_transposed: bool = False):
|
||||
def gather_fused_qkv_in_gpt2_style(
|
||||
qkv: torch.Tensor, n_fused: int, process_group: ProcessGroup, is_transposed: bool = False
|
||||
):
|
||||
"""
|
||||
The split qkv tensors look like [Q1, K1, V1] and [Q2, K2, V2]; this function gathers them into [Q1, Q2, K1, K2, V1, V2].
|
||||
|
||||
@@ -167,23 +165,25 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
process_group: ProcessGroup = None,
|
||||
async_communication: bool = False,
|
||||
gather_output: bool = False,
|
||||
seq_parallel: bool = False,
|
||||
overlap: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
n_fused: int = 3,
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
process_group: ProcessGroup = None,
|
||||
async_communication: bool = False,
|
||||
gather_output: bool = False,
|
||||
seq_parallel: bool = False,
|
||||
overlap: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
n_fused: int = 3,
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Keep input parameters
|
||||
@@ -199,7 +199,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||
self.async_communication = async_communication
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError('cannot skip bias addition if bias is None')
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
||||
# offset the seed with randomizer index and rank
|
||||
seed = torch.random.initial_seed()
|
||||
@@ -207,14 +207,14 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||
|
||||
# sanity check
|
||||
if weight is not None:
|
||||
assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
|
||||
assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
|
||||
else:
|
||||
assert bias_ is None, 'bias_ must be None if weight is None'
|
||||
assert bias_ is None, "bias_ must be None if weight is None"
|
||||
|
||||
# Parameters.
|
||||
if weight is None:
|
||||
# Initialize weight.
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
self.weight = Parameter(torch.empty(self.in_features, self.out_features, **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
@@ -249,8 +249,9 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
|
||||
**kwargs) -> ParallelModule:
|
||||
def from_native_module(
|
||||
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
) -> ParallelModule:
|
||||
r"""
|
||||
Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer.
|
||||
|
||||
@@ -268,8 +269,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||
|
||||
# ensure only one process group is passed
|
||||
if isinstance(process_group, (list, tuple)):
|
||||
assert len(process_group) == 1, \
|
||||
f'Expected only one process group, got {len(process_group)}.'
|
||||
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
|
||||
process_group = process_group[0]
|
||||
|
||||
tp_size = dist.get_world_size(process_group)
|
||||
@@ -278,17 +278,20 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||
|
||||
if out_features % tp_size != 0:
|
||||
raise ValueError(
|
||||
f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!")
|
||||
f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!"
|
||||
)
|
||||
|
||||
linear_1d = GPT2FusedLinearConv1D_Col(in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
*args,
|
||||
**kwargs)
|
||||
linear_1d = GPT2FusedLinearConv1D_Col(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return linear_1d
|
||||
|
||||
@@ -300,22 +303,26 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
assert input_.shape[-1] == self.weight.shape[0], \
|
||||
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1])
|
||||
assert (
|
||||
input_.shape[-1] == self.weight.shape[0]
|
||||
), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1]
|
||||
)
|
||||
|
||||
# Matrix multiply.
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
if self.seq_parallel:
|
||||
input_parallel = input_
|
||||
output_parallel = matmul_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
|
||||
self.process_group, True, 1, self.overlap)
|
||||
output_parallel = matmul_gather_forward_reducescatter_backward(
|
||||
input_parallel, self.weight, bias, self.process_group, True, 1, self.overlap
|
||||
)
|
||||
else:
|
||||
# Set up backprop all-reduce.
|
||||
input_parallel = reduce_backward(input_, self.process_group)
|
||||
output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group,
|
||||
self.async_communication)
|
||||
output_parallel = matmul_with_async_comm(
|
||||
input_parallel, self.weight, bias, self.process_group, self.async_communication
|
||||
)
|
||||
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
@@ -330,7 +337,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
||||
|
||||
|
||||
class GPT2FusedLinearConv1D_Row(ParallelModule):
|
||||
r""" Linear layer with row parallelism.
|
||||
r"""Linear layer with row parallelism.
|
||||
This layer is used to fit `Conv1D` layer (Fused QKV) in gpt2 of huggingface.
|
||||
|
||||
Args:
|
||||
@@ -351,21 +358,23 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
process_group: ProcessGroup = None,
|
||||
seq_parallel: bool = False,
|
||||
parallel_input: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
stream_chunk_num: int = 1):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
process_group: ProcessGroup = None,
|
||||
seq_parallel: bool = False,
|
||||
parallel_input: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
stream_chunk_num: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.stream_chunk_num = stream_chunk_num
|
||||
@@ -380,7 +389,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
||||
self.num_partitions = dist.get_world_size(self.process_group)
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError('cannot skip bias addition if bias is None')
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
||||
# offset the seed with randomizer index and rank
|
||||
seed = torch.random.initial_seed()
|
||||
@@ -391,14 +400,14 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
||||
|
||||
# sanity check
|
||||
if weight is not None:
|
||||
assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
|
||||
assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
|
||||
else:
|
||||
assert bias_ is None, 'bias_ must be None if weight is None'
|
||||
assert bias_ is None, "bias_ must be None if weight is None"
|
||||
|
||||
# Parameters.
|
||||
if weight is None:
|
||||
# Initialize weight.
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
self.weight = Parameter(torch.empty(self.in_features, self.out_features, **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
@@ -424,8 +433,9 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
|
||||
**kwargs) -> ParallelModule:
|
||||
def from_native_module(
|
||||
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
) -> ParallelModule:
|
||||
r"""
|
||||
Convert a native PyTorch linear layer to a parallelized linear layer.
|
||||
"""
|
||||
@@ -438,8 +448,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
||||
|
||||
# ensure only one process group is passed
|
||||
if isinstance(process_group, (list, tuple)):
|
||||
assert len(process_group) == 1, \
|
||||
f'Expected only one process group, got {len(process_group)}.'
|
||||
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
|
||||
process_group = process_group[0]
|
||||
|
||||
tp_size = dist.get_world_size(process_group)
|
||||
@@ -448,17 +457,20 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
||||
|
||||
if in_features % tp_size != 0:
|
||||
raise ValueError(
|
||||
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!")
|
||||
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
|
||||
)
|
||||
|
||||
linear_1d = GPT2FusedLinearConv1D_Row(in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
*args,
|
||||
**kwargs)
|
||||
linear_1d = GPT2FusedLinearConv1D_Row(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return linear_1d
|
||||
|
||||
@@ -485,14 +497,18 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
# Set up backprop all-reduce.
|
||||
if self.parallel_input:
|
||||
assert input_.shape[-1] == self.weight.shape[0], \
|
||||
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[0])
|
||||
assert (
|
||||
input_.shape[-1] == self.weight.shape[0]
|
||||
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[0]
|
||||
)
|
||||
input_ = input_
|
||||
else:
|
||||
assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[0], \
|
||||
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[0] * self.num_partitions)
|
||||
assert (
|
||||
divide(input_.shape[-1], self.num_partitions) == self.weight.shape[0]
|
||||
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[0] * self.num_partitions
|
||||
)
|
||||
input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
|
||||
|
||||
if self.stream_chunk_num > 1:
|
||||
@@ -503,9 +519,9 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
||||
handle_list = []
|
||||
for i in range(self.stream_chunk_num):
|
||||
output_parallel_list[i] = torch.matmul(input_, self.weight_list[i])
|
||||
handle = torch.distributed.all_reduce(output_parallel_list[i],
|
||||
group=self.process_group,
|
||||
async_op=True)
|
||||
handle = torch.distributed.all_reduce(
|
||||
output_parallel_list[i], group=self.process_group, async_op=True
|
||||
)
|
||||
handle_list.append(handle)
|
||||
# output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
|
||||
for handle in handle_list:
|
||||
@@ -559,21 +575,23 @@ class FusedLinear1D_Col(ParallelModule):
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
process_group: ProcessGroup = None,
|
||||
async_communication: bool = False,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
n_fused: int = 3,
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
process_group: ProcessGroup = None,
|
||||
async_communication: bool = False,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
n_fused: int = 3,
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
):
|
||||
super().__init__()
|
||||
# Keep input parameters
|
||||
self.in_features = in_features
|
||||
@@ -586,7 +604,7 @@ class FusedLinear1D_Col(ParallelModule):
|
||||
self.async_communication = async_communication
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError('cannot skip bias addition if bias is None')
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
||||
# offset the seed with randomizer index and rank
|
||||
seed = torch.random.initial_seed()
|
||||
@@ -594,14 +612,14 @@ class FusedLinear1D_Col(ParallelModule):
|
||||
|
||||
# sanity check
|
||||
if weight is not None:
|
||||
assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
|
||||
assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
|
||||
else:
|
||||
assert bias_ is None, 'bias_ must be None if weight is None'
|
||||
assert bias_ is None, "bias_ must be None if weight is None"
|
||||
|
||||
# Parameters.
|
||||
if weight is None:
|
||||
# Initialize weight.
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
@@ -636,8 +654,9 @@ class FusedLinear1D_Col(ParallelModule):
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int,
|
||||
*args, **kwargs) -> ParallelModule:
|
||||
def from_native_module(
|
||||
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int, *args, **kwargs
|
||||
) -> ParallelModule:
|
||||
r"""
|
||||
Convert a fused `torch.nn.Linear` layer to a parallelized linear layer.
|
||||
|
||||
@@ -654,19 +673,20 @@ class FusedLinear1D_Col(ParallelModule):
|
||||
|
||||
# ensure only one process group is passed
|
||||
if isinstance(process_group, (list, tuple)):
|
||||
assert len(process_group) == 1, \
|
||||
f'Expected only one process group, got {len(process_group)}.'
|
||||
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
|
||||
process_group = process_group[0]
|
||||
|
||||
linear_1d = FusedLinear1D_Col(in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
*args,
|
||||
**kwargs)
|
||||
linear_1d = FusedLinear1D_Col(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# # TODO: copy the sharded weights
|
||||
# with torch.no_grad():
|
||||
@@ -693,9 +713,11 @@ class FusedLinear1D_Col(ParallelModule):
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
assert input_.shape[-1] == self.weight.shape[-1], \
|
||||
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1])
|
||||
assert (
|
||||
input_.shape[-1] == self.weight.shape[-1]
|
||||
), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1]
|
||||
)
|
||||
# Set up backprop all-reduce.
|
||||
# input_parallel = reduce_backward(input_, self.process_group)
|
||||
input_parallel = input_
|
||||
|
Reference in New Issue
Block a user