From 86909045fc043855e40df3d5104d51f15bdd31ba Mon Sep 17 00:00:00 2001
From: csunny
Date: Sun, 30 Apr 2023 22:24:06 +0800
Subject: [PATCH] test

---
 pilot/conversation.py         |  8 ++++----
 pilot/server/vicuna_server.py | 31 ++++++++++++++++---------------
 2 files changed, 20 insertions(+), 19 deletions(-)

diff --git a/pilot/conversation.py b/pilot/conversation.py
index 4db5d9548..1ee142762 100644
--- a/pilot/conversation.py
+++ b/pilot/conversation.py
@@ -89,8 +89,8 @@ class Conversation:
 
 
 conv_one_shot = Conversation(
-    system="A chat between a curious human and an artificial intelligence assistant, who very familiar with database related knowledge. "
-    "The assistant gives helpful, detailed, professional and polite answers to the human's questions. ",
+    system="A chat between a curious human and an artificial intelligence assistant."
+    "The assistant gives helpful, detailed and polite answers to the human's questions. ",
     roles=("Human", "Assistant"),
     messages=(
         (
@@ -123,8 +123,8 @@ conv_one_shot = Conversation(
 )
 
 conv_vicuna_v1 = Conversation(
-    system = "A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. "
-    "The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
+    system = "A chat between a curious user and an artificial intelligence assistant. "
+    "The assistant gives helpful, detailed and polite answers to the user's questions. ",
     roles=("USER", "ASSISTANT"),
     messages=(),
     offset=0,
diff --git a/pilot/server/vicuna_server.py b/pilot/server/vicuna_server.py
index 6c18f9cc6..3012d7ef1 100644
--- a/pilot/server/vicuna_server.py
+++ b/pilot/server/vicuna_server.py
@@ -22,8 +22,14 @@ model_semaphore = None
 
 # ml = ModerLoader(model_path=model_path)
 # model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
-
 model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)
+
+class ModelWorker:
+    def __init__(self):
+        pass
+
+    # TODO
+
 app = FastAPI()
 
 class PromptRequest(BaseModel):
@@ -71,24 +77,19 @@ def generate_stream_gate(params):
 
 
 @app.post("/generate_stream")
-def api_generate_stream(request: StreamRequest):
+async def api_generate_stream(request: Request):
     global model_semaphore, global_counter
     global_counter += 1
-    params = {
-        "prompt": request.prompt,
-        "model": request.model,
-        "temperature": request.temperature,
-        "max_new_tokens": request.max_new_tokens,
-        "stop": request.stop
-    }
+    params = await request.json()
     print(model, tokenizer, params, DEVICE)
-    # if model_semaphore is None:
-    #     model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
-
+    if model_semaphore is None:
+        model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
+    await model_semaphore.acquire()
+
     generator = generate_stream_gate(params)
-    # background_tasks = BackgroundTasks()
-    # background_tasks.add_task(release_model_semaphore)
-    return StreamingResponse(generator)
+    background_tasks = BackgroundTasks()
+    background_tasks.add_task(release_model_semaphore)
+    return StreamingResponse(generator, background=background_tasks)
 
 @app.post("/generate")
 def generate(prompt_request: PromptRequest):
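Note (not part of the patch): after this change, /generate_stream parses the raw JSON request body instead of a typed StreamRequest, acquires the model semaphore, and releases it through BackgroundTasks once the StreamingResponse finishes. Below is a minimal client-side sketch of how the reworked endpoint could be exercised; the server address, model name, and sampling values are assumptions, and only the field names come from the params dict the old handler built.

# Hypothetical client for the reworked /generate_stream endpoint.
# The base URL is an assumption; adjust to wherever vicuna_server runs.
import json

import requests

# Field names mirror the dict the old handler built from StreamRequest;
# the concrete values here are placeholders.
params = {
    "prompt": "A chat between a curious human and an artificial intelligence assistant.",
    "model": "vicuna-13b",
    "temperature": 0.7,
    "max_new_tokens": 512,
    "stop": "###",
}

# The endpoint wraps generate_stream_gate() in a StreamingResponse, so read
# the body incrementally. This assumes the generator yields null-delimited
# JSON chunks (FastChat-style); adapt the delimiter handling if it does not.
with requests.post("http://localhost:8000/generate_stream", json=params, stream=True) as resp:
    for chunk in resp.iter_lines(delimiter=b"\0"):
        if chunk:
            print(json.loads(chunk))

Because release_model_semaphore is registered as a background task on the response, the semaphore is freed only after the stream is consumed or the connection closes, so a client should always drain or close the response.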