diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index b8d1dc3b7..1a72a4e04 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -216,7 +216,7 @@ async def params_load(conv_uid: str, chat_mode: str, doc_file: UploadFile = File conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename ) chat: BaseChat = get_chat_instance(dialogue) - resp = chat.prepare() + resp = await chat.prepare() ### refresh messages return Result.succ(get_hist_messages(conv_uid)) @@ -279,7 +279,7 @@ async def chat_prepare(dialogue: ConversationVo = Body()): chat: BaseChat = get_chat_instance(dialogue) if len(chat.history_message) > 0: return Result.succ(None) - resp = chat.prepare() + resp = await chat.prepare() return Result.succ(resp) diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 88f457935..e2c9a03b0 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -244,7 +244,7 @@ class BaseChat(ABC): else: return self._blocking_nostream_call() - def prepare(self): + async def prepare(self): pass def generate_llm_text(self) -> str: diff --git a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py index 806fada18..0bfb9f915 100644 --- a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py +++ b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py @@ -81,7 +81,7 @@ class ChatExcel(BaseChat): } return input_values - def prepare(self): + async def prepare(self): logger.info(f"{self.chat_mode} prepare start!") if len(self.history_message) > 0: return None @@ -93,7 +93,7 @@ class ChatExcel(BaseChat): "excel_reader": self.excel_reader, } learn_chat = ExcelLearning(**chat_param) - result = learn_chat.nostream_call() + result = await learn_chat.nostream_call() return result def do_action(self, prompt_response): diff --git a/pilot/utils/utils.py b/pilot/utils/utils.py index ca7cf9d3c..5b4ed064a 100644 --- a/pilot/utils/utils.py +++ b/pilot/utils/utils.py @@ -135,12 +135,13 @@ def pretty_print_semaphore(semaphore): def get_or_create_event_loop() -> asyncio.BaseEventLoop: + loop = None try: loop = asyncio.get_event_loop() - except Exception as e: + assert loop is not None + return loop + except RuntimeError as e: if not "no running event loop" in str(e): raise e logging.warning("Cant not get running event loop, create new event loop now") - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return loop + return asyncio.get_event_loop_policy().get_event_loop()