WEB API independent

tuyang.yhj 2023-06-29 09:55:43 +08:00
parent caa1a41065
commit 8e93833321
10 changed files with 157 additions and 82 deletions

.gitignore vendored
View File

@@ -7,6 +7,7 @@ __pycache__/
*.so
message/
static/
.env
.idea

View File

@@ -2,7 +2,7 @@ import uuid
import json
import asyncio
import time
from fastapi import APIRouter, Request, Body, status, HTTPException, Response
from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks
from fastapi.responses import JSONResponse
from fastapi.responses import StreamingResponse
@@ -35,6 +35,7 @@ knowledge_service = KnowledgeService()
model_semaphore = None
global_counter = 0
async def validation_exception_handler(request: Request, exc: RequestValidationError):
message = ""
for error in exc.errors():
@@ -176,6 +177,11 @@ async def dialogue_history_messages(con_uid: str):
@router.post("/v1/chat/completions")
async def chat_completions(dialogue: ConversationVo = Body()):
print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
if not dialogue.chat_mode:
dialogue.chat_mode = ChatScene.ChatNormal.value
if not dialogue.conv_uid:
dialogue.conv_uid = str(uuid.uuid1())
global model_semaphore, global_counter
global_counter += 1
if model_semaphore is None:
@@ -204,30 +210,51 @@ async def chat_completions(dialogue: ConversationVo = Body()):
chat_param.update({"knowledge_space": dialogue.select_param})
chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
if not chat.prompt_template.stream_out:
return chat.nostream_call()
else:
background_tasks = BackgroundTasks()
background_tasks.add_task(release_model_semaphore)
return StreamingResponse(stream_generator(chat), background=background_tasks)
headers = {
# "Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
# "Transfer-Encoding": "chunked",
}
if not chat.prompt_template.stream_out:
return StreamingResponse(no_stream_generator(chat), headers=headers, media_type='text/event-stream',
background=background_tasks)
else:
return StreamingResponse(stream_generator(chat), headers=headers, media_type='text/plain',
background=background_tasks)
def release_model_semaphore():
model_semaphore.release()
def stream_generator(chat):
async def no_stream_generator(chat):
msg = chat.nostream_call()
msg = msg.replace("\n", "\\n")
yield f"data: {msg}\n\n"
async def stream_generator(chat):
model_response = chat.stream_call()
if not CFG.NEW_SERVER_MODE:
for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
chat.current_message.add_ai_message(msg)
yield msg
msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"
await asyncio.sleep(0.1)
else:
for chunk in model_response:
if chunk:
msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
chat.current_message.add_ai_message(msg)
yield msg
msg = msg.replace("\n", "\\n")
yield f"data:{msg}\n\n"
await asyncio.sleep(0.1)
chat.memory.append(chat.current_message)
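
The endpoint above now frames every chunk as an SSE-style "data:<msg>\n\n" record (with literal newlines escaped as \n) and relies on BackgroundTasks to release the model semaphore when the response finishes. A minimal client-side sketch of consuming that stream, assuming the server listens on localhost:5000 and that ConversationVo accepts the chat_mode and user_input fields shown here (the port and payload keys are assumptions, not taken from this diff):

import requests

# Hypothetical client for the /v1/chat/completions streaming endpoint.
resp = requests.post(
    "http://localhost:5000/v1/chat/completions",          # port is an assumption
    json={"chat_mode": "chat_normal", "user_input": "hello"},  # field names assumed
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if line and line.startswith("data:"):
        # Undo the \n escaping applied by the generators above.
        print(line[len("data:"):].replace("\\n", "\n"), end="", flush=True)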

View File

@@ -6,7 +6,7 @@ from pilot.common.schema import ExampleType
class ExampleSelector(BaseModel, ABC):
examples: List[List]
examples_record: List[List]
use_example: bool = False
type: str = ExampleType.ONE_SHOT.value
@@ -16,17 +16,13 @@ class ExampleSelector(BaseModel, ABC):
else:
return self.__few_shot_context(count)
def __examples_text(self, used_examples):
def __few_shot_context(self, count: int = 2) -> List[List]:
"""
Use 2 or more examples, default 2
Returns: example text
"""
if self.use_example:
need_use = self.examples[:count]
need_use = self.examples_record[:count]
return need_use
return None
@@ -37,7 +33,7 @@ class ExampleSelector(BaseModel, ABC):
"""
if self.use_example:
need_use = self.examples[:1]
need_use = self.examples_record[:1]
return need_use
return None

View File

@@ -46,7 +46,10 @@ class PromptTemplate(BaseModel, ABC):
output_parser: BaseOutputParser = None
""""""
sep: str = SeparatorStyle.SINGLE.value
example: ExampleSelector = None
example_selector: ExampleSelector = None
need_historical_messages: bool = False
class Config:
"""Configuration for this pydantic object."""

View File

@@ -52,7 +52,7 @@ class BaseChat(ABC):
temperature: float = 0.6
max_new_tokens: int = 1024
# By default, keep the last two rounds of conversation records as the context
chat_retention_rounds: int = 2
chat_retention_rounds: int = 1
class Config:
"""Configuration for this pydantic object."""
@@ -79,8 +79,6 @@ class BaseChat(ABC):
self.history_message: List[OnceConversation] = self.memory.messages()
self.current_message: OnceConversation = OnceConversation(chat_mode.value)
self.current_tokens_used: int = 0
### load chat_session_id's chat historys
self._load_history(self.chat_session_id)
class Config:
"""Configuration for this pydantic object."""
@@ -107,18 +105,9 @@ class BaseChat(ABC):
self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# TODO
self.current_message.tokens = 0
current_prompt = None
if self.prompt_template.template:
current_prompt = self.prompt_template.format(**input_values)
### Build the current conversation: decide whether to construct it as a first-round prompt, and whether to account for switching knowledge bases
if self.history_message:
## TODO decide how to handle knowledge-base switching when historical conversation records exist
logger.info(
f"There are already {len(self.history_message)} rounds of conversations!"
)
if current_prompt:
self.current_message.add_system_message(current_prompt)
payload = {
@@ -155,7 +144,7 @@ class BaseChat(ABC):
self.current_message.add_view_message(
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
)
### store the conversation record
### store current conversation
self.memory.append(self.current_message)
def nostream_call(self):
@@ -165,7 +154,6 @@ class BaseChat(ABC):
try:
rsp_str = ""
if not CFG.NEW_SERVER_MODE:
### call the non-streaming model service interface
rsp_str = requests.post(
urljoin(CFG.MODEL_SERVER, "generate"),
headers=headers,
@@ -212,7 +200,7 @@ class BaseChat(ABC):
self.current_message.add_view_message(
f"""<span style=\"color:red\">ERROR!</span>{str(e)}\n {ai_response_text} """
)
### store the conversation record
### store dialogue
self.memory.append(self.current_message)
return self.current_ai_response()
@@ -224,15 +212,61 @@ class BaseChat(ABC):
def generate_llm_text(self) -> str:
text = ""
### Load scene setting or character definition
if self.prompt_template.template_define:
text = self.prompt_template.template_define + self.prompt_template.sep
text += self.prompt_template.template_define + self.prompt_template.sep
### Load prompt
text += self.__load_system_message()
### handle history messages
### Load examples
text += self.__load_example_messages()
### Load History
text += self.__load_histroy_messages()
### Load User Input
text += self.__load_user_message()
return text
def __load_system_message(self):
system_convs = self.current_message.get_system_conv()
system_text = ""
for system_conv in system_convs:
system_text += system_conv.type + ":" + system_conv.content + self.prompt_template.sep
return system_text
def __load_user_message(self):
user_conv = self.current_message.get_user_conv()
if user_conv:
return user_conv.type + ":" + user_conv.content + self.prompt_template.sep
else:
raise ValueError("Hi! What do you want to talk about")
def __load_example_messages(self):
example_text = ""
if self.prompt_template.example_selector:
for round_conv in self.prompt_template.example_selector.examples():
for round_message in round_conv['messages']:
if not round_message['type'] in [SystemMessage.type, ViewMessage.type]:
example_text += (
round_message['type']
+ ":"
+ round_message['data']['content']
+ self.prompt_template.sep
)
return example_text
def __load_histroy_messages(self):
history_text = ""
if self.prompt_template.need_historical_messages:
if self.history_message:
logger.info(
f"There are already {len(self.history_message)} rounds of conversations! Will use {self.chat_retention_rounds} rounds of content as history!"
)
if len(self.history_message) > self.chat_retention_rounds:
### Merge the first round and the last n rounds of history into the conversation context; view messages shown to the user must be filtered out when building the prompt
for first_message in self.history_message[0]['messages']:
if not first_message['type'] in [ViewMessage.type, SystemMessage.type]:
text += (
history_text += (
first_message['type']
+ ":"
+ first_message['data']['content']
@@ -242,8 +276,8 @@ class BaseChat(ABC):
index = self.chat_retention_rounds - 1
for round_conv in self.history_message[-index:]:
for round_message in round_conv['messages']:
if not isinstance(round_message, ViewMessage):
text += (
if not round_message['type'] in [SystemMessage.type, ViewMessage.type]:
history_text += (
round_message['type']
+ ":"
+ round_message['data']['content']
@@ -251,41 +285,26 @@ class BaseChat(ABC):
)
else:
### concatenate the full history directly
### use all history
for conversation in self.history_message:
for message in conversation['messages']:
if not isinstance(message, ViewMessage):
text += (
### history messages do not include prompt or view info
if not message['type'] in [SystemMessage.type, ViewMessage.type]:
history_text += (
message['type']
+ ":"
+ message['data']['content']
+ self.prompt_template.sep
)
### current conversation
for now_message in self.current_message.messages:
text += (
now_message.type + ":" + now_message.content + self.prompt_template.sep
)
return history_text
return text
# Temporary, for front-end compatibility
def current_ai_response(self) -> str:
for message in self.current_message.messages:
if message.type == "view":
return message.content
return None
def _load_history(self, session_id: str) -> List[OnceConversation]:
"""
load chat history by session_id
Args:
session_id:
Returns:
"""
return self.memory.messages()
def generate(self, p) -> str:
"""
generate context for LLM input
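
Taken together, generate_llm_text now assembles the prompt in a fixed order: scene definition, system prompt, few-shot examples, retained history, then the current user input, each rendered as "type:content" and joined with the template separator. A standalone sketch of that ordering (names and signatures here are illustrative, not the project's API):

def assemble_prompt(define, system_msgs, example_msgs, history_msgs, user_input, sep="\n"):
    # Mirrors the ordering in generate_llm_text: definition, system, examples, history, user input.
    parts = []
    if define:
        parts.append(define)
    for typ, content in [*system_msgs, *example_msgs, *history_msgs]:
        parts.append(f"{typ}:{content}")
    parts.append(f"human:{user_input}")
    return sep.join(parts) + sep

# Example: assemble_prompt("You are a SQL expert.", [("system", "Answer briefly.")], [], [], "show tables")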

View File

@@ -6,4 +6,4 @@ EXAMPLES = [
[{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}]
]
example = ExampleSelector(examples=EXAMPLES, use_example=True)
plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)

View File

@@ -5,7 +5,7 @@ from pilot.scene.base import ChatScene
from pilot.common.schema import SeparatorStyle, ExampleType
from pilot.scene.chat_execution.out_parser import PluginChatOutputParser
from pilot.scene.chat_execution.example import example
from pilot.scene.chat_execution.example import plugin_example
CFG = Config()
@@ -49,7 +49,7 @@ prompt = PromptTemplate(
output_parser=PluginChatOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
),
example=example,
example_selector=plugin_example,
)
CFG.prompt_templates.update({prompt.template_scene: prompt})

View File

@@ -85,12 +85,18 @@ class OnceConversation:
self.messages.clear()
self.session_id = None
def get_user_message(self):
for once in self.messages:
if isinstance(once, HumanMessage):
return once.content
return ""
def get_user_conv(self):
for message in self.messages:
if isinstance(message, HumanMessage):
return message
return None
def get_system_conv(self):
system_convs = []
for message in self.messages:
if isinstance(message, SystemMessage):
system_convs.append(message)
return system_convs
def _conversation_to_dic(once: OnceConversation) -> dict:
start_str: str = ""

View File

@@ -18,6 +18,7 @@ from pilot.utils import build_logger
from pilot.server.webserver_base import server_init
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, applications
from fastapi.openapi.docs import get_swagger_ui_html
from fastapi.exceptions import RequestValidationError
@@ -25,6 +26,7 @@ from fastapi.middleware.cors import CORSMiddleware
from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
static_file_path = os.path.join(os.getcwd(), "server/static")
CFG = Config()
logger = build_logger("webserver", LOGDIR + "webserver.log")
@@ -52,7 +54,9 @@ app.add_middleware(
allow_headers=["*"],
)
# app.mount("static", StaticFiles(directory="static"), name="static")
app.mount("/static", StaticFiles(directory=static_file_path), name="static")
app.add_route("/test", "static/test.html")
app.include_router(api_v1)
app.add_exception_handler(RequestValidationError, validation_exception_handler)

View File

@@ -0,0 +1,19 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Streaming Demo</title>
<script src="https://code.jquery.com/jquery-3.6.0.min.js"></script>
</head>
<body>
<div id="output"></div>
<script>
$(document).ready(function() {
var source = new EventSource("/v1/chat/completions");
source.onmessage = function(event) {
$("#output").append(event.data);
}
});
</script>
</body>
</html>