[chore] handle non member group

This commit is contained in:
hxwang 2024-07-05 07:03:45 +00:00 committed by Hongxin Liu
parent a249e71946
commit 0fad23c691

View File

@@ -7,6 +7,7 @@ from typing import Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import GroupMember
def prod(nums: List[int]) -> int: def prod(nums: List[int]) -> int:
@@ -47,7 +48,7 @@ class ProcessGroupMesh:
self._shape = size self._shape = size
self._rank = dist.get_rank() self._rank = dist.get_rank()
self._coord = ProcessGroupMesh.unravel(self._rank, self._shape) self._coord = ProcessGroupMesh.unravel(self._rank, self._shape)
self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {} self._ranks_to_group: Dict[Tuple[int, ...], Union[ProcessGroup, GroupMember]] = {}
self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {} self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}
def destroy_mesh_process_groups(self): def destroy_mesh_process_groups(self):
@@ -150,7 +151,8 @@ class ProcessGroupMesh:
if tuple(ranks_in_group) not in self._ranks_to_group: if tuple(ranks_in_group) not in self._ranks_to_group:
group = dist.new_group(ranks_in_group, backend=backend) group = dist.new_group(ranks_in_group, backend=backend)
self._ranks_to_group[tuple(ranks_in_group)] = group self._ranks_to_group[tuple(ranks_in_group)] = group
self._group_to_ranks[group] = tuple(ranks_in_group) if group is not GroupMember.NON_GROUP_MEMBER:
self._group_to_ranks[group] = tuple(ranks_in_group)
return self._ranks_to_group[tuple(ranks_in_group)] return self._ranks_to_group[tuple(ranks_in_group)]
def get_ranks_in_group(self, group: ProcessGroup) -> List[int]: def get_ranks_in_group(self, group: ProcessGroup) -> List[int]: