mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -33,7 +33,7 @@ from ._operation import (
|
||||
from .parallel_module import ParallelModule
|
||||
from .utils import create_randomizer_with_offset
|
||||
|
||||
__all__ = ['Linear1D_Col', 'Linear1D_Row']
|
||||
__all__ = ["Linear1D_Col", "Linear1D_Row"]
|
||||
|
||||
|
||||
class Linear1D_Col(ParallelModule):
|
||||
@@ -65,22 +65,24 @@ class Linear1D_Col(ParallelModule):
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
process_group: ProcessGroup = None,
|
||||
gather_output: bool = False,
|
||||
seq_parallel: bool = False,
|
||||
seq_parallel_dim: int = 1,
|
||||
overlap: torch.cuda.Stream = None,
|
||||
skip_bias_add: bool = False,
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
process_group: ProcessGroup = None,
|
||||
gather_output: bool = False,
|
||||
seq_parallel: bool = False,
|
||||
seq_parallel_dim: int = 1,
|
||||
overlap: torch.cuda.Stream = None,
|
||||
skip_bias_add: bool = False,
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Keep input parameters
|
||||
@@ -95,7 +97,7 @@ class Linear1D_Col(ParallelModule):
|
||||
self.process_group = process_group
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError('cannot skip bias addition if bias is None')
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
||||
# offset the seed with randomizer index and rank
|
||||
seed = torch.random.initial_seed()
|
||||
@@ -103,13 +105,13 @@ class Linear1D_Col(ParallelModule):
|
||||
|
||||
# sanity check
|
||||
if weight is not None:
|
||||
assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
|
||||
assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
|
||||
else:
|
||||
assert bias_ is None, 'bias_ must be None if weight is None'
|
||||
assert bias_ is None, "bias_ must be None if weight is None"
|
||||
|
||||
# Parameters.
|
||||
if weight is None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
@@ -135,8 +137,9 @@ class Linear1D_Col(ParallelModule):
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
|
||||
**kwargs) -> ParallelModule:
|
||||
def from_native_module(
|
||||
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
) -> ParallelModule:
|
||||
r"""
|
||||
Convert a native PyTorch linear layer to a parallelized linear layer.
|
||||
"""
|
||||
@@ -149,8 +152,7 @@ class Linear1D_Col(ParallelModule):
|
||||
|
||||
# ensure only one process group is passed
|
||||
if isinstance(process_group, (list, tuple)):
|
||||
assert len(process_group) == 1, \
|
||||
f'Expected only one process group, got {len(process_group)}.'
|
||||
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
|
||||
process_group = process_group[0]
|
||||
|
||||
tp_size = dist.get_world_size(process_group)
|
||||
@@ -159,17 +161,20 @@ class Linear1D_Col(ParallelModule):
|
||||
|
||||
if out_features % tp_size != 0:
|
||||
raise ValueError(
|
||||
f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!")
|
||||
f"The size of out_features:{out_features} is not integer multiples of tensor parallel size: {tp_size}!"
|
||||
)
|
||||
|
||||
linear_1d = Linear1D_Col(in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
*args,
|
||||
**kwargs)
|
||||
linear_1d = Linear1D_Col(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return linear_1d
|
||||
|
||||
@@ -181,9 +186,11 @@ class Linear1D_Col(ParallelModule):
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
|
||||
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
assert input_.shape[-1] == self.weight.shape[-1], \
|
||||
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1])
|
||||
assert (
|
||||
input_.shape[-1] == self.weight.shape[-1]
|
||||
), "Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1]
|
||||
)
|
||||
|
||||
# Set up backprop all-reduce.
|
||||
input_parallel = input_
|
||||
@@ -191,9 +198,9 @@ class Linear1D_Col(ParallelModule):
|
||||
# Matrix multiply.
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
if self.seq_parallel:
|
||||
output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
|
||||
self.process_group, True,
|
||||
self.seq_parallel_dim, self.overlap)
|
||||
output_parallel = linear_gather_forward_reducescatter_backward(
|
||||
input_parallel, self.weight, bias, self.process_group, True, self.seq_parallel_dim, self.overlap
|
||||
)
|
||||
else:
|
||||
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
|
||||
|
||||
@@ -210,7 +217,7 @@ class Linear1D_Col(ParallelModule):
|
||||
|
||||
|
||||
class Linear1D_Row(ParallelModule):
|
||||
r""" Linear layer with row parallelism
|
||||
r"""Linear layer with row parallelism
|
||||
|
||||
Args:
|
||||
in_features (int): size of each input sample.
|
||||
@@ -231,22 +238,24 @@ class Linear1D_Row(ParallelModule):
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
process_group: ProcessGroup = None,
|
||||
seq_parallel: bool = False,
|
||||
seq_parallel_dim: int = 1,
|
||||
parallel_input: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
stream_chunk_num: int = 1):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
process_group: ProcessGroup = None,
|
||||
seq_parallel: bool = False,
|
||||
seq_parallel_dim: int = 1,
|
||||
parallel_input: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
stream_chunk_num: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.stream_chunk_num = stream_chunk_num
|
||||
@@ -262,7 +271,7 @@ class Linear1D_Row(ParallelModule):
|
||||
self.num_partitions = dist.get_world_size(self.process_group)
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError('cannot skip bias addition if bias is None')
|
||||
raise ValueError("cannot skip bias addition if bias is None")
|
||||
|
||||
# offset the seed with randomizer index and rank
|
||||
seed = torch.random.initial_seed()
|
||||
@@ -270,14 +279,14 @@ class Linear1D_Row(ParallelModule):
|
||||
|
||||
# sanity check
|
||||
if weight is not None:
|
||||
assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
|
||||
assert not bias or bias_ is not None, "bias_ must be provided if bias is True when weight is not None"
|
||||
else:
|
||||
assert bias_ is None, 'bias_ must be None if weight is None'
|
||||
assert bias_ is None, "bias_ must be None if weight is None"
|
||||
|
||||
# Parameters.
|
||||
if weight is None:
|
||||
# Initialize weight.
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
@@ -304,8 +313,9 @@ class Linear1D_Row(ParallelModule):
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
|
||||
**kwargs) -> ParallelModule:
|
||||
def from_native_module(
|
||||
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
) -> ParallelModule:
|
||||
r"""
|
||||
Convert a native PyTorch linear layer to a parallelized linear layer.
|
||||
"""
|
||||
@@ -318,8 +328,7 @@ class Linear1D_Row(ParallelModule):
|
||||
|
||||
# ensure only one process group is passed
|
||||
if isinstance(process_group, (list, tuple)):
|
||||
assert len(process_group) == 1, \
|
||||
f'Expected only one process group, got {len(process_group)}.'
|
||||
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
|
||||
process_group = process_group[0]
|
||||
|
||||
tp_size = dist.get_world_size(process_group)
|
||||
@@ -328,17 +337,20 @@ class Linear1D_Row(ParallelModule):
|
||||
|
||||
if in_features % tp_size != 0:
|
||||
raise ValueError(
|
||||
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!")
|
||||
f"The size of in_features:{in_features} is not integer multiples of tensor parallel size: {tp_size}!"
|
||||
)
|
||||
|
||||
linear_1d = Linear1D_Row(in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
*args,
|
||||
**kwargs)
|
||||
linear_1d = Linear1D_Row(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return linear_1d
|
||||
|
||||
@@ -366,14 +378,18 @@ class Linear1D_Row(ParallelModule):
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
# Set up backprop all-reduce.
|
||||
if self.parallel_input:
|
||||
assert input_.shape[-1] == self.weight.shape[-1], \
|
||||
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1])
|
||||
assert (
|
||||
input_.shape[-1] == self.weight.shape[-1]
|
||||
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1]
|
||||
)
|
||||
input_ = input_
|
||||
else:
|
||||
assert divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1], \
|
||||
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions)
|
||||
assert (
|
||||
divide(input_.shape[-1], self.num_partitions) == self.weight.shape[-1]
|
||||
), "Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.".format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1] * self.num_partitions
|
||||
)
|
||||
input_ = split_forward_gather_backward(input_, dim=-1, process_group=self.process_group)
|
||||
|
||||
if self.stream_chunk_num > 1:
|
||||
@@ -384,9 +400,9 @@ class Linear1D_Row(ParallelModule):
|
||||
handle_list = []
|
||||
for i in range(self.stream_chunk_num):
|
||||
output_parallel_list[i] = F.linear(input_, self.weight_list[i])
|
||||
handle = torch.distributed.all_reduce(output_parallel_list[i],
|
||||
group=self.process_group,
|
||||
async_op=True)
|
||||
handle = torch.distributed.all_reduce(
|
||||
output_parallel_list[i], group=self.process_group, async_op=True
|
||||
)
|
||||
handle_list.append(handle)
|
||||
# output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
|
||||
for handle in handle_list:
|
||||
@@ -395,8 +411,9 @@ class Linear1D_Row(ParallelModule):
|
||||
else:
|
||||
output_parallel = F.linear(input_, self.weight)
|
||||
if self.seq_parallel:
|
||||
output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group,
|
||||
self.seq_parallel_dim)
|
||||
output = linear_reducescatter_forward_gather_backward(
|
||||
output_parallel, self.process_group, self.seq_parallel_dim
|
||||
)
|
||||
else:
|
||||
output = reduce_forward(output_parallel, self.process_group)
|
||||
|
||||
|
Reference in New Issue
Block a user