[hotfix] Fix the bug where process groups were not being properly released. (#4940)

* Fix the bug where process groups were not being properly released.

* test

* Revert "test"

This reverts commit 479900c139.
Author: littsk
Date: 2023-10-31 14:47:30 +08:00 (committed by GitHub)
Parent: 4f0234f236
Commit: be82b5d4ca
2 changed files with 69 additions and 2 deletions
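For context: `LayoutConverter` memoizes each conversion solution in `self.cached_solution`, and every cached `CommSpec` keeps references to the process groups it communicates over (its `process_group_dict`). When a user destroys one of those groups, the cached entry silently goes stale. A minimal two-rank repro of the underlying PyTorch behavior (a hedged sketch, not from the commit; the file name `repro.py` and the gloo backend are illustrative, and the exact error message varies across PyTorch versions):

    # repro.py -- run with: torchrun --nproc_per_node=2 repro.py
    import torch.distributed as dist

    dist.init_process_group(backend="gloo")
    pg = dist.new_group(ranks=[0, 1])   # stands in for a group held by a cached CommSpec

    dist.get_rank(pg)                   # fine while the group is registered
    dist.destroy_process_group(pg)      # the group is torn down, but a cache may still reference it

    try:
        dist.get_rank(pg)               # before the fix, replaying a cached solution hit this
    except RuntimeError as e:
        print(f"stale group detected: {e}")

    dist.destroy_process_group()        # clean up the default group

The fix below uses exactly this probe-and-catch behavior to detect dead groups before trusting a cached solution.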


@@ -4,6 +4,7 @@ from dataclasses import dataclass
 from typing import Dict, List, Tuple
 
 import torch
+import torch.distributed as dist
 
 from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.tensor.d_tensor.comm_spec import *
@@ -438,11 +439,58 @@ class LayoutConverter(metaclass=SingletonMeta):
         MAX_TRANSFORM_STEPS = 20
         total_steps = 0
         transform_path = []
-        comm_action_sequence = []
+        comm_action_sequence: List[CommSpec] = []
         spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence))
 
         if spec_pairs in self.cached_solution:
-            return self.cached_solution[spec_pairs]
+            # Solution cache hit
+
+            def _group_alive_check(cached_comm_action_sequence):
+                r"""
+                Check whether the process groups required for sharding have been deleted by the
+                torch.distributed.destroy_process_group method. If none have been deleted, return
+                True; otherwise, return False.
+
+                Args:
+                    cached_comm_action_sequence (List[CommSpec]): A list of communication specifications representing actions.
+
+                Returns:
+                    bool: True if all process groups are still registered, False if at least one has been deleted.
+
+                Raises:
+                    RuntimeError: If there is an error while checking the status of a process group.
+                """
+                # Collect all process groups used by the communication actions in the cached sequence
+                used_process_groups = [
+                    pg for comm_spec in cached_comm_action_sequence for pg in comm_spec.process_group_dict.values()
+                ]
+
+                # Check whether each process group is still alive
+                for process_group in used_process_groups:
+                    try:
+                        dist.get_rank(process_group)
+                    except RuntimeError as e:
+                        # If the group is not registered, it has been deleted
+                        if str(e) == (
+                            f"Group {process_group} is not registered, please create group with torch.distributed.new_group API"
+                        ):
+                            return False
+                        elif str(e) == "The given group does not exist":
+                            return False
+                        else:
+                            # Re-raise the exception if it is not related to group deletion
+                            raise e
+
+                # All process groups are alive
+                return True
+
+            cached_transform_path, cached_comm_action_sequence = self.cached_solution[spec_pairs]
+
+            if _group_alive_check(cached_comm_action_sequence):
+                # No process group has been deleted, so the cached solution is still valid
+                return cached_transform_path, cached_comm_action_sequence
+            else:
+                # At least one process group has been deleted, so the cached entry is stale; drop it
+                del self.cached_solution[spec_pairs]
 
         # We do nothing if the sharding spec is all the same.
         if source_spec.spec_diff(target_spec) == 0:
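Stripped of the caching context, the liveness probe at the heart of the fix can be expressed as a standalone helper. This is a hedged sketch, not ColossalAI API: `groups_alive` is a hypothetical name, and it matches the two "group is gone" messages by substring rather than by exact string as the commit does:

    from typing import List

    import torch.distributed as dist


    def groups_alive(groups: List[dist.ProcessGroup]) -> bool:
        # Hypothetical helper mirroring _group_alive_check above: probe each
        # group with dist.get_rank and treat the "group is gone" RuntimeErrors
        # as a signal that any cached solution built on these groups is stale.
        for pg in groups:
            try:
                dist.get_rank(pg)
            except RuntimeError as e:
                if "is not registered" in str(e) or "does not exist" in str(e):
                    return False  # a destroyed group invalidates the cache entry
                raise  # unrelated failure: do not mask it
        return True

Substring matching is looser than the commit's exact comparison and therefore more tolerant of message changes across PyTorch versions; in both variants any other RuntimeError is re-raised, so real communication failures are not silently converted into a cache miss.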