Make the web API independent

This commit is contained in:
tuyang.yhj 2023-06-28 11:34:40 +08:00
parent 1d3d6cb23c
commit b2d2828b4e
15 changed files with 201 additions and 171 deletions

View File

@ -17,6 +17,8 @@ class Config(metaclass=Singleton):
def __init__(self) -> None:
"""Initialize the Config class"""
self.NEW_SERVER_MODE = False
# Gradio language version: en, zh
self.LANGUAGE = os.getenv("LANGUAGE", "en")
self.WEB_SERVER_PORT = int(os.getenv("WEB_SERVER_PORT", 7860))

View File

@ -44,7 +44,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [conv_uid])
content = cursor.fetchone()
if content:
return cursor.fetchone()[0]
return content[0]
else:
return None
def messages(self) -> List[OnceConversation]:
@ -66,7 +66,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
[json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id])
else:
cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
[self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False)])
[self.chat_seesion_id, "", json.dumps(conversations, ensure_ascii=False)])
cursor.commit()
self.connect.commit()
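
The hunk above fixes a classic cursor bug: the old code called cursor.fetchone() a second time after the truthiness check, which advances the cursor past the only row and returns None, so the fix reuses the row that was already fetched. A minimal, self-contained sketch of the corrected pattern against an in-memory DuckDB database (table layout mirrors the diff; the sample data is illustrative):

import duckdb
import json

connect = duckdb.connect(":memory:")
connect.execute("CREATE TABLE chat_history (conv_uid VARCHAR, user_name VARCHAR, messages TEXT)")
connect.execute(
    "INSERT INTO chat_history(conv_uid, user_name, messages) VALUES (?,?,?)",
    ["uid-1", "", json.dumps([{"type": "human", "data": {"content": "hi"}}], ensure_ascii=False)],
)

cursor = connect.cursor()
cursor.execute("SELECT messages FROM chat_history where conv_uid=?", ["uid-1"])
content = cursor.fetchone()
# Reuse the row already fetched; a second cursor.fetchone() would return None here.
messages = content[0] if content else None
print(messages)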

View File

@ -39,7 +39,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
elif "ai:" in message:
history.append(
{
"role": "ai",
"role": "assistant",
"content": message.split("ai:")[1],
}
)
@ -57,6 +57,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
for m in temp_his:
if m["role"] == "user":
last_user_input = m
break
if last_user_input:
history.remove(last_user_input)
history.append(last_user_input)
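
Two small behavior changes land in this file: history entries prefixed with "ai:" are now mapped to the OpenAI-compatible role "assistant", and the added break stops the scan at the first matching (most recent) user turn so only that turn is moved to the end. A hedged sketch of that normalization; the "human:" handling and the newest-to-oldest scan are assumptions based on the surrounding code:

raw_messages = ["human:hello", "ai:hi, how can I help?", "human:show me the tables"]

history = []
for message in raw_messages:
    if "human:" in message:
        history.append({"role": "user", "content": message.split("human:")[1]})
    elif "ai:" in message:
        # Proxy/OpenAI-style chat APIs expect "assistant", not "ai".
        history.append({"role": "assistant", "content": message.split("ai:")[1]})

# Walk newest-to-oldest, stop at the first user turn (the added break),
# then move that turn to the end so it is the last message the proxy model sees.
last_user_input = None
for m in reversed(history):
    if m["role"] == "user":
        last_user_input = m
        break
if last_user_input:
    history.remove(last_user_input)
    history.append(last_user_input)

print(history)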

View File

@ -2,7 +2,7 @@ import uuid
import json
import asyncio
import time
from fastapi import APIRouter, Request, Body, status, HTTPException, Response
from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks
from fastapi.responses import JSONResponse
from fastapi.responses import StreamingResponse
@ -31,6 +31,8 @@ CHAT_FACTORY = ChatFactory()
logger = build_logger("api_v1", LOGDIR + "api_v1.log")
knowledge_service = KnowledgeService()
model_semaphore = None
global_counter = 0
async def validation_exception_handler(request: Request, exc: RequestValidationError):
message = ""
@ -148,6 +150,11 @@ async def dialogue_history_messages(con_uid: str):
@router.post('/v1/chat/completions')
async def chat_completions(dialogue: ConversationVo = Body()):
print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
global model_semaphore, global_counter
global_counter += 1
if model_semaphore is None:
model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY)
await model_semaphore.acquire()
if not ChatScene.is_valid_mode(dialogue.chat_mode):
raise StopAsyncIteration(Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!"))
@ -170,73 +177,31 @@ async def chat_completions(dialogue: ConversationVo = Body()):
chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
if not chat.prompt_template.stream_out:
return non_stream_response(chat)
return chat.nostream_call()
else:
return StreamingResponse(stream_generator(chat), media_type="text/plain")
def stream_test():
for message in ["Hello", "world", "how", "are", "you"]:
yield message
# yield json.dumps(Result.succ(message).__dict__).encode("utf-8")
background_tasks = BackgroundTasks()
background_tasks.add_task(release_model_semaphore)
return StreamingResponse(stream_generator(chat), background=background_tasks)
def release_model_semaphore():
model_semaphore.release()
def stream_generator(chat):
model_response = chat.stream_call()
for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
chat.current_message.add_ai_message(msg)
yield msg
# chat.current_message.add_ai_message(msg)
# vo = MessageVo(role="view", context=msg, order=chat.current_message.chat_order)
# json_text = json.dumps(vo.__dict__)
# yield json_text.encode('utf-8')
if not CFG.NEW_SERVER_MODE:
for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
chat.current_message.add_ai_message(msg)
yield msg
else:
for chunk in model_response:
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
chat.current_message.add_ai_message(msg)
yield msg
chat.memory.append(chat.current_message)
def message2Vo(message: dict, order) -> MessageVo:
# message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0
return MessageVo(role=message['type'], context=message['data']['content'], order=order)
def non_stream_response(chat):
logger.info("not stream out, wait model response!")
return chat.nostream_call()
@router.get('/v1/db/types', response_model=Result[str])
async def db_types():
return Result.succ(["mysql", "duckdb"])
@router.get('/v1/db/list', response_model=Result[str])
async def db_list():
db = CFG.local_db
dbs = db.get_database_list()
return Result.succ(dbs)
@router.get('/v1/knowledge/list')
async def knowledge_list():
return ["test1", "test2"]
@router.post('/v1/knowledge/add')
async def knowledge_add():
return ["test1", "test2"]
@router.post('/v1/knowledge/delete')
async def knowledge_delete():
return ["test1", "test2"]
@router.get('/v1/knowledge/types')
async def knowledge_types():
return ["test1", "test2"]
@router.get('/v1/knowledge/detail')
async def knowledge_detail():
return ["test1", "test2"]

View File

@ -47,6 +47,8 @@ class BaseOutputParser(ABC):
return code
def parse_model_stream_resp_ex(self, chunk, skip_echo_len):
if b"\0" in chunk:
chunk = chunk.replace(b"\0", b"")
data = json.loads(chunk.decode())
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
@ -95,11 +97,8 @@ class BaseOutputParser(ABC):
def parse_model_nostream_resp(self, response, sep: str):
text = response.text.strip()
text = text.rstrip()
respObj = json.loads(text)
xx = respObj["response"]
xx = xx.strip(b"\x00".decode())
respObj_ex = json.loads(xx)
text = text.strip(b"\x00".decode())
respObj_ex = json.loads(text)
if respObj_ex["error_code"] == 0:
all_text = respObj_ex["text"]
### Parse the returned text to extract the AI reply portion
@ -123,7 +122,7 @@ class BaseOutputParser(ABC):
def __extract_json(slef, s):
i = s.index("{")
count = 1  # current nesting depth, i.e. the number of '{' not yet closed
for j, c in enumerate(s[i + 1 :], start=i + 1):
for j, c in enumerate(s[i + 1:], start=i + 1):
if c == "}":
count -= 1
elif c == "{":
@ -131,7 +130,7 @@ class BaseOutputParser(ABC):
if count == 0:
break
assert count == 0  # make sure the final '}' was found
return s[i : j + 1]
return s[i: j + 1]
def parse_prompt_response(self, model_out_text) -> T:
"""
@ -148,9 +147,9 @@ class BaseOutputParser(ABC):
# if "```" in cleaned_output:
# cleaned_output, _ = cleaned_output.split("```")
if cleaned_output.startswith("```json"):
cleaned_output = cleaned_output[len("```json") :]
cleaned_output = cleaned_output[len("```json"):]
if cleaned_output.startswith("```"):
cleaned_output = cleaned_output[len("```") :]
cleaned_output = cleaned_output[len("```"):]
if cleaned_output.endswith("```"):
cleaned_output = cleaned_output[: -len("```")]
cleaned_output = cleaned_output.strip()
@ -159,9 +158,9 @@ class BaseOutputParser(ABC):
cleaned_output = self.__extract_json(cleaned_output)
cleaned_output = (
cleaned_output.strip()
.replace("\n", " ")
.replace("\\n", " ")
.replace("\\", " ")
.replace("\n", " ")
.replace("\\n", " ")
.replace("\\", " ")
)
return cleaned_output
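
The non-stream parser no longer unwraps a nested "response" field: with the llmserver change later in this commit returning the inner payload directly, a single json.loads on the NUL-stripped text is enough. A small sketch of that parsing step with an illustrative payload:

import json

# What the model service / in-process worker returns for a non-stream call (illustrative).
raw = '{"text": "ai: SELECT city, COUNT(*) FROM users GROUP BY city", "error_code": 0}\x00'

text = raw.strip()
text = text.strip(b"\x00".decode())  # drop the NUL padding used as a stream delimiter
respObj_ex = json.loads(text)
if respObj_ex["error_code"] == 0:
    all_text = respObj_ex["text"]
    print(all_text)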

View File

@ -15,6 +15,10 @@ class ExampleSelector(BaseModel, ABC):
else:
return self.__few_shot_context(count)
def __examples_text(self, used_examples):
def __few_shot_context(self, count: int = 2) -> List[List]:
"""
Use 2 or more examples, default 2

View File

@ -39,6 +39,7 @@ from pilot.scene.base_message import (
ViewMessage,
)
from pilot.configs.config import Config
from pilot.server.llmserver import worker
logger = build_logger("BaseChat", LOGDIR + "BaseChat.log")
headers = {"User-Agent": "dbgpt Client"}
@ -59,10 +60,10 @@ class BaseChat(ABC):
arbitrary_types_allowed = True
def __init__(
self,
chat_mode,
chat_session_id,
current_user_input,
):
self.chat_session_id = chat_session_id
self.chat_mode = chat_mode
@ -95,7 +96,6 @@ class BaseChat(ABC):
def generate_input_values(self):
pass
def do_action(self, prompt_response):
return prompt_response
@ -138,24 +138,17 @@ class BaseChat(ABC):
logger.info(f"Requert: \n{payload}")
ai_response_text = ""
try:
show_info = ""
response = requests.post(
urljoin(CFG.MODEL_SERVER, "generate_stream"),
headers=headers,
json=payload,
stream=True,
timeout=120,
)
return response
# yield self.prompt_template.output_parser.parse_model_stream_resp(response, skip_echo_len)
# for resp_text_trunck in ai_response_text:
# show_info = resp_text_trunck
# yield resp_text_trunck + "▌"
self.current_message.add_ai_message(show_info)
if not CFG.NEW_SERVER_MODE:
response = requests.post(
urljoin(CFG.MODEL_SERVER, "generate_stream"),
headers=headers,
json=payload,
stream=True,
timeout=120,
)
return response
else:
return worker.generate_stream_gate(payload)
except Exception as e:
print(traceback.format_exc())
logger.error("model response parase faild" + str(e))
@ -170,39 +163,28 @@ class BaseChat(ABC):
logger.info(f"Requert: \n{payload}")
ai_response_text = ""
try:
### Call the non-streaming model service endpoint
response = requests.post(
urljoin(CFG.MODEL_SERVER, "generate"),
headers=headers,
json=payload,
timeout=120,
)
rsp_str = ""
if not CFG.NEW_SERVER_MODE:
### Call the non-streaming model service endpoint
rsp_str = requests.post(
urljoin(CFG.MODEL_SERVER, "generate"),
headers=headers,
json=payload,
timeout=120,
)
else:
###TODO no stream mode need independent
output = worker.generate_stream_gate(payload)
for rsp in output:
rsp_str = str(rsp, "utf-8")
print("[TEST: output]:", rsp_str)
### output parse
ai_response_text = (
self.prompt_template.output_parser.parse_model_nostream_resp(
response, self.prompt_template.sep
)
)
# ### MOCK
# ai_response_text = """{
# "thoughts": "可以从users表和tran_order表联合查询按城市和订单数量进行分组统计并使用柱状图展示。",
# "reasoning": "为了分析用户在不同城市的分布情况需要查询users表和tran_order表使用LEFT JOIN将两个表联合起来。按照城市进行分组统计每个城市的订单数量。使用柱状图展示可以直观地看出每个城市的订单数量方便比较。",
# "speak": "根据您的分析目标,我查询了用户表和订单表,统计了每个城市的订单数量,并生成了柱状图展示。",
# "command": {
# "name": "histogram-executor",
# "args": {
# "title": "订单城市分布柱状图",
# "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city"
# }
# }
# }"""
ai_response_text = self.prompt_template.output_parser.parse_model_nostream_resp(rsp_str,
self.prompt_template.sep)
### model result deal
self.current_message.add_ai_message(ai_response_text)
prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
result = self.do_action(prompt_define_response)
if hasattr(prompt_define_response, "thoughts"):
@ -248,41 +230,42 @@ class BaseChat(ABC):
### Handle conversation history
if len(self.history_message) > self.chat_retention_rounds:
### Merge the first round and the last n rounds of history into the conversation context; messages shown only to the user (view messages) are filtered out of the prompt
for first_message in self.history_message[0].messages:
if not isinstance(first_message, ViewMessage):
for first_message in self.history_message[0]['messages']:
if not first_message['type'] in [ViewMessage.type, SystemMessage.type]:
text += (
first_message.type
+ ":"
+ first_message.content
+ self.prompt_template.sep
first_message['type']
+ ":"
+ first_message['data']['content']
+ self.prompt_template.sep
)
index = self.chat_retention_rounds - 1
for last_message in self.history_message[-index:].messages:
if not isinstance(last_message, ViewMessage):
text += (
last_message.type
+ ":"
+ last_message.content
+ self.prompt_template.sep
)
for round_conv in self.history_message[-index:]:
for round_message in round_conv['messages']:
if not isinstance(round_message, ViewMessage):
text += (
round_message['type']
+ ":"
+ round_message['data']['content']
+ self.prompt_template.sep
)
else:
### Concatenate the full history directly
for conversation in self.history_message:
for message in conversation.messages:
for message in conversation['messages']:
if not isinstance(message, ViewMessage):
text += (
message.type
+ ":"
+ message.content
+ self.prompt_template.sep
message['type']
+ ":"
+ message['data']['content']
+ self.prompt_template.sep
)
### current conversation
for now_message in self.current_message.messages:
text += (
now_message.type + ":" + now_message.content + self.prompt_template.sep
)
return text
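
Prompt assembly now reads persisted conversations as plain dicts (the format DuckdbHistoryMemory stores), so message fields are accessed via ['type'] and ['data']['content'] instead of attributes. A small sketch of the concatenation with illustrative data; the filtered message types follow the hunk above:

SEP = "\n"
FILTERED_TYPES = ("view", "system")  # hidden from the prompt, per the hunk above

history_message = [
    {
        "messages": [
            {"type": "human", "data": {"content": "show me the user table"}},
            {"type": "ai", "data": {"content": "SELECT * FROM users"}},
            {"type": "view", "data": {"content": "<chart>...</chart>"}},
        ]
    }
]

text = ""
for conversation in history_message:
    for message in conversation["messages"]:
        if message["type"] not in FILTERED_TYPES:
            text += message["type"] + ":" + message["data"]["content"] + SEP

print(text)

Note that the branches above which still check isinstance(round_message, ViewMessage) will always treat dict entries as non-view messages; filtering on the 'type' key, as in the first branch, appears to be the intended check.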

View File

@ -8,7 +8,7 @@ from pilot.common.schema import SeparatorStyle
CFG = Config()
PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, please follow the prompts and conventions of the system's input for your answers"""
PROMPT_SCENE_DEFINE = None
_DEFAULT_TEMPLATE = """

View File

@ -2,8 +2,8 @@ from pilot.prompts.example_base import ExampleSelector
## Two examples are defined by default
EXAMPLES = [
[{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}],
[{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}]
[{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}],
[{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}]
]
example = ExampleSelector(examples=EXAMPLES, use_example=True)
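
The example roles are normalized to the same lowercase keys the rest of the commit uses for persisted messages ("system", "human", "assistant"). A hedged sketch of how such few-shot examples might be flattened into prompt text; only the EXAMPLES structure and the ExampleSelector(examples=..., use_example=True) constructor come from the diff, the rendering helper is an assumption:

EXAMPLES = [
    [{"system": "123"}, {"system": "xxx"}, {"human": "xxx"}, {"assistant": "xxx"}],
    [{"system": "123"}, {"system": "xxx"}, {"human": "xxx"}, {"assistant": "xxx"}],
]


def render_few_shot(examples, count=2):
    # Flatten up to `count` example conversations into "role: content" lines.
    lines = []
    for shot in examples[:count]:
        for turn in shot:
            for role, content in turn.items():
                lines.append(f"{role}: {content}")
        lines.append("")  # blank line between shots
    return "\n".join(lines)


print(render_few_shot(EXAMPLES))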

View File

@ -9,10 +9,8 @@ from pilot.scene.chat_execution.example import example
CFG = Config()
# PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers."""
PROMPT_SCENE_DEFINE = "You are an AI designed to solve the user's goals with given commands, please follow the constraints of the system's input for your answers."
_DEFAULT_TEMPLATE = """
Goals:
{input}

View File

@ -8,8 +8,7 @@ from pilot.common.schema import SeparatorStyle
from pilot.scene.chat_normal.out_parser import NormalChatOutputParser
PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge.
The assistant gives helpful, detailed, professional and polite answers to the user's questions. """
PROMPT_SCENE_DEFINE = None
CFG = Config()

View File

@ -95,7 +95,7 @@ class OnceConversation:
def _conversation_to_dic(once: OnceConversation) -> dict:
start_str: str = ""
if once.start_date:
if hasattr(once, 'start_date') and once.start_date:
if isinstance(once.start_date, datetime):
start_str = once.start_date.strftime("%Y-%m-%d %H:%M:%S")
else:
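
The serialization guard above protects against conversations persisted before start_date existed. A tiny runnable sketch of the guarded path (the class is reduced to the one attribute that matters here):

from datetime import datetime


class OnceConversation:
    """Reduced stand-in; historical instances may not carry start_date at all."""


def _conversation_to_dic(once) -> dict:
    start_str = ""
    if hasattr(once, "start_date") and once.start_date:
        if isinstance(once.start_date, datetime):
            start_str = once.start_date.strftime("%Y-%m-%d %H:%M:%S")
        else:
            start_str = str(once.start_date)
    return {"start_date": start_str}


old = OnceConversation()                # no start_date -> empty string, no AttributeError
new = OnceConversation()
new.start_date = datetime(2023, 6, 28)  # formatted as "2023-06-28 00:00:00"
print(_conversation_to_dic(old), _conversation_to_dic(new))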

View File

@ -0,0 +1,74 @@
import traceback
import os
import shutil
import argparse
import sys
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
from pilot.configs.config import Config
from pilot.configs.model_config import (
DATASETS_DIR,
KNOWLEDGE_UPLOAD_ROOT_PATH,
LLM_MODEL_CONFIG,
LOGDIR,
)
from pilot.utils import build_logger
from pilot.server.webserver_base import server_init
from fastapi import FastAPI, applications
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
CFG = Config()
logger = build_logger("webserver", LOGDIR + "webserver.log")
def swagger_monkey_patch(*args, **kwargs):
return get_swagger_ui_html(
*args, **kwargs,
swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js',
swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css'
)
applications.get_swagger_ui_html = swagger_monkey_patch
app = FastAPI()
origins = ["*"]
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
allow_headers=["*"],
)
# app.mount("static", StaticFiles(directory="static"), name="static")
app.include_router(api_v1)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
# old version server config
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT)
parser.add_argument("--concurrency-count", type=int, default=10)
parser.add_argument("--share", default=False, action="store_true")
# init server config
args = parser.parse_args()
server_init(args)
CFG.NEW_SERVER_MODE = True
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=5000)
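
With the FastAPI app above listening on port 5000, the chat API can be exercised without Gradio. A hedged client sketch: chat_mode and select_param appear in the diff's ConversationVo usage, while the remaining field names and the mode value are assumptions:

import requests

payload = {
    "chat_mode": "chat_normal",   # must pass ChatScene.is_valid_mode (value assumed)
    "select_param": "",
    "user_input": "hello",        # field name assumed from ConversationVo
}

with requests.post(
    "http://127.0.0.1:5000/v1/chat/completions",
    json=payload,
    stream=True,
    timeout=120,
) as resp:
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        if chunk:
            print(chunk, end="", flush=True)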

View File

@ -27,14 +27,13 @@ from pilot.server.chat_adapter import get_llm_chat_adapter
CFG = Config()
class ModelWorker:
def __init__(self, model_path, model_name, device, num_gpus=1):
if model_path.endswith("/"):
model_path = model_path[:-1]
self.model_name = model_name or model_path.split("/")[-1]
self.device = device
print(f"Loading {model_name} LLM ModelServer in {device}! Please Wait......")
self.ml = ModelLoader(model_path=model_path)
self.model, self.tokenizer = self.ml.loader(
num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG
@ -42,11 +41,11 @@ class ModelWorker:
if not isinstance(self.model, str):
if hasattr(self.model, "config") and hasattr(
self.model.config, "max_sequence_length"
):
self.context_len = self.model.config.max_sequence_length
elif hasattr(self.model, "config") and hasattr(
self.model.config, "max_position_embeddings"
):
self.context_len = self.model.config.max_position_embeddings
@ -56,29 +55,32 @@ class ModelWorker:
self.llm_chat_adapter = get_llm_chat_adapter(model_path)
self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func()
def start_check(self):
print("LLM Model Loading Success")
def get_queue_length(self):
if (
model_semaphore is None
or model_semaphore._value is None
or model_semaphore._waiters is None
):
return 0
else:
(
CFG.LIMIT_MODEL_CONCURRENCY
- model_semaphore._value
+ len(model_semaphore._waiters)
)
def generate_stream_gate(self, params):
try:
for output in self.generate_stream_func(
self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
):
# Please do not open the output in production!
# The gpt4all thread shares stdout with the parent process,
# and opening it may affect the frontend output.
# print("output: ", output)
print("output: ", output)
ret = {
"text": output,
"error_code": 0,
@ -106,6 +108,7 @@ worker = ModelWorker(
app = FastAPI()
from pilot.openapi.knowledge.knowledge_controller import router
app.include_router(router)
origins = [
@ -122,6 +125,7 @@ app.add_middleware(
allow_headers=["*"]
)
class PromptRequest(BaseModel):
prompt: str
temperature: float
@ -177,10 +181,9 @@ def generate(prompt_request: PromptRequest):
for rsp in output:
# rsp = rsp.decode("utf-8")
rsp_str = str(rsp, "utf-8")
print("[TEST: output]:", rsp_str)
response.append(rsp_str)
return {"response": rsp_str}
return rsp_str
@app.post("/embedding")

View File

@ -39,6 +39,8 @@ def server_init(args):
# init config
cfg = Config()
from pilot.server.llmserver import worker
worker.start_check()
load_native_plugins(cfg)
signal.signal(signal.SIGINT, signal_handler)
async_db_summery()