[hotfix] ZeroDDP use new process group (#1333)

* process group supports getting ranks in group

* chunk mgr receives a process group

* update unit test

* fix unit tests
This commit is contained in:
ver217
2022-07-18 14:14:52 +08:00
committed by GitHub
parent 11d1436a67
commit 0c51ff2c13
9 changed files with 49 additions and 43 deletions

View File

@@ -118,7 +118,7 @@ class ColoDDP(torch.nn.Module):
return empty_grad
else:
#TODO(jiaruifang) fixme
# TODO(jiaruifang) fixme
self.process_group.set_cpu_groups()
dist.all_reduce(grad, group=self.process_group.cpu_dp_process_group())
return grad
@@ -191,11 +191,8 @@ class ZeroDDP(ColoDDP):
For more details, see the API reference of ``GeminiManager``.
"""
def __init__(self,
module: torch.nn.Module,
gemini_manager: GeminiManager,
process_group: Optional[ColoProcessGroup] = None) -> None:
super().__init__(module.half(), process_group=process_group)
def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
super().__init__(module.half(), process_group=gemini_manager.chunk_manager.process_group)
self.gemini_manager = gemini_manager
self.chunk_manager = gemini_manager.chunk_manager
self.param_op_hook = ZeROHookV2(gemini_manager)