[Tensor] add cpu group to ddp (#1200)

Author: Jiarui Fang
Date: 2022-07-05 14:58:28 +08:00
Committed by: GitHub
Parent: f7878f465c
Commit: b5f25eb32a

4 changed files with 30 additions and 26 deletions


@@ -54,14 +54,11 @@ class ColoDDP(torch.nn.Module):
         module (torch.nn.Module): Module to apply DDP.
         process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses.
             If it's None, the default data parallel group will be used. Defaults to None.
-        cpu_process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses for those parameters on CPU.
-            If it's None, the default CPU data parallel group will be used. Defaults to None.
     """

     def __init__(self,
                  module: torch.nn.Module,
                  process_group: ColoProcessGroup,
-                 cpu_process_group: Optional[dist.ProcessGroup] = None,
                  bucket_cap_mb: int = 25,
                  rebuild_bucket: bool = True) -> None:
         assert not isinstance(module, ColoDDP)
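
With the separate cpu_process_group parameter dropped, callers now hand ColoDDP the group wrapper itself rather than a raw dist.ProcessGroup plus a CPU group. A minimal usage sketch, with import paths assumed rather than taken from this commit:

import torch
from colossalai.tensor import ProcessGroup    # assumed import path for the wrapper
from colossalai.nn.parallel import ColoDDP    # assumed import path

model = torch.nn.Linear(8, 8).cuda()
pg = ProcessGroup()  # bundles the CUDA and CPU data-parallel groups
# No cpu_process_group argument anymore; the wrapper supplies both groups.
ddp_model = ColoDDP(model, process_group=pg)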
@@ -70,8 +67,9 @@ class ColoDDP(torch.nn.Module):
         self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
         assert process_group
-        self.process_group = process_group.dp_process_group()
-        self.dp_world_size = self.process_group.size()
+        self.process_group = process_group
+        self.dp_world_size = self.process_group.dp_world_size()
         self.reducer = Reducer(bucket_cap_mb)
         self.rebuild_bucket = rebuild_bucket
         for p in module.parameters():
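
ColoDDP now stores the wrapper and assumes it exposes four methods: dp_world_size(), dp_process_group(), set_cpu_groups(), and cpu_dp_process_group(). A hedged sketch of that interface (not the real ColoProcessGroup implementation) using only standard torch.distributed calls:

import torch.distributed as dist

class ProcessGroupSketch:
    """Illustrative stand-in: bundles a CUDA (NCCL) and a CPU (gloo) data-parallel group."""

    def __init__(self):
        self._dp_pg = dist.new_group(backend="nccl")   # used for GPU gradients
        self._cpu_dp_pg = None                         # created lazily

    def dp_world_size(self) -> int:
        return dist.get_world_size(self._dp_pg)

    def dp_process_group(self) -> dist.ProcessGroup:
        return self._dp_pg

    def set_cpu_groups(self) -> None:
        # Lazily build the gloo group so CPU tensors can be all-reduced.
        if self._cpu_dp_pg is None:
            self._cpu_dp_pg = dist.new_group(backend="gloo")

    def cpu_dp_process_group(self) -> dist.ProcessGroup:
        return self._cpu_dp_pg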
@@ -112,7 +110,7 @@ class ColoDDP(torch.nn.Module):
                 self.comm_stream.wait_stream(torch.cuda.current_stream())
                 with torch.cuda.stream(self.comm_stream):
                     self.reducer.all_reduce_async(grad,
-                                                  group=self.process_group,
+                                                  group=self.process_group.dp_process_group(),
                                                   callback_fn=partial(self._save_grad, p))
                 grad.record_stream(self.comm_stream)
             else:
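
In the GPU branch above, the gradient is all-reduced asynchronously on a dedicated communication stream, with record_stream keeping its storage alive until the collective finishes. The same pattern in isolation, as a sketch with plain torch.distributed standing in for ColossalAI's Reducer:

import torch
import torch.distributed as dist

comm_stream = torch.cuda.Stream()

def async_all_reduce(grad: torch.Tensor, group: dist.ProcessGroup) -> None:
    # Order the side stream after the stream that produced `grad`.
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        dist.all_reduce(grad, group=group, async_op=True)
    # Tell the caching allocator the side stream still uses this storage,
    # so it is not reclaimed before the collective completes.
    grad.record_stream(comm_stream)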
@@ -121,8 +119,8 @@ class ColoDDP(torch.nn.Module):
         else:
             #TODO(jiaruifang) fixme
-            raise NotImplementedError
-            dist.all_reduce(grad, group=self.cpu_process_group)
+            self.process_group.set_cpu_groups()
+            dist.all_reduce(grad, group=self.process_group.cpu_dp_process_group())
             return grad

     @staticmethod
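
The CPU branch needs a different backend from the GPU path: NCCL collectives only operate on CUDA tensors, so a gradient living on the CPU must be reduced over a gloo-backed group, which set_cpu_groups() presumably creates on demand. A minimal sketch of that synchronous path (the helper name is hypothetical):

import torch
import torch.distributed as dist

def reduce_cpu_grad(grad: torch.Tensor, cpu_group: dist.ProcessGroup) -> torch.Tensor:
    # Synchronous all-reduce over gloo; NCCL would reject a CPU tensor.
    assert grad.device.type == "cpu"
    dist.all_reduce(grad, group=cpu_group)
    return grad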