[model checkpoint] add gloo groups for cpu tensor communication (#589)
@@ -34,6 +34,7 @@ class ParallelContext(metaclass=SingletonMeta):
         self._local_ranks = dict()
         self._world_sizes = dict()
         self._groups = dict()
+        self._cpu_groups = dict()
         self._ranks_in_group = dict()

         # load config from file
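The new `_cpu_groups` dict mirrors `_groups`, keyed by `ParallelMode`, but holds Gloo-backed process groups. The motivation (my reading of the commit title, not text from the diff itself): NCCL collectives only operate on CUDA tensors, so communicating CPU tensors such as checkpoint state dicts needs a Gloo group. A minimal standalone sketch of that constraint, assuming a multi-process launch (e.g. via torchrun) with one GPU per process:

import torch
import torch.distributed as dist

# Minimal sketch, not ColossalAI code. Assumes RANK, WORLD_SIZE, MASTER_ADDR and
# MASTER_PORT are set by the launcher (e.g. torchrun) and each process has a GPU.
dist.init_process_group(backend='nccl')

# A Gloo group spanning the same ranks, usable for CPU tensors.
cpu_group = dist.new_group(list(range(dist.get_world_size())), backend='gloo')

t = torch.ones(4)                      # CPU tensor, e.g. a checkpoint shard
# dist.all_reduce(t)                   # would fail: NCCL only supports CUDA tensors
dist.all_reduce(t, group=cpu_group)    # works: Gloo communicates CPU tensors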
@@ -290,6 +291,32 @@ class ParallelContext(metaclass=SingletonMeta):
         self._check_parallel_mode(parallel_mode)
         self._groups[parallel_mode] = group

+    def get_cpu_group(self, parallel_mode: ParallelMode):
+        """Returns the Gloo group of the current device for `parallel_mode`.
+
+        :param parallel_mode: The chosen parallel mode
+        :type parallel_mode: :class:`colossalai.context.ParallelMode`
+        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
+            of :class:`colossalai.context.ParallelMode`
+        :return: The group of the current device for `parallel_mode`
+        :rtype: torch.distributed.ProcessGroup
+        """
+        self._check_parallel_mode(parallel_mode)
+        return self._cpu_groups[parallel_mode]
+
+    def add_cpu_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
+        """Adds the Gloo group of the current device for `parallel_mode`.
+
+        :param parallel_mode: The chosen parallel mode
+        :type parallel_mode: :class:`colossalai.context.ParallelMode`
+        :param group: The group to be added
+        :type group: torch.distributed.ProcessGroup
+        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
+            of :class:`colossalai.context.ParallelMode`
+        """
+        self._check_parallel_mode(parallel_mode)
+        self._cpu_groups[parallel_mode] = group
+
     def get_ranks_in_group(self, parallel_mode: ParallelMode):
         """Returns the rank of the current device for `parallel_mode` in the group.

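With the accessors above, checkpoint code can ask the parallel context for the Gloo group of a given mode instead of creating one ad hoc. A hedged usage sketch (the `global_context as gpc` import is the usual ColossalAI access path, but it is an assumption here rather than something shown in this diff):

import torch
import torch.distributed as dist
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc   # assumed access path, not part of this diff

def broadcast_cpu_state(state: torch.Tensor, src: int = 0):
    # Fetch the Gloo group registered for the global mode; it may be None when
    # the default backend is already 'gloo' (see the initialization hunk below),
    # in which case group=None falls back to the default (Gloo) group.
    cpu_group = gpc.get_cpu_group(ParallelMode.GLOBAL)
    dist.broadcast(state, src=src, group=cpu_group)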
@@ -335,13 +362,16 @@ class ParallelContext(metaclass=SingletonMeta):
         dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)

         # None will give the default global process group for pytorch dist operations
-        self._register_dist(rank, world_size, None, list(range(world_size)), ParallelMode.GLOBAL)
+        ranks = list(range(world_size))
+        cpu_group = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else None
+        self._register_dist(rank, world_size, None, cpu_group, ranks, ParallelMode.GLOBAL)
         self.add_global_rank(ParallelMode.GLOBAL, rank)

-    def _register_dist(self, local_rank, world_size, process_group, ranks_in_group, mode):
+    def _register_dist(self, local_rank, world_size, process_group, cpu_group, ranks_in_group, mode):
         self.add_local_rank(mode, local_rank)
         self.add_world_size(mode, world_size)
         self.add_group(mode, process_group)
+        self.add_cpu_group(mode, cpu_group)
         self.add_ranks_in_group(mode, ranks_in_group)

     def check_sanity(self):
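A note on the conditional in the initialization hunk: when the default backend is already 'gloo', no extra group is created and `None` is registered; since `group=None` selects the default process group in torch.distributed collectives, CPU tensors still travel over Gloo in that case. A small sketch of the same pattern outside the class, assuming the default group has already been initialized (illustrative helper name, not library API):

import torch.distributed as dist

def make_cpu_group(ranks):
    # Reuse the default group when it is already Gloo-backed; otherwise create a
    # parallel Gloo group over the same ranks for CPU tensor communication.
    if dist.get_backend() == 'gloo':
        return None          # group=None means "use the default (Gloo) group"
    return dist.new_group(ranks, backend='gloo')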