[model checkpoint] add gloo groups for cpu tensor communication (#589)

Author:  アマデウス (committed via GitHub)
Date:    2022-04-01 10:15:52 +08:00
Parent:  54e688b623
Commit:  297b8baae2
10 changed files with 98 additions and 23 deletions
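
Background for the change (not part of the diff): NCCL process groups only communicate CUDA tensors, so collectives on CPU tensors, such as checkpoint state dicts staged on the host, need a separate Gloo group over the same ranks. A minimal standalone sketch of the idea, assuming a default NCCL group initialized via torchrun:

    import torch
    import torch.distributed as dist

    # Default group uses NCCL (CUDA tensors only); a parallel Gloo group
    # over the same ranks handles tensors that live on the CPU.
    dist.init_process_group(backend='nccl')
    cpu_group = dist.new_group(backend='gloo')

    state = torch.zeros(4)                          # CPU tensor
    dist.broadcast(state, src=0, group=cpu_group)   # Gloo moves CPU tensors
    # dist.broadcast(state, src=0) would fail here: NCCL expects CUDA tensors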


@@ -34,6 +34,7 @@ class ParallelContext(metaclass=SingletonMeta):
         self._local_ranks = dict()
         self._world_sizes = dict()
         self._groups = dict()
+        self._cpu_groups = dict()
         self._ranks_in_group = dict()
         # load config from file
@@ -290,6 +291,32 @@ class ParallelContext(metaclass=SingletonMeta):
         self._check_parallel_mode(parallel_mode)
         self._groups[parallel_mode] = group
 
+    def get_cpu_group(self, parallel_mode: ParallelMode):
+        """Returns the Gloo group of the current device for `parallel_mode`.
+
+        :param parallel_mode: The chosen parallel mode
+        :type parallel_mode: :class:`colossalai.context.ParallelMode`
+        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
+            of :class:`colossalai.context.ParallelMode`
+        :return: The group of the current device for `parallel_mode`
+        :rtype: torch.distributed.ProcessGroup
+        """
+        self._check_parallel_mode(parallel_mode)
+        return self._cpu_groups[parallel_mode]
+
+    def add_cpu_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
+        """Adds the Gloo group of the current device for `parallel_mode`.
+
+        :param parallel_mode: The chosen parallel mode
+        :type parallel_mode: :class:`colossalai.context.ParallelMode`
+        :param group: The group to be added
+        :type group: torch.distributed.ProcessGroup
+        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
+            of :class:`colossalai.context.ParallelMode`
+        """
+        self._check_parallel_mode(parallel_mode)
+        self._cpu_groups[parallel_mode] = group
+
     def get_ranks_in_group(self, parallel_mode: ParallelMode):
         """Returns the rank of the current device for `parallel_mode` in the group.
@@ -335,13 +362,16 @@ class ParallelContext(metaclass=SingletonMeta):
         dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
 
         # None will give the default global process group for pytorch dist operations
-        self._register_dist(rank, world_size, None, list(range(world_size)), ParallelMode.GLOBAL)
+        ranks = list(range(world_size))
+        cpu_group = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else None
+        self._register_dist(rank, world_size, None, cpu_group, ranks, ParallelMode.GLOBAL)
         self.add_global_rank(ParallelMode.GLOBAL, rank)
 
-    def _register_dist(self, local_rank, world_size, process_group, ranks_in_group, mode):
+    def _register_dist(self, local_rank, world_size, process_group, cpu_group, ranks_in_group, mode):
         self.add_local_rank(mode, local_rank)
         self.add_world_size(mode, world_size)
         self.add_group(mode, process_group)
+        self.add_cpu_group(mode, cpu_group)
         self.add_ranks_in_group(mode, ranks_in_group)
 
     def check_sanity(self):
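
The same conditional-creation pattern can be reused wherever a process group for an arbitrary rank subset is built; a sketch with a hypothetical helper (`_make_groups` is illustrative, not part of this diff):

    import torch.distributed as dist

    def _make_groups(ranks):
        # Communication group on the configured backend (e.g. NCCL), plus a
        # matching Gloo group over the same ranks for CPU tensors. When the
        # default backend is already Gloo, None signals "reuse the default".
        group = dist.new_group(ranks)
        cpu_group = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else None
        return group, cpu_group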