WEB API independent

commit b2d2828b4e
parent 1d3d6cb23c
@@ -17,6 +17,8 @@ class Config(metaclass=Singleton):
     def __init__(self) -> None:
         """Initialize the Config class"""
 
+        self.NEW_SERVER_MODE = False
+
         # Gradio language version: en, zh
         self.LANGUAGE = os.getenv("LANGUAGE", "en")
         self.WEB_SERVER_PORT = int(os.getenv("WEB_SERVER_PORT", 7860))
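The new NEW_SERVER_MODE flag defaults to False here and is only switched to True by the new pilot/server/dbgpt_server.py entry point further down; the chat code then branches on it to choose between calling the model server over HTTP and calling the in-process worker. A minimal sketch of that consumption pattern, condensed from the hunks below (not the literal code):

from urllib.parse import urljoin
import requests

from pilot.configs.config import Config

CFG = Config()

def call_model_stream(payload: dict):
    # Sketch only: the real branching lives in BaseChat.stream_call below.
    if not CFG.NEW_SERVER_MODE:
        # old path: POST to the standalone model server over HTTP
        return requests.post(urljoin(CFG.MODEL_SERVER, "generate_stream"),
                             json=payload, stream=True, timeout=120)
    else:
        # new path: reuse the in-process ModelWorker from pilot.server.llmserver
        from pilot.server.llmserver import worker
        return worker.generate_stream_gate(payload)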
@@ -44,7 +44,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
         cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [conv_uid])
         content = cursor.fetchone()
         if content:
-            return cursor.fetchone()[0]
+            return content[0]
         else:
             return None
     def messages(self) -> List[OnceConversation]:

@@ -66,7 +66,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
                            [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id])
         else:
             cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
-                           [self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False)])
+                           [self.chat_seesion_id, "", json.dumps(conversations, ensure_ascii=False)])
-        cursor.commit()
+        self.connect.commit()
 
@@ -39,7 +39,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
         elif "ai:" in message:
             history.append(
                 {
-                    "role": "ai",
+                    "role": "assistant",
                     "content": message.split("ai:")[1],
                 }
             )

@@ -57,6 +57,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048):
     for m in temp_his:
         if m["role"] == "user":
            last_user_input = m
            break
     if last_user_input:
         history.remove(last_user_input)
+        history.append(last_user_input)
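The second hunk makes the proxy keep the latest user turn at the end of the OpenAI-style message list: it finds the last "user" entry, removes it and re-appends it. A small self-contained illustration of that reordering (the sample data is invented for the example):

history = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "first question"},
    {"role": "assistant", "content": "first answer"},
    {"role": "user", "content": "second question"},
    {"role": "assistant", "content": "stale answer"},
]

last_user_input = None
for m in history[::-1]:          # walk backwards, like temp_his in the diff
    if m["role"] == "user":
        last_user_input = m
        break
if last_user_input:
    history.remove(last_user_input)
    history.append(last_user_input)

# history now ends with {"role": "user", "content": "second question"}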
@@ -2,7 +2,7 @@ import uuid
 import json
 import asyncio
 import time
-from fastapi import APIRouter, Request, Body, status, HTTPException, Response
+from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks
 
 from fastapi.responses import JSONResponse
 from fastapi.responses import StreamingResponse

@@ -31,6 +31,8 @@ CHAT_FACTORY = ChatFactory()
 logger = build_logger("api_v1", LOGDIR + "api_v1.log")
 knowledge_service = KnowledgeService()
 
+model_semaphore = None
+global_counter = 0
 
 async def validation_exception_handler(request: Request, exc: RequestValidationError):
     message = ""

@@ -148,6 +150,11 @@ async def dialogue_history_messages(con_uid: str):
 @router.post('/v1/chat/completions')
 async def chat_completions(dialogue: ConversationVo = Body()):
     print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
+    global model_semaphore, global_counter
+    global_counter += 1
+    if model_semaphore is None:
+        model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
+    await model_semaphore.acquire()
 
     if not ChatScene.is_valid_mode(dialogue.chat_mode):
         raise StopAsyncIteration(Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!"))

@@ -170,73 +177,31 @@ async def chat_completions(dialogue: ConversationVo = Body()):
 
     chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
     if not chat.prompt_template.stream_out:
-        return non_stream_response(chat)
+        return chat.nostream_call()
     else:
-        return StreamingResponse(stream_generator(chat), media_type="text/plain")
-
-
-def stream_test():
-    for message in ["Hello", "world", "how", "are", "you"]:
-        yield message
-        # yield json.dumps(Result.succ(message).__dict__).encode("utf-8")
+        background_tasks = BackgroundTasks()
+        background_tasks.add_task(release_model_semaphore)
+        return StreamingResponse(stream_generator(chat), background=background_tasks)
 
 
+def release_model_semaphore():
+    model_semaphore.release()
+
 
 def stream_generator(chat):
     model_response = chat.stream_call()
-    for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
-        if chunk:
-            msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
-            chat.current_message.add_ai_message(msg)
-            yield msg
-            # chat.current_message.add_ai_message(msg)
-            # vo = MessageVo(role="view", context=msg, order=chat.current_message.chat_order)
-            # json_text = json.dumps(vo.__dict__)
-            # yield json_text.encode('utf-8')
+    if not CFG.NEW_SERVER_MODE:
+        for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+            if chunk:
+                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+                chat.current_message.add_ai_message(msg)
+                yield msg
+    else:
+        for chunk in model_response:
+            if chunk:
+                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
+                chat.current_message.add_ai_message(msg)
+                yield msg
     chat.memory.append(chat.current_message)
 
 
 def message2Vo(message: dict, order) -> MessageVo:
     # message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0
     return MessageVo(role=message['type'], context=message['data']['content'], order=order)
-
-
-def non_stream_response(chat):
-    logger.info("not stream out, wait model response!")
-    return chat.nostream_call()
-
-
-@router.get('/v1/db/types', response_model=Result[str])
-async def db_types():
-    return Result.succ(["mysql", "duckdb"])
-
-
-@router.get('/v1/db/list', response_model=Result[str])
-async def db_list():
-    db = CFG.local_db
-    dbs = db.get_database_list()
-    return Result.succ(dbs)
-
-
-@router.get('/v1/knowledge/list')
-async def knowledge_list():
-    return ["test1", "test2"]
-
-
-@router.post('/v1/knowledge/add')
-async def knowledge_add():
-    return ["test1", "test2"]
-
-
-@router.post('/v1/knowledge/delete')
-async def knowledge_delete():
-    return ["test1", "test2"]
-
-
-@router.get('/v1/knowledge/types')
-async def knowledge_types():
-    return ["test1", "test2"]
-
-
-@router.get('/v1/knowledge/detail')
-async def knowledge_detail():
-    return ["test1", "test2"]
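With this change a client talks to the chat endpoint over plain HTTP and reads the reply as a plain-text stream. A rough client-side sketch; the host/port and every payload field other than chat_mode and select_param are assumptions for illustration, not taken from this diff:

import requests

# Hypothetical payload: only chat_mode and select_param appear in this diff,
# the other field names and values are guesses.
payload = {
    "conv_uid": "some-conversation-id",
    "chat_mode": "chat_normal",
    "user_input": "hello",
    "select_param": "",
}

# Assumes the new dbgpt_server below is listening on localhost:5000.
with requests.post("http://127.0.0.1:5000/v1/chat/completions", json=payload, stream=True) as resp:
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        if chunk:
            print(chunk, end="", flush=True)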
@@ -47,6 +47,8 @@ class BaseOutputParser(ABC):
         return code
 
     def parse_model_stream_resp_ex(self, chunk, skip_echo_len):
+        if b"\0" in chunk:
+            chunk = chunk.replace(b"\0", b"")
         data = json.loads(chunk.decode())
 
         """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.

@@ -95,11 +97,8 @@ class BaseOutputParser(ABC):
     def parse_model_nostream_resp(self, response, sep: str):
-        text = response.text.strip()
+        text = text.rstrip()
-        respObj = json.loads(text)
-
-        xx = respObj["response"]
-        xx = xx.strip(b"\x00".decode())
-        respObj_ex = json.loads(xx)
+        text = text.strip(b"\x00".decode())
+        respObj_ex = json.loads(text)
         if respObj_ex["error_code"] == 0:
             all_text = respObj_ex["text"]
             ### parse the returned text and extract the AI reply part

@@ -123,7 +122,7 @@ class BaseOutputParser(ABC):
     def __extract_json(slef, s):
         i = s.index("{")
         count = 1  # current nesting depth, i.e. the number of '{' not yet closed
-        for j, c in enumerate(s[i + 1 :], start=i + 1):
+        for j, c in enumerate(s[i + 1:], start=i + 1):
             if c == "}":
                 count -= 1
             elif c == "{":

@@ -131,7 +130,7 @@ class BaseOutputParser(ABC):
             if count == 0:
                 break
         assert count == 0  # check that the final '}' was found
-        return s[i : j + 1]
+        return s[i: j + 1]
 
     def parse_prompt_response(self, model_out_text) -> T:
         """

@@ -148,9 +147,9 @@ class BaseOutputParser(ABC):
         # if "```" in cleaned_output:
         # cleaned_output, _ = cleaned_output.split("```")
         if cleaned_output.startswith("```json"):
-            cleaned_output = cleaned_output[len("```json") :]
+            cleaned_output = cleaned_output[len("```json"):]
         if cleaned_output.startswith("```"):
-            cleaned_output = cleaned_output[len("```") :]
+            cleaned_output = cleaned_output[len("```"):]
         if cleaned_output.endswith("```"):
             cleaned_output = cleaned_output[: -len("```")]
         cleaned_output = cleaned_output.strip()

@@ -159,9 +158,9 @@ class BaseOutputParser(ABC):
         cleaned_output = self.__extract_json(cleaned_output)
         cleaned_output = (
             cleaned_output.strip()
-            .replace("\n", " ")
-            .replace("\\n", " ")
-            .replace("\\", " ")
+                .replace("\n", " ")
+                .replace("\\n", " ")
+                .replace("\\", " ")
         )
         return cleaned_output
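The __extract_json helper scans for the first balanced {...} block in the model output; only the slice spelling changes in this commit. A standalone sketch of the same idea, handy for checking the behaviour outside the class:

def extract_first_json(s: str) -> str:
    # Find the first '{' and walk forward until the braces balance again.
    i = s.index("{")
    count = 1
    for j, c in enumerate(s[i + 1:], start=i + 1):
        if c == "}":
            count -= 1
        elif c == "{":
            count += 1
        if count == 0:
            break
    assert count == 0  # the closing '}' must be found
    return s[i: j + 1]

print(extract_first_json('noise {"speak": "hi", "command": {"name": "x"}} trailing'))
# -> {"speak": "hi", "command": {"name": "x"}}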
@@ -15,6 +15,10 @@ class ExampleSelector(BaseModel, ABC):
         else:
             return self.__few_shot_context(count)
 
+    def __examples_text(self, used_examples):
+
+
 
     def __few_shot_context(self, count: int = 2) -> List[List]:
         """
         Use 2 or more examples, default 2
@@ -39,6 +39,7 @@ from pilot.scene.base_message import (
     ViewMessage,
 )
 from pilot.configs.config import Config
+from pilot.server.llmserver import worker
 
 logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
 headers = {"User-Agent": "dbgpt Client"}

@@ -59,10 +60,10 @@ class BaseChat(ABC):
         arbitrary_types_allowed = True
 
     def __init__(
-        self,
-        chat_mode,
-        chat_session_id,
-        current_user_input,
+            self,
+            chat_mode,
+            chat_session_id,
+            current_user_input,
     ):
         self.chat_session_id = chat_session_id
         self.chat_mode = chat_mode

@@ -95,7 +96,6 @@ class BaseChat(ABC):
     def generate_input_values(self):
         pass
 
-
     def do_action(self, prompt_response):
         return prompt_response
 

@@ -138,24 +138,17 @@ class BaseChat(ABC):
         logger.info(f"Requert: \n{payload}")
         ai_response_text = ""
         try:
-            show_info = ""
-            response = requests.post(
-                urljoin(CFG.MODEL_SERVER, "generate_stream"),
-                headers=headers,
-                json=payload,
-                stream=True,
-                timeout=120,
-            )
-            return response
-
-            # yield self.prompt_template.output_parser.parse_model_stream_resp(response, skip_echo_len)
-
-            # for resp_text_trunck in ai_response_text:
-            #     show_info = resp_text_trunck
-            #     yield resp_text_trunck + "▌"
-
-            self.current_message.add_ai_message(show_info)
-
+            if not CFG.NEW_SERVER_MODE:
+                response = requests.post(
+                    urljoin(CFG.MODEL_SERVER, "generate_stream"),
+                    headers=headers,
+                    json=payload,
+                    stream=True,
+                    timeout=120,
+                )
+                return response
+            else:
+                return worker.generate_stream_gate(payload)
         except Exception as e:
             print(traceback.format_exc())
             logger.error("model response parase faild!" + str(e))

@@ -170,39 +163,28 @@ class BaseChat(ABC):
         logger.info(f"Requert: \n{payload}")
         ai_response_text = ""
         try:
-            ### call the non-streaming model service interface
-            response = requests.post(
-                urljoin(CFG.MODEL_SERVER, "generate"),
-                headers=headers,
-                json=payload,
-                timeout=120,
-            )
+            rsp_str = ""
+            if not CFG.NEW_SERVER_MODE:
+                ### call the non-streaming model service interface
+                rsp_str = requests.post(
+                    urljoin(CFG.MODEL_SERVER, "generate"),
+                    headers=headers,
+                    json=payload,
+                    timeout=120,
+                )
+            else:
+                ###TODO no stream mode need independent
+                output = worker.generate_stream_gate(payload)
+                for rsp in output:
+                    rsp_str = str(rsp, "utf-8")
+                    print("[TEST: output]:", rsp_str)
 
-            ### output parse
-            ai_response_text = (
-                self.prompt_template.output_parser.parse_model_nostream_resp(
-                    response, self.prompt_template.sep
-                )
-            )
-
-            # ### MOCK
-            # ai_response_text = """{
-            #     "thoughts": "Join the users table with the tran_order table, group by city to count orders, and show the result as a histogram.",
-            #     "reasoning": "To analyze how users are distributed across cities, query the users and tran_order tables and combine them with a LEFT JOIN. Group by city and count the orders of each city; a histogram makes the per-city order counts easy to see and compare.",
-            #     "speak": "Based on your analysis goal, I queried the user and order tables, counted the orders of each city and generated a histogram.",
-            #     "command": {
-            #         "name": "histogram-executor",
-            #         "args": {
-            #             "title": "Histogram of orders by city",
-            #             "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city"
-            #         }
-            #     }
-            # }"""
-
+            ai_response_text = self.prompt_template.output_parser.parse_model_nostream_resp(rsp_str,
+                                                                                            self.prompt_template.sep)
             ### model result deal
             self.current_message.add_ai_message(ai_response_text)
             prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
 
 
             result = self.do_action(prompt_define_response)
 
             if hasattr(prompt_define_response, "thoughts"):

@@ -248,41 +230,42 @@ class BaseChat(ABC):
         ### handle history messages
         if len(self.history_message) > self.chat_retention_rounds:
             ### merge the first round and the last n rounds of history into the conversation record; view messages shown to the user must be filtered out of the context prompt
-            for first_message in self.history_message[0].messages:
-                if not isinstance(first_message, ViewMessage):
+            for first_message in self.history_message[0]['messages']:
+                if not first_message['type'] in [ViewMessage.type, SystemMessage.type]:
                     text += (
-                        first_message.type
-                        + ":"
-                        + first_message.content
-                        + self.prompt_template.sep
+                            first_message['type']
+                            + ":"
+                            + first_message['data']['content']
+                            + self.prompt_template.sep
                     )
 
             index = self.chat_retention_rounds - 1
-            for last_message in self.history_message[-index:].messages:
-                if not isinstance(last_message, ViewMessage):
-                    text += (
-                        last_message.type
-                        + ":"
-                        + last_message.content
-                        + self.prompt_template.sep
-                    )
+            for round_conv in self.history_message[-index:]:
+                for round_message in round_conv['messages']:
+                    if not isinstance(round_message, ViewMessage):
+                        text += (
+                                round_message['type']
+                                + ":"
+                                + round_message['data']['content']
+                                + self.prompt_template.sep
+                        )
 
         else:
             ### concatenate the history directly
             for conversation in self.history_message:
-                for message in conversation.messages:
+                for message in conversation['messages']:
                     if not isinstance(message, ViewMessage):
                         text += (
-                            message.type
-                            + ":"
-                            + message.content
-                            + self.prompt_template.sep
+                                message['type']
+                                + ":"
+                                + message['data']['content']
+                                + self.prompt_template.sep
                         )
         ### current conversation
 
         for now_message in self.current_message.messages:
             text += (
-                now_message.type + ":" + now_message.content + self.prompt_template.sep
+                    now_message.type + ":" + now_message.content + self.prompt_template.sep
             )
 
         return text
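After this change the history builder reads raw dicts (as stored by the chat history memory) instead of message objects, so each stored round is expected to look roughly like the sketch below. The exact type values ("human", "ai", "view", "system") and the chat_order key are assumptions based on the rest of the code, not spelled out in this diff:

# Hypothetical shape of one serialized conversation round.
round_conv = {
    "chat_order": 1,
    "messages": [
        {"type": "human", "data": {"content": "show me the top cities by orders"}},
        {"type": "ai", "data": {"content": "SELECT city, COUNT(*) ..."}},
        {"type": "view", "data": {"content": "<chart>...</chart>"}},  # filtered out of the prompt
    ],
}

for message in round_conv["messages"]:
    if message["type"] not in ("view", "system"):
        print(message["type"] + ":" + message["data"]["content"])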
@@ -8,7 +8,7 @@ from pilot.common.schema import SeparatorStyle
 
 CFG = Config()
 
-PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, please follow the prompts and conventions of the system's input for your answers"""
+PROMPT_SCENE_DEFINE = None
 
 
 _DEFAULT_TEMPLATE = """
@@ -2,8 +2,8 @@ from pilot.prompts.example_base import ExampleSelector
 
 ## Two examples are defined by default
 EXAMPLES = [
-    [{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}],
-    [{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}]
+    [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}],
+    [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}]
 ]
 
 example = ExampleSelector(examples=EXAMPLES, use_example=True)
@@ -9,10 +9,8 @@ from pilot.scene.chat_execution.example import example
 
 CFG = Config()
 
-# PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers."""
 PROMPT_SCENE_DEFINE = "You are an AI designed to solve the user's goals with given commands, please follow the constraints of the system's input for your answers."
-
 
 _DEFAULT_TEMPLATE = """
 Goals:
     {input}
@@ -8,8 +8,7 @@ from pilot.common.schema import SeparatorStyle
 
 from pilot.scene.chat_normal.out_parser import NormalChatOutputParser
 
-PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge.
-The assistant gives helpful, detailed, professional and polite answers to the user's questions. """
+PROMPT_SCENE_DEFINE = None
 
 CFG = Config()
 
@@ -95,7 +95,7 @@ class OnceConversation:
 
 def _conversation_to_dic(once: OnceConversation) -> dict:
     start_str: str = ""
-    if once.start_date:
+    if hasattr(once, 'start_date') and once.start_date:
         if isinstance(once.start_date, datetime):
             start_str = once.start_date.strftime("%Y-%m-%d %H:%M:%S")
         else:
pilot/server/dbgpt_server.py (new file, 74 lines)
@@ -0,0 +1,74 @@
+import traceback
+import os
+import shutil
+import argparse
+import sys
+
+ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(ROOT_PATH)
+
+from pilot.configs.config import Config
+from pilot.configs.model_config import (
+    DATASETS_DIR,
+    KNOWLEDGE_UPLOAD_ROOT_PATH,
+    LLM_MODEL_CONFIG,
+    LOGDIR,
+)
+from pilot.utils import build_logger
+
+from pilot.server.webserver_base import server_init
+
+from fastapi import FastAPI, applications
+from fastapi.openapi.docs import get_swagger_ui_html
+from fastapi.exceptions import RequestValidationError
+from fastapi.middleware.cors import CORSMiddleware
+
+from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
+
+
+CFG = Config()
+logger = build_logger("webserver", LOGDIR + "webserver.log")
+
+
+def swagger_monkey_patch(*args, **kwargs):
+    return get_swagger_ui_html(
+        *args, **kwargs,
+        swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js',
+        swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css'
+    )
+
+
+applications.get_swagger_ui_html = swagger_monkey_patch
+
+app = FastAPI()
+origins = ["*"]
+
+# add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
+    allow_headers=["*"],
+)
+
+# app.mount("static", StaticFiles(directory="static"), name="static")
+app.include_router(api_v1)
+app.add_exception_handler(RequestValidationError, validation_exception_handler)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
+
+    # old version server config
+    parser.add_argument("--host", type=str, default="0.0.0.0")
+    parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT)
+    parser.add_argument("--concurrency-count", type=int, default=10)
+    parser.add_argument("--share", default=False, action="store_true")
+
+    # init server config
+    args = parser.parse_args()
+    server_init(args)
+    CFG.NEW_SERVER_MODE = True
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=5000)
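Presumably this new entry point can be started on its own (for example with `python pilot/server/dbgpt_server.py`): it runs server_init, switches CFG.NEW_SERVER_MODE on, and serves the v1 chat API directly with uvicorn on the hard-coded port 5000, so the Gradio webserver appears no longer required for the HTTP API. The --host/--port/--concurrency-count/--share arguments are only parsed for the old-style server configuration passed to server_init.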
@@ -27,14 +27,13 @@ from pilot.server.chat_adapter import get_llm_chat_adapter
 CFG = Config()
 
-
-
 class ModelWorker:
     def __init__(self, model_path, model_name, device, num_gpus=1):
         if model_path.endswith("/"):
             model_path = model_path[:-1]
         self.model_name = model_name or model_path.split("/")[-1]
         self.device = device
 
         print(f"Loading {model_name} LLM ModelServer in {device}! Please Wait......")
         self.ml = ModelLoader(model_path=model_path)
         self.model, self.tokenizer = self.ml.loader(
             num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG

@@ -42,11 +41,11 @@ class ModelWorker:
 
         if not isinstance(self.model, str):
             if hasattr(self.model, "config") and hasattr(
-                self.model.config, "max_sequence_length"
+                    self.model.config, "max_sequence_length"
             ):
                 self.context_len = self.model.config.max_sequence_length
             elif hasattr(self.model, "config") and hasattr(
-                self.model.config, "max_position_embeddings"
+                    self.model.config, "max_position_embeddings"
             ):
                 self.context_len = self.model.config.max_position_embeddings
 

@@ -56,29 +55,32 @@ class ModelWorker:
         self.llm_chat_adapter = get_llm_chat_adapter(model_path)
         self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func()
 
+    def start_check(self):
+        print("LLM Model Loading Success!")
+
     def get_queue_length(self):
         if (
-            model_semaphore is None
-            or model_semaphore._value is None
-            or model_semaphore._waiters is None
+                model_semaphore is None
+                or model_semaphore._value is None
+                or model_semaphore._waiters is None
         ):
             return 0
         else:
             (
-                CFG.LIMIT_MODEL_CONCURRENCY
-                - model_semaphore._value
-                + len(model_semaphore._waiters)
+                    CFG.LIMIT_MODEL_CONCURRENCY
+                    - model_semaphore._value
+                    + len(model_semaphore._waiters)
             )
 
     def generate_stream_gate(self, params):
         try:
             for output in self.generate_stream_func(
-                self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
+                    self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
             ):
                 # Please do not open the output in production!
                 # The gpt4all thread shares stdout with the parent process,
                 # and opening it may affect the frontend output.
-                # print("output: ", output)
+                print("output: ", output)
                 ret = {
                     "text": output,
                     "error_code": 0,

@@ -106,6 +108,7 @@ worker = ModelWorker(
 
 app = FastAPI()
+from pilot.openapi.knowledge.knowledge_controller import router
 
 app.include_router(router)
 
 origins = [

@@ -122,6 +125,7 @@ app.add_middleware(
     allow_headers=["*"]
 )
 
+
 class PromptRequest(BaseModel):
     prompt: str
     temperature: float

@@ -177,10 +181,9 @@ def generate(prompt_request: PromptRequest):
     for rsp in output:
         # rsp = rsp.decode("utf-8")
         rsp_str = str(rsp, "utf-8")
         print("[TEST: output]:", rsp_str)
         response.append(rsp_str)
 
-    return {"response": rsp_str}
+    return rsp_str
 
 
 @app.post("/embedding")
@@ -39,6 +39,8 @@ def server_init(args):
     # init config
     cfg = Config()
 
+    from pilot.server.llmserver import worker
+    worker.start_check()
     load_native_plugins(cfg)
     signal.signal(signal.SIGINT, signal_handler)
     async_db_summery()