DB-GPT/pilot/server/llmserver.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import asyncio
import json
import os
import sys
from typing import List

import torch  # used by the CUDA error handler in ModelWorker.generate_stream_gate
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import StreamingResponse

# from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

global_counter = 0
model_semaphore = None

ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)

from pilot.configs.config import Config
from pilot.configs.model_config import *
from pilot.model.llm_out.vicuna_base_llm import get_embeddings
from pilot.model.loader import ModelLoader, _get_model_real_path
from pilot.server.chat_adapter import get_llm_chat_adapter
from pilot.scene.base_message import ModelMessage

CFG = Config()
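

# ModelWorker wraps a single loaded LLM: it resolves the real model path, loads
# the model and tokenizer through ModelLoader, selects the matching chat adapter,
# and exposes the streaming-generation and embedding helpers used by the API below.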
class ModelWorker:
    def __init__(self, model_path, model_name, device):
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        model_path = _get_model_real_path(model_name, model_path)
        # self.model_name = model_name or model_path.split("/")[-1]
        self.device = device
        print(
            f"Loading {model_name} LLM ModelServer in {device} from model path {model_path}! Please Wait......"
        )
        self.ml: ModelLoader = ModelLoader(model_path=model_path, model_name=model_name)
        self.model, self.tokenizer = self.ml.loader(
            load_8bit=CFG.IS_LOAD_8BIT,
            load_4bit=CFG.IS_LOAD_4BIT,
            debug=ISDEBUG,
            max_gpu_memory=CFG.MAX_GPU_MEMORY,
        )

        if not isinstance(self.model, str):
            if hasattr(self.model, "config") and hasattr(
                self.model.config, "max_sequence_length"
            ):
                self.context_len = self.model.config.max_sequence_length
            elif hasattr(self.model, "config") and hasattr(
                self.model.config, "max_position_embeddings"
            ):
                self.context_len = self.model.config.max_position_embeddings
        else:
            self.context_len = 2048

        self.llm_chat_adapter = get_llm_chat_adapter(model_name, model_path)
        self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func(
            model_path
        )

    def start_check(self):
        print("LLM Model Loading Success")

    def get_queue_length(self):
        if (
            model_semaphore is None
            or model_semaphore._value is None
            or model_semaphore._waiters is None
        ):
            return 0
        else:
            return (
                CFG.LIMIT_MODEL_CONCURRENCY
                - model_semaphore._value
                + len(model_semaphore._waiters)
            )
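
    # generate_stream_gate adapts the request params for the active chat adapter
    # and yields b"\0"-terminated JSON chunks; failures (including CUDA OOM) are
    # reported in-band as chunks of the same shape rather than raised to the caller.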
    def generate_stream_gate(self, params):
        try:
            # params adaptation
            params, model_context = self.llm_chat_adapter.model_adaptation(
                params, self.ml.model_path, prompt_template=self.ml.prompt_template
            )

            for output in self.generate_stream_func(
                self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
            ):
                # Do not enable this output in production!
                # The gpt4all thread shares stdout with the parent process,
                # and enabling it may interfere with the frontend output.
                print("output: ", output)
                # return some model context to dgt-server
                ret = {"text": output, "error_code": 0, "model_context": model_context}
                yield json.dumps(ret).encode() + b"\0"
        except torch.cuda.CudaError:
            ret = {"text": "**GPU OutOfMemory, Please Refresh.**", "error_code": 0}
            yield json.dumps(ret).encode() + b"\0"
        except Exception as e:
            ret = {
                "text": f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                "error_code": 0,
            }
            yield json.dumps(ret).encode() + b"\0"

    def get_embeddings(self, prompt):
        return get_embeddings(self.model, self.tokenizer, prompt)


model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
worker = ModelWorker(model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE)

app = FastAPI()

# from pilot.openapi.knowledge.knowledge_controller import router
#
# app.include_router(router)
#
# origins = [
# "http://localhost",
# "http://localhost:8000",
# "http://localhost:3000",
# ]
#
# app.add_middleware(
# CORSMiddleware,
# allow_origins=origins,
# allow_credentials=True,
# allow_methods=["*"],
# allow_headers=["*"],
# )
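
# Request schemas for the endpoints below. Note that /generate_stream reads the
# raw request JSON itself, so StreamRequest describes the expected payload but is
# not bound to that route by FastAPI.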
class PromptRequest(BaseModel):
    messages: List[ModelMessage]
    prompt: str
    temperature: float
    max_new_tokens: int
    model: str
    stop: str = None
    echo: bool = True


class StreamRequest(BaseModel):
    model: str
    prompt: str
    temperature: float
    max_new_tokens: int
    stop: str


class EmbeddingRequest(BaseModel):
    prompt: str


def release_model_semaphore():
    model_semaphore.release()
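

# Streaming endpoint: concurrency is throttled with an asyncio.Semaphore sized by
# CFG.LIMIT_MODEL_CONCURRENCY; the semaphore is released by a background task once
# the StreamingResponse has finished sending.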
@app.post("/generate_stream")
async def api_generate_stream(request: Request):
global model_semaphore, global_counter
global_counter += 1
params = await request.json()
if model_semaphore is None:
model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
await model_semaphore.acquire()
generator = worker.generate_stream_gate(params)
background_tasks = BackgroundTasks()
background_tasks.add_task(release_model_semaphore)
return StreamingResponse(generator, background=background_tasks)
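

# Blocking endpoint: drains the same streaming generator and keeps only the last
# chunk (rsp_str is overwritten on every iteration), so the caller receives the
# final JSON payload produced for the prompt.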
@app.post("/generate")
def generate(prompt_request: PromptRequest) -> str:
params = {
"messages": prompt_request.messages,
"prompt": prompt_request.prompt,
"temperature": prompt_request.temperature,
"max_new_tokens": prompt_request.max_new_tokens,
"stop": prompt_request.stop,
"echo": prompt_request.echo,
}
rsp_str = ""
output = worker.generate_stream_gate(params)
for rsp in output:
# rsp = rsp.decode("utf-8")
rsp = rsp.replace(b"\0", b"")
rsp_str = rsp.decode()
return rsp_str
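

# Embedding endpoint: computes the embedding of a single prompt with the loaded
# model and returns it as a plain list of floats.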
@app.post("/embedding")
def embeddings(prompt_request: EmbeddingRequest):
params = {"prompt": prompt_request.prompt}
print("Received prompt: ", params["prompt"])
output = worker.get_embeddings(params["prompt"])
return {"response": [float(x) for x in output]}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=CFG.MODEL_PORT, log_level="info")
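

# A minimal client sketch for manual testing (a sketch only; it assumes the server
# is reachable on localhost at CFG.MODEL_PORT, shown here as 8000, and that the
# payload follows the StreamRequest fields defined above -- adjust both to your
# configuration):
#
#   import json
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/generate_stream",
#       json={
#           "model": CFG.LLM_MODEL,
#           "prompt": "Hello, who are you?",
#           "temperature": 0.7,
#           "max_new_tokens": 64,
#           "stop": "",
#       },
#       stream=True,
#   )
#   # Each chunk is a b"\0"-terminated JSON object emitted by generate_stream_gate.
#   for chunk in resp.iter_lines(delimiter=b"\0"):
#       if chunk:
#           print(json.loads(chunk.decode())["text"])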