Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-10 05:20:33 +00:00)
[Online Server] Chat Api for streaming and not streaming response (#5470)
* fix bugs
* fix bugs
* fix api server
* fix api server
* add chat api and test
* del request.n
@@ -11,7 +11,6 @@ Doc:
      -d '{"prompt":"hello, who are you? ","stream":"False"}'`
 """


 import argparse
 import json

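For reference, the `-d` line above is the tail of the docstring's curl example. Spelled out in full, the call might look like this, assuming the default bind address 127.0.0.1:8000 and the `/v1/completion` route added in this commit (neither default is shown in the hunk):

curl -X POST http://127.0.0.1:8000/v1/completion \
     -H 'Content-Type: application/json' \
     -d '{"prompt":"hello, who are you? ","stream":"false"}'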
@@ -21,16 +20,20 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from colossalai.inference.config import InferenceConfig
+from colossalai.inference.server.chat_service import ChatServing
 from colossalai.inference.server.completion_service import CompletionServing
 from colossalai.inference.server.utils import id_generator

 from colossalai.inference.core.async_engine import AsyncInferenceEngine, InferenceEngine  # noqa

 TIMEOUT_KEEP_ALIVE = 5  # seconds.
-app = FastAPI()
-engine = None
 supported_models_dict = {"Llama_Models": ("llama2-7b",)}
+prompt_template_choices = ["llama", "vicuna"]
 async_engine = None
+chat_serving = None
 completion_serving = None

+app = FastAPI()


 @app.get("/v0/models")
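A quick sanity check against the models route declared above, assuming the default bind address (an assumption, since the uvicorn defaults are defined outside this hunk):

curl http://127.0.0.1:8000/v0/models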
@@ -49,7 +52,7 @@ async def generate(request: Request) -> Response:
     """
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
-    stream = request_dict.pop("stream", None)
+    stream = request_dict.pop("stream", "false").lower()

     request_id = id_generator()
     generation_config = get_generation_config(request_dict)
@@ -61,7 +64,7 @@ async def generate(request: Request) -> Response:
         ret = {"text": request_output[len(prompt) :]}
         yield (json.dumps(ret) + "\0").encode("utf-8")

-    if stream:
+    if stream == "true":
         return StreamingResponse(stream_results())

     # Non-streaming case
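Since `stream` is now a lowercased string compared against `"true"`, a client opts into streaming explicitly, and the response arrives as null-delimited JSON chunks (`json.dumps(ret) + "\0"`) that must be read incrementally. A sketch with curl, where `-N` disables output buffering and the `/generate` path is an assumption (the route decorator sits outside this hunk):

curl -N -X POST http://127.0.0.1:8000/generate \
     -H 'Content-Type: application/json' \
     -d '{"prompt":"hello, who are you? ","stream":"true"}'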
@@ -81,17 +84,31 @@ async def generate(request: Request) -> Response:
 @app.post("/v1/completion")
 async def create_completion(request: Request):
     request_dict = await request.json()
-    stream = request_dict.pop("stream", False)
+    stream = request_dict.pop("stream", "false").lower()
     generation_config = get_generation_config(request_dict)
     result = await completion_serving.create_completion(request, generation_config)

     ret = {"request_id": result.request_id, "text": result.output}
-    if stream:
+    if stream == "true":
         return StreamingResponse(content=json.dumps(ret) + "\0", media_type="text/event-stream")
     else:
         return JSONResponse(content=ret)


+@app.post("/v1/chat")
+async def create_chat(request: Request):
+    request_dict = await request.json()
+
+    stream = request_dict.get("stream", "false").lower()
+    generation_config = get_generation_config(request_dict)
+    message = await chat_serving.create_chat(request, generation_config)
+    if stream == "true":
+        return StreamingResponse(content=message, media_type="text/event-stream")
+    else:
+        ret = {"role": message.role, "text": message.content}
+        return ret
+
+
 def get_generation_config(request):
     generation_config = async_engine.engine.generation_config
     for arg in request:
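The new `/v1/chat` route parallels `/v1/completion`: `"stream":"true"` yields an event stream, anything else returns a `{"role": ..., "text": ...}` JSON object built from the served message. A sketch of a non-streaming request, assuming `create_chat` accepts an OpenAI-style `messages` list (the payload schema lives in `ChatServing`, not in this diff):

curl -X POST http://127.0.0.1:8000/v1/chat \
     -H 'Content-Type: application/json' \
     -d '{"messages":[{"role":"user","content":"hello, who are you?"}],"stream":"false"}'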
@@ -175,6 +192,18 @@ def parse_args():
         "specified, the model name will be the same as "
         "the huggingface name.",
     )
+    parser.add_argument(
+        "--chat-template",
+        type=str,
+        default=None,
+        help="The file path to the chat template, " "or the template in single-line form " "for the specified model",
+    )
+    parser.add_argument(
+        "--response-role",
+        type=str,
+        default="assistant",
+        help="The role name to return if " "`request.add_generation_prompt=true`.",
+    )
     parser = add_engine_config(parser)

     return parser.parse_args()
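Both new flags are forwarded to the `ChatServing` instance constructed in `__main__` below. A possible launch line, assuming the `colossalai.inference.server.api_server` entry point and treating the file paths as placeholders:

python3 -m colossalai.inference.server.api_server \
     --model /path/to/llama2-7b \
     --chat-template /path/to/chat_template.jinja \
     --response-role assistant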
@@ -182,7 +211,6 @@ def parse_args():

 if __name__ == "__main__":
     args = parse_args()

     inference_config = InferenceConfig.from_dict(vars(args))
     model = AutoModelForCausalLM.from_pretrained(args.model)
     tokenizer = AutoTokenizer.from_pretrained(args.model)
@@ -191,10 +219,16 @@ if __name__ == "__main__":
     )
     engine = async_engine.engine
     completion_serving = CompletionServing(async_engine, served_model=model.__class__.__name__)

+    chat_serving = ChatServing(
+        async_engine,
+        served_model=model.__class__.__name__,
+        tokenizer=tokenizer,
+        response_role=args.response_role,
+        chat_template=args.chat_template,
+    )
     app.root_path = args.root_path
     uvicorn.run(
-        app,
+        app=app,
         host=args.host,
         port=args.port,
         log_level="debug",