csunny 2023-04-30 22:24:06 +08:00
parent a9f8477f3f
commit 86909045fc
2 changed files with 20 additions and 19 deletions

Changed file 1 of 2

@@ -89,8 +89,8 @@ class Conversation:
 conv_one_shot = Conversation(
-    system="A chat between a curious human and an artificial intelligence assistant, who very familiar with database related knowledge. "
-    "The assistant gives helpful, detailed, professional and polite answers to the human's questions. ",
+    system="A chat between a curious human and an artificial intelligence assistant."
+    "The assistant gives helpful, detailed and polite answers to the human's questions. ",
     roles=("Human", "Assistant"),
     messages=(
         (
@@ -123,8 +123,8 @@ conv_one_shot = Conversation(
 )
 conv_vicuna_v1 = Conversation(
-    system = "A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. "
-    "The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
+    system = "A chat between a curious user and an artificial intelligence assistant. "
+    "The assistant gives helpful, detailed and polite answers to the user's questions. ",
     roles=("USER", "ASSISTANT"),
     messages=(),
     offset=0,
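Both hunks only touch the system string of these prompt templates. For orientation, below is a minimal sketch of how a FastChat-style Conversation object like this one typically assembles system, roles, and messages into the prompt string handed to the model; the get_prompt() body and the separator value are assumptions for illustration, not necessarily this repo's exact implementation.

```python
# Hypothetical sketch of a FastChat-style conversation template; the real
# Conversation class in this repo may use different separators and fields.
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

@dataclass
class ConversationSketch:
    system: str
    roles: Tuple[str, str]
    messages: List[List[Optional[str]]] = field(default_factory=list)
    offset: int = 0
    sep: str = "###"

    def get_prompt(self) -> str:
        # system text first, then alternating "Role: message" turns;
        # an empty message leaves an open slot for the model to complete
        ret = self.system + self.sep
        for role, message in self.messages:
            if message:
                ret += f" {role}: {message}{self.sep}"
            else:
                ret += f" {role}:"
        return ret

conv = ConversationSketch(
    system="A chat between a curious human and an artificial intelligence assistant.",
    roles=("Human", "Assistant"),
)
conv.messages.append(["Human", "How do I list all tables in MySQL?"])
conv.messages.append(["Assistant", None])
print(conv.get_prompt())
```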

Changed file 2 of 2

@@ -22,8 +22,14 @@ model_semaphore = None
 # ml = ModerLoader(model_path=model_path)
 # model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
 model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)
+
+class ModelWorker:
+    def __init__(self):
+        pass
+
+    # TODO
 app = FastAPI()
 class PromptRequest(BaseModel):
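The worker now loads weights directly through load_model() with 8-bit quantization enabled. A rough sketch of what such a loader usually does with Hugging Face transformers follows; the function body is an assumption for illustration (the project's real loader lives elsewhere), and multi-GPU dispatch for num_gpus > 1 is omitted.

```python
# Illustrative stand-in for a load_model() helper; assumes Hugging Face
# transformers (plus bitsandbytes for 8-bit). Not the project's actual loader.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_model_sketch(model_path, device="cuda", num_gpus=1,
                      load_8bit=True, debug=False):
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    if load_8bit:
        # 8-bit weights via bitsandbytes; accelerate handles device placement
        model = AutoModelForCausalLM.from_pretrained(
            model_path, load_in_8bit=True, device_map="auto")
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16)
        model.to(device)
    if debug:
        print(model)
    return model, tokenizer
```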
@@ -71,24 +77,19 @@ def generate_stream_gate(params):
 @app.post("/generate_stream")
-def api_generate_stream(request: StreamRequest):
+async def api_generate_stream(request: Request):
     global model_semaphore, global_counter
     global_counter += 1
-    params = {
-        "prompt": request.prompt,
-        "model": request.model,
-        "temperature": request.temperature,
-        "max_new_tokens": request.max_new_tokens,
-        "stop": request.stop
-    }
+    params = await request.json()
     print(model, tokenizer, params, DEVICE)
-    # if model_semaphore is None:
-    #     model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
+    if model_semaphore is None:
+        model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
+    await model_semaphore.acquire()
     generator = generate_stream_gate(params)
-    # background_tasks = BackgroundTasks()
-    # background_tasks.add_task(release_model_semaphore)
-    return StreamingResponse(generator)
+    background_tasks = BackgroundTasks()
+    background_tasks.add_task(release_model_semaphore)
+    return StreamingResponse(generator, background=background_tasks)
 @app.post("/generate")
 def generate(prompt_request: PromptRequest):
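The rewritten /generate_stream endpoint is the core of this change: it reads the raw JSON body, caps concurrent generations with an asyncio.Semaphore, and releases the semaphore from a background task once the StreamingResponse has finished. A self-contained sketch of that pattern follows; the dummy token generator stands in for generate_stream_gate(), and release_model_semaphore() is written out here as an assumption since its definition is not part of this diff.

```python
# Self-contained sketch of the semaphore + BackgroundTasks pattern used by the
# new endpoint. The fake generator replaces the real model streamer.
import asyncio
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import StreamingResponse

LIMIT_MODEL_CONCURRENCY = 5
model_semaphore = None
app = FastAPI()

def fake_generate_stream_gate(params):
    # stand-in for generate_stream_gate(); yields encoded chunks
    for token in ("Hello", " ", "world"):
        yield token.encode()

def release_model_semaphore():
    model_semaphore.release()

@app.post("/generate_stream")
async def api_generate_stream(request: Request):
    global model_semaphore
    params = await request.json()
    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
    await model_semaphore.acquire()          # block once the limit is reached
    background_tasks = BackgroundTasks()
    background_tasks.add_task(release_model_semaphore)  # runs after the stream ends
    return StreamingResponse(fake_generate_stream_gate(params),
                             background=background_tasks)
```

Because the pydantic StreamRequest model is gone, clients keep sending the same fields it used to declare (prompt, model, temperature, max_new_tokens, stop) as a plain JSON body, e.g. curl -N -X POST http://localhost:8000/generate_stream -H 'Content-Type: application/json' -d '{"prompt": "...", "temperature": 0.7, "max_new_tokens": 512}'.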