mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-23 10:30:03 +00:00
[hotfix] ZeroDDP use new process group (#1333)
* process group supports getting ranks in group
* chunk mgr receives a process group
* update unit test
* fix unit tests
This commit is contained in:
@@ -118,7 +118,7 @@ class ColoDDP(torch.nn.Module):
|
||||
return empty_grad
|
||||
|
||||
else:
|
||||
#TODO(jiaruifang) fixme
|
||||
# TODO(jiaruifang) fixme
|
||||
self.process_group.set_cpu_groups()
|
||||
dist.all_reduce(grad, group=self.process_group.cpu_dp_process_group())
|
||||
return grad
|
||||
@@ -191,11 +191,8 @@ class ZeroDDP(ColoDDP):
|
||||
For more details, see the API reference of ``GeminiManager``.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
module: torch.nn.Module,
|
||||
gemini_manager: GeminiManager,
|
||||
process_group: Optional[ColoProcessGroup] = None) -> None:
|
||||
super().__init__(module.half(), process_group=process_group)
|
||||
def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
|
||||
super().__init__(module.half(), process_group=gemini_manager.chunk_manager.process_group)
|
||||
self.gemini_manager = gemini_manager
|
||||
self.chunk_manager = gemini_manager.chunk_manager
|
||||
self.param_op_hook = ZeROHookV2(gemini_manager)
|
||||
|
Reference in New Issue
Block a user