mirror of https://github.com/csunny/DB-GPT.git
synced 2025-08-23 10:33:46 +00:00
WEB API independent
This commit is contained in:
parent caa1a41065
commit 8e93833321
.gitignore (vendored), 1 change
@@ -7,6 +7,7 @@ __pycache__/
 *.so

 message/
 static/

 .env
 .idea
@@ -2,7 +2,7 @@ import uuid
 import json
 import asyncio
 import time
-from fastapi import APIRouter, Request, Body, status, HTTPException, Response
+from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks

 from fastapi.responses import JSONResponse
 from fastapi.responses import StreamingResponse
@@ -35,6 +35,7 @@ knowledge_service = KnowledgeService()
 model_semaphore = None
 global_counter = 0


 async def validation_exception_handler(request: Request, exc: RequestValidationError):
     message = ""
     for error in exc.errors():
@@ -102,7 +103,7 @@ async def dialogue_scenes():

 @router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
 async def dialogue_new(
-    chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
+    chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
 ):
     unique_id = uuid.uuid1()
     return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))
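For orientation, the new-dialogue endpoint can be exercised with any HTTP client. A minimal sketch (the host/port and the concrete ChatScene string are assumptions, not shown in this diff; since chat_mode and user_id are scalar defaults in the handler signature, FastAPI reads them from the query string):

    import requests

    # Hypothetical local deployment address; adjust to your webserver config.
    BASE_URL = "http://localhost:5000"

    # "chat_normal" is an assumed value of ChatScene.ChatNormal.value.
    resp = requests.post(
        f"{BASE_URL}/v1/chat/dialogue/new",
        params={"chat_mode": "chat_normal"},
    )
    # Expected shape per Result[ConversationVo]: a fresh uuid1 conv_uid plus the chat_mode.
    print(resp.json())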
@@ -176,6 +177,11 @@ async def dialogue_history_messages(con_uid: str):

 @router.post("/v1/chat/completions")
 async def chat_completions(dialogue: ConversationVo = Body()):
     print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
     if not dialogue.chat_mode:
         dialogue.chat_mode = ChatScene.ChatNormal.value
+    if not dialogue.conv_uid:
+        dialogue.conv_uid = str(uuid.uuid1())
+
     global model_semaphore, global_counter
     global_counter += 1
     if model_semaphore is None:
@@ -204,30 +210,51 @@ async def chat_completions(dialogue: ConversationVo = Body()):
         chat_param.update({"knowledge_space": dialogue.select_param})

     chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
+    background_tasks = BackgroundTasks()
+    background_tasks.add_task(release_model_semaphore)
+    headers = {
+        # "Content-Type": "text/event-stream",
+        "Cache-Control": "no-cache",
+        "Connection": "keep-alive",
+        # "Transfer-Encoding": "chunked",
+    }

     if not chat.prompt_template.stream_out:
-        return chat.nostream_call()
+        return StreamingResponse(no_stream_generator(chat), headers=headers, media_type='text/event-stream',
+                                 background=background_tasks)
     else:
-        background_tasks = BackgroundTasks()
-        background_tasks.add_task(release_model_semaphore)
-        return StreamingResponse(stream_generator(chat), background=background_tasks)
+        return StreamingResponse(stream_generator(chat), headers=headers, media_type='text/plain',
+                                 background=background_tasks)


 def release_model_semaphore():
     model_semaphore.release()


-def stream_generator(chat):
+async def no_stream_generator(chat):
+    msg = chat.nostream_call()
+    msg = msg.replace("\n", "\\n")
+    yield f"data: {msg}\n\n"
+
+
+async def stream_generator(chat):
     model_response = chat.stream_call()
     if not CFG.NEW_SERVER_MODE:
         for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
                 msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
-                chat.current_message.add_ai_message(msg)
-                yield msg
+                msg = msg.replace("\n", "\\n")
+                yield f"data:{msg}\n\n"
+                await asyncio.sleep(0.1)
     else:
         for chunk in model_response:
             if chunk:
                 msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
-                chat.current_message.add_ai_message(msg)
-                yield msg
+                msg = msg.replace("\n", "\\n")
+                yield f"data:{msg}\n\n"
+                await asyncio.sleep(0.1)
     chat.memory.append(chat.current_message)
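After this change the endpoint always answers with a StreamingResponse whose body is a sequence of "data:<msg>\n\n" chunks (model newlines escaped as literal \n), and the semaphore is released by the background task once the stream closes. A sketch of a consumer, assuming a local server; "user_input" is an assumed ConversationVo field name not visible in this diff:

    import requests

    payload = {
        "conv_uid": "6f9b...",       # from /v1/chat/dialogue/new
        "chat_mode": "chat_normal",  # assumed ChatScene value
        "user_input": "hello",       # assumed field name
    }

    with requests.post(
        "http://localhost:5000/v1/chat/completions", json=payload, stream=True
    ) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            # Each message arrives as "data:<msg>" with "\n" escaped,
            # so strip the SSE-style prefix and unescape.
            if line and line.startswith("data:"):
                print(line[len("data:"):].strip().replace("\\n", "\n"))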
@@ -6,7 +6,7 @@ from pilot.common.schema import ExampleType


 class ExampleSelector(BaseModel, ABC):
-    examples: List[List]
+    examples_record: List[List]
     use_example: bool = False
     type: str = ExampleType.ONE_SHOT.value
@@ -16,17 +16,13 @@ class ExampleSelector(BaseModel, ABC):
         else:
             return self.__few_shot_context(count)

-    def __examples_text(self, used_examples):

     def __few_shot_context(self, count: int = 2) -> List[List]:
         """
         Use 2 or more examples, default 2
         Returns: example text
         """
         if self.use_example:
-            need_use = self.examples[:count]
+            need_use = self.examples_record[:count]
             return need_use
         return None
@@ -37,7 +33,7 @@ class ExampleSelector(BaseModel, ABC):

         """
         if self.use_example:
-            need_use = self.examples[:1]
+            need_use = self.examples_record[:1]
             return need_use

         return None
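Since examples became examples_record, every constructor call must use the new keyword (the example.py hunk further down is updated accordingly). A toy sketch of the slicing the two selection modes perform; the list is hypothetical:

    # Hypothetical stand-in for ExampleSelector's selection logic.
    examples_record = [["round1"], ["round2"], ["round3"]]

    one_shot = examples_record[:1]   # __one_shot_context: first example only
    few_shot = examples_record[:2]   # __few_shot_context: first `count`, default 2
    print(one_shot, few_shot)        # [['round1']] [['round1'], ['round2']]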
@@ -46,7 +46,10 @@ class PromptTemplate(BaseModel, ABC):
     output_parser: BaseOutputParser = None
     """"""
     sep: str = SeparatorStyle.SINGLE.value
-    example: ExampleSelector = None
+    example_selector: ExampleSelector = None
+
+    need_historical_messages: bool = False

     class Config:
         """Configuration for this pydantic object."""
@@ -52,7 +52,7 @@ class BaseChat(ABC):
     temperature: float = 0.6
     max_new_tokens: int = 1024
     # By default, keep the last two rounds of conversation records as the context
-    chat_retention_rounds: int = 2
+    chat_retention_rounds: int = 1

     class Config:
         """Configuration for this pydantic object."""
@@ -79,8 +79,6 @@ class BaseChat(ABC):
         self.history_message: List[OnceConversation] = self.memory.messages()
         self.current_message: OnceConversation = OnceConversation(chat_mode.value)
         self.current_tokens_used: int = 0
-        ### load chat_session_id's chat history
-        self._load_history(self.chat_session_id)

     class Config:
         """Configuration for this pydantic object."""
@@ -107,18 +105,9 @@ class BaseChat(ABC):
         self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         # TODO
         self.current_message.tokens = 0
         current_prompt = None

         if self.prompt_template.template:
             current_prompt = self.prompt_template.format(**input_values)

-        ### build the current conversation: should the prompt be constructed only for the first round? and what about switching knowledge bases?
-        if self.history_message:
-            ## TODO: scenes that carry history still need to decide how to handle a knowledge-base switch
-            logger.info(
-                f"There are already {len(self.history_message)} rounds of conversations!"
-            )
         if current_prompt:
             self.current_message.add_system_message(current_prompt)

         payload = {
@@ -155,7 +144,7 @@ class BaseChat(ABC):
             self.current_message.add_view_message(
                 f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
             )
-        ### dialogue record storage
+        ### store current conversation
         self.memory.append(self.current_message)

     def nostream_call(self):
@@ -165,7 +154,6 @@ class BaseChat(ABC):
         try:
             rsp_str = ""
             if not CFG.NEW_SERVER_MODE:
                 ### call the non-streaming model-service endpoint
                 rsp_str = requests.post(
                     urljoin(CFG.MODEL_SERVER, "generate"),
                     headers=headers,
@@ -212,7 +200,7 @@ class BaseChat(ABC):
             self.current_message.add_view_message(
                 f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
             )
-        ### dialogue record storage
+        ### store dialogue
         self.memory.append(self.current_message)
         return self.current_ai_response()
@@ -224,68 +212,99 @@ class BaseChat(ABC):

     def generate_llm_text(self) -> str:
         text = ""
         ### Load scene setting or character definition
         if self.prompt_template.template_define:
-            text = self.prompt_template.template_define + self.prompt_template.sep
+            text += self.prompt_template.template_define + self.prompt_template.sep
+        ### Load prompt
+        text += self.__load_system_message()
+
+        ### Load examples
+        text += self.__load_example_messages()
+
+        ### Load History
+        text += self.__load_histroy_messages()
+
+        ### Load User Input
+        text += self.__load_user_message()
+        return text
-        ### handle history messages
-        if len(self.history_message) > self.chat_retention_rounds:
-            ### merge the first round and the last n rounds of history into the conversation record; user-facing view messages must be filtered out of the context prompt
-            for first_message in self.history_message[0]['messages']:
-                if not first_message['type'] in [ViewMessage.type, SystemMessage.type]:
-                    text += (
-                        first_message['type']
-                        + ":"
-                        + first_message['data']['content']
-                        + self.prompt_template.sep
-                    )
-            index = self.chat_retention_rounds - 1
-            for round_conv in self.history_message[-index:]:
-                for round_message in round_conv['messages']:
-                    if not isinstance(round_message, ViewMessage):
-                        text += (
-                            round_message['type']
-                            + ":"
-                            + round_message['data']['content']
-                            + self.prompt_template.sep
-                        )
-        else:
-            ### concatenate the history records directly
-            for conversation in self.history_message:
-                for message in conversation['messages']:
-                    if not isinstance(message, ViewMessage):
-                        text += (
-                            message['type']
-                            + ":"
-                            + message['data']['content']
-                            + self.prompt_template.sep
-                        )
-        ### current conversation
-        for now_message in self.current_message.messages:
-            text += (
-                now_message.type + ":" + now_message.content + self.prompt_template.sep
-            )
-        return text
+
+    def __load_system_message(self):
+        system_convs = self.current_message.get_system_conv()
+        system_text = ""
+        for system_conv in system_convs:
+            system_text += system_conv.type + ":" + system_conv.content + self.prompt_template.sep
+        return system_text
+
+    def __load_user_message(self):
+        user_conv = self.current_message.get_user_conv()
+        if user_conv:
+            return user_conv.type + ":" + user_conv.content + self.prompt_template.sep
+        else:
+            raise ValueError("Hi! What do you want to talk about?")
+
+    def __load_example_messages(self):
+        example_text = ""
+        if self.prompt_template.example_selector:
+            for round_conv in self.prompt_template.example_selector.examples():
+                for round_message in round_conv['messages']:
+                    if not round_message['type'] in [SystemMessage.type, ViewMessage.type]:
+                        example_text += (
+                            round_message['type']
+                            + ":"
+                            + round_message['data']['content']
+                            + self.prompt_template.sep
+                        )
+        return example_text
+
+    def __load_histroy_messages(self):
+        history_text = ""
+        if self.prompt_template.need_historical_messages:
+            if self.history_message:
+                logger.info(
+                    f"There are already {len(self.history_message)} rounds of conversations! Will use {self.chat_retention_rounds} rounds of content as history!"
+                )
+            if len(self.history_message) > self.chat_retention_rounds:
+                for first_message in self.history_message[0]['messages']:
+                    if not first_message['type'] in [ViewMessage.type, SystemMessage.type]:
+                        history_text += (
+                            first_message['type']
+                            + ":"
+                            + first_message['data']['content']
+                            + self.prompt_template.sep
+                        )
+                index = self.chat_retention_rounds - 1
+                for round_conv in self.history_message[-index:]:
+                    for round_message in round_conv['messages']:
+                        if not round_message['type'] in [SystemMessage.type, ViewMessage.type]:
+                            history_text += (
+                                round_message['type']
+                                + ":"
+                                + round_message['data']['content']
+                                + self.prompt_template.sep
+                            )
+            else:
+                ### use all history
+                for conversation in self.history_message:
+                    for message in conversation['messages']:
+                        ### history messages do not carry prompt and view info
+                        if not message['type'] in [SystemMessage.type, ViewMessage.type]:
+                            history_text += (
+                                message['type']
+                                + ":"
+                                + message['data']['content']
+                                + self.prompt_template.sep
+                            )
+        return history_text

     # temporary, for front-end compatibility
     def current_ai_response(self) -> str:
         for message in self.current_message.messages:
             if message.type == "view":
                 return message.content
         return None

     def _load_history(self, session_id: str) -> List[OnceConversation]:
         """
         load chat history by session_id
         Args:
             session_id:
         Returns:
         """
         return self.memory.messages()

     def generate(self, p) -> str:
         """
         generate context for LLM input
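The retention logic above keeps the first recorded round plus the last chat_retention_rounds - 1 rounds whenever more history exists than the retention budget. A small sketch of which rounds survive, assuming the same list-of-rounds shape:

    # Hypothetical: 5 recorded rounds, chat_retention_rounds = 2.
    history_message = ["round0", "round1", "round2", "round3", "round4"]
    chat_retention_rounds = 2

    if len(history_message) > chat_retention_rounds:
        index = chat_retention_rounds - 1
        # Note: with chat_retention_rounds = 1, index is 0 and lst[-0:] is the
        # whole list, so the slice would no longer truncate anything.
        kept = [history_message[0]] + history_message[-index:]
    else:
        kept = list(history_message)  # use all history

    print(kept)  # ['round0', 'round4'] -> first round plus the most recent round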
@@ -6,4 +6,4 @@ EXAMPLES = [
     [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}]
 ]

-example = ExampleSelector(examples=EXAMPLES, use_example=True)
+plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)
@@ -5,7 +5,7 @@ from pilot.scene.base import ChatScene
 from pilot.common.schema import SeparatorStyle, ExampleType

 from pilot.scene.chat_execution.out_parser import PluginChatOutputParser
-from pilot.scene.chat_execution.example import example
+from pilot.scene.chat_execution.example import plugin_example

 CFG = Config()
@@ -49,7 +49,7 @@ prompt = PromptTemplate(
     output_parser=PluginChatOutputParser(
         sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
     ),
-    example=example,
+    example_selector=plugin_example,
 )

 CFG.prompt_templates.update({prompt.template_scene: prompt})
@@ -85,12 +85,18 @@ class OnceConversation:
         self.messages.clear()
         self.session_id = None

-    def get_user_message(self):
-        for once in self.messages:
-            if isinstance(once, HumanMessage):
-                return once.content
-        return ""
+    def get_user_conv(self):
+        for message in self.messages:
+            if isinstance(message, HumanMessage):
+                return message
+        return None
+
+    def get_system_conv(self):
+        system_convs = []
+        for message in self.messages:
+            if isinstance(message, SystemMessage):
+                system_convs.append(message)
+        return system_convs


 def _conversation_to_dic(once: OnceConversation) -> dict:
     start_str: str = ""
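Where the old get_user_message returned only the content string, the new accessors return the message objects themselves, so callers such as __load_system_message and __load_user_message above can read both .type and .content. A usage sketch, assuming langchain-style message classes as used elsewhere in pilot:

    from langchain.schema import HumanMessage, SystemMessage

    # Hypothetical messages list on a OnceConversation instance.
    messages = [
        SystemMessage(content="You are a SQL assistant."),
        HumanMessage(content="show all tables"),
    ]

    user = next((m for m in messages if isinstance(m, HumanMessage)), None)
    systems = [m for m in messages if isinstance(m, SystemMessage)]

    if user:
        print(user.type, ":", user.content)  # human : show all tables
    for s in systems:
        print(s.type, ":", s.content)        # system : You are a SQL assistant.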
@@ -18,6 +18,7 @@ from pilot.utils import build_logger
 from pilot.server.webserver_base import server_init

+from fastapi.staticfiles import StaticFiles
 from fastapi import FastAPI, applications
 from fastapi.openapi.docs import get_swagger_ui_html
 from fastapi.exceptions import RequestValidationError

@@ -25,6 +26,7 @@ from fastapi.middleware.cors import CORSMiddleware

 from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler

+static_file_path = os.path.join(os.getcwd(), "server/static")

 CFG = Config()
 logger = build_logger("webserver", LOGDIR + "webserver.log")

@@ -52,7 +54,9 @@ app.add_middleware(
     allow_headers=["*"],
 )

+# app.mount("static", StaticFiles(directory="static"), name="static")
+app.mount("/static", StaticFiles(directory=static_file_path), name="static")
+app.add_route("/test", "static/test.html")

 app.include_router(api_v1)
 app.add_exception_handler(RequestValidationError, validation_exception_handler)
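One caveat worth noting: Starlette's add_route expects an endpoint callable, not a file path, so the "/test" line above would likely fail at request time; the mounted /static route is what actually serves the file. A hedged alternative sketch, reusing the static_file_path defined above:

    from fastapi.responses import FileResponse

    # Sketch only: serve the demo page through an explicit endpoint
    # instead of passing a path string to add_route.
    @app.get("/test")
    async def test_page():
        return FileResponse(os.path.join(static_file_path, "test.html"))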
pilot/server/static/test.html (new file, 19 lines)
@@ -0,0 +1,19 @@
+<!DOCTYPE html>
+<html>
+<head>
+    <meta charset="UTF-8">
+    <title>Streaming Demo</title>
+    <script src="https://code.jquery.com/jquery-3.6.0.min.js"></script>
+</head>
+<body>
+    <div id="output"></div>
+    <script>
+        $(document).ready(function() {
+            var source = new EventSource("/v1/chat/completions");
+            source.onmessage = function(event) {
+                $("#output").append(event.data);
+            }
+        });
+    </script>
+</body>
+</html>
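A note on this demo page: EventSource always issues GET requests and expects the text/event-stream protocol, while /v1/chat/completions is registered with @router.post, so the page as committed would need either a GET-capable route or a fetch()-based stream reader before it could display any output.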