[misc] refactor launch API and tensor constructor (#5666)
* [misc] remove config arg from initialize
* [misc] remove old tensor constructor
* [plugin] add npu support for ddp
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* [devops] fix doc test ci
* [test] fix test launch
* [doc] update launch doc

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
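The user-facing change in this commit is that ColossalAI's launch helpers no longer take a `config` argument. A minimal before/after sketch (the host/port values are illustrative, mirroring the calls updated in the hunks below; `launch_from_torch()` assumes the process was started via torchrun):

```python
import colossalai

# Old API (removed by this commit): an empty config dict had to be passed.
# colossalai.launch_from_torch(config={})
# colossalai.launch(config={}, rank=0, world_size=1, host="localhost", port=29500, backend="nccl")

# New API: the config argument is gone; everything else is unchanged.
# launch_from_torch() still reads rank/world size from torchrun's env vars.
colossalai.launch_from_torch()

# The explicit variant, for launchers that manage processes themselves:
# colossalai.launch(rank=0, world_size=1, host="localhost", port=29500, backend="nccl")
```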
@@ -56,7 +56,7 @@ class Worker:
         # initialize a ray collective group, otherwise colossalai distributed env won't be built successfully
         collective.init_collective_group(world_size, rank, "nccl", "default")
         # initialize and set distributed environment
-        colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+        colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
         ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..")
         log_cuda_info("Worker.setup")
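For orientation, a condensed sketch of the worker setup this hunk patches, using the new call signature (the actor decorator and the `setup` signature are assumptions, not taken from the diff; `collective` is Ray's `ray.util.collective`, as in the hunk):

```python
import colossalai
import ray
from ray.util import collective


@ray.remote(num_gpus=1)
class Worker:
    def setup(self, world_size: int, rank: int, port: int) -> bool:
        # A Ray collective group must exist first; otherwise the ColossalAI
        # distributed env cannot be built inside Ray actors (per the comment above).
        collective.init_collective_group(world_size, rank, "nccl", "default")
        # Refactored launch call: identical to before, minus the config dict.
        colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
        return True
```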
@@ -42,7 +42,7 @@ class CaiInferEngine:
 import colossalai
 from transformers import LlamaForCausalLM, LlamaTokenizer

-colossalai.launch_from_torch(config={})
+colossalai.launch_from_torch()

 model = LlamaForCausalLM.from_pretrained("your_path_to_model")
 tokenizer = LlamaTokenizer.from_pretrained("/home/lczyh/share/models/llama-7b-hf")
@@ -36,7 +36,7 @@ from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
 import colossalai
 from transformers import LlamaForCausalLM, LlamaTokenizer

-colossalai.launch_from_torch(config={})
+colossalai.launch_from_torch()

 model = LlamaForCausalLM.from_pretrained("/path/to/model")
 tokenizer = LlamaTokenizer.from_pretrained("/path/to/model")
@@ -57,27 +57,27 @@ We conducted multiple benchmark tests to evaluate the performance. We compared t
 ### Llama Throughput (tokens/s) | input length=1024, output length=128

 #### A10 7b, fp16
-| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)|
-| :---: | :---: | :---: | :---: | :---: | :---: | :---:|
-| Pipeline Inference | 40.35 | 77.1 | 139.03 | 232.7 | 257.81 | OOM |
-| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM |
+| batch_size(micro_batch size) | 2(1)  | 4(2)  | 8(4)   | 16(8)  | 32(8)  | 32(16) |
+|:----------------------------:|:-----:|:-----:|:------:|:------:|:------:|:------:|
+| Pipeline Inference           | 40.35 | 77.1  | 139.03 | 232.7  | 257.81 | OOM    |
+| Hugging Face                 | 41.43 | 65.30 | 91.93  | 114.62 | OOM    | OOM    |

 #### A10 13b, fp16
-| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
-| :---: | :---: | :---: | :---: | :---: |
-| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 |
-| Hugging Face | 23.48 | 37.59 | 53.44 | OOM |
+| batch_size(micro_batch size) | 2(1)  | 4(2)  | 8(4)  | 16(4) |
+|:----------------------------:|:-----:|:-----:|:-----:|:-----:|
+| Pipeline Inference           | 25.39 | 47.09 | 83.7  | 89.46 |
+| Hugging Face                 | 23.48 | 37.59 | 53.44 | OOM   |


 #### A800 7b, fp16
-| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
-| :---: | :---: | :---: | :---: | :---: | :---: |
-| Pipeline Inference| 57.97 | 110.13 | 213.33 | 389.86 | 670.12 |
-| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 |
+| batch_size(micro_batch size) | 2(1)  | 4(2)   | 8(4)   | 16(8)  | 32(16) |
+|:----------------------------:|:-----:|:------:|:------:|:------:|:------:|
+| Pipeline Inference           | 57.97 | 110.13 | 213.33 | 389.86 | 670.12 |
+| Hugging Face                 | 42.44 | 76.5   | 151.97 | 212.88 | 256.13 |


 #### A800 13b, fp16
-| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
-| :---: | :---: | :---: | :---: | :---: | :---: |
-| Pipeline Inference | 41.78 | 94.18 | 172.67| 310.75| 470.15 |
-| Hugging Face | 36.57 | 68.4 | 105.81 | 139.51 | 166.34 |
+| batch_size(micro_batch size) | 2(1)  | 4(2)  | 8(4)   | 16(8)  | 32(16) |
+|:----------------------------:|:-----:|:-----:|:------:|:------:|:------:|
+| Pipeline Inference           | 41.78 | 94.18 | 172.67 | 310.75 | 470.15 |
+| Hugging Face                 | 36.57 | 68.4  | 105.81 | 139.51 | 166.34 |

@@ -12,7 +12,7 @@ from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
 GIGABYTE = 1024**3
 MEGABYTE = 1024 * 1024

-colossalai.launch_from_torch(config={})
+colossalai.launch_from_torch()


 def data_gen(batch_size: int = 4, seq_len: int = 512):
@@ -56,7 +56,7 @@ class Worker:
         # initialize a ray collective group, otherwise colossalai distributed env won't be built successfully
         collective.init_collective_group(world_size, rank, "nccl", "default")
         # initialize and set distributed environment
-        colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+        colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
         ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..")
         log_cuda_info("Worker.setup")
@@ -98,7 +98,7 @@ class ColossalInferenceHandler(BaseHandler, ABC):
         self.model.cuda()
         self.model.eval()

-        colossalai.launch(config={}, rank=rank, world_size=world_size, host=host, port=port, backend="nccl")
+        colossalai.launch(rank=rank, world_size=world_size, host=host, port=port, backend="nccl")
         logger.info("Initializing TPInferEngine ...")
         shard_config = ShardConfig(
             enable_tensor_parallelism=True if self.tp_size > 1 else False, extra_kwargs={"inference_only": True}