csunny 2023-04-30 22:24:06 +08:00
parent a9f8477f3f
commit 86909045fc
2 changed files with 20 additions and 19 deletions

File 1 of 2:

@@ -89,8 +89,8 @@ class Conversation:
 conv_one_shot = Conversation(
-    system="A chat between a curious human and an artificial intelligence assistant, who very familiar with database related knowledge. "
-           "The assistant gives helpful, detailed, professional and polite answers to the human's questions. ",
+    system="A chat between a curious human and an artificial intelligence assistant."
+           "The assistant gives helpful, detailed and polite answers to the human's questions. ",
     roles=("Human", "Assistant"),
     messages=(
         (
@@ -123,8 +123,8 @@ conv_one_shot = Conversation(
 )
 conv_vicuna_v1 = Conversation(
-    system = "A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. "
-             "The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
+    system = "A chat between a curious user and an artificial intelligence assistant. "
+             "The assistant gives helpful, detailed and polite answers to the user's questions. ",
     roles=("USER", "ASSISTANT"),
     messages=(),
     offset=0,
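The fields shown in these templates (system, roles, messages, offset) are the pieces a prompt builder stitches into the final string sent to the model. As a rough illustration only, the sketch below shows how such a template is typically rendered; the rendering logic, the "###" separator, and the class name are assumptions, not the repository's actual Conversation implementation.

```python
# Minimal sketch of rendering a Conversation-style template into a prompt.
# Field names mirror the diff above; the separator and get_prompt logic are assumed.
from dataclasses import dataclass, field
from typing import List, Tuple

@dataclass
class ConversationSketch:
    system: str                      # system preamble, as in the diff above
    roles: Tuple[str, str]           # e.g. ("USER", "ASSISTANT")
    messages: List[List[str]] = field(default_factory=list)
    offset: int = 0                  # leading messages treated as few-shot examples
    sep: str = "###"                 # separator token, an assumption

    def append_message(self, role: str, text: str) -> None:
        self.messages.append([role, text])

    def get_prompt(self) -> str:
        # System text first, then "Role: message" turns joined by the separator.
        parts = [self.system]
        for role, text in self.messages:
            parts.append(f"{role}: {text}" if text else f"{role}:")
        return f" {self.sep} ".join(parts) + f" {self.sep} "

conv = ConversationSketch(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed and polite answers to the user's questions. ",
    roles=("USER", "ASSISTANT"),
)
conv.append_message(conv.roles[0], "Which databases support JSON columns?")
conv.append_message(conv.roles[1], "")   # empty slot the model is asked to complete
print(conv.get_prompt())
```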

File 2 of 2:

@@ -22,8 +22,14 @@ model_semaphore = None
 # ml = ModerLoader(model_path=model_path)
 # model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
 model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)
+
+class ModelWorker:
+    def __init__(self):
+        pass
+
+    # TODO
 app = FastAPI()
 class PromptRequest(BaseModel):
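The hunk cuts off before PromptRequest's fields, but the old handler removed in the next hunk reads prompt, model, temperature, max_new_tokens and stop from the typed request, so the model plausibly looks something like the sketch below; the types, defaults, and the sketch's class name are assumptions.

```python
# Plausible shape of PromptRequest, inferred from the fields the old
# /generate_stream handler read (see the next hunk). Types and defaults are assumed.
from typing import Optional
from pydantic import BaseModel

class PromptRequestSketch(BaseModel):
    prompt: str
    model: str
    temperature: float = 0.7        # assumed default
    max_new_tokens: int = 512       # assumed default
    stop: Optional[str] = None      # stop string; exact type (str vs list) unknown
```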
@@ -71,24 +77,19 @@ def generate_stream_gate(params):
 @app.post("/generate_stream")
-def api_generate_stream(request: StreamRequest):
+async def api_generate_stream(request: Request):
     global model_semaphore, global_counter
     global_counter += 1
-    params = {
-        "prompt": request.prompt,
-        "model": request.model,
-        "temperature": request.temperature,
-        "max_new_tokens": request.max_new_tokens,
-        "stop": request.stop
-    }
+    params = await request.json()
     print(model, tokenizer, params, DEVICE)
-    # if model_semaphore is None:
-    #     model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
+    if model_semaphore is None:
+        model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
+    await model_semaphore.acquire()
     generator = generate_stream_gate(params)
-    # background_tasks = BackgroundTasks()
-    # background_tasks.add_task(release_model_semaphore)
-    return StreamingResponse(generator)
+    background_tasks = BackgroundTasks()
+    background_tasks.add_task(release_model_semaphore)
+    return StreamingResponse(generator, background=background_tasks)

 @app.post("/generate")
 def generate(prompt_request: PromptRequest):
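After this change the streaming handler parses the raw JSON body itself, acquires the concurrency semaphore, and hands release_model_semaphore to a BackgroundTask, which FastAPI runs only after the StreamingResponse has finished sending, so in-flight generations stay capped at LIMIT_MODEL_CONCURRENCY. A hedged client sketch for the rewritten endpoint follows; the host, port, and payload values are assumptions based on the fields used elsewhere in this diff.

```python
# Client sketch: POST a JSON body to /generate_stream and consume the chunked
# response as it streams. Host/port and payload values are assumptions.
import requests

payload = {
    "prompt": "A chat between a curious user and an artificial intelligence assistant. "
              "USER: hello ASSISTANT:",
    "model": "vicuna-13b",        # assumed model name
    "temperature": 0.7,
    "max_new_tokens": 256,
    "stop": "###",
}

with requests.post("http://localhost:8000/generate_stream",  # port is an assumption
                   json=payload, stream=True, timeout=600) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None):
        if chunk:
            print(chunk.decode("utf-8", errors="ignore"), end="", flush=True)
```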