Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-12 20:54:35 +00:00)
[misc] refactor launch API and tensor constructor (#5666)
* [misc] remove config arg from initialize
* [misc] remove old tensor constructor
* [plugin] add npu support for ddp
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* [devops] fix doc test ci
* [test] fix test launch
* [doc] update launch doc

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
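The user-facing change is that `colossalai.launch` and its wrapper entry points no longer take a `config` argument, and internal tensors are created on the device reported by the accelerator API instead of via `torch.cuda.*` constructors. A minimal sketch of the new launch call under `torchrun`, using only arguments that appear in this diff:

```python
import colossalai

# Before: colossalai.launch_from_torch(config={})
# After:  no config placeholder is needed; all arguments are optional.
colossalai.launch_from_torch(backend="nccl", seed=1024)
```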
@@ -126,7 +126,7 @@ class AMPOptimizer(OptimizerWrapper):
        return self.grad_scaler.scale.item()

    def zero_grad(self, *args, **kwargs):
-       self.module.overflow_counter = torch.cuda.IntTensor([0])
+       self.module.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())
        return self.optim.zero_grad(set_to_none=True)

    def step(self, *args, **kwargs):
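The hunk above swaps the CUDA-only `torch.cuda.IntTensor` constructor for a device-agnostic one. A minimal sketch of the pattern, relying only on the `get_accelerator` helper that the diff itself uses:

```python
import torch
from colossalai.accelerator import get_accelerator

# get_accelerator() resolves the active backend (CUDA, NPU, ...), so the
# counter is allocated on whatever device the run is actually using rather
# than being hard-wired to CUDA.
overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())
```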
@@ -4,7 +4,7 @@ from typing import Optional, Set

import torch
import torch.nn as nn

-from colossalai.utils import _cast_float
+from colossalai.utils import _cast_float, get_current_device
from colossalai.utils.common import free_storage

from .region_manager import RegionManager

@@ -25,7 +25,7 @@ class BaseOffloadModule:
        self.model = model
        self.region_manager = region_manager
        self.grad_hook_list = []
-       self.overflow_counter = torch.cuda.IntTensor([0])
+       self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_current_device())

        self.grad_offload_stream = torch.cuda.current_stream() if is_sync else GlobalRuntimeInfo.d2h_stream
@@ -10,6 +10,7 @@ from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.quantization import BnbQuantizationConfig, quantize_model
+from colossalai.utils import get_current_device

from .dp_plugin_base import DPPluginBase

@@ -203,7 +204,7 @@ class TorchDDPPlugin(DPPluginBase):
        return True

    def supported_devices(self) -> List[str]:
-       return ["cuda"]
+       return ["cuda", "npu"]

    def configure(
        self,

@@ -214,7 +215,7 @@ class TorchDDPPlugin(DPPluginBase):
        lr_scheduler: Optional[LRScheduler] = None,
    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
        # cast model to cuda
-       model = model.cuda()
+       model = model.to(get_current_device())

        # convert model to sync bn
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)
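With `supported_devices` now reporting `["cuda", "npu"]` and the model moved via `get_current_device()`, the same DDP setup is expected to run on Ascend NPU as well as CUDA. A rough sketch of how the plugin is typically wired up with a `Booster`; the booster boilerplate follows the usual ColossalAI pattern and is not part of this diff:

```python
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch()  # rank/world size come from the torchrun env vars

# The plugin places the model with get_current_device() instead of .cuda(),
# so the same script can target either accelerator.
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)

model = nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)
```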
@@ -114,7 +114,7 @@ import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer

# launch distributed environment
-colossalai.launch_from_torch(config={})
+colossalai.launch_from_torch()

# load original model and tokenizer
model = LlamaForCausalLM.from_pretrained("/path/to/model")
@@ -2,20 +2,15 @@
# -*- encoding: utf-8 -*-

import os
import warnings
from pathlib import Path
from typing import Dict, Union

import torch.distributed as dist

from colossalai.accelerator import get_accelerator
-from colossalai.context import Config
from colossalai.logging import get_dist_logger
from colossalai.utils import set_seed


def launch(
-   config: Union[str, Path, Config, Dict],
    rank: int,
    world_size: int,
    host: str,

@@ -44,8 +39,6 @@ def launch(
    Raises:
        Exception: Raise exception when config type is wrong
    """
    if rank == 0:
        warnings.warn("`config` is deprecated and will be removed soon.")

    cur_accelerator = get_accelerator()

@@ -68,7 +61,6 @@


def launch_from_slurm(
-   config: Union[str, Path, Config, Dict],
    host: str,
    port: int,
    backend: str = "nccl",

@@ -95,7 +87,6 @@ def launch_from_slurm(
    )

    launch(
-       config=config,
        rank=rank,
        world_size=world_size,
        host=host,

@@ -107,7 +98,6 @@


def launch_from_openmpi(
-   config: Union[str, Path, Config, Dict],
    host: str,
    port: int,
    backend: str = "nccl",

@@ -135,7 +125,6 @@ def launch_from_openmpi(
    )

    launch(
-       config=config,
        local_rank=local_rank,
        rank=rank,
        world_size=world_size,

@@ -147,9 +136,7 @@
    )


-def launch_from_torch(
-   config: Union[str, Path, Config, Dict], backend: str = "nccl", seed: int = 1024, verbose: bool = True
-):
+def launch_from_torch(backend: str = "nccl", seed: int = 1024, verbose: bool = True):
    """A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
    from the environment variables set by PyTorch

@@ -171,7 +158,6 @@ def launch_from_torch(
    )

    launch(
-       config=config,
        local_rank=local_rank,
        rank=rank,
        world_size=world_size,
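After these hunks, every launcher entry point (`launch`, `launch_from_slurm`, `launch_from_openmpi`, `launch_from_torch`) drops the `config` parameter while keeping its rendezvous arguments. A hedged sketch of the cluster-launcher variants; the host name and port below are placeholders, and only the parameter names come from the diff:

```python
import colossalai

# Under SLURM or OpenMPI the rank and world size are read from the scheduler's
# environment, so only the rendezvous endpoint is passed explicitly.
colossalai.launch_from_slurm(host="node001", port=29500, backend="nccl")
# colossalai.launch_from_openmpi(host="node001", port=29500, backend="nccl")
```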
@@ -56,7 +56,7 @@ class Worker:
        # initialize a ray collective group, otherwise colossalai distributed env won't be built successfully
        collective.init_collective_group(world_size, rank, "nccl", "default")
        # initialize and set distributed environment
-       colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+       colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
        ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..")
        log_cuda_info("Worker.setup")
@@ -42,7 +42,7 @@ class CaiInferEngine:
    import colossalai
    from transformers import LlamaForCausalLM, LlamaTokenizer

-   colossalai.launch_from_torch(config={})
+   colossalai.launch_from_torch()

    model = LlamaForCausalLM.from_pretrained("your_path_to_model")
    tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf")

@@ -36,7 +36,7 @@ from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
import colossalai
from transformers import LlamaForCausalLM, LlamaTokenizer

-colossalai.launch_from_torch(config={})
+colossalai.launch_from_torch()

model = LlamaForCausalLM.from_pretrained("/path/to/model")
tokenizer = LlamaTokenizer.from_pretrained("/path/to/model")
@@ -57,27 +57,27 @@ We conducted multiple benchmark tests to evaluate the performance. We compared t
### Llama Throughput (tokens/s) | input length=1024, output length=128

#### A10 7b, fp16

| batch_size(micro_batch size) | 2(1)  | 4(2)  | 8(4)   | 16(8)  | 32(8)  | 32(16) |
|:----------------------------:|:-----:|:-----:|:------:|:------:|:------:|:------:|
| Pipeline Inference           | 40.35 | 77.1  | 139.03 | 232.7  | 257.81 | OOM    |
| Hugging Face                 | 41.43 | 65.30 | 91.93  | 114.62 | OOM    | OOM    |

#### A10 13b, fp16

| batch_size(micro_batch size) | 2(1)  | 4(2)  | 8(4)  | 16(4) |
|:----------------------------:|:-----:|:-----:|:-----:|:-----:|
| Pipeline Inference           | 25.39 | 47.09 | 83.7  | 89.46 |
| Hugging Face                 | 23.48 | 37.59 | 53.44 | OOM   |

#### A800 7b, fp16

| batch_size(micro_batch size) | 2(1)  | 4(2)   | 8(4)   | 16(8)  | 32(16) |
|:----------------------------:|:-----:|:------:|:------:|:------:|:------:|
| Pipeline Inference           | 57.97 | 110.13 | 213.33 | 389.86 | 670.12 |
| Hugging Face                 | 42.44 | 76.5   | 151.97 | 212.88 | 256.13 |

#### A800 13b, fp16

| batch_size(micro_batch size) | 2(1)  | 4(2)  | 8(4)   | 16(8)  | 32(16) |
|:----------------------------:|:-----:|:-----:|:------:|:------:|:------:|
| Pipeline Inference           | 41.78 | 94.18 | 172.67 | 310.75 | 470.15 |
| Hugging Face                 | 36.57 | 68.4  | 105.81 | 139.51 | 166.34 |
@@ -12,7 +12,7 @@ from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
GIGABYTE = 1024**3
MEGABYTE = 1024 * 1024

-colossalai.launch_from_torch(config={})
+colossalai.launch_from_torch()


def data_gen(batch_size: int = 4, seq_len: int = 512):
@@ -56,7 +56,7 @@ class Worker:
        # initialize a ray collective group, otherwise colossalai distributed env won't be built successfully
        collective.init_collective_group(world_size, rank, "nccl", "default")
        # initialize and set distributed environment
-       colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+       colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
        ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..")
        log_cuda_info("Worker.setup")
@@ -98,7 +98,7 @@ class ColossalInferenceHandler(BaseHandler, ABC):
        self.model.cuda()
        self.model.eval()

-       colossalai.launch(config={}, rank=rank, world_size=world_size, host=host, port=port, backend="nccl")
+       colossalai.launch(rank=rank, world_size=world_size, host=host, port=port, backend="nccl")
        logger.info("Initializing TPInferEngine ...")
        shard_config = ShardConfig(
            enable_tensor_parallelism=True if self.tp_size > 1 else False, extra_kwargs={"inference_only": True}
@@ -114,7 +114,7 @@ def run_worker(rank, args, master_func):
    port = args.master_port
    backend = "nccl" if device == "cuda" else "gloo"

-   launch(dict(), rank, world_size, host, int(port), backend, verbose=False)
+   launch(rank, world_size, host, int(port), backend, verbose=False)
    ppg.set_global_info(
        rank=rank,
        world_size=world_size,
@@ -8,7 +8,7 @@ Licensed under the MIT License.
"""
import torch

-from colossalai.utils import multi_tensor_applier
+from colossalai.utils import get_current_device, multi_tensor_applier


class FusedAdam(torch.optim.Optimizer):

@@ -75,7 +75,7 @@ class FusedAdam(torch.optim.Optimizer):
            fused_optim = FusedOptimizerLoader().load()

            # Skip buffer
-           self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+           self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=get_current_device())
            self.multi_tensor_adam = fused_optim.multi_tensor_adam
        else:
            raise RuntimeError("FusedAdam requires cuda extensions")
@@ -3,7 +3,7 @@ from typing import Any, Optional
import torch

from colossalai.kernel.kernel_loader import FusedOptimizerLoader
-from colossalai.utils import multi_tensor_applier
+from colossalai.utils import get_current_device, multi_tensor_applier

from .cpu_adam import CPUAdam

@@ -87,7 +87,7 @@ class HybridAdam(CPUAdam):
        if torch.cuda.is_available():
            fused_optim = FusedOptimizerLoader().load()
            self.gpu_adam_op = fused_optim.multi_tensor_adam
-           self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+           self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=get_current_device())

    @torch.no_grad()
    def step(self, closure=None, div_scale: float = -1):
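`FusedAdam` and `HybridAdam` receive the same constructor swap for their overflow buffer. A small usage sketch, assuming the documented `colossalai.nn.optimizer` import path (not shown in this diff):

```python
import torch
from colossalai.nn.optimizer import HybridAdam

model = torch.nn.Linear(16, 16).to("cuda")  # or the current NPU device
optimizer = HybridAdam(model.parameters(), lr=1e-3)
# Internally the optimizer's _dummy_overflow_buf is now created with
# torch.tensor(..., device=get_current_device()) instead of torch.cuda.IntTensor,
# so construction no longer assumes a CUDA-only build.
```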
@@ -38,7 +38,7 @@ from transformers import BertForMaskedLM
import colossalai

# launch colossalai
-colossalai.launch_from_torch(config={})
+colossalai.launch_from_torch()

# create model
config = BertConfig.from_pretrained('bert-base-uncased')
@@ -28,7 +28,7 @@ def to_device(x: Any, device: torch.device) -> Any:


def train(args):
-   colossalai.launch_from_torch(config={}, seed=42)
+   colossalai.launch_from_torch(seed=42)
    coordinator = DistCoordinator()

    # prepare for data and dataset
@@ -1,6 +1,7 @@
"""
Shardformer Benchmark
"""

import torch
import torch.distributed as dist
import transformers

@@ -84,5 +85,5 @@ def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, d
# start benchmark, command:
# torchrun --standalone --nproc_per_node=2 performance_benchmark.py
if __name__ == "__main__":
-   colossalai.launch_from_torch({})
+   colossalai.launch_from_torch()
    bench_shardformer.run(save_path=".", print_data=dist.get_rank() == 0)
@@ -26,7 +26,7 @@ class ShardFormer:
    import colossalai
    import torch

-   colossalai.launch_from_torch(config={})
+   colossalai.launch_from_torch()

    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
    shard_config = ShardConfig()
@@ -69,7 +69,7 @@ import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor import DTensor, ShardingSpec

-colossalai.launch_from_torch(config={})
+colossalai.launch_from_torch()

# define your device mesh
# assume you have 4 GPUs