mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 18:40:28 +00:00
[Online Server] Chat Api for streaming and not streaming response (#5470)
* fix bugs * fix bugs * fix api server * fix api server * add chat api and test * del request.n
This commit is contained in:
@@ -11,7 +11,6 @@ Doc:
|
||||
-d '{"prompt":"hello, who are you? ","stream":"False"}'`
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
@@ -21,16 +20,20 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.server.chat_service import ChatServing
|
||||
from colossalai.inference.server.completion_service import CompletionServing
|
||||
from colossalai.inference.server.utils import id_generator
|
||||
|
||||
from colossalai.inference.core.async_engine import AsyncInferenceEngine, InferenceEngine # noqa
|
||||
|
||||
TIMEOUT_KEEP_ALIVE = 5 # seconds.
|
||||
app = FastAPI()
|
||||
engine = None
|
||||
supported_models_dict = {"Llama_Models": ("llama2-7b",)}
|
||||
prompt_template_choices = ["llama", "vicuna"]
|
||||
async_engine = None
|
||||
chat_serving = None
|
||||
completion_serving = None
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.get("/v0/models")
|
||||
@@ -49,7 +52,7 @@ async def generate(request: Request) -> Response:
|
||||
"""
|
||||
request_dict = await request.json()
|
||||
prompt = request_dict.pop("prompt")
|
||||
stream = request_dict.pop("stream", None)
|
||||
stream = request_dict.pop("stream", "false").lower()
|
||||
|
||||
request_id = id_generator()
|
||||
generation_config = get_generation_config(request_dict)
|
||||
@@ -61,7 +64,7 @@ async def generate(request: Request) -> Response:
|
||||
ret = {"text": request_output[len(prompt) :]}
|
||||
yield (json.dumps(ret) + "\0").encode("utf-8")
|
||||
|
||||
if stream:
|
||||
if stream == "true":
|
||||
return StreamingResponse(stream_results())
|
||||
|
||||
# Non-streaming case
|
||||
@@ -81,17 +84,31 @@ async def generate(request: Request) -> Response:
|
||||
@app.post("/v1/completion")
|
||||
async def create_completion(request: Request):
|
||||
request_dict = await request.json()
|
||||
stream = request_dict.pop("stream", False)
|
||||
stream = request_dict.pop("stream", "false").lower()
|
||||
generation_config = get_generation_config(request_dict)
|
||||
result = await completion_serving.create_completion(request, generation_config)
|
||||
|
||||
ret = {"request_id": result.request_id, "text": result.output}
|
||||
if stream:
|
||||
if stream == "true":
|
||||
return StreamingResponse(content=json.dumps(ret) + "\0", media_type="text/event-stream")
|
||||
else:
|
||||
return JSONResponse(content=ret)
|
||||
|
||||
|
||||
@app.post("/v1/chat")
|
||||
async def create_chat(request: Request):
|
||||
request_dict = await request.json()
|
||||
|
||||
stream = request_dict.get("stream", "false").lower()
|
||||
generation_config = get_generation_config(request_dict)
|
||||
message = await chat_serving.create_chat(request, generation_config)
|
||||
if stream == "true":
|
||||
return StreamingResponse(content=message, media_type="text/event-stream")
|
||||
else:
|
||||
ret = {"role": message.role, "text": message.content}
|
||||
return ret
|
||||
|
||||
|
||||
def get_generation_config(request):
|
||||
generation_config = async_engine.engine.generation_config
|
||||
for arg in request:
|
||||
@@ -175,6 +192,18 @@ def parse_args():
|
||||
"specified, the model name will be the same as "
|
||||
"the huggingface name.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chat-template",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The file path to the chat template, " "or the template in single-line form " "for the specified model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--response-role",
|
||||
type=str,
|
||||
default="assistant",
|
||||
help="The role name to return if " "`request.add_generation_prompt=true`.",
|
||||
)
|
||||
parser = add_engine_config(parser)
|
||||
|
||||
return parser.parse_args()
|
||||
@@ -182,7 +211,6 @@ def parse_args():
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
inference_config = InferenceConfig.from_dict(vars(args))
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
@@ -191,10 +219,16 @@ if __name__ == "__main__":
|
||||
)
|
||||
engine = async_engine.engine
|
||||
completion_serving = CompletionServing(async_engine, served_model=model.__class__.__name__)
|
||||
|
||||
chat_serving = ChatServing(
|
||||
async_engine,
|
||||
served_model=model.__class__.__name__,
|
||||
tokenizer=tokenizer,
|
||||
response_role=args.response_role,
|
||||
chat_template=args.chat_template,
|
||||
)
|
||||
app.root_path = args.root_path
|
||||
uvicorn.run(
|
||||
app,
|
||||
app=app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level="debug",
|
||||
|
142
colossalai/inference/server/chat_service.py
Normal file
142
colossalai/inference/server/chat_service.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import asyncio
|
||||
import codecs
|
||||
import logging
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from colossalai.inference.core.async_engine import AsyncInferenceEngine
|
||||
|
||||
from .utils import ChatCompletionResponseStreamChoice, ChatMessage, DeltaMessage, id_generator
|
||||
|
||||
logger = logging.getLogger("colossalai-inference")
|
||||
|
||||
|
||||
class ChatServing:
|
||||
def __init__(
|
||||
self, engine: AsyncInferenceEngine, served_model: str, tokenizer, response_role: str, chat_template=None
|
||||
):
|
||||
self.engine = engine
|
||||
self.served_model = served_model
|
||||
self.tokenizer = tokenizer
|
||||
self.response_role = response_role
|
||||
self._load_chat_template(chat_template)
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
async def create_chat(self, request: Request, generation_config):
|
||||
request_dict = await request.json()
|
||||
messages = request_dict["messages"]
|
||||
stream = request_dict.pop("stream", "false").lower()
|
||||
add_generation_prompt = request_dict.pop("add_generation_prompt", False)
|
||||
request_id = id_generator()
|
||||
try:
|
||||
prompt = self.tokenizer.apply_chat_template(
|
||||
conversation=messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error in applying chat template from request: {str(e)}")
|
||||
|
||||
# it is not a intuitive way
|
||||
self.engine.engine.generation_config = generation_config
|
||||
result_generator = self.engine.generate(request_id, prompt=prompt)
|
||||
|
||||
if stream == "true":
|
||||
return self.chat_completion_stream_generator(request, request_dict, result_generator, request_id)
|
||||
else:
|
||||
return await self.chat_completion_full_generator(request, request_dict, result_generator, request_id)
|
||||
|
||||
async def chat_completion_stream_generator(self, request, request_dict, result_generator, request_id: int):
|
||||
# Send first response for each request.n (index) with the role
|
||||
role = self.get_chat_request_role(request, request_dict)
|
||||
n = request_dict.get("n", 1)
|
||||
echo = request_dict.get("echo", "false").lower()
|
||||
for i in range(n):
|
||||
choice_data = ChatCompletionResponseStreamChoice(index=i, message=DeltaMessage(role=role))
|
||||
data = choice_data.model_dump_json(exclude_unset=True)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Send response to echo the input portion of the last message
|
||||
if echo == "true":
|
||||
last_msg_content = ""
|
||||
if (
|
||||
request_dict["messages"]
|
||||
and isinstance(request_dict["messages"], list)
|
||||
and request_dict["messages"][-1].get("content")
|
||||
and request_dict["messages"][-1].get("role") == role
|
||||
):
|
||||
last_msg_content = request_dict["messages"][-1]["content"]
|
||||
if last_msg_content:
|
||||
for i in range(n):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i, message=DeltaMessage(content=last_msg_content)
|
||||
)
|
||||
data = choice_data.model_dump_json(exclude_unset=True)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
result = await result_generator
|
||||
choice_data = DeltaMessage(content=result.output)
|
||||
data = choice_data.model_dump_json(exclude_unset=True, exclude_none=True)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Send the final done message after all response.n are finished
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def chat_completion_full_generator(
|
||||
self,
|
||||
request: Request,
|
||||
request_dict: dict,
|
||||
result_generator,
|
||||
request_id,
|
||||
):
|
||||
if await request.is_disconnected():
|
||||
# Abort the request if the client disconnects.
|
||||
await self.engine.abort(request_id)
|
||||
return {"error_msg": "Client disconnected"}
|
||||
|
||||
result = await result_generator
|
||||
assert result is not None
|
||||
role = self.get_chat_request_role(request, request_dict)
|
||||
choice_data = ChatMessage(role=role, content=result.output)
|
||||
echo = request_dict.get("echo", "false").lower()
|
||||
|
||||
if echo == "true":
|
||||
last_msg_content = ""
|
||||
if (
|
||||
request.messages
|
||||
and isinstance(request.messages, list)
|
||||
and request.messages[-1].get("content")
|
||||
and request.messages[-1].get("role") == role
|
||||
):
|
||||
last_msg_content = request.messages[-1]["content"]
|
||||
|
||||
full_message = last_msg_content + choice_data.content
|
||||
choice_data.content = full_message
|
||||
|
||||
return choice_data
|
||||
|
||||
def get_chat_request_role(self, request: Request, request_dict: dict) -> str:
|
||||
add_generation_prompt = request_dict.get("add_generation_prompt", False)
|
||||
if add_generation_prompt:
|
||||
return self.response_role
|
||||
else:
|
||||
return request_dict["messages"][-1]["role"]
|
||||
|
||||
def _load_chat_template(self, chat_template):
|
||||
if chat_template is not None:
|
||||
try:
|
||||
with open(chat_template, "r") as f:
|
||||
self.tokenizer.chat_template = f.read()
|
||||
except OSError:
|
||||
# If opening a file fails, set chat template to be args to
|
||||
# ensure we decode so our escape are interpreted correctly
|
||||
self.tokenizer.chat_template = codecs.decode(chat_template, "unicode_escape")
|
||||
|
||||
logger.info(f"Using supplied chat template:\n{self.tokenizer.chat_template}")
|
||||
elif self.tokenizer.chat_template is not None:
|
||||
logger.info(f"Using default chat template:\n{self.tokenizer.chat_template}")
|
||||
else:
|
||||
logger.warning("No chat template provided. Chat API will not work.")
|
@@ -1,3 +1,8 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# make it singleton
|
||||
class NumericIDGenerator:
|
||||
_instance = None
|
||||
@@ -14,3 +19,18 @@ class NumericIDGenerator:
|
||||
|
||||
|
||||
id_generator = NumericIDGenerator()
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: Any
|
||||
|
||||
|
||||
class DeltaMessage(BaseModel):
|
||||
role: Optional[str] = None
|
||||
content: Optional[Any] = None
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
index: int
|
||||
message: DeltaMessage
|
||||
|
@@ -165,12 +165,13 @@ class Sequence:
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"(request_id={self.request_id}, "
|
||||
f"prompt={self.prompt}, "
|
||||
f"output_token_id={self.output_token_id},"
|
||||
f"status={self.status.name}, "
|
||||
f"sample_params={self.sample_params}, "
|
||||
f"input_len={self.input_len},"
|
||||
f"output_len={self.output_len})"
|
||||
f"prompt={self.prompt},\n"
|
||||
f"output_token_id={self.output_token_id},\n"
|
||||
f"output={self.output},\n"
|
||||
f"status={self.status.name},\n"
|
||||
f"sample_params={self.sample_params},\n"
|
||||
f"input_len={self.input_len},\n"
|
||||
f"output_len={self.output_len})\n"
|
||||
)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user