[Tensor] add cpu group to ddp (#1200)
@@ -54,14 +54,11 @@ class ColoDDP(torch.nn.Module):
         module (torch.nn.Module): Module to apply DDP.
         process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses.
             If it's None, the default data parallel group will be used. Defaults to None.
-        cpu_process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses for those parameters on CPU.
-            If it's None, the default CPU data parallel group will be used. Defaults to None.
     """
 
     def __init__(self,
                  module: torch.nn.Module,
                  process_group: ColoProcessGroup,
-                 cpu_process_group: Optional[dist.ProcessGroup] = None,
                  bucket_cap_mb: int = 25,
                  rebuild_bucket: bool = True) -> None:
         assert not isinstance(module, ColoDDP)
@@ -70,8 +67,9 @@ class ColoDDP(torch.nn.Module):
         self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
         assert process_group
 
-        self.process_group = process_group.dp_process_group()
-        self.dp_world_size = self.process_group.size()
+        self.process_group = process_group
+        self.dp_world_size = self.process_group.dp_world_size()
+
         self.reducer = Reducer(bucket_cap_mb)
         self.rebuild_bucket = rebuild_bucket
         for p in module.parameters():
@@ -112,7 +110,7 @@ class ColoDDP(torch.nn.Module):
                 self.comm_stream.wait_stream(torch.cuda.current_stream())
                 with torch.cuda.stream(self.comm_stream):
                     self.reducer.all_reduce_async(grad,
-                                                  group=self.process_group,
+                                                  group=self.process_group.dp_process_group(),
                                                   callback_fn=partial(self._save_grad, p))
                 grad.record_stream(self.comm_stream)
             else:
@@ -121,8 +119,8 @@ class ColoDDP(torch.nn.Module):
 
         else:
             #TODO(jiaruifang) fixme
-            raise NotImplementedError
-            dist.all_reduce(grad, group=self.cpu_process_group)
+            self.process_group.set_cpu_groups()
+            dist.all_reduce(grad, group=self.process_group.cpu_dp_process_group())
             return grad
 
     @staticmethod
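In short, ColoDDP no longer takes a separate cpu_process_group argument: it now holds a single ColoProcessGroup and asks it for the data-parallel group (dp_process_group()) when all-reducing CUDA gradients, and for a lazily created CPU data-parallel group (set_cpu_groups() followed by cpu_dp_process_group()) when a gradient lives on the CPU. The sketch below is a hypothetical stand-in, not ColossalAI's real ColoProcessGroup; it only models the four methods this diff calls, backing CUDA all-reduce with the default group and CPU all-reduce with a Gloo group.

# Minimal sketch (assumed class name, not the real ColossalAI implementation)
# of a process-group object exposing the four methods the new ColoDDP path uses.
import torch.distributed as dist


class DataParallelGroupSketch:

    def __init__(self):
        # Group used to all-reduce CUDA gradients (typically NCCL-backed).
        self._dp_group = dist.group.WORLD
        # Gloo-backed group for gradients on the CPU; NCCL cannot all-reduce
        # CPU tensors, so this group is created on demand.
        self._cpu_dp_group = None

    def dp_world_size(self) -> int:
        return dist.get_world_size(self._dp_group)

    def dp_process_group(self):
        return self._dp_group

    def set_cpu_groups(self):
        # Create the CPU (Gloo) data-parallel group if it does not exist yet.
        if self._cpu_dp_group is None:
            self._cpu_dp_group = dist.new_group(backend="gloo")

    def cpu_dp_process_group(self):
        return self._cpu_dp_group

With such an object, the CPU branch of grad_handle mirrors the CUDA branch: dist.all_reduce(grad, group=pg.cpu_dp_process_group()) runs over Gloo while the asynchronous reducer keeps using pg.dp_process_group(). Note that dist.new_group is a collective call, so a real implementation would normally create the Gloo group eagerly on every rank rather than lazily inside a gradient hook.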