mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 12:01:39 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -7,13 +7,12 @@ from colossalai.logging import get_dist_logger
|
||||
|
||||
|
||||
class PyTorchProcessGroupDict(metaclass=SingletonMeta):
|
||||
|
||||
def __init__(self):
|
||||
# distributed settings
|
||||
# use this dict to record all Pytorch ProcessGroups
|
||||
self.dict = {}
|
||||
# set a distributed logger
|
||||
self.logger = get_dist_logger('ProcessGroup')
|
||||
self.logger = get_dist_logger("ProcessGroup")
|
||||
|
||||
def log_pg_init(self, rank_list: List[int], backend: str):
|
||||
str_list = ["Pytorch ProcessGroup Init:"]
|
||||
@@ -21,9 +20,8 @@ class PyTorchProcessGroupDict(metaclass=SingletonMeta):
|
||||
str_list.append(f"ranks: {rank_list}")
|
||||
self.logger.info("\n\t".join(str_list), ranks=[0])
|
||||
|
||||
def get(self, rank_list: List[int], backend: str = 'nccl'):
|
||||
"""Reuse Pytorch ProcessGroup when such a group is initialized
|
||||
"""
|
||||
def get(self, rank_list: List[int], backend: str = "nccl"):
|
||||
"""Reuse Pytorch ProcessGroup when such a group is initialized"""
|
||||
# we need to convert the passed list to a tuple
|
||||
# since List is unhashable
|
||||
processgroup_key = (backend, tuple(rank_list))
|
||||
@@ -51,11 +49,13 @@ class ProcessGroup:
|
||||
dp_degree: Optional[int], data parallelism degree. How many processes are inside a dp process group. . default None means len(ranks).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
rank: Optional[int] = None,
|
||||
ranks: Optional[List[int]] = None,
|
||||
tp_degree: Optional[int] = None,
|
||||
dp_degree: Optional[int] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
rank: Optional[int] = None,
|
||||
ranks: Optional[List[int]] = None,
|
||||
tp_degree: Optional[int] = None,
|
||||
dp_degree: Optional[int] = None,
|
||||
) -> None:
|
||||
if not torch.distributed.is_initialized():
|
||||
self.is_init = False
|
||||
return
|
||||
@@ -64,13 +64,13 @@ class ProcessGroup:
|
||||
|
||||
self._rank = torch.distributed.get_rank()
|
||||
if rank is not None:
|
||||
assert self._rank == rank # make sure that the global rank is correct
|
||||
assert self._rank == rank # make sure that the global rank is correct
|
||||
|
||||
if ranks is None:
|
||||
self._rank_list = list(range(torch.distributed.get_world_size()))
|
||||
else:
|
||||
self._rank_list = ranks
|
||||
self._rank_list.sort() # ensure that the list is in order
|
||||
self._rank_list.sort() # ensure that the list is in order
|
||||
|
||||
self._world_size = len(self._rank_list)
|
||||
|
||||
@@ -79,31 +79,36 @@ class ProcessGroup:
|
||||
self._tp_degree = 1
|
||||
elif dp_degree and not tp_degree:
|
||||
self._dp_degree = dp_degree
|
||||
assert self._world_size % self._dp_degree == 0, f"DP degree {dp_degree} should be divisible by {self._world_size} hen DP degree is None"
|
||||
assert (
|
||||
self._world_size % self._dp_degree == 0
|
||||
), f"DP degree {dp_degree} should be divisible by {self._world_size} hen DP degree is None"
|
||||
self._tp_degree = self._world_size // dp_degree
|
||||
elif not dp_degree and tp_degree:
|
||||
self._tp_degree = tp_degree
|
||||
assert self._world_size % self._tp_degree == 0, f"TP degree {tp_degree} should be divisible by {self._world_size} when DP degree is None"
|
||||
assert (
|
||||
self._world_size % self._tp_degree == 0
|
||||
), f"TP degree {tp_degree} should be divisible by {self._world_size} when DP degree is None"
|
||||
self._dp_degree = self._world_size // tp_degree
|
||||
else:
|
||||
self._dp_degree = dp_degree
|
||||
self._tp_degree = tp_degree
|
||||
assert self._dp_degree * self._tp_degree == self._world_size, \
|
||||
f"the world size {self._world_size} should equals to the product of DP degree {self._dp_degree}" \
|
||||
assert self._dp_degree * self._tp_degree == self._world_size, (
|
||||
f"the world size {self._world_size} should equals to the product of DP degree {self._dp_degree}"
|
||||
f"and TP degree {self._tp_degree}"
|
||||
)
|
||||
|
||||
self._tp_rank_list = None
|
||||
self._dp_rank_list = None
|
||||
|
||||
for i in range(self._dp_degree):
|
||||
i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)]
|
||||
PYTORCHPGDICT_.get(i_tp_list, 'nccl')
|
||||
PYTORCHPGDICT_.get(i_tp_list, "nccl")
|
||||
if self._rank in i_tp_list:
|
||||
self._tp_rank_list = i_tp_list
|
||||
|
||||
for j in range(self._tp_degree):
|
||||
j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)]
|
||||
PYTORCHPGDICT_.get(j_dp_list, 'nccl')
|
||||
PYTORCHPGDICT_.get(j_dp_list, "nccl")
|
||||
if self._rank in j_dp_list:
|
||||
self._dp_rank_list = j_dp_list
|
||||
|
||||
@@ -119,11 +124,11 @@ class ProcessGroup:
|
||||
|
||||
for i in range(self._dp_degree):
|
||||
i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)]
|
||||
PYTORCHPGDICT_.get(i_tp_list, 'gloo')
|
||||
PYTORCHPGDICT_.get(i_tp_list, "gloo")
|
||||
|
||||
for j in range(self._tp_degree):
|
||||
j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)]
|
||||
PYTORCHPGDICT_.get(j_dp_list, 'gloo')
|
||||
PYTORCHPGDICT_.get(j_dp_list, "gloo")
|
||||
|
||||
self._has_cpu_groups = True
|
||||
|
||||
@@ -145,7 +150,7 @@ class ProcessGroup:
|
||||
else:
|
||||
return "ProcessGroup not initialized"
|
||||
|
||||
def __eq__(self, obj: 'ProcessGroup') -> bool:
|
||||
def __eq__(self, obj: "ProcessGroup") -> bool:
|
||||
if not isinstance(obj, ProcessGroup):
|
||||
return False
|
||||
if self._rank != obj._rank:
|
||||
@@ -260,7 +265,7 @@ class ProcessGroup:
|
||||
Returns:
|
||||
`torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group.
|
||||
"""
|
||||
return PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
|
||||
return PYTORCHPGDICT_.get(self._dp_rank_list, "nccl")
|
||||
|
||||
def tp_process_group(self):
|
||||
"""tp_process_group
|
||||
@@ -270,7 +275,7 @@ class ProcessGroup:
|
||||
Returns:
|
||||
`torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group.
|
||||
"""
|
||||
return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
|
||||
return PYTORCHPGDICT_.get(self._tp_rank_list, "nccl")
|
||||
|
||||
def cpu_dp_process_group(self):
|
||||
"""cpu_dp_process_group
|
||||
@@ -283,7 +288,7 @@ class ProcessGroup:
|
||||
`torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group.
|
||||
"""
|
||||
assert self._has_cpu_groups
|
||||
return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
|
||||
return PYTORCHPGDICT_.get(self._dp_rank_list, "gloo")
|
||||
|
||||
def cpu_tp_process_group(self):
|
||||
"""cpu_tp_process_group
|
||||
@@ -296,7 +301,7 @@ class ProcessGroup:
|
||||
`torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group.
|
||||
"""
|
||||
assert self._has_cpu_groups
|
||||
return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
|
||||
return PYTORCHPGDICT_.get(self._tp_rank_list, "gloo")
|
||||
|
||||
def get_ranks_in_dp(self) -> List[int]:
|
||||
"""get_ranks_in_dp
|
||||
|
Reference in New Issue
Block a user