Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 04:24:47 +00:00
[Inference] Fix API server, test and example (#5712)
* fix api server
* fix generation config
* fix api server
* fix comments
* fix infer hanging bug
* resolve comments, change backend to free port
@@ -20,10 +20,12 @@ from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+import colossalai
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.server.chat_service import ChatServing
 from colossalai.inference.server.completion_service import CompletionServing
 from colossalai.inference.server.utils import id_generator
+from colossalai.inference.utils import find_available_ports
 
 from colossalai.inference.core.async_engine import AsyncInferenceEngine, InferenceEngine  # noqa
 
@@ -54,8 +56,9 @@ async def generate(request: Request) -> Response:
     """
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
-    stream = request_dict.pop("stream", "false").lower()
-
+    stream = request_dict.pop("stream", "false")
+    if isinstance(stream, str):
+        stream = stream.lower()
     request_id = id_generator()
     generation_config = get_generation_config(request_dict)
     results = engine.generate(request_id, prompt, generation_config=generation_config)
@@ -66,7 +69,7 @@ async def generate(request: Request) -> Response:
             ret = {"text": request_output[len(prompt) :]}
             yield (json.dumps(ret) + "\0").encode("utf-8")
 
-    if stream == "true":
+    if stream == "true" or stream == True:
         return StreamingResponse(stream_results())
 
     # Non-streaming case
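For reference, a minimal client sketch for the /generate route above (not part of this patch), assuming the server is running on the parse_args defaults (127.0.0.1:8000); it sends stream as a JSON boolean, which the relaxed check now accepts, and splits the streamed body on the "\0" delimiter the handler emits:

# Client sketch for /generate (illustration only).
import json

import requests

payload = {"prompt": "Hello, my name is", "stream": True}  # bool or "true" both pass the check
with requests.post("http://127.0.0.1:8000/generate", json=payload, stream=True) as resp:
    # The streaming branch yields JSON chunks terminated by "\0".
    for chunk in resp.iter_lines(delimiter=b"\0"):
        if chunk:
            print(json.loads(chunk.decode("utf-8"))["text"])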
@@ -86,12 +89,14 @@ async def generate(request: Request) -> Response:
 @app.post("/completion")
 async def create_completion(request: Request):
     request_dict = await request.json()
-    stream = request_dict.pop("stream", "false").lower()
+    stream = request_dict.pop("stream", "false")
+    if isinstance(stream, str):
+        stream = stream.lower()
     generation_config = get_generation_config(request_dict)
     result = await completion_serving.create_completion(request, generation_config)
 
     ret = {"request_id": result.request_id, "text": result.output}
-    if stream == "true":
+    if stream == "true" or stream == True:
         return StreamingResponse(content=json.dumps(ret) + "\0", media_type="text/event-stream")
     else:
         return JSONResponse(content=ret)
@@ -101,10 +106,12 @@ async def create_completion(request: Request):
 async def create_chat(request: Request):
     request_dict = await request.json()
 
-    stream = request_dict.get("stream", "false").lower()
+    stream = request_dict.get("stream", "false")
+    if isinstance(stream, str):
+        stream = stream.lower()
     generation_config = get_generation_config(request_dict)
     message = await chat_serving.create_chat(request, generation_config)
-    if stream == "true":
+    if stream == "true" or stream == True:
         return StreamingResponse(content=message, media_type="text/event-stream")
     else:
         ret = {"role": message.role, "text": message.content}
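The same bool-or-string normalization now appears in /generate, /completion and /chat. A hypothetical helper (not part of this patch) that the three handlers could share, shown only as a sketch:

def is_streaming(raw) -> bool:
    """Interpret a `stream` field sent either as a JSON boolean or as a string like "true"/"false"."""
    if isinstance(raw, bool):
        return raw
    if isinstance(raw, str):
        return raw.lower() == "true"
    return False


assert is_streaming(True) and is_streaming("True")
assert not is_streaming("false") and not is_streaming(None)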
@@ -115,27 +122,29 @@ def get_generation_config(request):
     generation_config = async_engine.engine.generation_config
     for arg in request:
         if hasattr(generation_config, arg):
-            generation_config[arg] = request[arg]
+            setattr(generation_config, arg, request[arg])
     return generation_config
 
+
 def add_engine_config(parser):
-    parser.add_argument("--model", type=str, default="llama2-7b", help="name or path of the huggingface model to use")
-
     parser.add_argument(
-        "--max-model-len",
-        type=int,
-        default=None,
-        help="model context length. If unspecified, " "will be automatically derived from the model.",
+        "-m", "--model", type=str, default="llama2-7b", help="name or path of the huggingface model to use"
     )
-    # Parallel arguments
-    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas")
+    # Parallel arguments not supported now
+
     # KV cache arguments
     parser.add_argument("--block-size", type=int, default=16, choices=[8, 16, 32], help="token block size")
+
     parser.add_argument("--max_batch_size", type=int, default=8, help="maximum number of batch size")
+
     parser.add_argument("-i", "--max_input_len", type=int, default=128, help="max input length")
+
     parser.add_argument("-o", "--max_output_len", type=int, default=128, help="max output length")
+
     parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
+
     parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default")
+
     # generation arguments
     parser.add_argument(
         "--prompt_template",
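The setattr change matters because the engine's generation_config is an object with attributes, not a dict, so the old item assignment fails. A small sketch, assuming the config is a transformers GenerationConfig:

from transformers import GenerationConfig

cfg = GenerationConfig()

# New code path: attribute assignment works.
setattr(cfg, "max_new_tokens", 64)
print(cfg.max_new_tokens)  # 64

# Old code path: item assignment is not supported and raises TypeError.
try:
    cfg["max_new_tokens"] = 64
except TypeError as err:
    print(err)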
@@ -150,7 +159,7 @@ def parse_args():
     parser = argparse.ArgumentParser(description="Colossal-Inference API server.")
 
     parser.add_argument("--host", type=str, default="127.0.0.1")
-    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument("--port", type=int, default=8000, help="port of FastAPI server.")
     parser.add_argument("--ssl-keyfile", type=str, default=None)
     parser.add_argument("--ssl-certfile", type=str, default=None)
     parser.add_argument(
@@ -164,6 +173,7 @@ def parse_args():
         "specified, the model name will be the same as "
         "the huggingface name.",
     )
+
     parser.add_argument(
         "--chat-template",
         type=str,
@@ -184,13 +194,21 @@ def parse_args():
 if __name__ == "__main__":
     args = parse_args()
     inference_config = InferenceConfig.from_dict(vars(args))
-    model = AutoModelForCausalLM.from_pretrained(args.model)
     tokenizer = AutoTokenizer.from_pretrained(args.model)
+    colossalai_backend_port = find_available_ports(1)[0]
+    colossalai.launch(
+        rank=0,
+        world_size=1,
+        host=args.host,
+        port=colossalai_backend_port,
+        backend="nccl",
+    )
+    model = AutoModelForCausalLM.from_pretrained(args.model)
     async_engine = AsyncInferenceEngine(
-        start_engine_loop=True, model=model, tokenizer=tokenizer, inference_config=inference_config
+        start_engine_loop=True, model_or_path=model, tokenizer=tokenizer, inference_config=inference_config
     )
     engine = async_engine.engine
-    completion_serving = CompletionServing(async_engine, served_model=model.__class__.__name__)
+    completion_serving = CompletionServing(async_engine, model.__class__.__name__)
     chat_serving = ChatServing(
         async_engine,
         served_model=model.__class__.__name__,
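The launch block added here picks a free port for the colossalai backend with find_available_ports(1)[0], keeping it separate from the FastAPI --port (the commit message's "change backend to free port"). For illustration only (this is not ColossalAI's find_available_ports implementation), asking the OS for a free TCP port looks like this:

import socket


def pick_free_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))  # port 0: let the kernel choose
        return sock.getsockname()[1]


print(pick_free_port())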