diff --git a/.gitignore b/.gitignore
index d040022b1..f6e18e09f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,7 @@ __pycache__/
 *.so
 message/
+static/
 
 .env
 .idea
diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py
index 0369ecc2d..8d71ebdea 100644
--- a/pilot/openapi/api_v1/api_v1.py
+++ b/pilot/openapi/api_v1/api_v1.py
@@ -2,7 +2,7 @@ import uuid
 import json
 import asyncio
 import time
-from fastapi import APIRouter, Request, Body, status, HTTPException, Response
+from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks
 from fastapi.responses import JSONResponse
 from fastapi.responses import StreamingResponse
@@ -35,6 +35,7 @@ knowledge_service = KnowledgeService()
 model_semaphore = None
 global_counter = 0
 
+
 async def validation_exception_handler(request: Request, exc: RequestValidationError):
     message = ""
     for error in exc.errors():
@@ -102,7 +103,7 @@ async def dialogue_scenes():
 @router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo])
 async def dialogue_new(
-    chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
+        chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None
 ):
     unique_id = uuid.uuid1()
     return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode))
@@ -176,6 +177,11 @@ async def dialogue_history_messages(con_uid: str):
 @router.post("/v1/chat/completions")
 async def chat_completions(dialogue: ConversationVo = Body()):
     print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
+    if not dialogue.chat_mode:
+        dialogue.chat_mode = ChatScene.ChatNormal.value
+    if not dialogue.conv_uid:
+        dialogue.conv_uid = str(uuid.uuid1())
+
     global model_semaphore, global_counter
     global_counter += 1
     if model_semaphore is None:
@@ -204,32 +210,53 @@ async def chat_completions(dialogue: ConversationVo = Body()):
         chat_param.update({"knowledge_space": dialogue.select_param})
 
     chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
+    background_tasks = BackgroundTasks()
+    background_tasks.add_task(release_model_semaphore)
+    headers = {
+        # "Content-Type": "text/event-stream",
+        "Cache-Control": "no-cache",
+        "Connection": "keep-alive",
+        # "Transfer-Encoding": "chunked",
+    }
+
     if not chat.prompt_template.stream_out:
-        return chat.nostream_call()
+        return StreamingResponse(no_stream_generator(chat), headers=headers, media_type='text/event-stream',
+                                 background=background_tasks)
     else:
-        background_tasks = BackgroundTasks()
-        background_tasks.add_task(release_model_semaphore)
-        return StreamingResponse(stream_generator(chat), background=background_tasks)
+        return StreamingResponse(stream_generator(chat), headers=headers, media_type='text/plain',
+                                 background=background_tasks)
+
 
 def release_model_semaphore():
     model_semaphore.release()
 
-def stream_generator(chat):
+
+async def no_stream_generator(chat):
+    msg = chat.nostream_call()
+    msg = msg.replace("\n", "\\n")
+    yield f"data: {msg}\n\n"
+
+
+async def stream_generator(chat):
    model_response = chat.stream_call()
    if not CFG.NEW_SERVER_MODE:
        for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
                chat.current_message.add_ai_message(msg)
-                yield msg
+                msg = msg.replace("\n", "\\n")
+                yield f"data:{msg}\n\n"
+                await asyncio.sleep(0.1)
    else:
        for chunk in model_response:
            if chunk:
                msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len)
                chat.current_message.add_ai_message(msg)
-                yield msg
+
+                msg = msg.replace("\n", "\\n")
+                yield f"data:{msg}\n\n"
+                await asyncio.sleep(0.1)
    chat.memory.append(chat.current_message)
 
 
 def message2Vo(message: dict, order) -> MessageVo:
-    return MessageVo(role=message['type'], context=message['data']['content'], order=order)
\ No newline at end of file
+    return MessageVo(role=message['type'], context=message['data']['content'], order=order)
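
Both branches of `chat_completions` now return a `StreamingResponse`: every chunk is framed as `data:<msg>\n\n` with literal newlines escaped as `\n`, and the model semaphore is released by the background task once the response finishes. A minimal client sketch for this contract — the port and the `user_input` field name are assumptions, since only `conv_uid`, `chat_mode`, and `select_param` are visible in this diff:

```python
import requests

# conv_uid/chat_mode may be left empty now that the endpoint falls back to a
# fresh uuid1 and ChatScene.ChatNormal; user_input is an assumed field name.
payload = {"conv_uid": "", "chat_mode": "", "user_input": "hello"}

with requests.post(
    "http://localhost:5000/v1/chat/completions", json=payload, stream=True
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line:
            continue  # skip the blank separator between frames
        msg = line.removeprefix("data:").strip()
        print(msg.replace("\\n", "\n"))  # undo the server-side escaping
```
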
diff --git a/pilot/prompts/example_base.py b/pilot/prompts/example_base.py
index ab7c7379a..2553be150 100644
--- a/pilot/prompts/example_base.py
+++ b/pilot/prompts/example_base.py
@@ -6,7 +6,7 @@ from pilot.common.schema import ExampleType
 
 
 class ExampleSelector(BaseModel, ABC):
-    examples: List[List]
+    examples_record: List[List]
     use_example: bool = False
     type: str = ExampleType.ONE_SHOT.value
 
@@ -16,17 +16,13 @@ class ExampleSelector(BaseModel, ABC):
         else:
             return self.__few_shot_context(count)
 
-    def __examples_text(self, used_examples):
-
-
     def __few_shot_context(self, count: int = 2) -> List[List]:
         """
         Use 2 or more examples, default 2
         Returns: example text
         """
         if self.use_example:
-            need_use = self.examples[:count]
+            need_use = self.examples_record[:count]
             return need_use
         return None
 
@@ -37,7 +33,7 @@ class ExampleSelector(BaseModel, ABC):
         """
 
         if self.use_example:
-            need_use = self.examples[:1]
+            need_use = self.examples_record[:1]
             return need_use
         return None
diff --git a/pilot/prompts/prompt_new.py b/pilot/prompts/prompt_new.py
index 475f82ea4..80f05c730 100644
--- a/pilot/prompts/prompt_new.py
+++ b/pilot/prompts/prompt_new.py
@@ -46,7 +46,10 @@ class PromptTemplate(BaseModel, ABC):
     output_parser: BaseOutputParser = None
     """"""
     sep: str = SeparatorStyle.SINGLE.value
-    example: ExampleSelector = None
+
+    example_selector: ExampleSelector = None
+
+    need_historical_messages: bool = False
 
     class Config:
         """Configuration for this pydantic object."""
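
The rename from `examples` to `examples_record` (and from `example` to `example_selector` on `PromptTemplate`) is mechanical, but the selection behavior is worth spelling out: the default `ONE_SHOT` type keeps only `examples_record[:1]`, while few-shot keeps the first `count` records (default 2). A small sketch, reusing the message shape from `pilot/scene/chat_execution/example.py`:

```python
from pilot.prompts.example_base import ExampleSelector

# One inner list per example round; SystemMessage/ViewMessage entries are
# filtered out later, when BaseChat loads the examples into the prompt.
EXAMPLES = [
    [{"system": "..."}, {"human": "show tables"}, {"assistant": "SHOW TABLES;"}],
    [{"system": "..."}, {"human": "count users"}, {"assistant": "SELECT COUNT(*) FROM users;"}],
]

selector = ExampleSelector(examples_record=EXAMPLES, use_example=True)
rounds = selector.examples()  # default ONE_SHOT type -> first record only
```
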
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index 560ce7039..8e7b3dfe7 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -52,7 +52,7 @@ class BaseChat(ABC):
     temperature: float = 0.6
     max_new_tokens: int = 1024
     # By default, keep the last two rounds of conversation records as the context
-    chat_retention_rounds: int = 2
+    chat_retention_rounds: int = 1
 
     class Config:
         """Configuration for this pydantic object."""
@@ -79,8 +79,6 @@ class BaseChat(ABC):
         self.history_message: List[OnceConversation] = self.memory.messages()
         self.current_message: OnceConversation = OnceConversation(chat_mode.value)
         self.current_tokens_used: int = 0
-        ### load chat_session_id's chat historys
-        self._load_history(self.chat_session_id)
 
     class Config:
         """Configuration for this pydantic object."""
@@ -107,18 +105,9 @@ class BaseChat(ABC):
         self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         # TODO
         self.current_message.tokens = 0
-        current_prompt = None
         if self.prompt_template.template:
             current_prompt = self.prompt_template.format(**input_values)
-
-        ### Build the current conversation: construct the prompt as in the first round? What about switching databases?
-        if self.history_message:
-            ## TODO: scenarios that carry chat history still need to decide how to handle a database switch
-            logger.info(
-                f"There are already {len(self.history_message)} rounds of conversations!"
-            )
-
-        if current_prompt:
-            self.current_message.add_system_message(current_prompt)
+            self.current_message.add_system_message(current_prompt)
 
         payload = {
@@ -155,7 +144,7 @@ class BaseChat(ABC):
             self.current_message.add_view_message(
                 f"""ERROR!{str(e)}\n {ai_response_text} """
             )
-        ### store the conversation record
+        ### store current conversation
         self.memory.append(self.current_message)
 
     def nostream_call(self):
@@ -165,7 +154,6 @@ class BaseChat(ABC):
         try:
             rsp_str = ""
             if not CFG.NEW_SERVER_MODE:
-                ### call the non-streaming model service API
                 rsp_str = requests.post(
                     urljoin(CFG.MODEL_SERVER, "generate"),
                     headers=headers,
@@ -212,7 +200,7 @@ class BaseChat(ABC):
             self.current_message.add_view_message(
                 f"""ERROR!{str(e)}\n {ai_response_text} """
             )
-        ### store the conversation record
+        ### store dialogue
         self.memory.append(self.current_message)
         return self.current_ai_response()
@@ -224,68 +212,99 @@ class BaseChat(ABC):
 
     def generate_llm_text(self) -> str:
         text = ""
+        ### Load scene setting or character definition
         if self.prompt_template.template_define:
-            text = self.prompt_template.template_define + self.prompt_template.sep
+            text += self.prompt_template.template_define + self.prompt_template.sep
+        ### Load prompt
+        text += self.__load_system_message()
 
-        ### Process historical messages
-        if len(self.history_message) > self.chat_retention_rounds:
-            ### Merge the first round and the last n rounds of history into the conversation record; view messages shown to the user must be filtered out of the context prompt
-            for first_message in self.history_message[0]['messages']:
-                if not first_message['type'] in [ViewMessage.type, SystemMessage.type]:
-                    text += (
-                        first_message['type']
-                        + ":"
-                        + first_message['data']['content']
-                        + self.prompt_template.sep
-                    )
+        ### Load examples
+        text += self.__load_example_messages()
 
-            index = self.chat_retention_rounds - 1
-            for round_conv in self.history_message[-index:]:
+        ### Load History
+        text += self.__load_history_messages()
+
+        ### Load User Input
+        text += self.__load_user_message()
+        return text
+
+    def __load_system_message(self):
+        system_convs = self.current_message.get_system_conv()
+        system_text = ""
+        for system_conv in system_convs:
+            system_text += system_conv.type + ":" + system_conv.content + self.prompt_template.sep
+        return system_text
+
+    def __load_user_message(self):
+        user_conv = self.current_message.get_user_conv()
+        if user_conv:
+            return user_conv.type + ":" + user_conv.content + self.prompt_template.sep
+        else:
+            raise ValueError("Hi! What do you want to talk about?")
+
+    def __load_example_messages(self):
+        example_text = ""
+        if self.prompt_template.example_selector:
+            for round_conv in self.prompt_template.example_selector.examples():
                 for round_message in round_conv['messages']:
-                    if not isinstance(round_message, ViewMessage):
-                        text += (
+                    if not round_message['type'] in [SystemMessage.type, ViewMessage.type]:
+                        example_text += (
                             round_message['type']
                             + ":"
                             + round_message['data']['content']
                             + self.prompt_template.sep
                         )
+        return example_text
 
-        else:
-            ### concatenate the full history directly
-            for conversation in self.history_message:
-                for message in conversation['messages']:
-                    if not isinstance(message, ViewMessage):
-                        text += (
-                            message['type']
+    def __load_history_messages(self):
+        history_text = ""
+        if self.prompt_template.need_historical_messages:
+            if self.history_message:
+                logger.info(
+                    f"There are already {len(self.history_message)} rounds of conversations! Will use {self.chat_retention_rounds} rounds of content as history!"
+                )
+            if len(self.history_message) > self.chat_retention_rounds:
+                for first_message in self.history_message[0]['messages']:
+                    if not first_message['type'] in [ViewMessage.type, SystemMessage.type]:
+                        history_text += (
+                            first_message['type']
                             + ":"
-                            + message['data']['content']
+                            + first_message['data']['content']
                             + self.prompt_template.sep
                         )
-        ### current conversation
-        for now_message in self.current_message.messages:
-            text += (
-                now_message.type + ":" + now_message.content + self.prompt_template.sep
-            )
+                index = self.chat_retention_rounds - 1
+                for round_conv in self.history_message[-index:]:
+                    for round_message in round_conv['messages']:
+                        if not round_message['type'] in [SystemMessage.type, ViewMessage.type]:
+                            history_text += (
+                                round_message['type']
+                                + ":"
+                                + round_message['data']['content']
+                                + self.prompt_template.sep
+                            )
-
-        return text
+            else:
+                ### use all history
+                for conversation in self.history_message:
+                    for message in conversation['messages']:
+                        ### history messages carry no prompt or view info
+                        if not message['type'] in [SystemMessage.type, ViewMessage.type]:
+                            history_text += (
+                                message['type']
+                                + ":"
+                                + message['data']['content']
+                                + self.prompt_template.sep
+                            )
+
+        return history_text
 
-    # kept temporarily for front-end compatibility
     def current_ai_response(self) -> str:
         for message in self.current_message.messages:
             if message.type == "view":
                 return message.content
         return None
 
-    def _load_history(self, session_id: str) -> List[OnceConversation]:
-        """
-        load chat history by session_id
-        Args:
-            session_id:
-        Returns:
-        """
-        return self.memory.messages()
-
     def generate(self, p) -> str:
         """
         generate context for LLM input
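
`generate_llm_text` now assembles the prompt in a fixed order — scene definition, system prompt, examples, history (only when the template opts in via `need_historical_messages`), then the user input — each message rendered as `type:content` followed by the template separator, and a missing user message now raises a `ValueError`. Illustratively (not output from the repo), assuming the separator renders as a newline:

```python
# An assumed illustration of the assembled prompt, one segment per loader.
expected = (
    "<template_define>\n"            # scene setting / character definition
    "system:<formatted template>\n"  # __load_system_message
    "human:<example question>\n"     # __load_example_messages
    "assistant:<example answer>\n"
    "human:<earlier question>\n"     # __load_history_messages (opt-in)
    "ai:<earlier answer>\n"
    "human:<current input>\n"        # __load_user_message
)
```
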
diff --git a/pilot/scene/chat_execution/example.py b/pilot/scene/chat_execution/example.py
index 0008ee9ec..f50a7f546 100644
--- a/pilot/scene/chat_execution/example.py
+++ b/pilot/scene/chat_execution/example.py
@@ -6,4 +6,4 @@ EXAMPLES = [
     [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}]
 ]
 
-example = ExampleSelector(examples=EXAMPLES, use_example=True)
+plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)
diff --git a/pilot/scene/chat_execution/prompt.py b/pilot/scene/chat_execution/prompt.py
index 768527c19..af5087609 100644
--- a/pilot/scene/chat_execution/prompt.py
+++ b/pilot/scene/chat_execution/prompt.py
@@ -5,7 +5,7 @@ from pilot.scene.base import ChatScene
 from pilot.common.schema import SeparatorStyle, ExampleType
 from pilot.scene.chat_execution.out_parser import PluginChatOutputParser
-from pilot.scene.chat_execution.example import example
+from pilot.scene.chat_execution.example import plugin_example
 
 CFG = Config()
@@ -49,7 +49,7 @@ prompt = PromptTemplate(
     output_parser=PluginChatOutputParser(
         sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT
     ),
-    example=example,
+    example_selector=plugin_example,
 )
 
 CFG.prompt_templates.update({prompt.template_scene: prompt})
diff --git a/pilot/scene/message.py b/pilot/scene/message.py
index ba884c571..972331bbb 100644
--- a/pilot/scene/message.py
+++ b/pilot/scene/message.py
@@ -85,12 +85,18 @@ class OnceConversation:
         self.messages.clear()
         self.session_id = None
 
-    def get_user_message(self):
-        for once in self.messages:
-            if isinstance(once, HumanMessage):
-                return once.content
-        return ""
+    def get_user_conv(self):
+        for message in self.messages:
+            if isinstance(message, HumanMessage):
+                return message
+        return None
 
+    def get_system_conv(self):
+        system_convs = []
+        for message in self.messages:
+            if isinstance(message, SystemMessage):
+                system_convs.append(message)
+        return system_convs
 
 
 def _conversation_to_dic(once: OnceConversation) -> dict:
     start_str: str = ""
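
The renamed accessors return message objects instead of bare strings, which is what lets `BaseChat` read both `.type` and `.content` when rendering the prompt. A quick sketch — `add_user_message` is an assumed helper name; only `add_system_message`, `add_ai_message`, and `add_view_message` appear in this diff:

```python
from pilot.scene.message import OnceConversation

conv = OnceConversation("chat_normal")
conv.add_system_message("You answer questions about databases.")
conv.add_user_message("hello")  # assumed counterpart of add_ai_message

user = conv.get_user_conv()       # first HumanMessage object, or None
systems = conv.get_system_conv()  # every SystemMessage, in insertion order
print(user.content, [s.content for s in systems])
```
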
diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py
index 7032982d9..6c22105cd 100644
--- a/pilot/server/dbgpt_server.py
+++ b/pilot/server/dbgpt_server.py
@@ -18,6 +18,8 @@ from pilot.utils import build_logger
 
 from pilot.server.webserver_base import server_init
 
+from fastapi.staticfiles import StaticFiles
+from fastapi.responses import FileResponse
 from fastapi import FastAPI, applications
 from fastapi.openapi.docs import get_swagger_ui_html
 from fastapi.exceptions import RequestValidationError
@@ -25,6 +27,7 @@ from fastapi.middleware.cors import CORSMiddleware
 
 from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
 
+static_file_path = os.path.join(os.getcwd(), "server/static")
 
 CFG = Config()
 logger = build_logger("webserver", LOGDIR + "webserver.log")
@@ -52,7 +55,15 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-# app.mount("static", StaticFiles(directory="static"), name="static")
+app.mount("/static", StaticFiles(directory=static_file_path), name="static")
+
+
+@app.get("/test")
+async def test_page():
+    # Starlette's add_route expects a callable endpoint, not a file path, so
+    # the demo page is served through an explicit route instead
+    return FileResponse(os.path.join(static_file_path, "test.html"))
+
 app.include_router(api_v1)
 
 app.add_exception_handler(RequestValidationError, validation_exception_handler)
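
A quick smoke test for the new mount and route, assuming the webserver listens on its default port 5000:

```python
import requests

assert requests.get("http://localhost:5000/static/test.html").status_code == 200
assert requests.get("http://localhost:5000/test").status_code == 200
```
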
diff --git a/pilot/server/static/test.html b/pilot/server/static/test.html
new file mode 100644
index 000000000..709180f11
--- /dev/null
+++ b/pilot/server/static/test.html
@@ -0,0 +1,24 @@
+<!DOCTYPE html>
+<html>
+<head>
+    <title>Streaming Demo</title>
+</head>
+<body>
+<pre id="out"></pre>
+<script>
+    // The markup and script of this page were lost in extraction; this body is
+    // a minimal reconstruction, and the payload field names are assumptions.
+    fetch("/v1/chat/completions", {
+        method: "POST",
+        headers: {"Content-Type": "application/json"},
+        body: JSON.stringify({conv_uid: "test", chat_mode: "chat_normal", user_input: "hello"})
+    }).then(async (resp) => {
+        const reader = resp.body.getReader();
+        const decoder = new TextDecoder();
+        for (let r = await reader.read(); !r.done; r = await reader.read()) {
+            document.getElementById("out").textContent += decoder.decode(r.value);
+        }
+    });
+</script>
+</body>
+</html>
\ No newline at end of file