diff --git a/pilot/memory/chat_history/duckdb_history.py b/pilot/memory/chat_history/duckdb_history.py
index de80a5bc2..659690120 100644
--- a/pilot/memory/chat_history/duckdb_history.py
+++ b/pilot/memory/chat_history/duckdb_history.py
@@ -15,13 +15,12 @@ from pilot.common.formatting import MyEncoder
 
 default_db_path = os.path.join(os.getcwd(), "message")
 duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
-table_name = 'chat_history'
+table_name = "chat_history"
 
 CFG = Config()
 
 
 class DuckdbHistoryMemory(BaseChatHistoryMemory):
-
     def __init__(self, chat_session_id: str):
         self.chat_seesion_id = chat_session_id
         os.makedirs(default_db_path, exist_ok=True)
@@ -29,15 +28,16 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
         self.__init_chat_history_tables()
 
     def __init_chat_history_tables(self):
-        # Check whether the table exists
-        result = self.connect.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?",
-                                      [table_name]).fetchall()
+        result = self.connect.execute(
+            "SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name]
+        ).fetchall()
 
         if not result:
             # Create the table if it does not exist
             self.connect.execute(
-                "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)")
+                "CREATE TABLE chat_history (conv_uid VARCHAR(100) PRIMARY KEY, user_name VARCHAR(100), messages TEXT)"
+            )
 
     def __get_messages_by_conv_uid(self, conv_uid: str):
         cursor = self.connect.cursor()
@@ -47,6 +47,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
             return content[0]
         else:
             return None
+
     def messages(self) -> List[OnceConversation]:
         context = self.__get_messages_by_conv_uid(self.chat_seesion_id)
         if context:
@@ -62,23 +63,35 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
             conversations.append(_conversation_to_dic(once_message))
         cursor = self.connect.cursor()
         if context:
-            cursor.execute("UPDATE chat_history set messages=? where conv_uid=?",
-                           [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id])
+            cursor.execute(
+                "UPDATE chat_history set messages=? where conv_uid=?",
+                [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id],
+            )
         else:
-            cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
-                           [self.chat_seesion_id, "", json.dumps(conversations, ensure_ascii=False)])
+            cursor.execute(
+                "INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)",
+                [
+                    self.chat_seesion_id,
+                    "",
+                    json.dumps(conversations, ensure_ascii=False),
+                ],
+            )
         cursor.commit()
         self.connect.commit()
 
     def clear(self) -> None:
         cursor = self.connect.cursor()
-        cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
+        cursor.execute(
+            "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
+        )
         cursor.commit()
         self.connect.commit()
 
     def delete(self) -> bool:
         cursor = self.connect.cursor()
-        cursor.execute("DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id])
+        cursor.execute(
+            "DELETE FROM chat_history where conv_uid=?", [self.chat_seesion_id]
+        )
         cursor.commit()
         return True
@@ -87,7 +100,9 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
         if os.path.isfile(duckdb_path):
             cursor = duckdb.connect(duckdb_path).cursor()
             if user_name:
-                cursor.execute("SELECT * FROM chat_history where user_name=? limit 20", [user_name])
+                cursor.execute(
+                    "SELECT * FROM chat_history where user_name=? limit 20", [user_name]
+                )
             else:
                 cursor.execute("SELECT * FROM chat_history limit 20")
             # Get the field names of the query result
@@ -103,10 +118,11 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
 
         return []
 
-
-    def get_messages(self)-> List[OnceConversation]:
+    def get_messages(self) -> List[OnceConversation]:
         cursor = self.connect.cursor()
-        cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id])
+        cursor.execute(
+            "SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id]
+        )
         context = cursor.fetchone()
         if context:
             return json.loads(context[0])
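
For reference, a minimal usage sketch of the history store reformatted above. The session id is a hypothetical value, the sketch assumes the default ./message directory is writable, and the return behavior is inferred from the method annotations rather than confirmed by this diff:

from pilot.memory.chat_history.duckdb_history import DuckdbHistoryMemory

# Hypothetical session id; the backing table is created on first use.
history = DuckdbHistoryMemory(chat_session_id="example-conv-uid")
print(history.messages())  # stored conversations for this session (falsy when none)
history.clear()            # deletes this conversation's row but keeps the table
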
limit 20", [user_name] + ) else: cursor.execute("SELECT * FROM chat_history limit 20") # 获取查询结果字段名 @@ -103,10 +118,11 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): return [] - - def get_messages(self)-> List[OnceConversation]: + def get_messages(self) -> List[OnceConversation]: cursor = self.connect.cursor() - cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id]) + cursor.execute( + "SELECT messages FROM chat_history where conv_uid=?", [self.chat_seesion_id] + ) context = cursor.fetchone() if context: return json.loads(context[0]) diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index 4f2e23946..d391c6e41 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -2,17 +2,29 @@ import uuid import json import asyncio import time -from fastapi import APIRouter, Request, Body, status, HTTPException, Response, BackgroundTasks +from fastapi import ( + APIRouter, + Request, + Body, + status, + HTTPException, + Response, + BackgroundTasks, +) from fastapi.responses import JSONResponse from fastapi.responses import StreamingResponse from fastapi.encoders import jsonable_encoder from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse -from sse_starlette.sse import EventSourceResponse from typing import List -from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo +from pilot.openapi.api_v1.api_view_model import ( + Result, + ConversationVo, + MessageVo, + ChatSceneVo, +) from pilot.configs.config import Config from pilot.openapi.knowledge.knowledge_service import KnowledgeService from pilot.openapi.knowledge.request.knowledge_request import KnowledgeSpaceRequest @@ -103,7 +115,7 @@ async def dialogue_scenes(): @router.post("/v1/chat/dialogue/new", response_model=Result[ConversationVo]) async def dialogue_new( - chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None + chat_mode: str = ChatScene.ChatNormal.value, user_id: str = None ): unique_id = uuid.uuid1() return Result.succ(ConversationVo(conv_uid=str(unique_id), chat_mode=chat_mode)) @@ -220,11 +232,19 @@ async def chat_completions(dialogue: ConversationVo = Body()): } if not chat.prompt_template.stream_out: - return StreamingResponse(no_stream_generator(chat), headers=headers, media_type='text/event-stream', - background=background_tasks) + return StreamingResponse( + no_stream_generator(chat), + headers=headers, + media_type="text/event-stream", + background=background_tasks, + ) else: - return StreamingResponse(stream_generator(chat), headers=headers, media_type='text/plain', - background=background_tasks) + return StreamingResponse( + stream_generator(chat), + headers=headers, + media_type="text/plain", + background=background_tasks, + ) def release_model_semaphore(): @@ -236,12 +256,15 @@ async def no_stream_generator(chat): msg = msg.replace("\n", "\\n") yield f"data: {msg}\n\n" + async def stream_generator(chat): model_response = chat.stream_call() if not CFG.NEW_SERVER_MODE: for chunk in model_response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: - msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex(chunk, chat.skip_echo_len) + msg = chat.prompt_template.output_parser.parse_model_stream_resp_ex( + chunk, chat.skip_echo_len + ) chat.current_message.add_ai_message(msg) msg = msg.replace("\n", "\\n") yield f"data:{msg}\n\n" @@ -249,7 +272,9 @@ async def stream_generator(chat): else: for chunk in model_response: if chunk: - 
diff --git a/pilot/openapi/knowledge/knowledge_controller.py b/pilot/openapi/knowledge/knowledge_controller.py
index 1b452119a..aec612e9c 100644
--- a/pilot/openapi/knowledge/knowledge_controller.py
+++ b/pilot/openapi/knowledge/knowledge_controller.py
@@ -76,18 +76,34 @@ def document_list(space_name: str, query_request: DocumentQueryRequest):
 
 
 @router.post("/knowledge/{space_name}/document/upload")
-async def document_upload(space_name: str, doc_name: str = Form(...), doc_type: str = Form(...), doc_file: UploadFile = File(...)):
+async def document_upload(
+    space_name: str,
+    doc_name: str = Form(...),
+    doc_type: str = Form(...),
+    doc_file: UploadFile = File(...),
+):
     print(f"/document/upload params: {space_name}")
     try:
         if doc_file:
-            with NamedTemporaryFile(dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name), delete=False) as tmp:
+            with NamedTemporaryFile(
+                dir=os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name), delete=False
+            ) as tmp:
                 tmp.write(await doc_file.read())
                 tmp_path = tmp.name
-            shutil.move(tmp_path, os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename))
+            shutil.move(
+                tmp_path,
+                os.path.join(
+                    KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename
+                ),
+            )
             request = KnowledgeDocumentRequest()
             request.doc_name = doc_name
             request.doc_type = doc_type
-            request.content = os.path.join(KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename),
+            request.content = (
+                os.path.join(
+                    KNOWLEDGE_UPLOAD_ROOT_PATH, space_name, doc_file.filename
+                ),
+            )
             knowledge_space_service.create_knowledge_document(
                 space=space_name, request=request
             )
diff --git a/pilot/openapi/knowledge/knowledge_service.py b/pilot/openapi/knowledge/knowledge_service.py
index 9ee9f3c40..2f035fcb5 100644
--- a/pilot/openapi/knowledge/knowledge_service.py
+++ b/pilot/openapi/knowledge/knowledge_service.py
@@ -25,7 +25,10 @@ from pilot.openapi.knowledge.request.knowledge_request import (
 )
 from enum import Enum
 
-from pilot.openapi.knowledge.request.knowledge_response import ChunkQueryResponse, DocumentQueryResponse
+from pilot.openapi.knowledge.request.knowledge_response import (
+    ChunkQueryResponse,
+    DocumentQueryResponse,
+)
 
 knowledge_space_dao = KnowledgeSpaceDao()
 knowledge_document_dao = KnowledgeDocumentDao()
diff --git a/pilot/openapi/knowledge/request/knowledge_response.py b/pilot/openapi/knowledge/request/knowledge_response.py
index 71d426643..7fbf36155 100644
--- a/pilot/openapi/knowledge/request/knowledge_response.py
+++ b/pilot/openapi/knowledge/request/knowledge_response.py
@@ -5,6 +5,7 @@ from pydantic import BaseModel
 
 class ChunkQueryResponse(BaseModel):
     """data: data"""
+
     data: List = None
     """total: total size"""
     total: int = None
@@ -14,9 +15,9 @@ class ChunkQueryResponse(BaseModel):
 
 class DocumentQueryResponse(BaseModel):
     """data: data"""
+
     data: List = None
     """total: total size"""
     total: int = None
     """page: current page"""
     page: int = None
-
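
A hypothetical call to the upload route reformatted above. The form and file field names come from the handler signature; the host, port, space name, and doc_type value are assumptions. Note that the reformatting also makes a pre-existing quirk visible: request.content keeps its original trailing comma, so it is still assigned a one-element tuple rather than a plain path string.

import requests

with open("manual.md", "rb") as f:
    resp = requests.post(
        "http://localhost:5000/knowledge/default/document/upload",
        data={"doc_name": "manual", "doc_type": "DOCUMENT"},  # doc_type value assumed
        files={"doc_file": ("manual.md", f)},
    )
print(resp.status_code, resp.text)
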
diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index ca308e92f..cd75c950c 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -122,7 +122,7 @@ class BaseOutputParser(ABC):
     def __extract_json(slef, s):
         i = s.index("{")
         count = 1  # current nesting depth, i.e. the number of '{' not yet closed
-        for j, c in enumerate(s[i + 1:], start=i + 1):
+        for j, c in enumerate(s[i + 1 :], start=i + 1):
             if c == "}":
                 count -= 1
             elif c == "{":
@@ -130,7 +130,7 @@ class BaseOutputParser(ABC):
             if count == 0:
                 break
         assert count == 0  # check that the final '}' was found
-        return s[i: j + 1]
+        return s[i : j + 1]
 
     def parse_prompt_response(self, model_out_text) -> T:
         """
@@ -147,9 +147,9 @@ class BaseOutputParser(ABC):
         # if "```" in cleaned_output:
         #     cleaned_output, _ = cleaned_output.split("```")
         if cleaned_output.startswith("```json"):
-            cleaned_output = cleaned_output[len("```json"):]
+            cleaned_output = cleaned_output[len("```json") :]
         if cleaned_output.startswith("```"):
-            cleaned_output = cleaned_output[len("```"):]
+            cleaned_output = cleaned_output[len("```") :]
         if cleaned_output.endswith("```"):
             cleaned_output = cleaned_output[: -len("```")]
         cleaned_output = cleaned_output.strip()
@@ -158,9 +158,9 @@ class BaseOutputParser(ABC):
             cleaned_output = self.__extract_json(cleaned_output)
         cleaned_output = (
             cleaned_output.strip()
-                .replace("\n", " ")
-                .replace("\\n", " ")
-                .replace("\\", " ")
+            .replace("\n", " ")
+            .replace("\\n", " ")
+            .replace("\\", " ")
         )
         return cleaned_output
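
The slice-spacing fixes above touch __extract_json, whose brace-depth scan is easy to verify in isolation. A standalone re-implementation with a worked example (not the class method itself):

def extract_json(s: str) -> str:
    i = s.index("{")
    count = 1  # number of '{' not yet closed
    for j, c in enumerate(s[i + 1 :], start=i + 1):
        if c == "}":
            count -= 1
        elif c == "{":
            count += 1
        if count == 0:
            break
    assert count == 0  # the matching closing '}' was found
    return s[i : j + 1]

print(extract_json('noise {"a": {"b": 1}} trailing'))  # {"a": {"b": 1}}
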
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index 8e7b3dfe7..245851062 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -60,10 +60,10 @@ class BaseChat(ABC):
         arbitrary_types_allowed = True
 
     def __init__(
-            self,
-            chat_mode,
-            chat_session_id,
-            current_user_input,
+        self,
+        chat_mode,
+        chat_session_id,
+        current_user_input,
     ):
         self.chat_session_id = chat_session_id
         self.chat_mode = chat_mode
@@ -102,7 +102,9 @@ class BaseChat(ABC):
         ### Chat sequence advance
         self.current_message.chat_order = len(self.history_message) + 1
         self.current_message.add_user_message(self.current_user_input)
-        self.current_message.start_date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        self.current_message.start_date = datetime.datetime.now().strftime(
+            "%Y-%m-%d %H:%M:%S"
+        )
         # TODO
         self.current_message.tokens = 0
 
@@ -168,11 +170,18 @@ class BaseChat(ABC):
             print("[TEST: output]:", rsp_str)
 
             ### output parse
-            ai_response_text = self.prompt_template.output_parser.parse_model_nostream_resp(rsp_str,
-                                                                                            self.prompt_template.sep)
+            ai_response_text = (
+                self.prompt_template.output_parser.parse_model_nostream_resp(
+                    rsp_str, self.prompt_template.sep
+                )
+            )
             ### model result deal
             self.current_message.add_ai_message(ai_response_text)
-            prompt_define_response = self.prompt_template.output_parser.parse_prompt_response(ai_response_text)
+            prompt_define_response = (
+                self.prompt_template.output_parser.parse_prompt_response(
+                    ai_response_text
+                )
+            )
             result = self.do_action(prompt_define_response)
 
             if hasattr(prompt_define_response, "thoughts"):
@@ -232,7 +241,9 @@ class BaseChat(ABC):
         system_convs = self.current_message.get_system_conv()
         system_text = ""
         for system_conv in system_convs:
-            system_text += system_conv.type + ":" + system_conv.content + self.prompt_template.sep
+            system_text += (
+                system_conv.type + ":" + system_conv.content + self.prompt_template.sep
+            )
         return system_text
 
     def __load_user_message(self):
@@ -246,13 +257,16 @@ class BaseChat(ABC):
 
         example_text = ""
         if self.prompt_template.example_selector:
            for round_conv in self.prompt_template.example_selector.examples():
-                for round_message in round_conv['messages']:
-                    if not round_message['type'] in [SystemMessage.type, ViewMessage.type]:
+                for round_message in round_conv["messages"]:
+                    if not round_message["type"] in [
+                        SystemMessage.type,
+                        ViewMessage.type,
+                    ]:
                         example_text += (
-                                round_message['type']
-                                + ":"
-                                + round_message['data']['content']
-                                + self.prompt_template.sep
+                            round_message["type"]
+                            + ":"
+                            + round_message["data"]["content"]
+                            + self.prompt_template.sep
                         )
         return example_text
@@ -264,37 +278,46 @@ class BaseChat(ABC):
                 f"There are already {len(self.history_message)} rounds of conversations! Will use {self.chat_retention_rounds} rounds of content as history!"
             )
             if len(self.history_message) > self.chat_retention_rounds:
-                for first_message in self.history_message[0]['messages']:
-                    if not first_message['type'] in [ViewMessage.type, SystemMessage.type]:
+                for first_message in self.history_message[0]["messages"]:
+                    if not first_message["type"] in [
+                        ViewMessage.type,
+                        SystemMessage.type,
+                    ]:
                         history_text += (
-                                first_message['type']
-                                + ":"
-                                + first_message['data']['content']
-                                + self.prompt_template.sep
+                            first_message["type"]
+                            + ":"
+                            + first_message["data"]["content"]
+                            + self.prompt_template.sep
                         )
 
                 index = self.chat_retention_rounds - 1
                 for round_conv in self.history_message[-index:]:
-                    for round_message in round_conv['messages']:
-                        if not round_message['type'] in [SystemMessage.type, ViewMessage.type]:
+                    for round_message in round_conv["messages"]:
+                        if not round_message["type"] in [
+                            SystemMessage.type,
+                            ViewMessage.type,
+                        ]:
                             history_text += (
-                                    round_message['type']
-                                    + ":"
-                                    + round_message['data']['content']
-                                    + self.prompt_template.sep
+                                round_message["type"]
+                                + ":"
+                                + round_message["data"]["content"]
+                                + self.prompt_template.sep
                             )
 
         else:
             ### use all history
             for conversation in self.history_message:
-                for message in conversation['messages']:
+                for message in conversation["messages"]:
                     ### history messages do not include prompt and view info
-                    if not message['type'] in [SystemMessage.type, ViewMessage.type]:
+                    if not message["type"] in [
+                        SystemMessage.type,
+                        ViewMessage.type,
+                    ]:
                         history_text += (
-                                message['type']
-                                + ":"
-                                + message['data']['content']
-                                + self.prompt_template.sep
+                            message["type"]
+                            + ":"
+                            + message["data"]["content"]
+                            + self.prompt_template.sep
                         )
 
         return history_text
diff --git a/pilot/scene/chat_execution/example.py b/pilot/scene/chat_execution/example.py
index f50a7f546..9e2aee6a2 100644
--- a/pilot/scene/chat_execution/example.py
+++ b/pilot/scene/chat_execution/example.py
@@ -2,8 +2,8 @@ from pilot.prompts.example_base import ExampleSelector
 
 ## Two examples are defined by default
 EXAMPLES = [
-    [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}],
-    [{"system": "123"},{"system":"xxx"},{"human":"xxx"},{"assistant":"xxx"}]
+    [{"system": "123"}, {"system": "xxx"}, {"human": "xxx"}, {"assistant": "xxx"}],
+    [{"system": "123"}, {"system": "xxx"}, {"human": "xxx"}, {"assistant": "xxx"}],
 ]
 
 plugin_example = ExampleSelector(examples_record=EXAMPLES, use_example=True)
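
The stitching loops reformatted above all build "type:content" segments joined by prompt_template.sep, skipping system and view messages. A standalone sketch, assuming SystemMessage.type and ViewMessage.type resolve to "system" and "view" (the exact type strings are an assumption):

sep = "\n"
round_conv = {
    "messages": [
        {"type": "human", "data": {"content": "hi"}},
        {"type": "ai", "data": {"content": "hello"}},
        {"type": "view", "data": {"content": "rendered output"}},  # filtered out
    ]
}
history_text = ""
for round_message in round_conv["messages"]:
    if round_message["type"] not in ["system", "view"]:
        history_text += round_message["type"] + ":" + round_message["data"]["content"] + sep
print(repr(history_text))  # 'human:hi\nai:hello\n'
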
diff --git a/pilot/scene/message.py b/pilot/scene/message.py
index 972331bbb..51ec2643e 100644
--- a/pilot/scene/message.py
+++ b/pilot/scene/message.py
@@ -98,9 +98,10 @@ class OnceConversation:
                 system_convs.append(message)
         return system_convs
 
+
 def _conversation_to_dic(once: OnceConversation) -> dict:
     start_str: str = ""
-    if hasattr(once, 'start_date') and once.start_date:
+    if hasattr(once, "start_date") and once.start_date:
         if isinstance(once.start_date, datetime):
             start_str = once.start_date.strftime("%Y-%m-%d %H:%M:%S")
         else:
diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py
index 6c22105cd..c4f9ad87e 100644
--- a/pilot/server/dbgpt_server.py
+++ b/pilot/server/dbgpt_server.py
@@ -23,9 +23,12 @@ from fastapi import FastAPI, applications
 from fastapi.openapi.docs import get_swagger_ui_html
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
+from pilot.openapi.knowledge.knowledge_controller import router as knowledge_router
+
 from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler
+
 
 static_file_path = os.path.join(os.getcwd(), "server/static")
 
 CFG = Config()
@@ -34,9 +37,10 @@ logger = build_logger("webserver", LOGDIR + "webserver.log")
 
 def swagger_monkey_patch(*args, **kwargs):
     return get_swagger_ui_html(
-        *args, **kwargs,
-        swagger_js_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js',
-        swagger_css_url='https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css'
+        *args,
+        **kwargs,
+        swagger_js_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui-bundle.js",
+        swagger_css_url="https://cdn.bootcdn.net/ajax/libs/swagger-ui/4.10.3/swagger-ui.css"
     )
 
 
@@ -55,14 +59,16 @@ app.add_middleware(
 )
 
 app.mount("/static", StaticFiles(directory=static_file_path), name="static")
-app.add_route("/test", "static/test.html")
-
+app.add_route("/test", "static/test.html")
 
+app.include_router(knowledge_router)
 app.include_router(api_v1)
 
 app.add_exception_handler(RequestValidationError, validation_exception_handler)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_list_mode", type=str, default="once", choices=["once", "reload"])
+    parser.add_argument(
+        "--model_list_mode", type=str, default="once", choices=["once", "reload"]
+    )
 
     # old version server config
     parser.add_argument("--host", type=str, default="0.0.0.0")
@@ -75,4 +81,5 @@ if __name__ == "__main__":
     server_init(args)
     CFG.NEW_SERVER_MODE = True
     import uvicorn
+
     uvicorn.run(app, host="0.0.0.0", port=5000)
diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py
index 67e6183b2..d87540a8e 100644
--- a/pilot/server/llmserver.py
+++ b/pilot/server/llmserver.py
@@ -9,7 +9,8 @@ import sys
 import uvicorn
 from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import StreamingResponse
-from fastapi.middleware.cors import CORSMiddleware
+
+# from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 
 global_counter = 0
@@ -41,11 +42,11 @@ class ModelWorker:
 
         if not isinstance(self.model, str):
             if hasattr(self.model, "config") and hasattr(
-                    self.model.config, "max_sequence_length"
+                self.model.config, "max_sequence_length"
             ):
                 self.context_len = self.model.config.max_sequence_length
             elif hasattr(self.model, "config") and hasattr(
-                    self.model.config, "max_position_embeddings"
+                self.model.config, "max_position_embeddings"
             ):
                 self.context_len = self.model.config.max_position_embeddings
 
@@ -60,22 +61,22 @@ class ModelWorker:
 
     def get_queue_length(self):
         if (
-                model_semaphore is None
-                or model_semaphore._value is None
-                or model_semaphore._waiters is None
+            model_semaphore is None
+            or model_semaphore._value is None
+            or model_semaphore._waiters is None
         ):
             return 0
         else:
             (
-                    CFG.LIMIT_MODEL_CONCURRENCY
-                    - model_semaphore._value
-                    + len(model_semaphore._waiters)
+                CFG.LIMIT_MODEL_CONCURRENCY
+                - model_semaphore._value
+                + len(model_semaphore._waiters)
             )
 
     def generate_stream_gate(self, params):
         try:
             for output in self.generate_stream_func(
-                    self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
+                self.model, self.tokenizer, params, DEVICE, CFG.MAX_POSITION_EMBEDDINGS
             ):
                 # Please do not open the output in production!
                 # The gpt4all thread shares stdout with the parent process,
@@ -107,23 +108,23 @@ worker = ModelWorker(
 )
 
 app = FastAPI()
-from pilot.openapi.knowledge.knowledge_controller import router
-
-app.include_router(router)
-
-origins = [
-    "http://localhost",
-    "http://localhost:8000",
-    "http://localhost:3000",
-]
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=origins,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
+# from pilot.openapi.knowledge.knowledge_controller import router
+#
+# app.include_router(router)
+#
+# origins = [
+#     "http://localhost",
+#     "http://localhost:8000",
+#     "http://localhost:3000",
+# ]
+#
+# app.add_middleware(
+#     CORSMiddleware,
+#     allow_origins=origins,
+#     allow_credentials=True,
+#     allow_methods=["*"],
+#     allow_headers=["*"],
+# )
 
 
 class PromptRequest(BaseModel):
diff --git a/pilot/server/webserver_base.py b/pilot/server/webserver_base.py
index c76984c37..33486c439 100644
--- a/pilot/server/webserver_base.py
+++ b/pilot/server/webserver_base.py
@@ -40,6 +40,7 @@ def server_init(args):
     cfg = Config()
 
     from pilot.server.llmserver import worker
+
     worker.start_check()
     load_native_plugins(cfg)
     signal.signal(signal.SIGINT, signal_handler)