mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[zero] solve hang
This commit is contained in:
@@ -137,7 +137,7 @@ class ProcessGroupMesh:
|
||||
assert mode in ["raise", "wrap", "clip"]
|
||||
return int(np.ravel_multi_index(coord, shape, mode))
|
||||
|
||||
def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
|
||||
def _get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
|
||||
"""Get the process group with the given ranks. It the process group doesn't exist, it will be created.
|
||||
|
||||
Args:
|
||||
@@ -240,7 +240,7 @@ class ProcessGroupMesh:
|
||||
for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
|
||||
coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
|
||||
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
|
||||
group = self.get_group(ranks_in_group, backend=backend)
|
||||
group = self._get_group(ranks_in_group, backend=backend)
|
||||
if self._rank in ranks_in_group:
|
||||
target_group = group
|
||||
return target_group
|
||||
|
Reference in New Issue
Block a user