Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-10 21:40:02 +00:00
[FAW] parallel FreqAwareEmbedding (#1424)
@@ -195,3 +195,39 @@ def split_forward_gather_backward(input_, process_group, dim):
def gather_forward_split_backward(input_, process_group, dim):
    return _GatherForwardSplitBackward.apply(input_, process_group, dim)


def _all_to_all(x: torch.Tensor, pg: ProcessGroup, scatter_dim: int, gather_dim: int) -> torch.Tensor:
    world_size = pg.tp_world_size()
    if world_size == 1:
        return x

    # TODO: enable the mpi backend to support CPU all_to_all
    assert x.device.type == 'cuda', "Currently, the collective function dual_all_to_all only supports the nccl backend"

    # split x into world_size chunks along scatter_dim, exchange one chunk with
    # every rank, then concatenate the received chunks along gather_dim
    shapes = list(x.size())
    shapes[scatter_dim] = shapes[scatter_dim] // world_size

    scatter_list = [each.contiguous() for each in torch.tensor_split(x, world_size, scatter_dim)]
    gather_list = [torch.empty(*shapes, dtype=x.dtype, device=x.device) for _ in range(world_size)]
    torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())

    return torch.cat(gather_list, dim=gather_dim).contiguous()


class _DualAllToAll(torch.autograd.Function):
    # autograd-aware all-to-all: the backward pass is another all-to-all with
    # the scatter and gather dimensions swapped

    @staticmethod
    def forward(ctx, x, pg, scatter_dim, gather_dim):
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.pg = pg
        return _all_to_all(x, pg, scatter_dim, gather_dim)

    @staticmethod
    def backward(ctx, grad):
        return _all_to_all(grad, ctx.pg, ctx.gather_dim, ctx.scatter_dim), None, None, None


def dual_all_to_all(x, pg, scatter_dim: int, gather_dim: int):
    return _DualAllToAll.apply(x, pg, scatter_dim, gather_dim)
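Below is a minimal usage sketch of the new dual_all_to_all helper, assuming a tensor-parallel ProcessGroup spanning all launched GPU processes. The import paths, the launch call, the tp_degree keyword, and the tensor shape are illustrative assumptions and are not part of this diff.

# Hedged usage sketch (not part of the commit): run with one process per GPU,
# e.g. under torchrun; ProcessGroup import location is assumed.
import torch
import torch.distributed as dist

import colossalai
from colossalai.tensor import ProcessGroup  # assumed import path

colossalai.launch_from_torch(config={})
pg = ProcessGroup(tp_degree=dist.get_world_size())  # pure tensor-parallel group

# the size along scatter_dim (0) must be divisible by the tp world size
x = torch.randn(4 * pg.tp_world_size(), 8, device='cuda', requires_grad=True)

# scatter rows across ranks and gather the received chunks along the columns:
# output shape is (4, 8 * tp_world_size)
out = dual_all_to_all(x, pg, scatter_dim=0, gather_dim=1)

# backward runs the same all-to-all with the dims swapped, so x.grad comes
# back in the original layout
out.sum().backward()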