[Tensor] 1d row embedding (#1075)

* Add CPU 1d row embedding

* polish
This commit is contained in:
Ziyue Jiang
2022-06-08 12:04:59 +08:00
committed by GitHub
parent d66ffb4df4
commit 0653c63eaa
2 changed files with 12 additions and 10 deletions

View File

@@ -32,7 +32,8 @@ def _reduce(input_, parallel_mode):
# skip if only one rank involved
if gpc.get_world_size(parallel_mode) == 1:
return input_
- dist.all_reduce(input_, group=gpc.get_group(parallel_mode))
+ group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode)
+ dist.all_reduce(input_, group=group)
return input_
@@ -66,7 +67,8 @@ def _gather(input_, parallel_mode, dim=-1):
rank = gpc.get_local_rank(parallel_mode)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
- torch.distributed.all_gather(tensor_list, input_, group=gpc.get_group(parallel_mode))
+ group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode)
+ torch.distributed.all_gather(tensor_list, input_, group=group)
# concat
output = torch.cat(tensor_list, dim=dim).contiguous()