DB-GPT/pilot/server/llmserver.py (2023-05-21)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import uvicorn
import asyncio
import json
import sys

import torch
from typing import Optional, List
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
global_counter = 0
model_semaphore = None
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
from pilot.model.inference import generate_stream
from pilot.model.inference import generate_output, get_embeddings
from pilot.model.loader import ModelLoader
from pilot.configs.model_config import *
from pilot.configs.config import Config
from pilot.server.chat_adapter import get_llm_chat_adapter
CFG = Config()


class ModelWorker:
    def __init__(self, model_path, model_name, device, num_gpus=1):
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        self.model_name = model_name or model_path.split("/")[-1]
        self.device = device

        self.ml = ModelLoader(model_path=model_path)
        self.model, self.tokenizer = self.ml.loader(
            num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG
        )

        if hasattr(self.model.config, "max_sequence_length"):
            self.context_len = self.model.config.max_sequence_length
        elif hasattr(self.model.config, "max_position_embeddings"):
            self.context_len = self.model.config.max_position_embeddings
        else:
            self.context_len = 2048

        self.llm_chat_adapter = get_llm_chat_adapter(model_path)
        self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func()

    def get_queue_length(self):
        if (
            model_semaphore is None
            or model_semaphore._value is None
            or model_semaphore._waiters is None
        ):
            return 0
        else:
            # Requests currently holding the semaphore plus those still waiting for it.
            return (
                CFG.LIMIT_MODEL_CONCURRENCY
                - model_semaphore._value
                + len(model_semaphore._waiters)
            )

    def generate_stream_gate(self, params):
        try:
            # Each chunk is a JSON dict terminated by a NUL byte, so the client
            # can split the byte stream back into individual messages.
            for output in self.generate_stream_func(
                self.model,
                self.tokenizer,
                params,
                DEVICE,
                CFG.MAX_POSITION_EMBEDDINGS,
            ):
                print("output: ", output)
                ret = {
                    "text": output,
                    "error_code": 0,
                }
                yield json.dumps(ret).encode() + b"\0"
        except torch.cuda.CudaError:
            ret = {
                "text": "**GPU OutOfMemory, Please Refresh.**",
                "error_code": 0,
            }
            yield json.dumps(ret).encode() + b"\0"

    def get_embeddings(self, prompt):
        return get_embeddings(self.model, self.tokenizer, prompt)
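
# A single ModelWorker instance (created in the __main__ block below as the
# module-level name `worker`) holds the loaded model and tokenizer and is
# shared by all of the request handlers defined on `app`.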
app = FastAPI()


class PromptRequest(BaseModel):
    prompt: str
    temperature: float
    max_new_tokens: int
    model: str
    stop: str = None


class StreamRequest(BaseModel):
    model: str
    prompt: str
    temperature: float
    max_new_tokens: int
    stop: str


class EmbeddingRequest(BaseModel):
    prompt: str
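
# /generate validates its JSON body against PromptRequest, while
# /generate_stream passes the raw JSON dict straight to the worker. An
# illustrative body (the model name and values are examples only):
#   {"prompt": "Hello", "temperature": 0.7, "max_new_tokens": 256,
#    "model": "vicuna-13b", "stop": "###"}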


def release_model_semaphore():
    model_semaphore.release()


@app.post("/generate_stream")
async def api_generate_stream(request: Request):
    global model_semaphore, global_counter
    global_counter += 1
    params = await request.json()

    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
    await model_semaphore.acquire()

    generator = worker.generate_stream_gate(params)
    background_tasks = BackgroundTasks()
    # Release the concurrency semaphore once the streaming response has finished.
    background_tasks.add_task(release_model_semaphore)
    return StreamingResponse(generator, background=background_tasks)
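
# A minimal client-side sketch (not part of this module) for consuming the
# stream, assuming the server is reachable on localhost and that 8000 stands
# in for CFG.MODEL_PORT:
#
#   import json, requests
#   payload = {"prompt": "Hi", "temperature": 0.7, "max_new_tokens": 64, "stop": None}
#   with requests.post("http://localhost:8000/generate_stream",
#                      json=payload, stream=True) as r:
#       for chunk in r.iter_lines(delimiter=b"\0"):
#           if chunk:
#               print(json.loads(chunk.decode())["text"])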


@app.post("/generate")
def generate(prompt_request: PromptRequest):
    params = {
        "prompt": prompt_request.prompt,
        "temperature": prompt_request.temperature,
        "max_new_tokens": prompt_request.max_new_tokens,
        "stop": prompt_request.stop,
    }

    response = []
    rsp_str = ""
    output = worker.generate_stream_gate(params)
    for rsp in output:
        # rsp = rsp.decode("utf-8")
        rsp_str = str(rsp, "utf-8")
        print("[TEST: output]:", rsp_str)
        response.append(rsp_str)

    # Only the final streamed chunk is returned to the caller.
    return {"response": rsp_str}
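
# Illustrative non-streaming call (again assuming port 8000; the model name is
# an example only):
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hi", "temperature": 0.7, "max_new_tokens": 64,
#             "model": "vicuna-13b", "stop": null}'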


@app.post("/embedding")
def embeddings(prompt_request: EmbeddingRequest):
    params = {"prompt": prompt_request.prompt}
    print("Received prompt: ", params["prompt"])
    output = worker.get_embeddings(params["prompt"])
    return {"response": [float(x) for x in output]}
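
# Illustrative embedding request (the length of the returned vector depends on
# the loaded model):
#   curl -X POST http://localhost:8000/embedding \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "DB-GPT"}'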


if __name__ == "__main__":
    model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
    print(model_path, DEVICE)

    worker = ModelWorker(
        model_path=model_path,
        model_name=CFG.LLM_MODEL,
        device=DEVICE,
        num_gpus=1,
    )
    uvicorn.run(app, host="0.0.0.0", port=CFG.MODEL_PORT, log_level="info")
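
# Running this module directly (sketch; the exact values come from pilot.configs):
#   python pilot/server/llmserver.py
# loads the model named by CFG.LLM_MODEL from LLM_MODEL_CONFIG and serves it on
# CFG.MODEL_PORT.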