ColossalAI/colossalai/legacy/inference/serving/ray_serve/Colossal_Inference_rayserve.py
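"""Ray Serve entry point for ColossalAI tensor-parallel (TP) inference.

A CPU-only Driver deployment spawns one GPU Worker actor per TP rank, builds a
ColossalAI TPInferEngine inside each worker, and broadcasts every batched request
to all ranks, returning one copy of the generated text.
"""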
import logging
import os
from typing import Any, List, Union

import ray
import ray.util.collective as collective
import starlette
import torch
from pydantic import BaseModel
from ray import serve
from ray.serve import Application
from transformers import AutoModelForCausalLM, AutoTokenizer

import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.shardformer import ShardConfig
from colossalai.testing import free_port

ray_serve_logger = logging.getLogger("ray.serve")


class GenConfigArgs(BaseModel):
    """Config for generation"""

    path: str
    tp_size: int = 2
    max_batch_size: int = 4
    max_input_len: int = 128
    max_output_len: int = 32


def log_cuda_info(scope_name: str):
    ray_serve_logger.info(f" {scope_name}: ray.get_gpu_ids(): {ray.get_gpu_ids()}")
    ray_serve_logger.info(
        f" {scope_name}: CUDA_VISIBLE_DEVICES: {os.getenv('CUDA_VISIBLE_DEVICES', 'NO DEVICES FOUND!')}"
    )
    if torch.cuda.is_available():
        ray_serve_logger.info(
            f" {scope_name}: cuda current_device: {torch.cuda.current_device()}, cuda device count: {torch.cuda.device_count()}"
        )
    else:
        ray_serve_logger.info(f" {scope_name}: cuda is not available!")


@ray.remote(num_gpus=1)
class Worker:
    def __init__(self, model_path: str, tp_size: int, max_batch_size: int, max_input_len: int, max_output_len: int):
        log_cuda_info("Worker.init")
        self.tp_size = tp_size
        self.model_path = model_path
        self.max_batch_size = max_batch_size
        self.max_input_len = max_input_len
        self.max_output_len = max_output_len

    def setup(self, world_size, rank, port):
        # initialize a ray collective group, otherwise the colossalai distributed env won't be built successfully
        collective.init_collective_group(world_size, rank, "nccl", "default")
        # initialize and set the distributed environment
        colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
        ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..")
        log_cuda_info("Worker.setup")

        # load tokenizer and model, then shard the model with the TP inference engine
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16
        )
        shard_config = ShardConfig(
            enable_tensor_parallelism=(world_size > 1), extra_kwargs={"inference_only": True}
        )
        self.infer_engine = TPInferEngine(
            self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
        )
        self.generate_kwargs = dict(max_new_tokens=self.max_output_len, do_sample=False)

        return True

    def generate(self, text: Union[str, List[str]]) -> List[str]:
        # tokenize the whole batch at once; the driver always sends a list of prompts
        input_tokens = self.tokenizer.batch_encode_plus(text, return_tensors="pt", padding=True)
        ray_serve_logger.info(f"text: {text},\ninput_tokens: {input_tokens}")

        model_output = self.infer_engine.generate(input_tokens, **self.generate_kwargs)
        ray_serve_logger.info(f"model_output.shape: {model_output.shape}")

        text_output = []
        for i in range(len(model_output)):
            text_output.append(self.tokenizer.decode(model_output[i]))
        ray_serve_logger.info(f"output: {text_output}")

        return text_output
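
# The Driver deployment below is CPU-only (num_gpus: 0): it spawns one GPU Worker
# actor per tensor-parallel rank, wires them into a shared collective group, and
# broadcasts every incoming batch to all ranks, returning a single copy of the output.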
@serve.deployment(
    ray_actor_options={"num_cpus": 1, "num_gpus": 0},
    max_concurrent_queries=5,
    autoscaling_config={
        "target_num_ongoing_requests_per_replica": 1,
        "min_replicas": 1,
        "initial_replicas": 1,
        "max_replicas": 1,
    },
)
class Driver:
    def __init__(self, config: GenConfigArgs):
        log_cuda_info("Driver:init")
        model_path = config.path
        tp_size = config.tp_size

        self.num_workers = tp_size
        self.workers = []
        init_rets = []

        # Just grab a free port on localhost
        # NOTE workers in this communication group listen to the same port
        available_port = free_port()

        for i in range(self.num_workers):
            worker_name = "worker_idx_{}".format(i)
            w = Worker.options(name=worker_name).remote(
                model_path, self.num_workers, config.max_batch_size, config.max_input_len, config.max_output_len
            )
            self.workers.append(w)
            init_rets.append(w.setup.remote(self.num_workers, i, available_port))

        _options = {
            "group_name": "default_driver",
            "world_size": self.num_workers,
            "ranks": [i for i in range(self.num_workers)],
            "backend": "nccl",
        }
        collective.create_collective_group(self.workers, **_options)
        _ = ray.get(init_rets)

    # set batch wait delay in seconds and maximum number of sequences in a batch
    @serve.batch(batch_wait_timeout_s=0.8, max_batch_size=4)
    async def batch_generate(self, requests: List[str]):
        ray_serve_logger.info(f"Driver.batch_generate: requests length: {len(requests)}\n requests: {requests}")
        results = ray.get([w.generate.remote(requests) for w in self.workers])
        text_res = results[0]  # get any one of the copies
        return text_res

    async def __call__(self, request: starlette.requests.Request) -> Any:
        return await self.batch_generate(request.query_params["text"])


def app(args: GenConfigArgs) -> Application:
    print(args)
    if args.path is None or not os.path.exists(args.path):
        raise ValueError("Model path not provided or invalid path!")

    return Driver.options(name="Colossal-Inference-Driver").bind(config=args)
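
# Example usage (a sketch, not part of the original file): with Ray Serve versions
# that support typed application builders, the deployment can be launched with
#
#   serve run Colossal_Inference_rayserve:app path="<PATH_TO_MODEL_DIR>" tp_size=2
#
# and then queried over HTTP (default Serve address assumed to be localhost:8000):
#
#   curl "http://localhost:8000/?text=Introduce%20some%20landmarks%20in%20Beijing"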