diff --git a/.gitignore b/.gitignore index d040022b1..f6e18e09f 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ __pycache__/ *.so message/ +static/ .env .idea diff --git a/pilot/configs/config.py b/pilot/configs/config.py index 7e259f9fc..0be6f18fc 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -17,6 +17,8 @@ class Config(metaclass=Singleton): def __init__(self) -> None: """Initialize the Config class""" + self.NEW_SERVER_MODE = False + # Gradio language version: en, zh self.LANGUAGE = os.getenv("LANGUAGE", "en") self.WEB_SERVER_PORT = int(os.getenv("WEB_SERVER_PORT", 7860)) diff --git a/pilot/connections/rdbms/py_study/test_duckdb.py b/pilot/connections/rdbms/py_study/test_duckdb.py index dbcf2ecb7..dc3f926ac 100644 --- a/pilot/connections/rdbms/py_study/test_duckdb.py +++ b/pilot/connections/rdbms/py_study/test_duckdb.py @@ -8,6 +8,7 @@ duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db") if __name__ == "__main__": if os.path.isfile(duckdb_path): cursor = duckdb.connect(duckdb_path).cursor() + # cursor.execute("SELECT * FROM chat_history limit 20") cursor.execute("SELECT * FROM chat_history limit 20") data = cursor.fetchall() print(data) diff --git a/pilot/memory/chat_history/duckdb_history.py b/pilot/memory/chat_history/duckdb_history.py index d97232217..de80a5bc2 100644 --- a/pilot/memory/chat_history/duckdb_history.py +++ b/pilot/memory/chat_history/duckdb_history.py @@ -8,18 +8,20 @@ from pilot.memory.chat_history.base import BaseChatHistoryMemory from pilot.scene.message import ( OnceConversation, conversation_from_dict, + _conversation_to_dic, conversations_to_dict, ) from pilot.common.formatting import MyEncoder default_db_path = os.path.join(os.getcwd(), "message") duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db") -table_name = "chat_history" +table_name = 'chat_history' CFG = Config() class DuckdbHistoryMemory(BaseChatHistoryMemory): + def __init__(self, chat_session_id: str): self.chat_seesion_id = chat_session_id os.makedirs(default_db_path, exist_ok=True) @@ -27,26 +29,28 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): self.__init_chat_history_tables() def __init_chat_history_tables(self): + # 检查表是否存在 - result = self.connect.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name] - ).fetchall() + result = self.connect.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", + [table_name]).fetchall() if not result: # 如果表不存在,则创建新表 self.connect.execute( - "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)" - ) + "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)") def __get_messages_by_conv_uid(self, conv_uid: str): cursor = self.connect.cursor() cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [conv_uid]) - return cursor.fetchone() - + content = cursor.fetchone() + if content: + return content[0] + else: + return None def messages(self) -> List[OnceConversation]: context = self.__get_messages_by_conv_uid(self.chat_seesion_id) if context: - conversations: List[OnceConversation] = json.loads(context[0]) + conversations: List[OnceConversation] = json.loads(context) return conversations return [] @@ -54,50 +58,27 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): context = self.__get_messages_by_conv_uid(self.chat_seesion_id) conversations: List[OnceConversation] = [] if context: - conversations = json.load(context) - 
conversations.append(once_message) + conversations = json.loads(context) + conversations.append(_conversation_to_dic(once_message)) cursor = self.connect.cursor() if context: - cursor.execute( - "UPDATE chat_history set messages=? where conv_uid=?", - [ - json.dumps( - conversations_to_dict(conversations), - ensure_ascii=False, - indent=4, - ), - self.chat_seesion_id, - ], - ) + cursor.execute("UPDATE chat_history set messages=? where conv_uid=?", + [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id]) else: - cursor.execute( - "INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)", - [ - self.chat_seesion_id, - "", - json.dumps( - conversations_to_dict(conversations), - ensure_ascii=False, - indent=4, - ), - ], - ) + cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)", + [self.chat_seesion_id, "", json.dumps(conversations, ensure_ascii=False)]) cursor.commit() self.connect.commit() def clear(self) -> None: cursor = self.connect.cursor() - cursor.execute( - "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id] - ) + cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]) cursor.commit() self.connect.commit() def delete(self) -> bool: cursor = self.connect.cursor() - cursor.execute( - "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id] - ) + cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]) cursor.commit() return True @@ -106,9 +87,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): if os.path.isfile(duckdb_path): cursor = duckdb.connect(duckdb_path).cursor() if user_name: - cursor.execute( - "SELECT * FROM chat_history where user_name=? limit 20", [user_name] - ) + cursor.execute("SELECT * FROM chat_history where user_name=? 
limit 20", [user_name]) else: cursor.execute("SELECT * FROM chat_history limit 20") # 获取查询结果字段名 @@ -124,11 +103,10 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): return [] - def get_messages(self) -> List[OnceConversation]: + + def get_messages(self)-> List[OnceConversation]: cursor = self.connect.cursor() - cursor.execute( - "SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id] - ) + cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id]) context = cursor.fetchone() if context: return json.loads(context[0]) diff --git a/pilot/model/llm_out/proxy_llm.py b/pilot/model/llm_out/proxy_llm.py index 4336d43e3..c4423a1a6 100644 --- a/pilot/model/llm_out/proxy_llm.py +++ b/pilot/model/llm_out/proxy_llm.py @@ -39,7 +39,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048) elif "ai:" in message: history.append( { - "role": "ai", + "role": "assistant", "content": message.split("ai:")[1], } ) @@ -57,6 +57,7 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048) for m in temp_his: if m["role"] == "user": last_user_input = m + break if last_user_input: history.remove(last_user_input) history.append(last_user_input) diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index ba5ef2f04..4f2e23946 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -2,7 +2,7 @@ import uuid import json import asyncio import time -from fastapi import APIRouter, Request, Body, status, HTTPException, Response +from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks from fastapi.responses import JSONResponse from fastapi.responses import StreamingResponse @@ -12,12 +12,7 @@ from fastapi.responses import JSONResponse from sse_starlette.sse import EventSourceResponse from typing import List -from pilot.server.api_v1.api_view_model import ( - Result, - ConversationVo, - MessageVo, - ChatSceneVo, -) +from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo from pilot.configs.config import Config from pilot.openapi.knowledge.knowledge_service import KnowledgeService from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest @@ -37,6 +32,9 @@ CHAT_FACTORY = ChatFactory() logger = build_logger("api_v1", LOGDIR + "api_v1.log") knowledge_service = KnowledgeService() +model_semaphore = None +global_counter = 0 + async def validation_exception_handler(request: Request, exc: RequestValidationError): message = "" @@ -76,15 +74,15 @@ async def dialogue_list(response: Response, user_id: str = None): ) dialogues.append(conv_vo) - return Result[ConversationVo].succ(dialogues) + return Result[ConversationVo].succ(dialogues[-10:][::-1]) @router.post("/v1/chat/dialogue/scenes", response_model=Result[List[ChatSceneVo]]) async def dialogue_scenes(): scene_vos: List[ChatSceneVo] = [] new_modes: List[ChatScene] = [ - ChatScene.ChatDb, - ChatScene.ChatData, + ChatScene.ChatWithDbExecute, + ChatScene.ChatWithDbQA, ChatScene.ChatDashboard, ChatScene.ChatKnowledge, ChatScene.ChatExecution, @@ -105,7 +103,7 @@ async def dialogue_scenes(): @router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo]) async def dialogue_new( - chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None + chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None ): unique_id = uuid.uuid1() return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode)) 
@@ -139,9 +137,9 @@ def knowledge_list(): @router.post("/v1/chat/mode/params/list", response_model=Result[dict]) async def params_list(chat_mode: str = ChatScene.ChatNormal.value): - if ChatScene.ChatDb.value == chat_mode: + if ChatScene.ChatWithDbQA.value == chat_mode: return Result.succ(get_db_list()) - elif ChatScene.ChatData.value == chat_mode: + elif ChatScene.ChatWithDbExecute.value == chat_mode: return Result.succ(get_db_list()) elif ChatScene.ChatDashboard.value == chat_mode: return Result.succ(get_db_list()) @@ -179,6 +177,16 @@ async def dialogue_history_messages(con_uid: str): @router.post("/v1/chat/completions") async def chat_completions(dialogue: ConversationVo = Body()): print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}") + if not dialogue.chat_mode: + dialogue.chat_mode = ChatScene.ChatNormal.value + if not dialogue.conv_uid: + dialogue.conv_uid = str(uuid.uuid1()) + + global model_semaphore, global_counter + global_counter += 1 + if model_semaphore is None: + model_semaphore = asyncio.Semaphore(CFG.LIMIT_MODEL_CONCURRENCY) + await model_semaphore.acquire() if not ChatScene.is_valid_mode(dialogue.chat_mode): raise StopAsyncIteration( @@ -190,99 +198,65 @@ async def chat_completions(dialogue: ConversationVo = Body()): "user_input": dialogue.user_input, } - if ChatScene.ChatDb == dialogue.chat_mode: - chat_param.update("db_name", dialogue.select_param) - elif ChatScene.ChatData == dialogue.chat_mode: - chat_param.update("db_name", dialogue.select_param) - elif ChatScene.ChatDashboard == dialogue.chat_mode: - chat_param.update("db_name", dialogue.select_param) - elif ChatScene.ChatExecution == dialogue.chat_mode: - chat_param.update("plugin_selector", dialogue.select_param) - elif ChatScene.ChatKnowledge == dialogue.chat_mode: - chat_param.update("knowledge_space", dialogue.select_param) + if ChatScene.ChatWithDbQA.value == dialogue.chat_mode: + chat_param.update({"db_name": dialogue.select_param}) + elif ChatScene.ChatWithDbExecute.value == dialogue.chat_mode: + chat_param.update({"db_name": dialogue.select_param}) + elif ChatScene.ChatDashboard.value == dialogue.chat_mode: + chat_param.update({"db_name": dialogue.select_param}) + elif ChatScene.ChatExecution.value == dialogue.chat_mode: + chat_param.update({"plugin_selector": dialogue.select_param}) + elif ChatScene.ChatKnowledge.value == dialogue.chat_mode: + chat_param.update({"knowledge_space": dialogue.select_param}) chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param) + background_tasks = BackgroundTasks() + background_tasks.add_task(release_model_semaphore) + headers = { + # "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + # "Transfer-Encoding": "chunked", + } + if not chat.prompt_template.stream_out: - return non_stream_response(chat) + return StreamingResponse(no_stream_generator(chat), headers=headers, media_type='text/event-stream', + background=background_tasks) else: - # generator = stream_generator(chat) - # result = Result.succ(data=StreamingResponse(stream_test(), media_type='text/plain')) - # return result - return StreamingResponse(stream_generator(chat), media_type="text/plain") + return StreamingResponse(stream_generator(chat), headers=headers, media_type='text/plain', + background=background_tasks) -def stream_test(): - for message in ["Hello", "world", "how", "are", "you"]: - yield message - # yield json.dumps(Result.succ(message).__dict__).encode("utf-8") +def release_model_semaphore(): + 
model_semaphore.release() -def stream_generator(chat): +async def no_stream_generator(chat): + msg = chat.nostream_call() + msg = msg.replace("\n", "\\n") + yield f"data: {msg}\n\n" + +async def stream_generator(chat): model_response = chat.stream_call() - for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"): - if chunk: - msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex( - chunk, chat.skip_echo_len - ) - chat.current_message.add_ai_message(msg) - yield msg - # chat.current_message.add_ai_message(msg) - # vo = MessageVo(role="view", context=msg, order=chat.current_message.chat_order) - # json_text = json.dumps(vo.__dict__) - # yield json_text.encode('utf-8') + if not CFG.NEW_SERVER_MODE: + for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"): + if chunk: + msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) + chat.current_message.add_ai_message(msg) + msg = msg.replace("\n", "\\n") + yield f"data:{msg}\n\n" + await asyncio.sleep(0.1) + else: + for chunk in model_response: + if chunk: + msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) + chat.current_message.add_ai_message(msg) + + msg = msg.replace("\n", "\\n") + yield f"data:{msg}\n\n" + await asyncio.sleep(0.1) chat.memory.append(chat.current_message) -# def stream_response(chat): -# logger.info("stream out start!") -# api_response = StreamingResponse(stream_generator(chat), media_type="application/json") -# return api_response - - def message2Vo(message: dict, order) -> MessageVo: - # message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0 - return MessageVo( - role=message["type"], context=message["data"]["content"], order=order - ) - - -def non_stream_response(chat): - logger.info("not stream out, wait model response!") - return chat.nostream_call() - - -@router.get("/v1/db/types", response_model=Result[str]) -async def db_types(): - return Result.succ(["mysql", "duckdb"]) - - -@router.get("/v1/db/list", response_model=Result[str]) -async def db_list(): - db = CFG.local_db - dbs = db.get_database_list() - return Result.succ(dbs) - - -@router.get("/v1/knowledge/list") -async def knowledge_list(): - return ["test1", "test2"] - - -@router.post("/v1/knowledge/add") -async def knowledge_add(): - return ["test1", "test2"] - - -@router.post("/v1/knowledge/delete") -async def knowledge_delete(): - return ["test1", "test2"] - - -@router.get("/v1/knowledge/types") -async def knowledge_types(): - return ["test1", "test2"] - - -@router.get("/v1/knowledge/detail") -async def knowledge_detail(): - return ["test1", "test2"] + return MessageVo(role=message['type'], context=message['data']['content'], order=order) diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py index 4476a68fc..ca308e92f 100644 --- a/pilot/out_parser/base.py +++ b/pilot/out_parser/base.py @@ -47,6 +47,8 @@ class BaseOutputParser(ABC): return code def parse_model_stream_resp_ex(self, chunk, skip_echo_len): + if b"\0" in chunk: + chunk = chunk.replace(b"\0", b"") data = json.loads(chunk.decode()) """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode. 
@@ -95,11 +97,8 @@ class BaseOutputParser(ABC): def parse_model_nostream_resp(self, response, sep: str): text = response.text.strip() text = text.rstrip() - respObj = json.loads(text) - - xx = respObj["response"] - xx = xx.strip(b"\x00".decode()) - respObj_ex = json.loads(xx) + text = text.strip(b"\x00".decode()) + respObj_ex = json.loads(text) if respObj_ex["error_code"] == 0: all_text = respObj_ex["text"] ### 解析返回文本,获取AI回复部分 @@ -123,7 +122,7 @@ class BaseOutputParser(ABC): def __extract_json(slef, s): i = s.index("{") count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数 - for j, c in enumerate(s[i + 1 :], start=i + 1): + for j, c in enumerate(s[i + 1:], start=i + 1): if c == "}": count -= 1 elif c == "{": @@ -131,7 +130,7 @@ class BaseOutputParser(ABC): if count == 0: break assert count == 0 # 检查是否找到最后一个'}' - return s[i : j + 1] + return s[i: j + 1] def parse_prompt_response(self, model_out_text) -> T: """ @@ -148,9 +147,9 @@ class BaseOutputParser(ABC): # if "```" in cleaned_output: # cleaned_output, _ = cleaned_output.split("```") if cleaned_output.startswith("```json"): - cleaned_output = cleaned_output[len("```json") :] + cleaned_output = cleaned_output[len("```json"):] if cleaned_output.startswith("```"): - cleaned_output = cleaned_output[len("```") :] + cleaned_output = cleaned_output[len("```"):] if cleaned_output.endswith("```"): cleaned_output = cleaned_output[: -len("```")] cleaned_output = cleaned_output.strip() @@ -159,9 +158,9 @@ class BaseOutputParser(ABC): cleaned_output = self.__extract_json(cleaned_output) cleaned_output = ( cleaned_output.strip() - .replace("\n", " ") - .replace("\\n", " ") - .replace("\\", " ") + .replace("\n", " ") + .replace("\\n", " ") + .replace("\\", " ") ) return cleaned_output diff --git a/pilot/prompts/example_base.py b/pilot/prompts/example_base.py index e930927c9..2553be150 100644 --- a/pilot/prompts/example_base.py +++ b/pilot/prompts/example_base.py @@ -6,7 +6,7 @@ from pilot.common.schema import ExampleType class ExampleSelector(BaseModel, ABC): - examples: List[List] + examples_record: List[List] use_example: bool = False type: str = ExampleType.ONE_SHOT.value @@ -22,7 +22,7 @@ class ExampleSelector(BaseModel, ABC): Returns: example text """ if self.use_example: - need_use = self.examples[:count] + need_use = self.examples_record[:count] return need_use return None @@ -33,7 +33,7 @@ class ExampleSelector(BaseModel, ABC): """ if self.use_example: - need_use = self.examples[:1] + need_use = self.examples_record[:1] return need_use return None diff --git a/pilot/prompts/prompt_new.py b/pilot/prompts/prompt_new.py index 475f82ea4..80f05c730 100644 --- a/pilot/prompts/prompt_new.py +++ b/pilot/prompts/prompt_new.py @@ -46,7 +46,10 @@ class PromptTemplate(BaseModel, ABC): output_parser: BaseOutputParser = None """""" sep: str = SeparatorStyle.SINGLE.value - example: ExampleSelector = None + + example_selector: ExampleSelector = None + + need_historical_messages: bool = False class Config: """Configuration for this pydantic object.""" diff --git a/pilot/scene/base.py b/pilot/scene/base.py index d0bb99255..7c7d483f9 100644 --- a/pilot/scene/base.py +++ b/pilot/scene/base.py @@ -20,8 +20,8 @@ class ChatScene(Enum): ChatNormal = "chat_normal" ChatDashboard = "chat_dashboard" ChatKnowledge = "chat_knowledge" - ChatDb = "chat_db" - ChatData = "chat_data" + # ChatDb = "chat_db" + # ChatData= "chat_data" @staticmethod def is_valid_mode(mode): diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index c9ada6b29..8e7b3dfe7 100644 --- a/pilot/scene/base_chat.py 
+++ b/pilot/scene/base_chat.py @@ -39,6 +39,7 @@ from pilot.scene.base_message import ( ViewMessage, ) from pilot.configs.config import Config +from pilot.server.llmserver import worker logger = build_logger("BaseChat", LOGDIR + "BaseChat.log") headers = {"User-Agent": "dbgpt Client"} @@ -51,7 +52,7 @@ class BaseChat(ABC): temperature: float = 0.6 max_new_tokens: int = 1024 # By default, keep the last two rounds of conversation records as the context - chat_retention_rounds: int = 2 + chat_retention_rounds: int = 1 class Config: """Configuration for this pydantic object.""" @@ -59,10 +60,10 @@ class BaseChat(ABC): arbitrary_types_allowed = True def __init__( - self, - chat_mode, - chat_session_id, - current_user_input, + self, + chat_mode, + chat_session_id, + current_user_input, ): self.chat_session_id = chat_session_id self.chat_mode = chat_mode @@ -75,11 +76,9 @@ class BaseChat(ABC): self.prompt_template: PromptTemplate = CFG.prompt_templates[ self.chat_mode.value ] - self.history_message: List[OnceConversation] = [] + self.history_message: List[OnceConversation] = self.memory.messages() self.current_message: OnceConversation = OnceConversation(chat_mode.value) self.current_tokens_used: int = 0 - ### load chat_session_id's chat historys - self._load_history(self.chat_session_id) class Config: """Configuration for this pydantic object.""" @@ -95,7 +94,6 @@ class BaseChat(ABC): def generate_input_values(self): pass - @abstractmethod def do_action(self, prompt_response): return prompt_response @@ -104,23 +102,12 @@ class BaseChat(ABC): ### Chat sequence advance self.current_message.chat_order = len(self.history_message) + 1 self.current_message.add_user_message(self.current_user_input) - self.current_message.start_date = datetime.datetime.now().strftime( - "%Y-%m-%d %H:%M:%S" - ) + self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") # TODO self.current_message.tokens = 0 - current_prompt = None if self.prompt_template.template: current_prompt = self.prompt_template.format(**input_values) - - ### 构建当前对话, 是否安第一次对话prompt构造? 是否考虑切换库 - if self.history_message: - ## TODO 带历史对话记录的场景需要确定切换库后怎么处理 - logger.info( - f"There are already {len(self.history_message)} rounds of conversations!" - ) - if current_prompt: self.current_message.add_system_message(current_prompt) payload = { @@ -140,31 +127,24 @@ class BaseChat(ABC): logger.info(f"Requert: \n{payload}") ai_response_text = "" try: - show_info = "" - response = requests.post( - urljoin(CFG.MODEL_SERVER, "generate_stream"), - headers=headers, - json=payload, - stream=True, - timeout=120, - ) - return response - - # yield self.prompt_template.output_parser.parse_model_stream_resp(response, skip_echo_len) - - # for resp_text_trunck in ai_response_text: - # show_info = resp_text_trunck - # yield resp_text_trunck + "▌" - - self.current_message.add_ai_message(show_info) - + if not CFG.NEW_SERVER_MODE: + response = requests.post( + urljoin(CFG.MODEL_SERVER, "generate_stream"), + headers=headers, + json=payload, + stream=True, + timeout=120, + ) + return response + else: + return worker.generate_stream_gate(payload) except Exception as e: print(traceback.format_exc()) logger.error("model response parase faild!" 
+ str(e)) self.current_message.add_view_message( f"""ERROR!{str(e)}\n {ai_response_text} """ ) - ### 对话记录存储 + ### store current conversation self.memory.append(self.current_message) def nostream_call(self): @@ -172,42 +152,27 @@ class BaseChat(ABC): logger.info(f"Requert: \n{payload}") ai_response_text = "" try: - ### 走非流式的模型服务接口 - response = requests.post( - urljoin(CFG.MODEL_SERVER, "generate"), - headers=headers, - json=payload, - timeout=120, - ) + rsp_str = "" + if not CFG.NEW_SERVER_MODE: + rsp_str = requests.post( + urljoin(CFG.MODEL_SERVER, "generate"), + headers=headers, + json=payload, + timeout=120, + ) + else: + ###TODO no stream mode need independent + output = worker.generate_stream_gate(payload) + for rsp in output: + rsp_str = str(rsp, "utf-8") + print("[TEST: output]:", rsp_str) ### output parse - ai_response_text = ( - self.prompt_template.output_parser.parse_model_nostream_resp( - response, self.prompt_template.sep - ) - ) - - # ### MOCK - # ai_response_text = """{ - # "thoughts": "可以从users表和tran_order表联合查询,按城市和订单数量进行分组统计,并使用柱状图展示。", - # "reasoning": "为了分析用户在不同城市的分布情况,需要查询users表和tran_order表,使用LEFT JOIN将两个表联合起来。按照城市进行分组,统计每个城市的订单数量。使用柱状图展示可以直观地看出每个城市的订单数量,方便比较。", - # "speak": "根据您的分析目标,我查询了用户表和订单表,统计了每个城市的订单数量,并生成了柱状图展示。", - # "command": { - # "name": "histogram-executor", - # "args": { - # "title": "订单城市分布柱状图", - # "sql": "SELECT users.city, COUNT(tran_order.order_id) AS order_count FROM users LEFT JOIN tran_order ON users.user_name = tran_order.user_name GROUP BY users.city" - # } - # } - # }""" - + ai_response_text = self.prompt_template.output_parser.parse_model_nostream_resp(rsp_str, + self.prompt_template.sep) + ### model result deal self.current_message.add_ai_message(ai_response_text) - prompt_define_response = ( - self.prompt_template.output_parser.parse_prompt_response( - ai_response_text - ) - ) - + prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text) result = self.do_action(prompt_define_response) if hasattr(prompt_define_response, "thoughts"): @@ -235,7 +200,7 @@ class BaseChat(ABC): self.current_message.add_view_message( f"""ERROR!{str(e)}\n {ai_response_text} """ ) - ### 对话记录存储 + ### store dialogue self.memory.append(self.current_message) return self.current_ai_response() @@ -247,67 +212,99 @@ class BaseChat(ABC): def generate_llm_text(self) -> str: text = "" + ### Load scene setting or character definition if self.prompt_template.template_define: - text = self.prompt_template.template_define + self.prompt_template.sep + text += self.prompt_template.template_define + self.prompt_template.sep + ### Load prompt + text += self.__load_system_message() - ### 处理历史信息 - if len(self.history_message) > self.chat_retention_rounds: - ### 使用历史信息的第一轮和最后n轮数据合并成历史对话记录, 做上下文提示时,用户展示消息需要过滤掉 - for first_message in self.history_message[0].messages: - if not isinstance(first_message, ViewMessage): - text += ( - first_message.type - + ":" - + first_message.content - + self.prompt_template.sep - ) + ### Load examples + text += self.__load_example_messages() - index = self.chat_retention_rounds - 1 - for last_message in self.history_message[-index:].messages: - if not isinstance(last_message, ViewMessage): - text += ( - last_message.type - + ":" - + last_message.content - + self.prompt_template.sep - ) - - else: - ### 直接历史记录拼接 - for conversation in self.history_message: - for message in conversation.messages: - if not isinstance(message, ViewMessage): - text += ( - message.type - + ":" - + message.content - + self.prompt_template.sep 
- ) - ### current conversation - - for now_message in self.current_message.messages: - text += ( - now_message.type + ":" + now_message.content + self.prompt_template.sep - ) + ### Load History + text += self.__load_histroy_messages() + ### Load User Input + text += self.__load_user_message() return text - # 暂时为了兼容前端 + def __load_system_message(self): + system_convs = self.current_message.get_system_conv() + system_text = "" + for system_conv in system_convs: + system_text += system_conv.type + ":" + system_conv.content + self.prompt_template.sep + return system_text + + def __load_user_message(self): + user_conv = self.current_message.get_user_conv() + if user_conv: + return user_conv.type + ":" + user_conv.content + self.prompt_template.sep + else: + raise ValueError("Hi! What do you want to talk about?") + + def __load_example_messages(self): + example_text = "" + if self.prompt_template.example_selector: + for round_conv in self.prompt_template.example_selector.examples(): + for round_message in round_conv['messages']: + if not round_message['type'] in [SystemMessage.type, ViewMessage.type]: + example_text += ( + round_message['type'] + + ":" + + round_message['data']['content'] + + self.prompt_template.sep + ) + return example_text + + def __load_histroy_messages(self): + history_text = "" + if self.prompt_template.need_historical_messages: + if self.history_message: + logger.info( + f"There are already {len(self.history_message)} rounds of conversations! Will use {self.chat_retention_rounds} rounds of content as history!" + ) + if len(self.history_message) > self.chat_retention_rounds: + for first_message in self.history_message[0]['messages']: + if not first_message['type'] in [ViewMessage.type, SystemMessage.type]: + history_text += ( + first_message['type'] + + ":" + + first_message['data']['content'] + + self.prompt_template.sep + ) + + index = self.chat_retention_rounds - 1 + for round_conv in self.history_message[-index:]: + for round_message in round_conv['messages']: + if not round_message['type'] in [SystemMessage.type, ViewMessage.type]: + history_text += ( + round_message['type'] + + ":" + + round_message['data']['content'] + + self.prompt_template.sep + ) + + else: + ### user all history + for conversation in self.history_message: + for message in conversation['messages']: + ### histroy message not have promot and view info + if not message['type'] in [SystemMessage.type, ViewMessage.type]: + history_text += ( + message['type'] + + ":" + + message['data']['content'] + + self.prompt_template.sep + ) + + return history_text + def current_ai_response(self) -> str: for message in self.current_message.messages: if message.type == "view": return message.content return None - def _load_history(self, session_id: str) -> List[OnceConversation]: - """ - load chat history by session_id - Args: - session_id: - Returns: - """ - return self.memory.messages() - def generate(self, p) -> str: """ generate context for LLM input diff --git a/pilot/scene/chat_db/auto_execute/prompt.py b/pilot/scene/chat_db/auto_execute/prompt.py index 938860aaf..cc4878c70 100644 --- a/pilot/scene/chat_db/auto_execute/prompt.py +++ b/pilot/scene/chat_db/auto_execute/prompt.py @@ -8,7 +8,7 @@ from pilot.common.schema import SeparatorStyle CFG = Config() -PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, please follow the prompts and conventions of the system's input for your answers""" +PROMPT_SCENE_DEFINE = None _DEFAULT_TEMPLATE = """ diff --git 
a/pilot/scene/chat_execution/example.py b/pilot/scene/chat_execution/example.py index e41de3abd..f50a7f546 100644 --- a/pilot/scene/chat_execution/example.py +++ b/pilot/scene/chat_execution/example.py @@ -2,8 +2,8 @@ from pilot.prompts.example_base import ExampleSelector ## Two examples are defined by default EXAMPLES = [ - [{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}], - [{"System": "123"}, {"System": "xxx"}, {"User": "xxx"}, {"Assistant": "xxx"}], + [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}], + [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}] ] -example = ExampleSelector(examples=EXAMPLES, use_example=True) +plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True) diff --git a/pilot/scene/chat_execution/prompt.py b/pilot/scene/chat_execution/prompt.py index ceb5db09d..af5087609 100644 --- a/pilot/scene/chat_execution/prompt.py +++ b/pilot/scene/chat_execution/prompt.py @@ -5,14 +5,12 @@ from pilot.scene.base import ChatScene from pilot.common.schema import SeparatorStyle, ExampleType from pilot.scene.chat_execution.out_parser import PluginChatOutputParser -from pilot.scene.chat_execution.example import example +from pilot.scene.chat_execution.example import plugin_example CFG = Config() -# PROMPT_SCENE_DEFINE = """You are an AI designed to solve the user's goals with given commands, please follow the prompts and constraints of the system's input for your answers.""" PROMPT_SCENE_DEFINE = "You are an AI designed to solve the user's goals with given commands, please follow the constraints of the system's input for your answers." - _DEFAULT_TEMPLATE = """ Goals: {input} @@ -51,7 +49,7 @@ prompt = PromptTemplate( output_parser=PluginChatOutputParser( sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_STREAM_OUT ), - example=example, + example_selector=plugin_example, ) CFG.prompt_templates.update({prompt.template_scene: prompt}) diff --git a/pilot/scene/chat_normal/prompt.py b/pilot/scene/chat_normal/prompt.py index 55514dffa..1dc455255 100644 --- a/pilot/scene/chat_normal/prompt.py +++ b/pilot/scene/chat_normal/prompt.py @@ -8,8 +8,7 @@ from pilot.common.schema import SeparatorStyle from pilot.scene.chat_normal.out_parser import NormalChatOutputParser -PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. - The assistant gives helpful, detailed, professional and polite answers to the user's questions. 
""" +PROMPT_SCENE_DEFINE = None CFG = Config() diff --git a/pilot/scene/message.py b/pilot/scene/message.py index ae32bbfe7..972331bbb 100644 --- a/pilot/scene/message.py +++ b/pilot/scene/message.py @@ -85,16 +85,22 @@ class OnceConversation: self.messages.clear() self.session_id = None - def get_user_message(self): - for once in self.messages: - if isinstance(once, HumanMessage): - return once.content - return "" + def get_user_conv(self): + for message in self.messages: + if isinstance(message, HumanMessage): + return message + return None + def get_system_conv(self): + system_convs = [] + for message in self.messages: + if isinstance(message, SystemMessage): + system_convs.append(message) + return system_convs def _conversation_to_dic(once: OnceConversation) -> dict: start_str: str = "" - if once.start_date: + if hasattr(once, 'start_date') and once.start_date: if isinstance(once.start_date, datetime): start_str = once.start_date.strftime("%Y-%m-%d %H:%M:%S") else: diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py new file mode 100644 index 000000000..6c22105cd --- /dev/null +++ b/pilot/server/dbgpt_server.py @@ -0,0 +1,78 @@ +import traceback +import os +import shutil +import argparse +import sys + +ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(ROOT_PATH) + +from pilot.configs.config import Config +from pilot.configs.model_config import ( + DATASETS_DIR, + KNOWLEDGE_UPLOAD_ROOT_PATH, + LLM_MODEL_CONFIG, + LOGDIR, +) +from pilot.utils import build_logger + +from pilot.server.webserver_base import server_init + +from fastapi.staticfiles import StaticFiles +from fastapi import FastAPI, applications +from fastapi.openapi.docs import get_swagger_ui_html +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware + +from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler + +static_file_path = os.path.join(os.getcwd(), "server/static") + +CFG = Config() +logger = build_logger("webserver", LOGDIR + "webserver.log") + + +def swagger_monkey_patch(*args, **kwargs): + return get_swagger_ui_html( + *args, **kwargs, + swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js', + swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css' + ) + + +applications.get_swagger_ui_html = swagger_monkey_patch + +app = FastAPI() +origins = ["*"] + +# 添加跨域中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], + allow_headers=["*"], +) + +app.mount("/static", StaticFiles(directory=static_file_path), name="static") +app.add_route("/test", "static/test.html") + +app.include_router(api_v1) +app.add_exception_handler(RequestValidationError, validation_exception_handler) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"]) + + # old version server config + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int, default=CFG.WEB_SERVER_PORT) + parser.add_argument("--concurrency-count", type=int, default=10) + parser.add_argument("--share", default=False, action="store_true") + + # init server config + args = parser.parse_args() + server_init(args) + CFG.NEW_SERVER_MODE = True + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=5000) diff 
--git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index 367671d68..67e6183b2 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -33,7 +33,7 @@ class ModelWorker: model_path = model_path[:-1] self.model_name = model_name or model_path.split("/")[-1] self.device = device - + print(f"Loading {model_name} LLM ModelServer in {device}! Please Wait......") self.ml = ModelLoader(model_path=model_path) self.model, self.tokenizer = self.ml.loader( num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG @@ -41,11 +41,11 @@ class ModelWorker: if not isinstance(self.model, str): if hasattr(self.model, "config") and hasattr( - self.model.config, "max_sequence_length" + self.model.config, "max_sequence_length" ): self.context_len = self.model.config.max_sequence_length elif hasattr(self.model, "config") and hasattr( - self.model.config, "max_position_embeddings" + self.model.config, "max_position_embeddings" ): self.context_len = self.model.config.max_position_embeddings @@ -55,29 +55,32 @@ class ModelWorker: self.llm_chat_adapter = get_llm_chat_adapter(model_path) self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func() + def start_check(self): + print("LLM Model Loading Success!") + def get_queue_length(self): if ( - model_semaphore is None - or model_semaphore._value is None - or model_semaphore._waiters is None + model_semaphore is None + or model_semaphore._value is None + or model_semaphore._waiters is None ): return 0 else: ( - CFG.LIMIT_MODEL_CONCURRENCY - - model_semaphore._value - + len(model_semaphore._waiters) + CFG.LIMIT_MODEL_CONCURRENCY + - model_semaphore._value + + len(model_semaphore._waiters) ) def generate_stream_gate(self, params): try: for output in self.generate_stream_func( - self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS + self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS ): # Please do not open the output in production! # The gpt4all thread shares stdout with the parent process, # and opening it may affect the frontend output. - # print("output: ", output) + print("output: ", output) ret = { "text": output, "error_code": 0, @@ -178,10 +181,9 @@ def generate(prompt_request: PromptRequest): for rsp in output: # rsp = rsp.decode("utf-8") rsp_str = str(rsp, "utf-8") - print("[TEST: output]:", rsp_str) response.append(rsp_str) - return {"response": rsp_str} + return rsp_str @app.post("/embedding") diff --git a/pilot/server/static/test.html b/pilot/server/static/test.html new file mode 100644 index 000000000..709180f11 --- /dev/null +++ b/pilot/server/static/test.html @@ -0,0 +1,19 @@ + + +
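A minimal client sketch for the new streaming /v1/chat/completions response, assuming the FastAPI app from dbgpt_server.py is running on port 5000; the request fields mirror how ConversationVo is read in this diff, and the "data:" framing and newline escaping mirror stream_generator above:

import requests  # assumed HTTP client

payload = {
    "conv_uid": "",              # left empty: chat_completions assigns a uuid1 when missing
    "chat_mode": "chat_normal",  # must be a valid ChatScene value
    "user_input": "hello",
    "select_param": None,        # db name / plugin / knowledge space, depending on the scene
}

with requests.post("http://127.0.0.1:5000/v1/chat/completions", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line:
            continue
        # stream_generator yields "data:<text>" chunks with real newlines escaped as \n
        text = line[len("data:"):] if line.startswith("data:") else line
        print(text.replace("\\n", "\n"), flush=True)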