mirror of https://github.com/hpcaitech/ColossalAI.git
[misc] refactor launch API and tensor constructor (#5666)
* [misc] remove config arg from initialize
* [misc] remove old tensor constructor
* [plugin] add npu support for ddp
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci
* [devops] fix doc test ci
* [test] fix test launch
* [doc] update launch doc

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
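For orientation, a minimal before/after sketch of what the launch refactor means at a call site; the argument values are illustrative and the new signature is taken from the diff below:

import colossalai

# Before this commit: every launcher still accepted a (deprecated) config argument.
# colossalai.launch_from_torch(config={}, backend="nccl")

# After this commit: the config argument is removed from launch() and all wrappers.
colossalai.launch_from_torch(backend="nccl", seed=1024, verbose=True)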
@@ -2,20 +2,15 @@
 # -*- encoding: utf-8 -*-
 
 import os
-import warnings
-from pathlib import Path
-from typing import Dict, Union
 
 import torch.distributed as dist
 
 from colossalai.accelerator import get_accelerator
-from colossalai.context import Config
 from colossalai.logging import get_dist_logger
 from colossalai.utils import set_seed
 
 
 def launch(
-    config: Union[str, Path, Config, Dict],
     rank: int,
     world_size: int,
     host: str,
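With the config parameter gone, calling launch() directly takes only the distributed coordinates. A sketch assuming the rest of the signature (port, backend, seed, verbose) is unchanged by this hunk; the host and port values are hypothetical:

import colossalai

# Single-process example; rank 0 of a world of 1, rendezvous on localhost.
colossalai.launch(rank=0, world_size=1, host="127.0.0.1", port=29500)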
@@ -44,8 +39,6 @@ def launch(
     Raises:
         Exception: Raise exception when config type is wrong
     """
-    if rank == 0:
-        warnings.warn("`config` is deprecated and will be removed soon.")
 
     cur_accelerator = get_accelerator()
 
@@ -68,7 +61,6 @@ def launch(
 
 
 def launch_from_slurm(
-    config: Union[str, Path, Config, Dict],
     host: str,
     port: int,
     backend: str = "nccl",
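launch_from_slurm likewise now takes only the rendezvous endpoint; rank and world size are presumably still read from SLURM's environment (e.g. SLURM_PROCID) inside the wrapper, which this diff does not show. A hedged sketch:

import colossalai

# Inside a script started via srun; SLURM supplies the process rank and count.
colossalai.launch_from_slurm(host="node0", port=29500, backend="nccl")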
@@ -95,7 +87,6 @@ def launch_from_slurm(
     )
 
     launch(
-        config=config,
         rank=rank,
         world_size=world_size,
         host=host,
@@ -107,7 +98,6 @@ def launch_from_slurm(
 
 
 def launch_from_openmpi(
-    config: Union[str, Path, Config, Dict],
     host: str,
     port: int,
     backend: str = "nccl",
@@ -135,7 +125,6 @@ def launch_from_openmpi(
     )
 
     launch(
-        config=config,
         local_rank=local_rank,
         rank=rank,
         world_size=world_size,
@@ -147,9 +136,7 @@ def launch_from_openmpi(
     )
 
 
-def launch_from_torch(
-    config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024, verbose: bool = True
-):
+def launch_from_torch(backend: str = "nccl", seed: int = 1024, verbose: bool = True):
     """A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
     from the environment variables set by PyTorch
 
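The launch_from_torch signature collapses to three defaulted arguments; per the docstring above, rank and world size come from the environment variables that torchrun sets. A usage sketch (train.py is a hypothetical script name):

# train.py
import colossalai

colossalai.launch_from_torch(backend="nccl")

# launched with, for example:
#   torchrun --nproc_per_node=4 train.py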
@@ -171,7 +158,6 @@ def launch_from_torch(
     )
 
     launch(
-        config=config,
         local_rank=local_rank,
         rank=rank,
         world_size=world_size,