[Inference] Fix API server, test and example (#5712)

* fix api server

* fix generation config

* fix api server

* fix comments

* fix infer hanging bug

* resolve comments, change backend to free port
Author: Jianghai, 2024-05-15 15:47:31 +08:00 (committed by GitHub)
parent 74c47921fa
commit f47f2fbb24
5 changed files with 73 additions and 32 deletions

File: colossalai/inference/core/async_engine.py

@@ -4,6 +4,7 @@ from functools import partial
 from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type
 from colossalai.inference.core.engine import InferenceEngine
+from colossalai.inference.sampler import search_tokens
 # CLI logger
 logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
@@ -168,26 +169,44 @@ class _AsyncInferenceEngine(InferenceEngine):
         generated results.
         """
         batch = self.request_handler.schedule()
+        input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
         loop = asyncio.get_running_loop()
+        if input_meta_data.use_cuda_graph:
+            model_executable = self.graph_runners[input_meta_data.batch_size]
+        else:
+            model_executable = self.model
         # Use run_in_executor to asyncally run the sync method model.forward().
         logits = await loop.run_in_executor(
             None,
-            self.model,
-            batch,
+            model_executable,
+            input_token_ids,
+            output_tensor,
+            input_meta_data,
             self.k_cache,
             self.v_cache,
         )
         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
-        self.request_handler.search_tokens(self.generation_config, logits)
+        next_tokens = search_tokens(
+            self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
+        )
+        self.request_handler.append_next_tokens(next_tokens)
         finished_sequences = self.request_handler.update()
         for sequence in finished_sequences:
             sequence.output = self.tokenizer.decode(sequence.output_token_id)
-        return finished_sequences, self.request_handler.total_requests_in_batch_bucket() > 0
+        return finished_sequences, not self.request_handler.running_list.is_empty()
+    def add_single_request(self, request_id: int, prompt: str, prompt_token_ids, generation_config=None):
+        prompts = [prompt]
+        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
+        self.add_request(request_ids=request_id, prompts=prompts, prompts_token_ids=prompt_token_ids, **gen_config_dict)
 class AsyncInferenceEngine:
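Note on the async_step hunk above: the model forward pass is still a synchronous call, so it is handed to the default executor via run_in_executor to keep the event loop responsive while the model runs. A minimal, self-contained sketch of that pattern (the names below are illustrative, not part of this commit):

    import asyncio

    def blocking_forward(inputs):
        # stand-in for the synchronous model forward pass
        return [x * 2 for x in inputs]

    async def async_step(inputs):
        loop = asyncio.get_running_loop()
        # run_in_executor(None, ...) runs the call on the default thread pool,
        # so awaiting it does not block other coroutines handling new requests
        return await loop.run_in_executor(None, blocking_forward, inputs)

    print(asyncio.run(async_step([1, 2, 3])))  # -> [2, 4, 6]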
@@ -240,7 +259,6 @@ class AsyncInferenceEngine:
         for new_request in new_requests:
             self.engine.add_single_request(**new_request)
         newly_finished_seqs, has_running_requests = await self.engine.async_step()
         for seq in newly_finished_seqs:
             self._request_tracer.process_finished_request(seq)
@@ -273,6 +291,7 @@ class AsyncInferenceEngine:
         request_id: int,
         prompt: Optional[str],
         prompt_token_ids: Optional[List[int]] = None,
+        generation_config=None,
     ) -> RequstStream:
         """
         Add a request to the background tracker(waiting queue), start the background loop if needed.
@@ -286,6 +305,7 @@ class AsyncInferenceEngine:
             request_id,
             prompt=prompt,
             prompt_token_ids=prompt_token_ids,
+            generation_config=generation_config,
         )
         return stream
@@ -294,13 +314,16 @@ class AsyncInferenceEngine:
         request_id: int,
         prompt: Optional[str],
         prompt_token_ids: Optional[List[int]] = None,
+        generation_config=None,
     ) -> AsyncIterator[str]:
         """
         Generate output from a request. It receives the request from http server, adds it into the
         waitting queue of Async Engine and streams the output sequence.
         """
         try:
-            stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
+            stream = await self.add_request(
+                request_id, prompt, prompt_token_ids=prompt_token_ids, generation_config=generation_config
+            )
             return await stream.get_result()
         except (Exception, asyncio.CancelledError) as e:
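With generation_config now threaded through generate, add_request and add_single_request, each request can carry its own sampling settings. A hedged usage sketch (assumes an already constructed AsyncInferenceEngine instance named async_engine; transformers.GenerationConfig is assumed here because add_single_request calls generation_config.to_dict()):

    import asyncio
    from transformers import GenerationConfig

    async def ask(async_engine, prompt):
        # per-request settings travel with the request instead of a shared engine config
        gen_config = GenerationConfig(max_new_tokens=64, do_sample=False)
        return await async_engine.generate(
            request_id=0, prompt=prompt, generation_config=gen_config
        )

    # asyncio.run(ask(async_engine, "What is Colossal-AI?"))  # requires a running engine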

File: colossalai/inference/core/engine.py

@@ -154,7 +154,6 @@ class InferenceEngine:
             else:
                 model_type = "nopadding_" + self.model_config.model_type
             model_policy = model_policy_map[model_type]()
         pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
         tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
@@ -589,7 +588,7 @@ class InferenceEngine:
     def add_request(
         self,
         request_ids: Union[List[int], int] = None,
-        prompts: List[str] = None,
+        prompts: Union[List[str], str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
         **kwargs,
     ) -> None:
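The widened prompts annotation suggests add_request is meant to accept either a single prompt string or a list of prompts, matching add_single_request in async_engine.py, which forwards prompts=[prompt]. A hypothetical call sketch (engine construction omitted; argument values are illustrative):

    # one request id with one prompt string
    engine.add_request(request_ids=0, prompts="Hello, Colossal-AI")
    # or a batch of prompts with matching request ids
    engine.add_request(request_ids=[1, 2], prompts=["Hi there", "Summarize this file"])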

File: colossalai/inference/server/api_server.py

@@ -20,10 +20,12 @@ from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from transformers import AutoModelForCausalLM, AutoTokenizer
+import colossalai
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.server.chat_service import ChatServing
 from colossalai.inference.server.completion_service import CompletionServing
 from colossalai.inference.server.utils import id_generator
+from colossalai.inference.utils import find_available_ports
 from colossalai.inference.core.async_engine import AsyncInferenceEngine, InferenceEngine  # noqa
@@ -54,8 +56,9 @@ async def generate(request: Request) -> Response:
     """
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
-    stream = request_dict.pop("stream", "false").lower()
+    stream = request_dict.pop("stream", "false")
+    if isinstance(stream, str):
+        stream = stream.lower()
     request_id = id_generator()
     generation_config = get_generation_config(request_dict)
     results = engine.generate(request_id, prompt, generation_config=generation_config)
@@ -66,7 +69,7 @@ async def generate(request: Request) -> Response:
             ret = {"text": request_output[len(prompt) :]}
             yield (json.dumps(ret) + "\0").encode("utf-8")
-    if stream == "true":
+    if stream == "true" or stream == True:
         return StreamingResponse(stream_results())
     # Non-streaming case
@@ -86,12 +89,14 @@ async def generate(request: Request) -> Response:
 @app.post("/completion")
 async def create_completion(request: Request):
     request_dict = await request.json()
-    stream = request_dict.pop("stream", "false").lower()
+    stream = request_dict.pop("stream", "false")
+    if isinstance(stream, str):
+        stream = stream.lower()
     generation_config = get_generation_config(request_dict)
     result = await completion_serving.create_completion(request, generation_config)
     ret = {"request_id": result.request_id, "text": result.output}
-    if stream == "true":
+    if stream == "true" or stream == True:
         return StreamingResponse(content=json.dumps(ret) + "\0", media_type="text/event-stream")
     else:
         return JSONResponse(content=ret)
@@ -101,10 +106,12 @@ async def create_completion(request: Request):
 async def create_chat(request: Request):
     request_dict = await request.json()
-    stream = request_dict.get("stream", "false").lower()
+    stream = request_dict.get("stream", "false")
+    if isinstance(stream, str):
+        stream = stream.lower()
     generation_config = get_generation_config(request_dict)
     message = await chat_serving.create_chat(request, generation_config)
-    if stream == "true":
+    if stream == "true" or stream == True:
         return StreamingResponse(content=message, media_type="text/event-stream")
     else:
         ret = {"role": message.role, "text": message.content}
@@ -115,27 +122,29 @@ def get_generation_config(request):
     generation_config = async_engine.engine.generation_config
     for arg in request:
         if hasattr(generation_config, arg):
-            generation_config[arg] = request[arg]
+            setattr(generation_config, arg, request[arg])
     return generation_config
 def add_engine_config(parser):
-    parser.add_argument("--model", type=str, default="llama2-7b", help="name or path of the huggingface model to use")
     parser.add_argument(
-        "--max-model-len",
-        type=int,
-        default=None,
-        help="model context length. If unspecified, " "will be automatically derived from the model.",
+        "-m", "--model", type=str, default="llama2-7b", help="name or path of the huggingface model to use"
     )
-    # Parallel arguments
-    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas")
+    # Parallel arguments not supported now
     # KV cache arguments
     parser.add_argument("--block-size", type=int, default=16, choices=[8, 16, 32], help="token block size")
     parser.add_argument("--max_batch_size", type=int, default=8, help="maximum number of batch size")
+    parser.add_argument("-i", "--max_input_len", type=int, default=128, help="max input length")
+    parser.add_argument("-o", "--max_output_len", type=int, default=128, help="max output length")
+    parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
+    parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default")
     # generation arguments
     parser.add_argument(
         "--prompt_template",
@@ -150,7 +159,7 @@ def parse_args():
     parser = argparse.ArgumentParser(description="Colossal-Inference API server.")
     parser.add_argument("--host", type=str, default="127.0.0.1")
-    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument("--port", type=int, default=8000, help="port of FastAPI server.")
     parser.add_argument("--ssl-keyfile", type=str, default=None)
     parser.add_argument("--ssl-certfile", type=str, default=None)
     parser.add_argument(
@@ -164,6 +173,7 @@ def parse_args():
         "specified, the model name will be the same as "
         "the huggingface name.",
     )
     parser.add_argument(
         "--chat-template",
         type=str,
@@ -184,13 +194,21 @@
 if __name__ == "__main__":
     args = parse_args()
     inference_config = InferenceConfig.from_dict(vars(args))
-    model = AutoModelForCausalLM.from_pretrained(args.model)
     tokenizer = AutoTokenizer.from_pretrained(args.model)
+    colossalai_backend_port = find_available_ports(1)[0]
+    colossalai.launch(
+        rank=0,
+        world_size=1,
+        host=args.host,
+        port=colossalai_backend_port,
+        backend="nccl",
+    )
+    model = AutoModelForCausalLM.from_pretrained(args.model)
     async_engine = AsyncInferenceEngine(
-        start_engine_loop=True, model=model, tokenizer=tokenizer, inference_config=inference_config
+        start_engine_loop=True, model_or_path=model, tokenizer=tokenizer, inference_config=inference_config
     )
     engine = async_engine.engine
-    completion_serving = CompletionServing(async_engine, served_model=model.__class__.__name__)
+    completion_serving = CompletionServing(async_engine, model.__class__.__name__)
     chat_serving = ChatServing(
         async_engine,
         served_model=model.__class__.__name__,
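The get_generation_config fix above swaps item assignment for setattr, presumably because the engine's generation config object does not support subscription. A minimal sketch of the same pattern, assuming transformers.GenerationConfig (the helper name apply_overrides is illustrative, not part of this commit):

    from transformers import GenerationConfig

    def apply_overrides(generation_config: GenerationConfig, request: dict) -> GenerationConfig:
        # mirror get_generation_config: copy matching request fields onto the config
        for arg, value in request.items():
            if hasattr(generation_config, arg):
                setattr(generation_config, arg, value)
        return generation_config

    cfg = apply_overrides(GenerationConfig(), {"max_new_tokens": 64, "top_k": 50})
    print(cfg.max_new_tokens, cfg.top_k)  # 64 50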

File: colossalai/inference/server/completion_service.py

@@ -23,7 +23,7 @@ class CompletionServing:
         # it is not a intuitive way
         self.engine.engine.generation_config = generation_config
-        result_generator = self.engine.generate(request_id, prompt=prompt)
+        result_generator = self.engine.generate(request_id, prompt=prompt, generation_config=generation_config)
         if await request.is_disconnected():
             # Abort the request if the client disconnects.
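Passing generation_config explicitly into engine.generate, rather than relying only on the assignment to self.engine.engine.generation_config, presumably keeps each request's sampling settings attached to that request, so concurrent completions are less likely to overwrite one another's configuration.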

File: online server benchmark script (shell)

@@ -6,8 +6,9 @@
 model_path=${1:-"lmsys/vicuna-7b-v1.3"}
 chat_template="{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"
 echo "Model Path: $model_path"
+echo "Chat Tempelate" "${chat_template}"
 echo "Starting server..."
-python -m colossalai.inference.server.api_server --model $model_path --chat-template $chat_template &
+python -m colossalai.inference.server.api_server --model $model_path --chat-template "${chat_template}" &
 SERVER_PID=$!
 # waiting time
@@ -17,9 +18,9 @@ sleep 60
 echo "Starting Locust..."
 echo "The test will automatically begin, you can turn to http://0.0.0.0:8089 for more information."
 echo "Test completion api first"
-locust -f locustfile.py -t 300 --tags online-generation --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10
+locust -f locustfile.py -t 300 --tags online-generation --host http://127.0.0.1:8000 --autostart --users 300 --stop-timeout 10
 echo "Test chat api"
-locust -f locustfile.py -t 300 --tags online-chat --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10
+locust -f locustfile.py -t 300 --tags online-chat --host http://127.0.0.1:8000 --autostart --users 300 --stop-timeout 10
 # kill Server
 echo "Stopping server..."
 kill $SERVER_PID