mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-28 11:45:23 +00:00
* sequence parallel optimization * validate sequence parallel in llama (code to be polished) * shardformer api writing * integrate sequence parallel in ShardFormer * fix pp bugs and sp bugs for LlaMa model * integrating ring-based sequence parallelism into ShardFormer * [sequence parallelism]: Add fused megatron function * integrating ring-based sequence parallelism into ShardFormer --------- Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> * fix bugs when using sp and flashattention together * fix operation function name * support flash attention for ulysses-style sp * clarify sp process group * fix compatibility bugs in moe plugin * fix fused linear bugs * fix linear layer test * support gpt model all-to-all sp * modify shard data dimension (meant to be dim=-1) * support megatron-style sp and distributed attn for llama model * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatibility * finish sp mode 3 support for gpt * using all_to_all_single when batch size is 1 * support mode 2 sp in gpt2 (#5) * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatibility * refactor ring implementation * support mode 2 sp in gpt2 * polish code * enable distributed attn mask when using sp mode 2 and 3 in llama * automatically enable flash attn when using sp mode 2 and 3 in llama * inplace attn mask * add zero2 support for sequence parallel * polish code * fix bugs * fix gemini checkpoint io * loosen tensor checking atol and rtol * add comment * fix llama layernorm grad * fix zero grad * fix zero grad * fix conflict * update split and gather auto grad func * sequence parallel: inside text split (#6) * polish code (part 1) * polish code (part 2) * polish code (part 
2.5) * polish code (part 3) * sequence parallel: inside text split * miscellaneous minor fixes * polish code * fix ulysses style ZeRO * sequence parallel: inside text split * miscellaneous minor fixes * disaggregate sp group and dp group for sp * fix llama and gpt sp * polish code * move ulysses grad sync to ddp (#9) * remove zero_stage and unbind the grad sync for alltoall sp * add 2d group creation test * move ulysses grad sync to ddp * add 2d group creation test * remove useless code * change shard config not to enable sp when enable_all_optimizations * add sp warnings for several models * remove useless code --------- Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
198 lines
7.2 KiB
Python
198 lines
7.2 KiB
Python
import pytest
|
|
import torch.distributed as dist
|
|
|
|
import colossalai
|
|
from colossalai.cluster import ProcessGroupMesh
|
|
from colossalai.testing import spawn
|
|
|
|
|
|
def check_process_group_mesh_with_gpc():
    """Cross-check a ``ProcessGroupMesh(1, 2, 2)`` against the legacy global context.

    For the tensor, pipeline and data parallel modes this compares what
    ``gpc`` reports (world size, local rank, ranks in group) with the values
    derived from the mesh.  For the tensor and pipeline modes it also checks
    that the previous/next global ranks match the mesh coordinate shifted by
    one along the corresponding axis.

    NOTE(review): presumably requires ``colossalai.launch`` to have been
    called with a matching parallel config beforehand -- confirm with caller.
    """
    from colossalai.legacy.context import ParallelMode
    from colossalai.legacy.core import global_context as gpc

    DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
    pg_mesh = ProcessGroupMesh(1, 2, 2)

    # (parallel mode, mesh axis) pairs; the order matters for group creation below.
    all_axes = (
        (ParallelMode.TENSOR, TP_DIM),
        (ParallelMode.PIPELINE, PP_DIM),
        (ParallelMode.DATA, DP_DIM),
    )
    # Only these modes are checked for prev/next neighbours.
    neighbour_axes = all_axes[:2]

    # check world size
    assert gpc.get_world_size(ParallelMode.TENSOR) == pg_mesh.size(
        TP_DIM
    ), f"{gpc.get_world_size(ParallelMode.TENSOR)} != {pg_mesh.size(TP_DIM)}"
    for mode, dim in all_axes[1:]:
        assert gpc.get_world_size(mode) == pg_mesh.size(dim)

    # check local rank (coordinate)
    assert gpc.get_local_rank(ParallelMode.TENSOR) == pg_mesh.coordinate(
        TP_DIM
    ), f"{gpc.get_local_rank(ParallelMode.TENSOR)} != {pg_mesh.coordinate(TP_DIM)}"
    for mode, dim in all_axes[1:]:
        assert gpc.get_local_rank(mode) == pg_mesh.coordinate(dim)

    # check ranks in group (groups are fetched in TP, PP, DP order)
    for mode, dim in all_axes:
        axis_group = pg_mesh.get_group_along_axis(dim)
        assert gpc.get_ranks_in_group(mode) == pg_mesh.get_ranks_in_group(axis_group)

    coord = pg_mesh.coordinate()

    def shifted(dim, step):
        # Copy of this rank's coordinate moved by ``step`` along axis ``dim``.
        return coord[:dim] + (coord[dim] + step,) + coord[dim + 1 :]

    # check prev rank
    for mode, dim in neighbour_axes:
        if gpc.is_first_rank(mode):
            continue
        assert coord[dim] != 0
        prev_coord = shifted(dim, -1)
        assert gpc.get_prev_global_rank(mode) == pg_mesh.ravel(prev_coord, pg_mesh.shape)

    # check next rank
    for mode, dim in neighbour_axes:
        if gpc.is_last_rank(mode):
            continue
        assert coord[dim] != pg_mesh.size(dim) - 1
        next_coord = shifted(dim, 1)
        assert gpc.get_next_global_rank(mode) == pg_mesh.ravel(next_coord, pg_mesh.shape)
|
|
|
|
|
|
def check_process_group_mesh_with_cases():
    """Validate a 1 x 2 x 2 (DP x PP x TP) ``ProcessGroupMesh`` against fixed tables.

    Every expectation is written out per global rank (0-3): mesh coordinate,
    ranks in the TP / PP / DP groups, ranks in the multi-axis groups, and
    ranks in a multi-axis group restricted by per-axis indices.  The function
    then checks sizes, coordinates, group membership and that ``ravel`` of the
    neighbouring coordinate yields the neighbouring rank in the TP and PP
    groups.

    NOTE(review): calls ``dist.get_rank()``, so it presumably needs an
    initialised process group with 4 ranks -- confirm with caller.
    """
    DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
    DP_SIZE, PP_SIZE, TP_SIZE = 1, 2, 2

    # ---- expected values, keyed by global rank ----
    rank_to_coordinate = {0: (0, 0, 0), 1: (0, 0, 1), 2: (0, 1, 0), 3: (0, 1, 1)}
    tp_ranks_in_group = {0: [0, 1], 1: [0, 1], 2: [2, 3], 3: [2, 3]}
    pp_ranks_in_group = {0: [0, 2], 1: [1, 3], 2: [0, 2], 3: [1, 3]}
    dp_ranks_in_group = {0: [0], 1: [1], 2: [2], 3: [3]}
    tpxpp_ranks_in_group = {rk: [0, 1, 2, 3] for rk in range(4)}
    dpxtp_ranks_in_group = {0: [0, 1], 1: [0, 1], 2: [2, 3], 3: [2, 3]}
    # Per-rank index lists passed as the second argument of create_group_along_axis.
    tpxpp_partial_indices = {
        0: [[0, 1], [0]],
        1: [[1], [0, 1]],
        2: [[0], [0, 1]],
        3: [[0, 1], [1]],
    }
    tpxpp_ranks_in_group_partial = {0: [0, 1], 1: [1, 3], 2: [0, 2], 3: [2, 3]}

    pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE, TP_SIZE)

    rank = dist.get_rank()
    assert rank == pg_mesh.rank

    # check world size
    for dim, expected_size in ((TP_DIM, 2), (PP_DIM, 2), (DP_DIM, 1)):
        assert pg_mesh.size(dim) == expected_size

    # check coordinate
    for dim in (TP_DIM, PP_DIM, DP_DIM):
        assert pg_mesh.coordinate(dim) == rank_to_coordinate[rank][dim]

    # check ranks in group
    # Groups are requested in a fixed order (TP, PP, DP, DPxTP, TPxPP,
    # partial TPxPP) so that every rank issues the same sequence of calls.
    single_axis_cases = (
        (TP_DIM, tp_ranks_in_group),
        (PP_DIM, pp_ranks_in_group),
        (DP_DIM, dp_ranks_in_group),
    )
    for dim, expected in single_axis_cases:
        axis_group = pg_mesh.get_group_along_axis(dim)
        assert pg_mesh.get_ranks_in_group(axis_group) == expected[rank]

    dpxtp_group = pg_mesh.create_group_along_axis([DP_DIM, TP_DIM])
    assert pg_mesh.get_ranks_in_group(dpxtp_group) == dpxtp_ranks_in_group[rank]

    tpxpp_group = pg_mesh.create_group_along_axis([TP_DIM, PP_DIM])
    assert pg_mesh.get_ranks_in_group(tpxpp_group) == tpxpp_ranks_in_group[rank]

    tpxpp_group_partial = pg_mesh.create_group_along_axis([TP_DIM, PP_DIM], tpxpp_partial_indices[rank])
    assert pg_mesh.get_ranks_in_group(tpxpp_group_partial) == tpxpp_ranks_in_group_partial[rank]

    my_coord = rank_to_coordinate[rank]

    def shifted(dim, step):
        # Expected coordinate of the neighbour ``step`` positions away along ``dim``.
        return my_coord[:dim] + (my_coord[dim] + step,) + my_coord[dim + 1 :]

    neighbour_cases = (
        (TP_DIM, TP_SIZE, tp_ranks_in_group[rank]),
        (PP_DIM, PP_SIZE, pp_ranks_in_group[rank]),
    )

    # check prev rank
    for dim, _, group_ranks in neighbour_cases:
        if my_coord[dim] == 0:
            continue  # first along this axis: no predecessor to check
        prev_rank = group_ranks[group_ranks.index(rank) - 1]
        assert pg_mesh.ravel(shifted(dim, -1), pg_mesh.shape) == prev_rank

    # check next rank
    for dim, axis_size, group_ranks in neighbour_cases:
        if my_coord[dim] == axis_size - 1:
            continue  # last along this axis: no successor to check
        next_rank = group_ranks[group_ranks.index(rank) + 1]
        assert pg_mesh.ravel(shifted(dim, 1), pg_mesh.shape) == next_rank
|
|
|
|
|
|
def run_dist(rank, world_size, port):
    """Per-process entry point: launch ColossalAI on localhost, then run the mesh checks.

    Args:
        rank: rank forwarded to ``colossalai.launch``.
        world_size: world size forwarded to ``colossalai.launch``.
        port: port forwarded to ``colossalai.launch``.
    """
    # data=1, pipeline=2, 1D tensor parallel of size 2.
    launch_config = {
        "parallel": {
            "data": 1,
            "pipeline": 2,
            "tensor": {"mode": "1d", "size": 2},
        }
    }
    colossalai.launch(
        config=launch_config,
        rank=rank,
        world_size=world_size,
        port=port,
        host="localhost",
    )
    # TODO(ver217): this function should be removed when gpc is removed
    # check_process_group_mesh_with_gpc()
    check_process_group_mesh_with_cases()
|
|
|
|
|
|
@pytest.mark.dist
def test_process_group_mesh():
    """Run ``run_dist`` through ``spawn`` with 4 processes."""
    num_processes = 4  # the expectation tables in the checks cover ranks 0-3
    spawn(run_dist, num_processes)
|
|
|
|
|
|
# Allow running this test file directly as a script, without pytest.
if __name__ == "__main__":
    test_process_group_mesh()
|