mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -19,12 +19,14 @@ def _check_tesseract_env_var(tesseract_dim: int, tesseract_dep: int):
|
||||
env_tesseract_dep = env.tesseract_dep
|
||||
|
||||
if env_tesseract_dim and env_tesseract_dep:
|
||||
assert int(env_tesseract_dim) == tesseract_dim, \
|
||||
'TESSERACT_DIM has been set in the current environment and ' \
|
||||
'does not match with the value passed to this initialized'
|
||||
assert int(env_tesseract_dep) == tesseract_dep, \
|
||||
'TESSERACT_DEP has been set in the current environment and ' \
|
||||
'does not match with the value passed to this initialized'
|
||||
assert int(env_tesseract_dim) == tesseract_dim, (
|
||||
"TESSERACT_DIM has been set in the current environment and "
|
||||
"does not match with the value passed to this initialized"
|
||||
)
|
||||
assert int(env_tesseract_dep) == tesseract_dep, (
|
||||
"TESSERACT_DEP has been set in the current environment and "
|
||||
"does not match with the value passed to this initialized"
|
||||
)
|
||||
else:
|
||||
env.tesseract_dim = tesseract_dim
|
||||
env.tesseract_dep = tesseract_dep
|
||||
@@ -50,8 +52,9 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
|
||||
self.num_group = self.world_size // self.tensor_parallel_size
|
||||
self.tesseract_dep = tesseract_dep
|
||||
self.tesseract_dim = tesseract_dim
|
||||
assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
|
||||
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
|
||||
assert (
|
||||
self.tensor_parallel_size == self.tesseract_dim**2 * self.tesseract_dep
|
||||
), "Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
|
||||
|
||||
def init_dist_group(self):
|
||||
"""Initialize 2.5D tensor row parallel groups, and assign local_ranks and groups to each gpu.
|
||||
@@ -75,7 +78,7 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
|
||||
for i in range(self.tesseract_dim)
|
||||
]
|
||||
group = dist.new_group(ranks)
|
||||
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
|
||||
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
|
||||
|
||||
if self.rank in ranks:
|
||||
local_rank = ranks.index(self.rank)
|
||||
@@ -129,7 +132,7 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
|
||||
for j in range(self.tesseract_dim)
|
||||
]
|
||||
group = dist.new_group(ranks)
|
||||
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
|
||||
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
|
||||
|
||||
if self.rank in ranks:
|
||||
local_rank = ranks.index(self.rank)
|
||||
@@ -183,7 +186,7 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
|
||||
for k in range(self.tesseract_dep)
|
||||
]
|
||||
group = dist.new_group(ranks)
|
||||
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
|
||||
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
|
||||
|
||||
if self.rank in ranks:
|
||||
local_rank = ranks.index(self.rank)
|
||||
@@ -238,7 +241,7 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer):
|
||||
for j in range(self.tesseract_dim)
|
||||
]
|
||||
group = dist.new_group(ranks)
|
||||
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
|
||||
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
|
||||
|
||||
if self.rank in ranks:
|
||||
local_rank = ranks.index(self.rank)
|
||||
@@ -265,16 +268,25 @@ class Initializer_2p5D(ProcessGroupInitializer):
|
||||
depth (int): The depth of 2.5d parallel.
|
||||
"""
|
||||
|
||||
def __init__(self, rank: int, world_size: int, config: Config, data_parallel_size: int, pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int, depth: int):
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
config: Config,
|
||||
data_parallel_size: int,
|
||||
pipeline_parallel_size: int,
|
||||
tensor_parallel_size: int,
|
||||
depth: int,
|
||||
):
|
||||
args = (rank, world_size, config, data_parallel_size, pipeline_parallel_size, tensor_parallel_size)
|
||||
super().__init__(*args)
|
||||
self.num_group = self.world_size // self.tensor_parallel_size
|
||||
self.tesseract_dim = int(math.sqrt(self.tensor_parallel_size / depth))
|
||||
self.tesseract_dep = depth
|
||||
|
||||
assert self.tensor_parallel_size == self.tesseract_dim ** 2 * self.tesseract_dep, \
|
||||
"2.5D tesseract dim should equal to (tensor parallel size / tesseract dep) ^ 0.5"
|
||||
assert (
|
||||
self.tensor_parallel_size == self.tesseract_dim**2 * self.tesseract_dep
|
||||
), "2.5D tesseract dim should equal to (tensor parallel size / tesseract dep) ^ 0.5"
|
||||
_check_tesseract_env_var(self.tesseract_dim, self.tesseract_dep)
|
||||
|
||||
self.col_initializer = Initializer_2p5D_Col(self.tesseract_dim, self.tesseract_dep, *args)
|
||||
@@ -293,6 +305,6 @@ class Initializer_2p5D(ProcessGroupInitializer):
|
||||
self.col_initializer.init_dist_group(),
|
||||
self.row_initializer.init_dist_group(),
|
||||
self.dep_initializer.init_dist_group(),
|
||||
self.xz_initializer.init_dist_group()
|
||||
self.xz_initializer.init_dist_group(),
|
||||
]
|
||||
return parallel_setting
|
||||
|
Reference in New Issue
Block a user