From 5777761320b557a694baed4e294e04e7e2ccd8f0 Mon Sep 17 00:00:00 2001 From: "tuyang.yhj" Date: Tue, 27 Jun 2023 17:45:07 +0800 Subject: [PATCH 1/4] WEB API independent --- pilot/memory/chat_history/duckdb_history.py | 18 ++++++---- pilot/openapi/api_v1/api_v1.py | 40 +++++++++------------ pilot/scene/base.py | 4 +-- pilot/scene/base_chat.py | 4 +-- pilot/server/webserver.py | 2 +- 5 files changed, 32 insertions(+), 36 deletions(-) diff --git a/pilot/memory/chat_history/duckdb_history.py b/pilot/memory/chat_history/duckdb_history.py index b24546d19..42177be75 100644 --- a/pilot/memory/chat_history/duckdb_history.py +++ b/pilot/memory/chat_history/duckdb_history.py @@ -8,6 +8,7 @@ from pilot.memory.chat_history.base import BaseChatHistoryMemory from pilot.scene.message import ( OnceConversation, conversation_from_dict, + _conversation_to_dic, conversations_to_dict, ) from pilot.common.formatting import MyEncoder @@ -41,12 +42,15 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): def __get_messages_by_conv_uid(self, conv_uid: str): cursor = self.connect.cursor() cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [conv_uid]) - return cursor.fetchone() - + content = cursor.fetchone() + if content: + return cursor.fetchone()[0] + else: + return None def messages(self) -> List[OnceConversation]: context = self.__get_messages_by_conv_uid(self.chat_seesion_id) if context: - conversations: List[OnceConversation] = json.loads(context[0]) + conversations: List[OnceConversation] = json.loads(context) return conversations return [] @@ -54,15 +58,15 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): context = self.__get_messages_by_conv_uid(self.chat_seesion_id) conversations: List[OnceConversation] = [] if context: - conversations = json.load(context) - conversations.append(once_message) + conversations = json.loads(context) + conversations.append(_conversation_to_dic(once_message)) cursor = self.connect.cursor() if context: cursor.execute("UPDATE chat_history set messages=? where conv_uid=?", - [json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4), self.chat_seesion_id]) + [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id]) else: cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)", - [self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4)]) + [self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False)]) cursor.commit() self.connect.commit() diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index b75179eaf..f04c20145 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -12,7 +12,7 @@ from fastapi.responses import JSONResponse from sse_starlette.sse import EventSourceResponse from typing import List -from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo +from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo from pilot.configs.config import Config from pilot.scene.base_chat import BaseChat from pilot.scene.base import ChatScene @@ -70,7 +70,8 @@ async def dialogue_list(response: Response, user_id: str = None): @router.post('/v1/chat/dialogue/scenes', response_model=Result[List[ChatSceneVo]]) async def dialogue_scenes(): scene_vos: List[ChatSceneVo] = [] - new_modes:List[ChatScene] = [ChatScene.ChatDb, ChatScene.ChatData, ChatScene.ChatDashboard, ChatScene.ChatKnowledge, ChatScene.ChatExecution] + new_modes: List[ChatScene] = [ChatScene.ChatWithDbQA, ChatScene.ChatWithDbExecute, ChatScene.ChatDashboard, + ChatScene.ChatKnowledge, ChatScene.ChatExecution] for scene in new_modes: if not scene.value in [ChatScene.ChatNormal.value, ChatScene.InnerChatDBSummary.value]: scene_vo = ChatSceneVo(chat_scene=scene.value, scene_name=scene.name, param_title="Selection Param") @@ -87,7 +88,7 @@ async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value, user_id: str def get_db_list(): db = CFG.local_db dbs = db.get_database_list() - params:dict = {} + params: dict = {} for name in dbs: params.update({name: name}) return params @@ -108,9 +109,9 @@ def knowledge_list(): @router.post('/v1/chat/mode/params/list', response_model=Result[dict]) async def params_list(chat_mode: str = ChatScene.ChatNormal.value): - if ChatScene.ChatDb.value == chat_mode: + if ChatScene.ChatWithDbQA.value == chat_mode: return Result.succ(get_db_list()) - elif ChatScene.ChatData.value == chat_mode: + elif ChatScene.ChatWithDbExecute.value == chat_mode: return Result.succ(get_db_list()) elif ChatScene.ChatDashboard.value == chat_mode: return Result.succ(get_db_list()) @@ -155,24 +156,21 @@ async def chat_completions(dialogue: ConversationVo = Body()): "user_input": dialogue.user_input, } - if ChatScene.ChatDb == dialogue.chat_mode: - chat_param.update("db_name", dialogue.select_param) - elif ChatScene.ChatData == dialogue.chat_mode: - chat_param.update("db_name", dialogue.select_param) - elif ChatScene.ChatDashboard == dialogue.chat_mode: - chat_param.update("db_name", dialogue.select_param) - elif ChatScene.ChatExecution == dialogue.chat_mode: - chat_param.update("plugin_selector", dialogue.select_param) - elif ChatScene.ChatKnowledge == dialogue.chat_mode: - chat_param.update("knowledge_name", dialogue.select_param) + if ChatScene.ChatWithDbQA.value == dialogue.chat_mode: + chat_param.update({"db_name": dialogue.select_param}) + elif ChatScene.ChatWithDbExecute.value == dialogue.chat_mode: + chat_param.update({"db_name": dialogue.select_param}) + elif ChatScene.ChatDashboard.value == dialogue.chat_mode: + chat_param.update({"db_name": dialogue.select_param}) + elif ChatScene.ChatExecution.value == dialogue.chat_mode: + chat_param.update({"plugin_selector": dialogue.select_param}) + elif ChatScene.ChatKnowledge.value == dialogue.chat_mode: + chat_param.update({"knowledge_name": dialogue.select_param}) chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param) if not chat.prompt_template.stream_out: return non_stream_response(chat) else: - # generator = stream_generator(chat) - # result = Result.succ(data=StreamingResponse(stream_test(), media_type='text/plain')) - # return result return StreamingResponse(stream_generator(chat), media_type="text/plain") @@ -196,12 +194,6 @@ def stream_generator(chat): chat.memory.append(chat.current_message) -# def stream_response(chat): -# logger.info("stream out start!") -# api_response = StreamingResponse(stream_generator(chat), media_type="application/json") -# return api_response - - def message2Vo(message: dict, order) -> MessageVo: # message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0 return MessageVo(role=message['type'], context=message['data']['content'], order=order) diff --git a/pilot/scene/base.py b/pilot/scene/base.py index cec443beb..b5e15b4e8 100644 --- a/pilot/scene/base.py +++ b/pilot/scene/base.py @@ -18,8 +18,8 @@ class ChatScene(Enum): ChatNormal = "chat_normal" ChatDashboard = "chat_dashboard" ChatKnowledge = "chat_knowledge" - ChatDb = "chat_db" - ChatData= "chat_data" + # ChatDb = "chat_db" + # ChatData= "chat_data" @staticmethod def is_valid_mode(mode): diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index d6628a19a..042d6af59 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -75,7 +75,7 @@ class BaseChat(ABC): self.prompt_template: PromptTemplate = CFG.prompt_templates[ self.chat_mode.value ] - self.history_message: List[OnceConversation] = [] + self.history_message: List[OnceConversation] = self.memory.messages() self.current_message: OnceConversation = OnceConversation(chat_mode.value) self.current_tokens_used: int = 0 ### load chat_session_id's chat historys @@ -95,7 +95,7 @@ class BaseChat(ABC): def generate_input_values(self): pass - @abstractmethod + def do_action(self, prompt_response): return prompt_response diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index c2cd9d434..86dc0d426 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -58,7 +58,7 @@ from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles -from pilot.server.api_v1.api_v1 import router as api_v1, validation_exception_handler +from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler # 加载插件 CFG = Config() From b2d2828b4eb7e5eb343034fa54086285e75a5e53 Mon Sep 17 00:00:00 2001 From: "tuyang.yhj" Date: Wed, 28 Jun 2023 11:34:40 +0800 Subject: [PATCH 2/4] WEB API independent --- pilot/configs/config.py | 2 + pilot/memory/chat_history/duckdb_history.py | 4 +- pilot/model/llm_out/proxy_llm.py | 3 +- pilot/openapi/api_v1/api_v1.py | 89 +++++--------- pilot/out_parser/base.py | 23 ++-- pilot/prompts/example_base.py | 4 + pilot/scene/base_chat.py | 127 +++++++++----------- pilot/scene/chat_db/auto_execute/prompt.py | 2 +- pilot/scene/chat_execution/example.py | 4 +- pilot/scene/chat_execution/prompt.py | 2 - pilot/scene/chat_normal/prompt.py | 3 +- pilot/scene/message.py | 2 +- pilot/server/dbgpt_server.py | 74 ++++++++++++ pilot/server/llmserver.py | 31 ++--- pilot/server/webserver_base.py | 2 + 15 files changed, 201 insertions(+), 171 deletions(-) create mode 100644 pilot/server/dbgpt_server.py diff --git a/pilot/configs/config.py b/pilot/configs/config.py index 94ff19e21..7bf6d97e6 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -17,6 +17,8 @@ class Config(metaclass=Singleton): def __init__(self) -> None: """Initialize the Config class""" + self.NEW_SERVER_MODE = False + # Gradio language version: en, zh self.LANGUAGE = os.getenv("LANGUAGE", "en") self.WEB_SERVER_PORT = int(os.getenv("WEB_SERVER_PORT", 7860)) diff --git a/pilot/memory/chat_history/duckdb_history.py b/pilot/memory/chat_history/duckdb_history.py index 42177be75..de80a5bc2 100644 --- a/pilot/memory/chat_history/duckdb_history.py +++ b/pilot/memory/chat_history/duckdb_history.py @@ -44,7 +44,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [conv_uid]) content = cursor.fetchone() if content: - return cursor.fetchone()[0] + return content[0] else: return None def messages(self) -> List[OnceConversation]: @@ -66,7 +66,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id]) else: cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)", - [self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False)]) + [self.chat_seesion_id, "", json.dumps(conversations, ensure_ascii=False)]) cursor.commit() self.connect.commit() diff --git a/pilot/model/llm_out/proxy_llm.py b/pilot/model/llm_out/proxy_llm.py index 4336d43e3..c4423a1a6 100644 --- a/pilot/model/llm_out/proxy_llm.py +++ b/pilot/model/llm_out/proxy_llm.py @@ -39,7 +39,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048) elif "ai:" in message: history.append( { - "role": "ai", + "role": "assistant", "content": message.split("ai:")[1], } ) @@ -57,6 +57,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048) for m in temp_his: if m["role"] == "user": last_user_input = m + break if last_user_input: history.remove(last_user_input) history.append(last_user_input) diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index 832c6785c..d11ba0976 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -2,7 +2,7 @@ import uuid import json import asyncio import time -from fastapi import APIRouter, Request, Body, status, HTTPException, Response +from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks from fastapi.responses import JSONResponse from fastapi.responses import StreamingResponse @@ -31,6 +31,8 @@ CHAT_FACTORY = ChatFactory() logger = build_logger("api_v1", LOGDIR + "api_v1.log") knowledge_service = KnowledgeService() +model_semaphore = None +global_counter = 0 async def validation_exception_handler(request: Request, exc: RequestValidationError): message = "" @@ -148,6 +150,11 @@ async def dialogue_history_messages(con_uid: str): @router.post('/v1/chat/completions') async def chat_completions(dialogue: ConversationVo = Body()): print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}") + global model_semaphore, global_counter + global_counter += 1 + if model_semaphore is None: + model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY) + await model_semaphore.acquire() if not ChatScene.is_valid_mode(dialogue.chat_mode): raise StopAsyncIteration(Result.faild("Unsupported Chat Mode," + dialogue.chat_mode + "!")) @@ -170,73 +177,31 @@ async def chat_completions(dialogue: ConversationVo = Body()): chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param) if not chat.prompt_template.stream_out: - return non_stream_response(chat) + return chat.nostream_call() else: - return StreamingResponse(stream_generator(chat), media_type="text/plain") - - -def stream_test(): - for message in ["Hello", "world", "how", "are", "you"]: - yield message - # yield json.dumps(Result.succ(message).__dict__).encode("utf-8") + background_tasks = BackgroundTasks() + background_tasks.add_task(release_model_semaphore) + return StreamingResponse(stream_generator(chat), background=background_tasks) +def release_model_semaphore(): + model_semaphore.release() def stream_generator(chat): model_response = chat.stream_call() - for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"): - if chunk: - msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) - chat.current_message.add_ai_message(msg) - yield msg - # chat.current_message.add_ai_message(msg) - # vo = MessageVo(role="view", context=msg, order=chat.current_message.chat_order) - # json_text = json.dumps(vo.__dict__) - # yield json_text.encode('utf-8') + if not CFG.NEW_SERVER_MODE: + for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) + chat.current_message.add_ai_message(msg) + yield msg + else: + for chunk in model_response: + if chunk: + msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) + chat.current_message.add_ai_message(msg) + yield msg chat.memory.append(chat.current_message) def message2Vo(message: dict, order) -> MessageVo: - # message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0 - return MessageVo(role=message['type'], context=message['data']['content'], order=order) - - -def non_stream_response(chat): - logger.info("not stream out, wait model response!") - return chat.nostream_call() - - -@router.get('/v1/db/types', response_model=Result[str]) -async def db_types(): - return Result.succ(["mysql", "duckdb"]) - - -@router.get('/v1/db/list', response_model=Result[str]) -async def db_list(): - db = CFG.local_db - dbs = db.get_database_list() - return Result.succ(dbs) - - -@router.get('/v1/knowledge/list') -async def knowledge_list(): - return ["test1", "test2"] - - -@router.post('/v1/knowledge/add') -async def knowledge_add(): - return ["test1", "test2"] - - -@router.post('/v1/knowledge/delete') -async def knowledge_delete(): - return ["test1", "test2"] - - -@router.get('/v1/knowledge/types') -async def knowledge_types(): - return ["test1", "test2"] - - -@router.get('/v1/knowledge/detail') -async def knowledge_detail(): - return ["test1", "test2"] + return MessageVo(role=message['type'], context=message['data']['content'], order=order) \ No newline at end of file diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py index 4476a68fc..ca308e92f 100644 --- a/pilot/out_parser/base.py +++ b/pilot/out_parser/base.py @@ -47,6 +47,8 @@ class BaseOutputParser(ABC): return code def parse_model_stream_resp_ex(self, chunk, skip_echo_len): + if b"\0" in chunk: + chunk = chunk.replace(b"\0", b"") data = json.loads(chunk.decode()) """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode. @@ -95,11 +97,8 @@ class BaseOutputParser(ABC): def parse_model_nostream_resp(self, response, sep: str): text = response.text.strip() text = text.rstrip() - respObj = json.loads(text) - - xx = respObj["response"] - xx = xx.strip(b"\x00".decode()) - respObj_ex = json.loads(xx) + text = text.strip(b"\x00".decode()) + respObj_ex = json.loads(text) if respObj_ex["error_code"] == 0: all_text = respObj_ex["text"] ### 解析返回文本,获取AI回复部分 @@ -123,7 +122,7 @@ class BaseOutputParser(ABC): def __extract_json(slef, s): i = s.index("{") count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数 - for j, c in enumerate(s[i + 1 :], start=i + 1): + for j, c in enumerate(s[i + 1:], start=i + 1): if c == "}": count -= 1 elif c == "{": @@ -131,7 +130,7 @@ class BaseOutputParser(ABC): if count == 0: break assert count == 0 # 检查是否找到最后一个'}' - return s[i : j + 1] + return s[i: j + 1] def parse_prompt_response(self, model_out_text) -> T: """ @@ -148,9 +147,9 @@ class BaseOutputParser(ABC): # if "```" in cleaned_output: # cleaned_output, _ = cleaned_output.split("```") if cleaned_output.startswith("```json"): - cleaned_output = cleaned_output[len("```json") :] + cleaned_output = cleaned_output[len("```json"):] if cleaned_output.startswith("```"): - cleaned_output = cleaned_output[len("```") :] + cleaned_output = cleaned_output[len("```"):] if cleaned_output.endswith("```"): cleaned_output = cleaned_output[: -len("```")] cleaned_output = cleaned_output.strip() @@ -159,9 +158,9 @@ class BaseOutputParser(ABC): cleaned_output = self.__extract_json(cleaned_output) cleaned_output = ( cleaned_output.strip() - .replace("\n", " ") - .replace("\\n", " ") - .replace("\\", " ") + .replace("\n", " ") + .replace("\\n", " ") + .replace("\\", " ") ) return cleaned_output diff --git a/pilot/prompts/example_base.py b/pilot/prompts/example_base.py index 4d876aa51..3372c2b1d 100644 --- a/pilot/prompts/example_base.py +++ b/pilot/prompts/example_base.py @@ -15,6 +15,10 @@ class ExampleSelector(BaseModel, ABC): else: return self.__few_shot_context(count) + def __examples_text(self, used_examples): + + + def __few_shot_context(self, count: int = 2) -> List[List]: """ Use 2 or more examples, default 2 diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 042d6af59..560ce7039 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -39,6 +39,7 @@ from pilot.scene.base_message import ( ViewMessage, ) from pilot.configs.config import Config +from pilot.server.llmserver import worker logger = build_logger("BaseChat", LOGDIR + "BaseChat.log") headers = {"User-Agent": "dbgpt Client"} @@ -59,10 +60,10 @@ class BaseChat(ABC): arbitrary_types_allowed = True def __init__( - self, - chat_mode, - chat_session_id, - current_user_input, + self, + chat_mode, + chat_session_id, + current_user_input, ): self.chat_session_id = chat_session_id self.chat_mode = chat_mode @@ -95,7 +96,6 @@ class BaseChat(ABC): def generate_input_values(self): pass - def do_action(self, prompt_response): return prompt_response @@ -138,24 +138,17 @@ class BaseChat(ABC): logger.info(f"Requert: \n{payload}") ai_response_text = "" try: - show_info = "" - response = requests.post( - urljoin(CFG.MODEL_SERVER, "generate_stream"), - headers=headers, - json=payload, - stream=True, - timeout=120, - ) - return response - - # yield self.prompt_template.output_parser.parse_model_stream_resp(response, skip_echo_len) - - # for resp_text_trunck in ai_response_text: - # show_info = resp_text_trunck - # yield resp_text_trunck + "▌" - - self.current_message.add_ai_message(show_info) - + if not CFG.NEW_SERVER_MODE: + response = requests.post( + urljoin(CFG.MODEL_SERVER, "generate_stream"), + headers=headers, + json=payload, + stream=True, + timeout=120, + ) + return response + else: + return worker.generate_stream_gate(payload) except Exception as e: print(traceback.format_exc()) logger.error("model response parase faild!" + str(e)) @@ -170,39 +163,28 @@ class BaseChat(ABC): logger.info(f"Requert: \n{payload}") ai_response_text = "" try: - ### 走非流式的模型服务接口 - response = requests.post( - urljoin(CFG.MODEL_SERVER, "generate"), - headers=headers, - json=payload, - timeout=120, - ) + rsp_str = "" + if not CFG.NEW_SERVER_MODE: + ### 走非流式的模型服务接口 + rsp_str = requests.post( + urljoin(CFG.MODEL_SERVER, "generate"), + headers=headers, + json=payload, + timeout=120, + ) + else: + ###TODO no stream mode need independent + output = worker.generate_stream_gate(payload) + for rsp in output: + rsp_str = str(rsp, "utf-8") + print("[TEST: output]:", rsp_str) ### output parse - ai_response_text = ( - self.prompt_template.output_parser.parse_model_nostream_resp( - response, self.prompt_template.sep - ) - ) - - # ### MOCK - # ai_response_text = """{ - # "thoughts": "可以从users表和tran_order表联合查询,按城市和订单数量进行分组统计,并使用柱状图展示。", - # "reasoning": "为了分析用户在不同城市的分布情况,需要查询users表和tran_order表,使用LEFT JOIN将两个表联合起来。按照城市进行分组,统计每个城市的订单数量。使用柱状图展示可以直观地看出每个城市的订单数量,方便比较。", - # "speak": "根据您的分析目标,我查询了用户表和订单表,统计了每个城市的订单数量,并生成了柱状图展示。", - # "command": { - # "name": "histogram-executor", - # "args": { - # "title": "订单城市分布柱状图", - # "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city" - # } - # } - # }""" - + ai_response_text = self.prompt_template.output_parser.parse_model_nostream_resp(rsp_str, + self.prompt_template.sep) + ### model result deal self.current_message.add_ai_message(ai_response_text) prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text) - - result = self.do_action(prompt_define_response) if hasattr(prompt_define_response, "thoughts"): @@ -248,41 +230,42 @@ class BaseChat(ABC): ### 处理历史信息 if len(self.history_message) > self.chat_retention_rounds: ### 使用历史信息的第一轮和最后n轮数据合并成历史对话记录, 做上下文提示时,用户展示消息需要过滤掉 - for first_message in self.history_message[0].messages: - if not isinstance(first_message, ViewMessage): + for first_message in self.history_message[0]['messages']: + if not first_message['type'] in [ViewMessage.type, SystemMessage.type]: text += ( - first_message.type - + ":" - + first_message.content - + self.prompt_template.sep + first_message['type'] + + ":" + + first_message['data']['content'] + + self.prompt_template.sep ) index = self.chat_retention_rounds - 1 - for last_message in self.history_message[-index:].messages: - if not isinstance(last_message, ViewMessage): - text += ( - last_message.type - + ":" - + last_message.content - + self.prompt_template.sep - ) + for round_conv in self.history_message[-index:]: + for round_message in round_conv['messages']: + if not isinstance(round_message, ViewMessage): + text += ( + round_message['type'] + + ":" + + round_message['data']['content'] + + self.prompt_template.sep + ) else: ### 直接历史记录拼接 for conversation in self.history_message: - for message in conversation.messages: + for message in conversation['messages']: if not isinstance(message, ViewMessage): text += ( - message.type - + ":" - + message.content - + self.prompt_template.sep + message['type'] + + ":" + + message['data']['content'] + + self.prompt_template.sep ) ### current conversation for now_message in self.current_message.messages: text += ( - now_message.type + ":" + now_message.content + self.prompt_template.sep + now_message.type + ":" + now_message.content + self.prompt_template.sep ) return text diff --git a/pilot/scene/chat_db/auto_execute/prompt.py b/pilot/scene/chat_db/auto_execute/prompt.py index 938860aaf..cc4878c70 100644 --- a/pilot/scene/chat_db/auto_execute/prompt.py +++ b/pilot/scene/chat_db/auto_execute/prompt.py @@ -8,7 +8,7 @@ from pilot.common.schema import SeparatorStyle CFG = Config() -PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, please follow the prompts and conventions of the system's input for your answers""" +PROMPT_SCENE_DEFINE = None _DEFAULT_TEMPLATE = """ diff --git a/pilot/scene/chat_execution/example.py b/pilot/scene/chat_execution/example.py index 6cd71b39c..0008ee9ec 100644 --- a/pilot/scene/chat_execution/example.py +++ b/pilot/scene/chat_execution/example.py @@ -2,8 +2,8 @@ from pilot.prompts.example_base import ExampleSelector ## Two examples are defined by default EXAMPLES = [ - [{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}], - [{"System": "123"},{"System":"xxx"},{"User":"xxx"},{"Assistant":"xxx"}] + [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}], + [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}] ] example = ExampleSelector(examples=EXAMPLES, use_example=True) diff --git a/pilot/scene/chat_execution/prompt.py b/pilot/scene/chat_execution/prompt.py index b5bc38bb2..fef2a13b3 100644 --- a/pilot/scene/chat_execution/prompt.py +++ b/pilot/scene/chat_execution/prompt.py @@ -9,10 +9,8 @@ from pilot.scene.chat_execution.example import example CFG = Config() -# PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers.""" PROMPT_SCENE_DEFINE = "You are an AI designed to solve the user's goals with given commands, please follow the constraints of the system's input for your answers." - _DEFAULT_TEMPLATE = """ Goals: {input} diff --git a/pilot/scene/chat_normal/prompt.py b/pilot/scene/chat_normal/prompt.py index 55514dffa..1dc455255 100644 --- a/pilot/scene/chat_normal/prompt.py +++ b/pilot/scene/chat_normal/prompt.py @@ -8,8 +8,7 @@ from pilot.common.schema import SeparatorStyle from pilot.scene.chat_normal.out_parser import NormalChatOutputParser -PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. - The assistant gives helpful, detailed, professional and polite answers to the user's questions. """ +PROMPT_SCENE_DEFINE = None CFG = Config() diff --git a/pilot/scene/message.py b/pilot/scene/message.py index a2a894fe8..f2bd1fa56 100644 --- a/pilot/scene/message.py +++ b/pilot/scene/message.py @@ -95,7 +95,7 @@ class OnceConversation: def _conversation_to_dic(once: OnceConversation) -> dict: start_str: str = "" - if once.start_date: + if hasattr(once, 'start_date') and once.start_date: if isinstance(once.start_date, datetime): start_str = once.start_date.strftime("%Y-%m-%d %H:%M:%S") else: diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py new file mode 100644 index 000000000..7032982d9 --- /dev/null +++ b/pilot/server/dbgpt_server.py @@ -0,0 +1,74 @@ +import traceback +import os +import shutil +import argparse +import sys + +ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(ROOT_PATH) + +from pilot.configs.config import Config +from pilot.configs.model_config import ( + DATASETS_DIR, + KNOWLEDGE_UPLOAD_ROOT_PATH, + LLM_MODEL_CONFIG, + LOGDIR, +) +from pilot.utils import build_logger + +from pilot.server.webserver_base import server_init + +from fastapi import FastAPI, applications +from fastapi.openapi.docs import get_swagger_ui_html +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware + +from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler + + +CFG = Config() +logger = build_logger("webserver", LOGDIR + "webserver.log") + + +def swagger_monkey_patch(*args, **kwargs): + return get_swagger_ui_html( + *args, **kwargs, + swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js', + swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css' + ) + + +applications.get_swagger_ui_html = swagger_monkey_patch + +app = FastAPI() +origins = ["*"] + +# 添加跨域中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], + allow_headers=["*"], +) + +# app.mount("static", StaticFiles(directory="static"), name="static") +app.include_router(api_v1) +app.add_exception_handler(RequestValidationError, validation_exception_handler) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"]) + + # old version server config + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT) + parser.add_argument("--concurrency-count", type=int, default=10) + parser.add_argument("--share", default=False, action="store_true") + + # init server config + args = parser.parse_args() + server_init(args) + CFG.NEW_SERVER_MODE = True + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=5000) diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index 3030d1fdc..0fb4c5ae6 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -27,14 +27,13 @@ from pilot.server.chat_adapter import get_llm_chat_adapter CFG = Config() - class ModelWorker: def __init__(self, model_path, model_name, device, num_gpus=1): if model_path.endswith("/"): model_path = model_path[:-1] self.model_name = model_name or model_path.split("/")[-1] self.device = device - + print(f"Loading {model_name} LLM ModelServer in {device}! Please Wait......") self.ml = ModelLoader(model_path=model_path) self.model, self.tokenizer = self.ml.loader( num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG @@ -42,11 +41,11 @@ class ModelWorker: if not isinstance(self.model, str): if hasattr(self.model, "config") and hasattr( - self.model.config, "max_sequence_length" + self.model.config, "max_sequence_length" ): self.context_len = self.model.config.max_sequence_length elif hasattr(self.model, "config") and hasattr( - self.model.config, "max_position_embeddings" + self.model.config, "max_position_embeddings" ): self.context_len = self.model.config.max_position_embeddings @@ -56,29 +55,32 @@ class ModelWorker: self.llm_chat_adapter = get_llm_chat_adapter(model_path) self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func() + def start_check(self): + print("LLM Model Loading Success!") + def get_queue_length(self): if ( - model_semaphore is None - or model_semaphore._value is None - or model_semaphore._waiters is None + model_semaphore is None + or model_semaphore._value is None + or model_semaphore._waiters is None ): return 0 else: ( - CFG.LIMIT_MODEL_CONCURRENCY - - model_semaphore._value - + len(model_semaphore._waiters) + CFG.LIMIT_MODEL_CONCURRENCY + - model_semaphore._value + + len(model_semaphore._waiters) ) def generate_stream_gate(self, params): try: for output in self.generate_stream_func( - self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS + self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS ): # Please do not open the output in production! # The gpt4all thread shares stdout with the parent process, # and opening it may affect the frontend output. - # print("output: ", output) + print("output: ", output) ret = { "text": output, "error_code": 0, @@ -106,6 +108,7 @@ worker = ModelWorker( app = FastAPI() from pilot.openapi.knowledge.knowledge_controller import router + app.include_router(router) origins = [ @@ -122,6 +125,7 @@ app.add_middleware( allow_headers=["*"] ) + class PromptRequest(BaseModel): prompt: str temperature: float @@ -177,10 +181,9 @@ def generate(prompt_request: PromptRequest): for rsp in output: # rsp = rsp.decode("utf-8") rsp_str = str(rsp, "utf-8") - print("[TEST: output]:", rsp_str) response.append(rsp_str) - return {"response": rsp_str} + return rsp_str @app.post("/embedding") diff --git a/pilot/server/webserver_base.py b/pilot/server/webserver_base.py index 0aa2ac3f9..c76984c37 100644 --- a/pilot/server/webserver_base.py +++ b/pilot/server/webserver_base.py @@ -39,6 +39,8 @@ def server_init(args): # init config cfg = Config() + from pilot.server.llmserver import worker + worker.start_check() load_native_plugins(cfg) signal.signal(signal.SIGINT, signal_handler) async_db_summery() From 307ffd6d905956553b98b479e8f1b52aa7a6ea64 Mon Sep 17 00:00:00 2001 From: "tuyang.yhj" Date: Thu, 29 Jun 2023 09:55:43 +0800 Subject: [PATCH 3/4] WEB API independent --- .gitignore | 1 + pilot/openapi/api_v1/api_v1.py | 47 ++++++++-- pilot/prompts/example_base.py | 10 +- pilot/prompts/prompt_new.py | 5 +- pilot/scene/base_chat.py | 129 +++++++++++++++----------- pilot/scene/chat_execution/example.py | 2 +- pilot/scene/chat_execution/prompt.py | 4 +- pilot/scene/message.py | 16 +++- pilot/server/dbgpt_server.py | 6 +- pilot/server/static/test.html | 19 ++++ 10 files changed, 157 insertions(+), 82 deletions(-) create mode 100644 pilot/server/static/test.html diff --git a/.gitignore b/.gitignore index d040022b1..f6e18e09f 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ __pycache__/ *.so message/ +static/ .env .idea diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index 0369ecc2d..8d71ebdea 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -2,7 +2,7 @@ import uuid import json import asyncio import time -from fastapi import APIRouter, Request, Body, status, HTTPException, Response +from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks from fastapi.responses import JSONResponse from fastapi.responses import StreamingResponse @@ -35,6 +35,7 @@ knowledge_service = KnowledgeService() model_semaphore = None global_counter = 0 + async def validation_exception_handler(request: Request, exc: RequestValidationError): message = "" for error in exc.errors(): @@ -102,7 +103,7 @@ async def dialogue_scenes(): @router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo]) async def dialogue_new( - chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None + chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None ): unique_id = uuid.uuid1() return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode)) @@ -176,6 +177,11 @@ async def dialogue_history_messages(con_uid: str): @router.post("/v1/chat/completions") async def chat_completions(dialogue: ConversationVo = Body()): print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}") + if not dialogue.chat_mode: + dialogue.chat_mode = ChatScene.ChatNormal.value + if not dialogue.conv_uid: + dialogue.conv_uid = str(uuid.uuid1()) + global model_semaphore, global_counter global_counter += 1 if model_semaphore is None: @@ -204,32 +210,53 @@ async def chat_completions(dialogue: ConversationVo = Body()): chat_param.update({"knowledge_space": dialogue.select_param}) chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param) + background_tasks = BackgroundTasks() + background_tasks.add_task(release_model_semaphore) + headers = { + # "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + # "Transfer-Encoding": "chunked", + } + if not chat.prompt_template.stream_out: - return chat.nostream_call() + return StreamingResponse(no_stream_generator(chat), headers=headers, media_type='text/event-stream', + background=background_tasks) else: - background_tasks = BackgroundTasks() - background_tasks.add_task(release_model_semaphore) - return StreamingResponse(stream_generator(chat), background=background_tasks) + return StreamingResponse(stream_generator(chat), headers=headers, media_type='text/plain', + background=background_tasks) + def release_model_semaphore(): model_semaphore.release() -def stream_generator(chat): + +async def no_stream_generator(chat): + msg = chat.nostream_call() + msg = msg.replace("\n", "\\n") + yield f"data: {msg}\n\n" + +async def stream_generator(chat): model_response = chat.stream_call() if not CFG.NEW_SERVER_MODE: for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) chat.current_message.add_ai_message(msg) - yield msg + msg = msg.replace("\n", "\\n") + yield f"data:{msg}\n\n" + await asyncio.sleep(0.1) else: for chunk in model_response: if chunk: msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) chat.current_message.add_ai_message(msg) - yield msg + + msg = msg.replace("\n", "\\n") + yield f"data:{msg}\n\n" + await asyncio.sleep(0.1) chat.memory.append(chat.current_message) def message2Vo(message: dict, order) -> MessageVo: - return MessageVo(role=message['type'], context=message['data']['content'], order=order) \ No newline at end of file + return MessageVo(role=message['type'], context=message['data']['content'], order=order) diff --git a/pilot/prompts/example_base.py b/pilot/prompts/example_base.py index ab7c7379a..2553be150 100644 --- a/pilot/prompts/example_base.py +++ b/pilot/prompts/example_base.py @@ -6,7 +6,7 @@ from pilot.common.schema import ExampleType class ExampleSelector(BaseModel, ABC): - examples: List[List] + examples_record: List[List] use_example: bool = False type: str = ExampleType.ONE_SHOT.value @@ -16,17 +16,13 @@ class ExampleSelector(BaseModel, ABC): else: return self.__few_shot_context(count) - def __examples_text(self, used_examples): - - - def __few_shot_context(self, count: int = 2) -> List[List]: """ Use 2 or more examples, default 2 Returns: example text """ if self.use_example: - need_use = self.examples[:count] + need_use = self.examples_record[:count] return need_use return None @@ -37,7 +33,7 @@ class ExampleSelector(BaseModel, ABC): """ if self.use_example: - need_use = self.examples[:1] + need_use = self.examples_record[:1] return need_use return None diff --git a/pilot/prompts/prompt_new.py b/pilot/prompts/prompt_new.py index 475f82ea4..80f05c730 100644 --- a/pilot/prompts/prompt_new.py +++ b/pilot/prompts/prompt_new.py @@ -46,7 +46,10 @@ class PromptTemplate(BaseModel, ABC): output_parser: BaseOutputParser = None """""" sep: str = SeparatorStyle.SINGLE.value - example: ExampleSelector = None + + example_selector: ExampleSelector = None + + need_historical_messages: bool = False class Config: """Configuration for this pydantic object.""" diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 560ce7039..8e7b3dfe7 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -52,7 +52,7 @@ class BaseChat(ABC): temperature: float = 0.6 max_new_tokens: int = 1024 # By default, keep the last two rounds of conversation records as the context - chat_retention_rounds: int = 2 + chat_retention_rounds: int = 1 class Config: """Configuration for this pydantic object.""" @@ -79,8 +79,6 @@ class BaseChat(ABC): self.history_message: List[OnceConversation] = self.memory.messages() self.current_message: OnceConversation = OnceConversation(chat_mode.value) self.current_tokens_used: int = 0 - ### load chat_session_id's chat historys - self._load_history(self.chat_session_id) class Config: """Configuration for this pydantic object.""" @@ -107,18 +105,9 @@ class BaseChat(ABC): self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") # TODO self.current_message.tokens = 0 - current_prompt = None if self.prompt_template.template: current_prompt = self.prompt_template.format(**input_values) - - ### 构建当前对话, 是否安第一次对话prompt构造? 是否考虑切换库 - if self.history_message: - ## TODO 带历史对话记录的场景需要确定切换库后怎么处理 - logger.info( - f"There are already {len(self.history_message)} rounds of conversations!" - ) - if current_prompt: self.current_message.add_system_message(current_prompt) payload = { @@ -155,7 +144,7 @@ class BaseChat(ABC): self.current_message.add_view_message( f"""ERROR!{str(e)}\n {ai_response_text} """ ) - ### 对话记录存储 + ### store current conversation self.memory.append(self.current_message) def nostream_call(self): @@ -165,7 +154,6 @@ class BaseChat(ABC): try: rsp_str = "" if not CFG.NEW_SERVER_MODE: - ### 走非流式的模型服务接口 rsp_str = requests.post( urljoin(CFG.MODEL_SERVER, "generate"), headers=headers, @@ -212,7 +200,7 @@ class BaseChat(ABC): self.current_message.add_view_message( f"""ERROR!{str(e)}\n {ai_response_text} """ ) - ### 对话记录存储 + ### store dialogue self.memory.append(self.current_message) return self.current_ai_response() @@ -224,68 +212,99 @@ class BaseChat(ABC): def generate_llm_text(self) -> str: text = "" + ### Load scene setting or character definition if self.prompt_template.template_define: - text = self.prompt_template.template_define + self.prompt_template.sep + text += self.prompt_template.template_define + self.prompt_template.sep + ### Load prompt + text += self.__load_system_message() - ### 处理历史信息 - if len(self.history_message) > self.chat_retention_rounds: - ### 使用历史信息的第一轮和最后n轮数据合并成历史对话记录, 做上下文提示时,用户展示消息需要过滤掉 - for first_message in self.history_message[0]['messages']: - if not first_message['type'] in [ViewMessage.type, SystemMessage.type]: - text += ( - first_message['type'] - + ":" - + first_message['data']['content'] - + self.prompt_template.sep - ) + ### Load examples + text += self.__load_example_messages() - index = self.chat_retention_rounds - 1 - for round_conv in self.history_message[-index:]: + ### Load History + text += self.__load_histroy_messages() + + ### Load User Input + text += self.__load_user_message() + return text + + def __load_system_message(self): + system_convs = self.current_message.get_system_conv() + system_text = "" + for system_conv in system_convs: + system_text += system_conv.type + ":" + system_conv.content + self.prompt_template.sep + return system_text + + def __load_user_message(self): + user_conv = self.current_message.get_user_conv() + if user_conv: + return user_conv.type + ":" + user_conv.content + self.prompt_template.sep + else: + raise ValueError("Hi! What do you want to talk about?") + + def __load_example_messages(self): + example_text = "" + if self.prompt_template.example_selector: + for round_conv in self.prompt_template.example_selector.examples(): for round_message in round_conv['messages']: - if not isinstance(round_message, ViewMessage): - text += ( + if not round_message['type'] in [SystemMessage.type, ViewMessage.type]: + example_text += ( round_message['type'] + ":" + round_message['data']['content'] + self.prompt_template.sep ) + return example_text - else: - ### 直接历史记录拼接 - for conversation in self.history_message: - for message in conversation['messages']: - if not isinstance(message, ViewMessage): - text += ( - message['type'] + def __load_histroy_messages(self): + history_text = "" + if self.prompt_template.need_historical_messages: + if self.history_message: + logger.info( + f"There are already {len(self.history_message)} rounds of conversations! Will use {self.chat_retention_rounds} rounds of content as history!" + ) + if len(self.history_message) > self.chat_retention_rounds: + for first_message in self.history_message[0]['messages']: + if not first_message['type'] in [ViewMessage.type, SystemMessage.type]: + history_text += ( + first_message['type'] + ":" - + message['data']['content'] + + first_message['data']['content'] + self.prompt_template.sep ) - ### current conversation - for now_message in self.current_message.messages: - text += ( - now_message.type + ":" + now_message.content + self.prompt_template.sep - ) + index = self.chat_retention_rounds - 1 + for round_conv in self.history_message[-index:]: + for round_message in round_conv['messages']: + if not round_message['type'] in [SystemMessage.type, ViewMessage.type]: + history_text += ( + round_message['type'] + + ":" + + round_message['data']['content'] + + self.prompt_template.sep + ) - return text + else: + ### user all history + for conversation in self.history_message: + for message in conversation['messages']: + ### histroy message not have promot and view info + if not message['type'] in [SystemMessage.type, ViewMessage.type]: + history_text += ( + message['type'] + + ":" + + message['data']['content'] + + self.prompt_template.sep + ) + + return history_text - # 暂时为了兼容前端 def current_ai_response(self) -> str: for message in self.current_message.messages: if message.type == "view": return message.content return None - def _load_history(self, session_id: str) -> List[OnceConversation]: - """ - load chat history by session_id - Args: - session_id: - Returns: - """ - return self.memory.messages() - def generate(self, p) -> str: """ generate context for LLM input diff --git a/pilot/scene/chat_execution/example.py b/pilot/scene/chat_execution/example.py index 0008ee9ec..f50a7f546 100644 --- a/pilot/scene/chat_execution/example.py +++ b/pilot/scene/chat_execution/example.py @@ -6,4 +6,4 @@ EXAMPLES = [ [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}] ] -example = ExampleSelector(examples=EXAMPLES, use_example=True) +plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True) diff --git a/pilot/scene/chat_execution/prompt.py b/pilot/scene/chat_execution/prompt.py index 768527c19..af5087609 100644 --- a/pilot/scene/chat_execution/prompt.py +++ b/pilot/scene/chat_execution/prompt.py @@ -5,7 +5,7 @@ from pilot.scene.base import ChatScene from pilot.common.schema import SeparatorStyle, ExampleType from pilot.scene.chat_execution.out_parser import PluginChatOutputParser -from pilot.scene.chat_execution.example import example +from pilot.scene.chat_execution.example import plugin_example CFG = Config() @@ -49,7 +49,7 @@ prompt = PromptTemplate( output_parser=PluginChatOutputParser( sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT ), - example=example, + example_selector=plugin_example, ) CFG.prompt_templates.update({prompt.template_scene: prompt}) diff --git a/pilot/scene/message.py b/pilot/scene/message.py index ba884c571..972331bbb 100644 --- a/pilot/scene/message.py +++ b/pilot/scene/message.py @@ -85,12 +85,18 @@ class OnceConversation: self.messages.clear() self.session_id = None - def get_user_message(self): - for once in self.messages: - if isinstance(once, HumanMessage): - return once.content - return "" + def get_user_conv(self): + for message in self.messages: + if isinstance(message, HumanMessage): + return message + return None + def get_system_conv(self): + system_convs = [] + for message in self.messages: + if isinstance(message, SystemMessage): + system_convs.append(message) + return system_convs def _conversation_to_dic(once: OnceConversation) -> dict: start_str: str = "" diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py index 7032982d9..6c22105cd 100644 --- a/pilot/server/dbgpt_server.py +++ b/pilot/server/dbgpt_server.py @@ -18,6 +18,7 @@ from pilot.utils import build_logger from pilot.server.webserver_base import server_init +from fastapi.staticfiles import StaticFiles from fastapi import FastAPI, applications from fastapi.openapi.docs import get_swagger_ui_html from fastapi.exceptions import RequestValidationError @@ -25,6 +26,7 @@ from fastapi.middleware.cors import CORSMiddleware from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler +static_file_path = os.path.join(os.getcwd(), "server/static") CFG = Config() logger = build_logger("webserver", LOGDIR + "webserver.log") @@ -52,7 +54,9 @@ app.add_middleware( allow_headers=["*"], ) -# app.mount("static", StaticFiles(directory="static"), name="static") +app.mount("/static", StaticFiles(directory=static_file_path), name="static") +app.add_route("/test", "static/test.html") + app.include_router(api_v1) app.add_exception_handler(RequestValidationError, validation_exception_handler) diff --git a/pilot/server/static/test.html b/pilot/server/static/test.html new file mode 100644 index 000000000..709180f11 --- /dev/null +++ b/pilot/server/static/test.html @@ -0,0 +1,19 @@ + + + + + Streaming Demo + + + +
+ + + \ No newline at end of file From 6a3bf33a249507ac3be5503744e144e9d1ca5701 Mon Sep 17 00:00:00 2001 From: "tuyang.yhj" Date: Thu, 29 Jun 2023 11:58:13 +0800 Subject: [PATCH 4/4] WEB API independent --- pilot/connections/rdbms/py_study/test_duckdb.py | 1 + pilot/openapi/api_v1/api_v1.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pilot/connections/rdbms/py_study/test_duckdb.py b/pilot/connections/rdbms/py_study/test_duckdb.py index dbcf2ecb7..dc3f926ac 100644 --- a/pilot/connections/rdbms/py_study/test_duckdb.py +++ b/pilot/connections/rdbms/py_study/test_duckdb.py @@ -8,6 +8,7 @@ duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db") if __name__ == "__main__": if os.path.isfile(duckdb_path): cursor = duckdb.connect(duckdb_path).cursor() + # cursor.execute("SELECT * FROM chat_history limit 20") cursor.execute("SELECT * FROM chat_history limit 20") data = cursor.fetchall() print(data) diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index 8d71ebdea..4f2e23946 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -74,15 +74,15 @@ async def dialogue_list(response: Response, user_id: str = None): ) dialogues.append(conv_vo) - return Result[ConversationVo].succ(dialogues) + return Result[ConversationVo].succ(dialogues[-10:][::-1]) @router.post("/v1/chat/dialogue/scenes", response_model=Result[List[ChatSceneVo]]) async def dialogue_scenes(): scene_vos: List[ChatSceneVo] = [] new_modes: List[ChatScene] = [ - ChatScene.ChatDb, - ChatScene.ChatData, + ChatScene.ChatWithDbExecute, + ChatScene.ChatWithDbQA, ChatScene.ChatDashboard, ChatScene.ChatKnowledge, ChatScene.ChatExecution,