[misc] refactor launch API and tensor constructor (#5666)

* [misc] remove config arg from initialize

* [misc] remove old tensor contrusctor

* [plugin] add npu support for ddp

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [devops] fix doc test ci

* [test] fix test launch

* [doc] update launch doc

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Hongxin Liu
2024-04-29 10:40:11 +08:00
committed by GitHub
parent 91fa553775
commit 7f8b16635b
223 changed files with 294 additions and 403 deletions

View File

@@ -6,7 +6,7 @@ from colossalai.testing import spawn
def check_device_mesh_manager(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
device_mesh_manager = DeviceMeshManager()
# TODO(ver217): this test is strictly relies on hardware, temporary skip it
# device_mesh_info_auto = DeviceMeshInfo(physical_ids=[0, 1, 2, 3],)

View File

@@ -6,57 +6,6 @@ from colossalai.cluster import ProcessGroupMesh
from colossalai.testing import spawn
def check_process_group_mesh_with_gpc():
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
pg_mesh = ProcessGroupMesh(1, 2, 2)
# check world size
assert gpc.get_world_size(ParallelMode.TENSOR) == pg_mesh.size(
TP_DIM
), f"{gpc.get_world_size(ParallelMode.TENSOR)} != {pg_mesh.size(TP_DIM)}"
assert gpc.get_world_size(ParallelMode.PIPELINE) == pg_mesh.size(PP_DIM)
assert gpc.get_world_size(ParallelMode.DATA) == pg_mesh.size(DP_DIM)
# check locak rank (coordinate)
assert gpc.get_local_rank(ParallelMode.TENSOR) == pg_mesh.coordinate(
TP_DIM
), f"{gpc.get_local_rank(ParallelMode.TENSOR)} != {pg_mesh.coordinate(TP_DIM)}"
assert gpc.get_local_rank(ParallelMode.PIPELINE) == pg_mesh.coordinate(PP_DIM)
assert gpc.get_local_rank(ParallelMode.DATA) == pg_mesh.coordinate(DP_DIM)
# check ranks in group
tp_group = pg_mesh.get_group_along_axis(TP_DIM)
assert gpc.get_ranks_in_group(ParallelMode.TENSOR) == pg_mesh.get_ranks_in_group(tp_group)
pp_group = pg_mesh.get_group_along_axis(PP_DIM)
assert gpc.get_ranks_in_group(ParallelMode.PIPELINE) == pg_mesh.get_ranks_in_group(pp_group)
dp_group = pg_mesh.get_group_along_axis(DP_DIM)
assert gpc.get_ranks_in_group(ParallelMode.DATA) == pg_mesh.get_ranks_in_group(dp_group)
# check prev rank
coord = pg_mesh.coordinate()
if not gpc.is_first_rank(ParallelMode.TENSOR):
assert coord[TP_DIM] != 0
prev_coord = coord[:TP_DIM] + (coord[TP_DIM] - 1,) + coord[TP_DIM + 1 :]
assert gpc.get_prev_global_rank(ParallelMode.TENSOR) == pg_mesh.ravel(prev_coord, pg_mesh.shape)
if not gpc.is_first_rank(ParallelMode.PIPELINE):
assert coord[PP_DIM] != 0
prev_coord = coord[:PP_DIM] + (coord[PP_DIM] - 1,) + coord[PP_DIM + 1 :]
assert gpc.get_prev_global_rank(ParallelMode.PIPELINE) == pg_mesh.ravel(prev_coord, pg_mesh.shape)
# check next rank
if not gpc.is_last_rank(ParallelMode.TENSOR):
assert coord[TP_DIM] != pg_mesh.size(TP_DIM) - 1
next_coord = coord[:TP_DIM] + (coord[TP_DIM] + 1,) + coord[TP_DIM + 1 :]
assert gpc.get_next_global_rank(ParallelMode.TENSOR) == pg_mesh.ravel(next_coord, pg_mesh.shape)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
assert coord[PP_DIM] != pg_mesh.size(PP_DIM) - 1
next_coord = coord[:PP_DIM] + (coord[PP_DIM] + 1,) + coord[PP_DIM + 1 :]
assert gpc.get_next_global_rank(ParallelMode.PIPELINE) == pg_mesh.ravel(next_coord, pg_mesh.shape)
def check_process_group_mesh_with_cases():
DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
DP_SIZE, PP_SIZE, TP_SIZE = 1, 2, 2
@@ -177,14 +126,11 @@ def check_process_group_mesh_with_cases():
def run_dist(rank, world_size, port):
colossalai.launch(
config=dict(parallel=dict(data=1, pipeline=2, tensor=dict(mode="1d", size=2))),
rank=rank,
world_size=world_size,
port=port,
host="localhost",
)
# TODO(ver217): this function should be removed when gpc is removed
# check_process_group_mesh_with_gpc()
check_process_group_mesh_with_cases()