From 1b49267976c96bb3f68c03bc12e345bb1a6c9cc7 Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Wed, 13 Sep 2023 14:20:04 +0800 Subject: [PATCH] style:fmt --- pilot/openapi/api_v1/api_v1.py | 30 ++++++++++++++----- pilot/scene/base_chat.py | 12 ++++---- pilot/scene/chat_dashboard/chat.py | 9 ++---- .../chat_excel/excel_analyze/chat.py | 4 +-- .../chat_excel/excel_learning/chat.py | 4 +-- pilot/scene/chat_db/professional_qa/chat.py | 4 +-- pilot/scene/chat_execution/chat.py | 4 +-- 7 files changed, 35 insertions(+), 32 deletions(-) diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index c25475973..b7cc46d3a 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -217,7 +217,9 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value()): @router.post("/v1/chat/mode/params/file/load") -async def params_load(conv_uid: str, chat_mode: str, model_name: str, doc_file: UploadFile = File(...)): +async def params_load( + conv_uid: str, chat_mode: str, model_name: str, doc_file: UploadFile = File(...) +): print(f"params_load: {conv_uid},{chat_mode},{model_name}") try: if doc_file: @@ -237,7 +239,10 @@ async def params_load(conv_uid: str, chat_mode: str, model_name: str, doc_file: ) ## chat prepare dialogue = ConversationVo( - conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename, model_name=model_name + conv_uid=conv_uid, + chat_mode=chat_mode, + select_param=doc_file.filename, + model_name=model_name, ) chat: BaseChat = get_chat_instance(dialogue) resp = await chat.prepare() @@ -264,7 +269,8 @@ def get_hist_messages(conv_uid: str): print(f"once:{once}") model_name = once.get("model_name", CFG.LLM_MODEL) once_message_vos = [ - message2Vo(element, once["chat_order"], model_name) for element in once["messages"] + message2Vo(element, once["chat_order"], model_name) + for element in once["messages"] ] message_vos.extend(once_message_vos) return message_vos @@ -293,9 +299,11 @@ def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat: "chat_session_id": dialogue.conv_uid, "current_user_input": dialogue.user_input, "select_param": dialogue.select_param, - "model_name": dialogue.model_name + "model_name": dialogue.model_name, } - chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **{"chat_param": chat_param}) + chat: BaseChat = CHAT_FACTORY.get_implementation( + dialogue.chat_mode, **{"chat_param": chat_param} + ) return chat @@ -314,7 +322,9 @@ async def chat_prepare(dialogue: ConversationVo = Body()): @router.post("/v1/chat/completions") async def chat_completions(dialogue: ConversationVo = Body()): # dialogue.model_name = CFG.LLM_MODEL - print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}") + print( + f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}" + ) chat: BaseChat = get_chat_instance(dialogue) # background_tasks = BackgroundTasks() # background_tasks.add_task(release_model_semaphore) @@ -344,9 +354,10 @@ async def model_types(): print(f"/controller/model/types") try: import httpx + async with httpx.AsyncClient() as client: response = await client.get( - f"{CFG.MODEL_SERVER}/api/controller/models?healthy_only=true", + f"{CFG.MODEL_SERVER}/api/controller/models?healthy_only=true", ) types = set() if response.status_code == 200: @@ -387,5 +398,8 @@ async def stream_generator(chat): def message2Vo(message: dict, order, model_name) -> MessageVo: return MessageVo( - role=message["type"], context=message["data"]["content"], order=order, model_name=model_name + role=message["type"], + context=message["data"]["content"], + order=order, + model_name=model_name, ) diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index a2c284114..9e7a22373 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -32,13 +32,13 @@ class BaseChat(ABC): arbitrary_types_allowed = True - def __init__( - self, chat_param: Dict - ): + def __init__(self, chat_param: Dict): self.chat_session_id = chat_param["chat_session_id"] self.chat_mode = chat_param["chat_mode"] self.current_user_input: str = chat_param["current_user_input"] - self.llm_model = chat_param["model_name"] if chat_param["model_name"] else CFG.LLM_MODEL + self.llm_model = ( + chat_param["model_name"] if chat_param["model_name"] else CFG.LLM_MODEL + ) self.llm_echo = False ### load prompt template @@ -58,7 +58,9 @@ class BaseChat(ABC): self.memory = DuckdbHistoryMemory(chat_param["chat_session_id"]) self.history_message: List[OnceConversation] = self.memory.messages() - self.current_message: OnceConversation = OnceConversation(self.chat_mode.value()) + self.current_message: OnceConversation = OnceConversation( + self.chat_mode.value() + ) self.current_message.model_name = self.llm_model if chat_param["select_param"]: if len(self.chat_mode.param_types()) > 0: diff --git a/pilot/scene/chat_dashboard/chat.py b/pilot/scene/chat_dashboard/chat.py index 066d60cd3..f6213c292 100644 --- a/pilot/scene/chat_dashboard/chat.py +++ b/pilot/scene/chat_dashboard/chat.py @@ -21,16 +21,11 @@ class ChatDashboard(BaseChat): report_name: str """Number of results to return from the query""" - def __init__( - self, - chat_param: Dict - ): + def __init__(self, chat_param: Dict): """ """ self.db_name = chat_param["select_param"] chat_param["chat_mode"] = ChatScene.ChatDashboard - super().__init__( - chat_param=chat_param - ) + super().__init__(chat_param=chat_param) if not self.db_name: raise ValueError(f"{ChatScene.ChatDashboard.value} mode should choose db!") self.db_name = self.db_name diff --git a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py index c5248d9ed..27bdeaec8 100644 --- a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py +++ b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py @@ -37,9 +37,7 @@ class ChatExcel(BaseChat): ) ) - super().__init__( - chat_param=chat_param - ) + super().__init__(chat_param=chat_param) def _generate_command_string(self, command: Dict[str, Any]) -> str: """ diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/chat.py b/pilot/scene/chat_data/chat_excel/excel_learning/chat.py index dd8970a23..96338589c 100644 --- a/pilot/scene/chat_data/chat_excel/excel_learning/chat.py +++ b/pilot/scene/chat_data/chat_excel/excel_learning/chat.py @@ -43,9 +43,7 @@ class ExcelLearning(BaseChat): "select_param": select_param, "model_name": model_name, } - super().__init__( - chat_param=chat_param - ) + super().__init__(chat_param=chat_param) if parent_mode: self.current_message.chat_mode = parent_mode.value() diff --git a/pilot/scene/chat_db/professional_qa/chat.py b/pilot/scene/chat_db/professional_qa/chat.py index f2e3c8304..39f4052a6 100644 --- a/pilot/scene/chat_db/professional_qa/chat.py +++ b/pilot/scene/chat_db/professional_qa/chat.py @@ -18,9 +18,7 @@ class ChatWithDbQA(BaseChat): """ """ self.db_name = chat_param["select_param"] chat_param["chat_mode"] = ChatScene.ChatWithDbQA - super().__init__( - chat_param=chat_param - ) + super().__init__(chat_param=chat_param) if self.db_name: self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name) diff --git a/pilot/scene/chat_execution/chat.py b/pilot/scene/chat_execution/chat.py index 508e69391..fd18a5564 100644 --- a/pilot/scene/chat_execution/chat.py +++ b/pilot/scene/chat_execution/chat.py @@ -18,9 +18,7 @@ class ChatWithPlugin(BaseChat): def __init__(self, chat_param: Dict): self.plugin_selector = chat_param.select_param chat_param["chat_mode"] = ChatScene.ChatExecution - super().__init__( - chat_param=chat_param - ) + super().__init__(chat_param=chat_param) self.plugins_prompt_generator = PluginPromptGenerator() self.plugins_prompt_generator.command_registry = CFG.command_registry # 加载插件中可用命令