Mirror of https://github.com/hpcaitech/ColossalAI.git
[Inference] Fix API server, test and example (#5712)

* fix api server
* fix generation config
* fix api server
* fix comments
* fix infer hanging bug
* resolve comments, change backend to free port

commit f47f2fbb24 (parent 74c47921fa)
@@ -4,6 +4,7 @@ from functools import partial
 from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type
 
 from colossalai.inference.core.engine import InferenceEngine
+from colossalai.inference.sampler import search_tokens
 
 # CLI logger
 logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
@@ -168,26 +169,44 @@ class _AsyncInferenceEngine(InferenceEngine):
         generated results.
         """
         batch = self.request_handler.schedule()
+        input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
+
         loop = asyncio.get_running_loop()
 
+        if input_meta_data.use_cuda_graph:
+            model_executable = self.graph_runners[input_meta_data.batch_size]
+        else:
+            model_executable = self.model
+
         # Use run_in_executor to asyncally run the sync method model.forward().
         logits = await loop.run_in_executor(
             None,
-            self.model,
-            batch,
+            model_executable,
+            input_token_ids,
+            output_tensor,
+            input_meta_data,
             self.k_cache,
             self.v_cache,
         )
 
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
-        self.request_handler.search_tokens(self.generation_config, logits)
+        next_tokens = search_tokens(
+            self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
+        )
+
+        self.request_handler.append_next_tokens(next_tokens)
         finished_sequences = self.request_handler.update()
 
         for sequence in finished_sequences:
             sequence.output = self.tokenizer.decode(sequence.output_token_id)
 
-        return finished_sequences, self.request_handler.total_requests_in_batch_bucket() > 0
+        return finished_sequences, not self.request_handler.running_list.is_empty()
+
+    def add_single_request(self, request_id: int, prompt: str, prompt_token_ids, generation_config=None):
+        prompts = [prompt]
+        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
+        self.add_request(request_ids=request_id, prompts=prompts, prompts_token_ids=prompt_token_ids, **gen_config_dict)
 
 
 class AsyncInferenceEngine:
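
The core of the hang fix above is that the blocking model forward now runs through loop.run_in_executor with a callable chosen per batch (CUDA-graph runner or the eager model). A minimal, self-contained sketch of that pattern, with a dummy blocking_forward standing in for the real model call (illustrative names only, not ColossalAI code):

    import asyncio
    import time

    def blocking_forward(input_ids, meta):
        # Stand-in for a synchronous model.forward(); sleeps to mimic GPU work.
        time.sleep(0.1)
        return {"num_tokens": len(input_ids), "meta": meta}

    async def async_step_sketch():
        loop = asyncio.get_running_loop()
        # Offload the blocking call to the default thread-pool executor so the
        # event loop stays free to accept new requests while the model runs.
        return await loop.run_in_executor(None, blocking_forward, [1, 2, 3], {"is_prompts": True})

    print(asyncio.run(async_step_sketch()))
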
@@ -240,7 +259,6 @@ class AsyncInferenceEngine:
         for new_request in new_requests:
             self.engine.add_single_request(**new_request)
         newly_finished_seqs, has_running_requests = await self.engine.async_step()
 
         for seq in newly_finished_seqs:
             self._request_tracer.process_finished_request(seq)
-
@@ -273,6 +291,7 @@ class AsyncInferenceEngine:
         request_id: int,
         prompt: Optional[str],
         prompt_token_ids: Optional[List[int]] = None,
+        generation_config=None,
     ) -> RequstStream:
         """
         Add a request to the background tracker(waiting queue), start the background loop if needed.
@@ -286,6 +305,7 @@ class AsyncInferenceEngine:
             request_id,
             prompt=prompt,
             prompt_token_ids=prompt_token_ids,
+            generation_config=generation_config,
         )
         return stream
 
@@ -294,13 +314,16 @@ class AsyncInferenceEngine:
         request_id: int,
         prompt: Optional[str],
         prompt_token_ids: Optional[List[int]] = None,
+        generation_config=None,
     ) -> AsyncIterator[str]:
         """
         Generate output from a request. It receives the request from http server, adds it into the
         waitting queue of Async Engine and streams the output sequence.
         """
         try:
-            stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
+            stream = await self.add_request(
+                request_id, prompt, prompt_token_ids=prompt_token_ids, generation_config=generation_config
+            )
             return await stream.get_result()
 
         except (Exception, asyncio.CancelledError) as e:
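
With generation_config now threaded from generate() through add_request() down to add_single_request(), a per-request transformers-style GenerationConfig is flattened into keyword arguments via to_dict(). A small illustrative sketch of that conversion (the field values here are arbitrary):

    from transformers import GenerationConfig

    generation_config = GenerationConfig(max_new_tokens=64, temperature=0.7, do_sample=True)
    # add_single_request spreads this dict into add_request(..., **gen_config_dict).
    gen_config_dict = generation_config.to_dict()
    print(gen_config_dict["max_new_tokens"], gen_config_dict["temperature"])
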
@@ -154,7 +154,6 @@ class InferenceEngine:
         else:
             model_type = "nopadding_" + self.model_config.model_type
             model_policy = model_policy_map[model_type]()
-
         pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
         tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
 
@@ -589,7 +588,7 @@ class InferenceEngine:
     def add_request(
         self,
         request_ids: Union[List[int], int] = None,
-        prompts: List[str] = None,
+        prompts: Union[List[str], str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
         **kwargs,
     ) -> None:
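
The widened prompts annotation means a caller may now pass a single string instead of a list. A hypothetical helper showing the normalization this implies (not the engine's actual code):

    from typing import List, Union

    def normalize_prompts(prompts: Union[List[str], str]) -> List[str]:
        # A bare string is wrapped into a one-element list before scheduling.
        return [prompts] if isinstance(prompts, str) else list(prompts)

    print(normalize_prompts("hello"), normalize_prompts(["a", "b"]))
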
@@ -20,10 +20,12 @@ from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+import colossalai
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.server.chat_service import ChatServing
 from colossalai.inference.server.completion_service import CompletionServing
 from colossalai.inference.server.utils import id_generator
+from colossalai.inference.utils import find_available_ports
 
 from colossalai.inference.core.async_engine import AsyncInferenceEngine, InferenceEngine  # noqa
 
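
The new find_available_ports import lets the server pick a free port for the colossalai distributed backend instead of relying on a fixed one. A rough stand-in using only the standard library, assuming the helper behaves like an OS-assigned ephemeral port lookup (the real colossalai utility may differ):

    import socket

    def find_free_port_sketch() -> int:
        # Bind to port 0 and let the OS hand back an unused ephemeral port.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]

    print(find_free_port_sketch())
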
@@ -54,8 +56,9 @@ async def generate(request: Request) -> Response:
     """
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
-    stream = request_dict.pop("stream", "false").lower()
+    stream = request_dict.pop("stream", "false")
+    if isinstance(stream, str):
+        stream = stream.lower()
     request_id = id_generator()
     generation_config = get_generation_config(request_dict)
     results = engine.generate(request_id, prompt, generation_config=generation_config)
@@ -66,7 +69,7 @@ async def generate(request: Request) -> Response:
             ret = {"text": request_output[len(prompt) :]}
             yield (json.dumps(ret) + "\0").encode("utf-8")
 
-    if stream == "true":
+    if stream == "true" or stream == True:
         return StreamingResponse(stream_results())
 
     # Non-streaming case
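
The stream-flag change accepts both a JSON boolean ({"stream": true}) and a string ({"stream": "true"}), which is why the .lower() call is now guarded by isinstance. A condensed sketch of the resulting check:

    def wants_streaming(request_dict: dict) -> bool:
        # Mirror the server's handling: tolerate both bool and str values.
        stream = request_dict.pop("stream", "false")
        if isinstance(stream, str):
            stream = stream.lower()
        return stream == "true" or stream is True

    print(wants_streaming({"stream": True}), wants_streaming({"stream": "True"}), wants_streaming({}))
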
@@ -86,12 +89,14 @@ async def generate(request: Request) -> Response:
 @app.post("/completion")
 async def create_completion(request: Request):
     request_dict = await request.json()
-    stream = request_dict.pop("stream", "false").lower()
+    stream = request_dict.pop("stream", "false")
+    if isinstance(stream, str):
+        stream = stream.lower()
     generation_config = get_generation_config(request_dict)
     result = await completion_serving.create_completion(request, generation_config)
 
     ret = {"request_id": result.request_id, "text": result.output}
-    if stream == "true":
+    if stream == "true" or stream == True:
         return StreamingResponse(content=json.dumps(ret) + "\0", media_type="text/event-stream")
     else:
         return JSONResponse(content=ret)
@@ -101,10 +106,12 @@ async def create_completion(request: Request):
 async def create_chat(request: Request):
     request_dict = await request.json()
 
-    stream = request_dict.get("stream", "false").lower()
+    stream = request_dict.get("stream", "false")
+    if isinstance(stream, str):
+        stream = stream.lower()
     generation_config = get_generation_config(request_dict)
     message = await chat_serving.create_chat(request, generation_config)
-    if stream == "true":
+    if stream == "true" or stream == True:
         return StreamingResponse(content=message, media_type="text/event-stream")
     else:
         ret = {"role": message.role, "text": message.content}
@@ -115,27 +122,29 @@ def get_generation_config(request):
     generation_config = async_engine.engine.generation_config
     for arg in request:
         if hasattr(generation_config, arg):
-            generation_config[arg] = request[arg]
+            setattr(generation_config, arg, request[arg])
     return generation_config
 
 
 def add_engine_config(parser):
-    parser.add_argument("--model", type=str, default="llama2-7b", help="name or path of the huggingface model to use")
-
     parser.add_argument(
-        "--max-model-len",
-        type=int,
-        default=None,
-        help="model context length. If unspecified, " "will be automatically derived from the model.",
+        "-m", "--model", type=str, default="llama2-7b", help="name or path of the huggingface model to use"
     )
-    # Parallel arguments
-    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas")
+    # Parallel arguments not supported now
 
     # KV cache arguments
     parser.add_argument("--block-size", type=int, default=16, choices=[8, 16, 32], help="token block size")
 
     parser.add_argument("--max_batch_size", type=int, default=8, help="maximum number of batch size")
 
+    parser.add_argument("-i", "--max_input_len", type=int, default=128, help="max input length")
+
+    parser.add_argument("-o", "--max_output_len", type=int, default=128, help="max output length")
+
+    parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
+
+    parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default")
+
     # generation arguments
     parser.add_argument(
         "--prompt_template",
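
The setattr fix matters because generation_config here is a transformers GenerationConfig object, not a dict, so item assignment (generation_config[arg] = ...) raises TypeError. A minimal demonstration:

    from transformers import GenerationConfig

    config = GenerationConfig()
    # config["top_k"] = 50 would fail: GenerationConfig does not support item assignment.
    setattr(config, "top_k", 50)
    setattr(config, "max_new_tokens", 128)
    print(config.top_k, config.max_new_tokens)
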
@@ -150,7 +159,7 @@ def parse_args():
     parser = argparse.ArgumentParser(description="Colossal-Inference API server.")
 
     parser.add_argument("--host", type=str, default="127.0.0.1")
-    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument("--port", type=int, default=8000, help="port of FastAPI server.")
     parser.add_argument("--ssl-keyfile", type=str, default=None)
     parser.add_argument("--ssl-certfile", type=str, default=None)
     parser.add_argument(
@@ -164,6 +173,7 @@ def parse_args():
         "specified, the model name will be the same as "
         "the huggingface name.",
     )
+
     parser.add_argument(
         "--chat-template",
         type=str,
@@ -184,13 +194,21 @@ def parse_args():
 if __name__ == "__main__":
     args = parse_args()
     inference_config = InferenceConfig.from_dict(vars(args))
-    model = AutoModelForCausalLM.from_pretrained(args.model)
     tokenizer = AutoTokenizer.from_pretrained(args.model)
+    colossalai_backend_port = find_available_ports(1)[0]
+    colossalai.launch(
+        rank=0,
+        world_size=1,
+        host=args.host,
+        port=colossalai_backend_port,
+        backend="nccl",
+    )
+    model = AutoModelForCausalLM.from_pretrained(args.model)
     async_engine = AsyncInferenceEngine(
-        start_engine_loop=True, model=model, tokenizer=tokenizer, inference_config=inference_config
+        start_engine_loop=True, model_or_path=model, tokenizer=tokenizer, inference_config=inference_config
     )
     engine = async_engine.engine
-    completion_serving = CompletionServing(async_engine, served_model=model.__class__.__name__)
+    completion_serving = CompletionServing(async_engine, model.__class__.__name__)
     chat_serving = ChatServing(
         async_engine,
         served_model=model.__class__.__name__,
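
Once the server is launched as above, requests can be sent to the /generate and /completion routes (plus the chat route registered alongside them). An illustrative non-streaming client call, assuming the default host and port from parse_args() and that extra fields such as max_new_tokens are only applied by get_generation_config when the engine's config defines them:

    import requests

    resp = requests.post(
        "http://127.0.0.1:8000/generate",
        json={"prompt": "Hello, my name is", "stream": "false", "max_new_tokens": 32},
    )
    print(resp.status_code, resp.text)
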
@@ -23,7 +23,7 @@ class CompletionServing:
 
         # it is not a intuitive way
         self.engine.engine.generation_config = generation_config
-        result_generator = self.engine.generate(request_id, prompt=prompt)
+        result_generator = self.engine.generate(request_id, prompt=prompt, generation_config=generation_config)
 
         if await request.is_disconnected():
             # Abort the request if the client disconnects.
@@ -6,8 +6,9 @@
 model_path=${1:-"lmsys/vicuna-7b-v1.3"}
 chat_template="{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"
 echo "Model Path: $model_path"
+echo "Chat Tempelate" "${chat_template}"
 echo "Starting server..."
-python -m colossalai.inference.server.api_server --model $model_path --chat-template $chat_template &
+python -m colossalai.inference.server.api_server --model $model_path --chat-template "${chat_template}" &
 SERVER_PID=$!
 
 # waiting time
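
Quoting "${chat_template}" matters because the Jinja template contains spaces; unquoted, the shell splits it into many argv tokens and argparse only receives the first one. A quick way to see the difference with Python's shlex (the template string here is shortened for illustration):

    import shlex

    chat_template = "{% for m in messages %}{{ m.content }}{% endfor %}"
    print(shlex.split("--chat-template " + chat_template))          # many tokens
    print(shlex.split('--chat-template "' + chat_template + '"'))   # flag + one value
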
@@ -17,9 +18,9 @@ sleep 60
 echo "Starting Locust..."
 echo "The test will automatically begin, you can turn to http://0.0.0.0:8089 for more information."
 echo "Test completion api first"
-locust -f locustfile.py -t 300 --tags online-generation --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10
+locust -f locustfile.py -t 300 --tags online-generation --host http://127.0.0.1:8000 --autostart --users 300 --stop-timeout 10
 echo "Test chat api"
-locust -f locustfile.py -t 300 --tags online-chat --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10
+locust -f locustfile.py -t 300 --tags online-chat --host http://127.0.0.1:8000 --autostart --users 300 --stop-timeout 10
 # kill Server
 echo "Stopping server..."
 kill $SERVER_PID