Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 09:07:51 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -41,7 +41,7 @@ def _reduce(input_, pg: ProcessGroup):
     # skip if only one rank involved
     if pg.tp_world_size() == 1:
         return input_
-    assert input_.device.type == 'cuda'
+    assert input_.device.type == "cuda"
     group = pg.tp_process_group()
     dist.all_reduce(input_, group=group)
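`_reduce` above is an in-place sum all-reduce over the tensor-parallel group. A minimal standalone sketch of the same call, using the default process group from `torch.distributed` rather than ColossalAI's `ProcessGroup` wrapper (the script name and launch command are illustrative):

import torch
import torch.distributed as dist

def reduce_sketch() -> torch.Tensor:
    # Launch with e.g. `torchrun --nproc_per_node=2 reduce_sketch.py`; NCCL needs one GPU per rank.
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    # Each rank holds a partial result; all_reduce sums them in place (the default op is SUM).
    x = torch.full((4,), float(rank + 1), device="cuda")
    dist.all_reduce(x)
    # With 2 ranks, every rank now holds [3., 3., 3., 3.].
    return x

if __name__ == "__main__":
    print(reduce_sketch())
    dist.destroy_process_group()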
@@ -56,9 +56,10 @@ def _split(input_, pg: ProcessGroup, dim=-1):
 
     # Split along last dimension.
     dim_size = input_.size(dim)
-    assert dim_size % world_size == 0, \
-        f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \
-        f'cannot split tensor evenly'
+    assert dim_size % world_size == 0, (
+        f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), "
+        f"cannot split tensor evenly"
+    )
 
     tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
     rank = pg.tp_local_rank()
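The splitting step needs no communication, only the divisibility check and a `torch.split`; each rank keeps the chunk at its own index. A small single-process sketch with stand-in values for `world_size` and `rank`:

import torch

def split_sketch(x: torch.Tensor, world_size: int, rank: int, dim: int = -1) -> torch.Tensor:
    # Same divisibility check as in the diff above, then pick this rank's chunk.
    dim_size = x.size(dim)
    assert dim_size % world_size == 0, (
        f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), "
        f"cannot split tensor evenly"
    )
    chunks = torch.split(x, dim_size // world_size, dim=dim)
    return chunks[rank].contiguous()

x = torch.arange(8).reshape(2, 4)
print(split_sketch(x, world_size=2, rank=1, dim=-1))  # tensor([[2, 3], [6, 7]])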
@@ -77,7 +78,7 @@ def _gather(input_, pg: ProcessGroup, dim=-1):
     rank = pg.tp_local_rank()
     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
     tensor_list[rank] = input_
-    assert input_.device.type == 'cuda'
+    assert input_.device.type == "cuda"
     group = pg.tp_process_group()
     torch.distributed.all_gather(tensor_list, input_, group=group)
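`_gather` is the inverse of `_split`: every rank contributes its shard, receives all the others, and concatenates along `dim`. A hedged sketch against the default process group (not ColossalAI's `ProcessGroup` wrapper), assuming the group was already initialized as in the reduce sketch above:

import torch
import torch.distributed as dist

def gather_sketch(shard: torch.Tensor, dim: int = -1) -> torch.Tensor:
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    tensor_list = [torch.empty_like(shard) for _ in range(world_size)]
    tensor_list[rank] = shard  # reuse the local shard instead of receiving a copy of it
    dist.all_gather(tensor_list, shard)
    # Stitch the shards back together along the split dimension.
    return torch.cat(tensor_list, dim=dim).contiguous()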
@@ -203,7 +204,7 @@ def _all_to_all(x: torch.Tensor, pg: ProcessGroup, scatter_dim: int, gather_dim:
         return x
 
     # TODO: enabling mpi backend to support CPU all_to_all
-    assert x.device.type == 'cuda', f"Currently, the collective function dual_all_to_all only supports nccl backend"
+    assert x.device.type == "cuda", f"Currently, the collective function dual_all_to_all only supports nccl backend"
 
     shapes = list(x.size())
     shapes[scatter_dim] = shapes[scatter_dim] // world_size
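The dual all-to-all splits the input along `scatter_dim`, exchanges one chunk with every rank, and concatenates the received chunks along `gather_dim`; the assert reflects that the NCCL backend only handles CUDA tensors. A rough sketch of the data movement, assuming identically shaped inputs on every rank and the default process group:

import torch
import torch.distributed as dist

def all_to_all_sketch(x: torch.Tensor, scatter_dim: int, gather_dim: int) -> torch.Tensor:
    # Assumes an initialized NCCL process group and the same x.shape on every rank.
    world_size = dist.get_world_size()
    if world_size == 1:
        return x
    # One chunk per destination rank, taken along scatter_dim.
    input_list = [t.contiguous() for t in torch.chunk(x, world_size, dim=scatter_dim)]
    output_list = [torch.empty_like(t) for t in input_list]
    dist.all_to_all(output_list, input_list)
    # Received chunks are reassembled along gather_dim.
    return torch.cat(output_list, dim=gather_dim).contiguous()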
@@ -216,7 +217,6 @@ def _all_to_all(x: torch.Tensor, pg: ProcessGroup, scatter_dim: int, gather_dim:
 
 
 class _DualAllToAll(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, x, pg, scatter_dim, gather_dim):
         ctx.scatter_dim = scatter_dim
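The autograd wrapper only has to remember the two dimensions: the backward pass is another all-to-all with `scatter_dim` and `gather_dim` swapped, which routes each gradient chunk back to the rank it came from. A sketch of that pattern, reusing the `all_to_all_sketch` helper defined above (an assumption of this sketch, not the library function):

import torch

class DualAllToAllSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scatter_dim, gather_dim):
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        return all_to_all_sketch(x, scatter_dim, gather_dim)

    @staticmethod
    def backward(ctx, grad):
        # Swapping the dims sends each gradient chunk back to the rank its activation came from.
        return all_to_all_sketch(grad, ctx.gather_dim, ctx.scatter_dim), None, None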
@@ -236,16 +236,14 @@ def dual_all_to_all(x, pg, scatter_dim: int, gather_dim: int):
 # table wise embedding shard
 
 
-def _all_to_all_for_tablewise(x: torch.Tensor,
-                              pg: ProcessGroup,
-                              scatter_strides: List[int],
-                              gather_strides: List[int],
-                              forward=True) -> torch.Tensor:
+def _all_to_all_for_tablewise(
+    x: torch.Tensor, pg: ProcessGroup, scatter_strides: List[int], gather_strides: List[int], forward=True
+) -> torch.Tensor:
     world_size = pg.tp_world_size()
     rank = pg.tp_local_rank()
     if world_size == 1:
         return x
-    assert x.device.type == 'cuda', f"Currently, the collective function dual_all_to_all only supports nccl backend"
+    assert x.device.type == "cuda", f"Currently, the collective function dual_all_to_all only supports nccl backend"
     if forward:
         scatter_list = list(x.split(scatter_strides, 0))
         gather_list = [
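The table-wise variant differs from the plain all-to-all in that chunks are unevenly sized: `x.split(scatter_strides, 0)` cuts the rows into per-rank chunks whose sizes are given by the stride list, so `torch.split` receives a list of sizes rather than a single chunk size. A minimal single-process illustration of that bookkeeping with stand-in stride values (no communication):

import torch

# Stand-in values: 10 embedding rows on this rank, split 3 for rank 0 and 7 for rank 1.
x = torch.randn(10, 4)
scatter_strides = [3, 7]
scatter_list = list(x.split(scatter_strides, 0))
assert [t.size(0) for t in scatter_list] == scatter_strides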
@@ -266,7 +264,6 @@ def _all_to_all_for_tablewise(x: torch.Tensor,
 
 
 class _DualAllToAllForTablewise(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, x, pg, scatter_strides, gather_strides):
         ctx.pg = pg
@@ -276,8 +273,12 @@ class _DualAllToAllForTablewise(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, grad):
-        return _all_to_all_for_tablewise(grad, ctx.pg, ctx.gather_strides, ctx.scatter_strides,
-                                         forward=False), None, None, None
+        return (
+            _all_to_all_for_tablewise(grad, ctx.pg, ctx.gather_strides, ctx.scatter_strides, forward=False),
+            None,
+            None,
+            None,
+        )
 
 
 def dual_all_to_all_tablewise(x, pg, scatter_strides, gather_strides):