test

@@ -22,8 +22,14 @@ model_semaphore = None
 # ml = ModerLoader(model_path=model_path)
 # model, tokenizer = ml.loader(load_8bit=isload_8bit, debug=isdebug)
 
-
+model, tokenizer = load_model(model_path=model_path, device=DEVICE, num_gpus=1, load_8bit=True, debug=False)
+
+class ModelWorker:
+    def __init__(self):
+        pass
+
+    # TODO
 
 app = FastAPI()
 
 class PromptRequest(BaseModel):
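
For context on this first hunk: the model is now loaded once at import time, before the FastAPI app is created, so every request handler shares a single (model, tokenizer) pair. Below is a minimal, runnable sketch of that pattern; DummyModel, DummyTokenizer, and the stub load_model are hypothetical stand-ins that only mirror the call shape used in the diff, not the project's real loader.

# Sketch only: load-once-at-import pattern. DummyModel/DummyTokenizer and
# this load_model stub are hypothetical stand-ins for the real loader.
from fastapi import FastAPI

class DummyTokenizer:
    def encode(self, text):
        return text.split()

class DummyModel:
    def generate(self, tokens):
        return tokens

def load_model(model_path, device, num_gpus, load_8bit, debug):
    # A real loader would map (optionally 8-bit quantized) weights onto
    # `device`; this stub only mirrors the keyword signature in the diff.
    return DummyModel(), DummyTokenizer()

# Created once at import time and shared by every request handler.
model, tokenizer = load_model("path/to/weights", device="cuda",
                              num_gpus=1, load_8bit=True, debug=False)

app = FastAPI()
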
@@ -71,24 +77,19 @@ def generate_stream_gate(params):
 
 @app.post("/generate_stream")
-def api_generate_stream(request: StreamRequest):
+async def api_generate_stream(request: Request):
     global model_semaphore, global_counter
     global_counter += 1
-    params = {
-        "prompt": request.prompt,
-        "model": request.model,
-        "temperature": request.temperature,
-        "max_new_tokens": request.max_new_tokens,
-        "stop": request.stop
-    }
+    params = await request.json()
     print(model, tokenizer, params, DEVICE)
-    # if model_semaphore is None:
-    #     model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
+    if model_semaphore is None:
+        model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
+    await model_semaphore.acquire()
 
     generator = generate_stream_gate(params)
-    # background_tasks = BackgroundTasks()
-    # background_tasks.add_task(release_model_semaphore)
-    return StreamingResponse(generator)
+    background_tasks = BackgroundTasks()
+    background_tasks.add_task(release_model_semaphore)
+    return StreamingResponse(generator, background=background_tasks)
 
 
 @app.post("/generate")
 def generate(prompt_request: PromptRequest):
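
The second hunk switches the endpoint to a standard FastAPI/Starlette throttling idiom: read the raw JSON body, gate generation behind an asyncio.Semaphore, and hand a BackgroundTasks callback to the StreamingResponse so the permit is released only after the client finishes consuming the stream. A minimal, self-contained sketch of the same idiom follows; fake_generate_stream and the LIMIT_MODEL_CONCURRENCY value are illustrative stand-ins, not the project's code.

# Sketch of the semaphore + BackgroundTasks streaming idiom; the generator
# and the concurrency limit below are illustrative stand-ins.
import asyncio
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import StreamingResponse

LIMIT_MODEL_CONCURRENCY = 5  # assumed value for the sketch
app = FastAPI()
model_semaphore = None

def fake_generate_stream(params):
    # Stand-in for the model's token generator.
    for token in ("hello", " ", "world"):
        yield token.encode()

def release_model_semaphore():
    # Runs only after the streamed response has been fully sent.
    model_semaphore.release()

@app.post("/generate_stream")
async def api_generate_stream(request: Request):
    global model_semaphore
    params = await request.json()
    # Create the semaphore lazily so it binds to the running event loop.
    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
    await model_semaphore.acquire()
    background_tasks = BackgroundTasks()
    background_tasks.add_task(release_model_semaphore)
    return StreamingResponse(fake_generate_stream(params),
                             background=background_tasks)

Releasing in a background task rather than inline matters because the handler returns before the stream is consumed; releasing before returning would let a new request start while the previous generation is still streaming.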