[FAW] parallel FreqAwareEmbedding (#1424)

Author: Jiarui Fang
Date:   2022-08-10 13:44:30 +08:00
Committed by: GitHub
Commit: cb98cf5558
Parent: 0d212183c4
4 changed files with 272 additions and 2 deletions

@@ -195,3 +195,39 @@ def split_forward_gather_backward(input_, process_group, dim):
def gather_forward_split_backward(input_, process_group, dim):
    return _GatherForwardSplitBackward.apply(input_, process_group, dim)
def _all_to_all(x: torch.Tensor, pg: ProcessGroup, scatter_dim: int, gather_dim: int) -> torch.Tensor:
    world_size = pg.tp_world_size()
    if world_size == 1:
        return x

    # TODO: enable the MPI backend to support CPU all_to_all
    assert x.device.type == 'cuda', \
        "Currently, the collective function dual_all_to_all only supports the NCCL backend"

    # After the exchange, each rank keeps 1/world_size of the scatter dimension.
    shapes = list(x.size())
    shapes[scatter_dim] = shapes[scatter_dim] // world_size

    scatter_list = [each.contiguous() for each in torch.tensor_split(x, world_size, scatter_dim)]
    gather_list = [torch.empty(*shapes, dtype=x.dtype, device=x.device) for _ in range(world_size)]
    torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())

    return torch.cat(gather_list, dim=gather_dim).contiguous()


class _DualAllToAll(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, pg, scatter_dim, gather_dim):
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.pg = pg
        return _all_to_all(x, pg, scatter_dim, gather_dim)

    @staticmethod
    def backward(ctx, grad):
        # The gradient flows back through the mirrored all-to-all (dims swapped).
        return _all_to_all(grad, ctx.pg, ctx.gather_dim, ctx.scatter_dim), None, None, None


def dual_all_to_all(x: torch.Tensor, pg: ProcessGroup, scatter_dim: int, gather_dim: int) -> torch.Tensor:
    # Scatter x along scatter_dim and gather along gather_dim; backward reverses the two.
    return _DualAllToAll.apply(x, pg, scatter_dim, gather_dim)
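
For context, a minimal usage sketch of the new dual_all_to_all follows. It is not part of this commit: the _PG wrapper, the import path of dual_all_to_all, and the torchrun launch are assumptions made for illustration. With scatter_dim=0 and gather_dim=1, a per-rank tensor of shape (batch, local_embed_dim) becomes (batch // world_size, local_embed_dim * world_size), which is the kind of exchange a column-parallel embedding forward would need to hand each rank full embedding vectors for a slice of the batch.

# Usage sketch only, not part of the commit.
# Assumptions: NCCL backend, one GPU per process on a single node,
# launched with `torchrun --nproc_per_node=<world_size> demo.py`.
import torch
import torch.distributed as dist

# Hypothetical import path; adjust to the module modified above.
from colossalai.nn._ops._utils import dual_all_to_all


class _PG:
    # Hypothetical stand-in for the ProcessGroup wrapper used above,
    # exposing only the two methods that _all_to_all relies on.
    def tp_world_size(self):
        return dist.get_world_size()

    def tp_process_group(self):
        return dist.group.WORLD


def demo():
    dist.init_process_group(backend='nccl')
    rank = dist.get_rank()
    torch.cuda.set_device(rank)  # single-node assumption
    pg = _PG()
    world_size = pg.tp_world_size()

    # Each rank holds the full batch but only a slice of the embedding dim.
    # The batch size must be divisible by the tensor-parallel world size.
    batch, local_embed_dim = 8, 4
    x = torch.randn(batch, local_embed_dim, device='cuda', requires_grad=True)

    # Scatter the batch dim (0), gather the embedding dim (1):
    # (batch, local_embed_dim) -> (batch // world_size, local_embed_dim * world_size)
    y = dual_all_to_all(x, pg, scatter_dim=0, gather_dim=1)
    assert y.shape == (batch // world_size, local_embed_dim * world_size)

    # The backward pass runs the mirrored all-to-all, restoring the original layout.
    y.sum().backward()
    assert x.grad.shape == x.shape


if __name__ == '__main__':
    demo()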