diff --git a/pilot/embedding_engine/knowledge_type.py b/pilot/embedding_engine/knowledge_type.py
index 77fb98666..c4c278be4 100644
--- a/pilot/embedding_engine/knowledge_type.py
+++ b/pilot/embedding_engine/knowledge_type.py
@@ -41,7 +41,11 @@ class KnowledgeType(Enum):


 def get_knowledge_embedding(
-    knowledge_type, knowledge_source, vector_store_config, source_reader, text_splitter
+    knowledge_type,
+    knowledge_source,
+    vector_store_config=None,
+    source_reader=None,
+    text_splitter=None,
 ):
     match knowledge_type:
         case KnowledgeType.DOCUMENT.value:
diff --git a/pilot/embedding_engine/source_embedding.py b/pilot/embedding_engine/source_embedding.py
index 24bae97b2..5b1e57ae2 100644
--- a/pilot/embedding_engine/source_embedding.py
+++ b/pilot/embedding_engine/source_embedding.py
@@ -31,11 +31,11 @@ class SourceEmbedding(ABC):
     ):
         """Initialize with Loader url, model_name, vector_store_config"""
         self.file_path = file_path
-        self.vector_store_config = vector_store_config
+        self.vector_store_config = vector_store_config or {}
         self.source_reader = source_reader or None
         self.text_splitter = text_splitter or None
         self.embedding_args = embedding_args
-        self.embeddings = vector_store_config["embeddings"]
+        self.embeddings = self.vector_store_config.get("embeddings", None)

     @abstractmethod
     @register
diff --git a/pilot/embedding_engine/url_embedding.py b/pilot/embedding_engine/url_embedding.py
index e00cf84e2..39e7bf1dc 100644
--- a/pilot/embedding_engine/url_embedding.py
+++ b/pilot/embedding_engine/url_embedding.py
@@ -27,7 +27,7 @@ class URLEmbedding(SourceEmbedding):
             file_path, vector_store_config, source_reader=None, text_splitter=None
         )
         self.file_path = file_path
-        self.vector_store_config = vector_store_config
+        self.vector_store_config = vector_store_config or None
         self.source_reader = source_reader or None
         self.text_splitter = text_splitter or None

diff --git a/pilot/scene/base.py b/pilot/scene/base.py
index 162759e3c..489e98d0b 100644
--- a/pilot/scene/base.py
+++ b/pilot/scene/base.py
@@ -75,6 +75,16 @@ class ChatScene(Enum):
         "Dialogue through natural language and private documents and knowledge bases.",
         ["Knowledge Space Select"],
     )
+    ExtractTriplet = Scene(
+        "extract_triplet",
+        "Extract Triplet",
+        "Extract Triplet",
+        ["Extract Select"],
+        True,
+    )
+    ExtractEntity = Scene(
+        "extract_entity", "Extract Entity", "Extract Entity", ["Extract Select"], True
+    )

     @staticmethod
     def of_mode(mode):
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index daab56964..5c16ee286 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -205,6 +205,40 @@ class BaseChat(ABC):
         self.memory.append(self.current_message)
         return self.current_ai_response()

+    async def get_llm_response(self):
+        payload = self.__call_base()
+        logger.info(f"Request: \n{payload}")
+        ai_response_text = ""
+        try:
+            from pilot.model.cluster import WorkerManagerFactory
+
+            worker_manager = CFG.SYSTEM_APP.get_component(
+                ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
+            ).create()
+
+            model_output = await worker_manager.generate(payload)
+
+            ### output parse
+            ai_response_text = (
+                self.prompt_template.output_parser.parse_model_nostream_resp(
+                    model_output, self.prompt_template.sep
+                )
+            )
+            ### model result deal
+            self.current_message.add_ai_message(ai_response_text)
+            prompt_define_response = (
+                self.prompt_template.output_parser.parse_prompt_response(
+                    ai_response_text
+                )
+            )
+        except Exception as e:
+            print(traceback.format_exc())
+            logger.error("model response parse failed!" + str(e))
+            self.current_message.add_view_message(
+                f"""model response parse failed!{str(e)}\n {ai_response_text} """
+            )
+        return prompt_define_response
+
     def _blocking_stream_call(self):
         logger.warn(
             "_blocking_stream_call is only temporarily used in webserver and will be deleted soon, please use stream_call to replace it for higher performance"
diff --git a/pilot/scene/chat_factory.py b/pilot/scene/chat_factory.py
index ad2da3b3f..fc47b7468 100644
--- a/pilot/scene/chat_factory.py
+++ b/pilot/scene/chat_factory.py
@@ -13,6 +13,8 @@ class ChatFactory(metaclass=Singleton):
         from pilot.scene.chat_dashboard.chat import ChatDashboard
         from pilot.scene.chat_knowledge.v1.chat import ChatKnowledge
         from pilot.scene.chat_knowledge.inner_db_summary.chat import InnerChatDBSummary
+        from pilot.scene.chat_knowledge.extract_triplet.chat import ExtractTriplet
+        from pilot.scene.chat_knowledge.extract_entity.chat import ExtractEntity
         from pilot.scene.chat_data.chat_excel.excel_analyze.chat import ChatExcel

         chat_classes = BaseChat.__subclasses__()
diff --git a/pilot/utils/utils.py b/pilot/utils/utils.py
index b72745a33..ebb3534e0 100644
--- a/pilot/utils/utils.py
+++ b/pilot/utils/utils.py
@@ -168,10 +168,12 @@ def get_or_create_event_loop() -> asyncio.BaseEventLoop:
         assert loop is not None
         return loop
     except RuntimeError as e:
-        if not "no running event loop" in str(e):
+        if not "no running event loop" in str(e) and not "no current event loop" in str(
+            e
+        ):
             raise e
         logging.warning("Cant not get running event loop, create new event loop now")
-        return asyncio.get_event_loop_policy().get_event_loop()
+        return asyncio.get_event_loop_policy().new_event_loop()


 def logging_str_to_uvicorn_level(log_level_str):
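A minimal sketch of what the `source_embedding.py` change permits, assuming the constructor signature is otherwise unchanged; `NoopEmbedding` is a hypothetical subclass added only for illustration, since `SourceEmbedding` is abstract:

```python
# Hypothetical concrete subclass, purely to exercise __init__.
from pilot.embedding_engine.source_embedding import SourceEmbedding, register


class NoopEmbedding(SourceEmbedding):
    @register
    def read(self):
        return []


# Before the patch, vector_store_config=None raised a TypeError at
# `vector_store_config["embeddings"]`; now the config defaults to {}
# and self.embeddings falls back to None.
emb = NoopEmbedding("docs/readme.md", vector_store_config=None)
assert emb.vector_store_config == {}
assert emb.embeddings is None
```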
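The `utils.py` hunk matters on Python 3.10+, where asking for an event loop in a thread that has none raises "There is no current event loop" rather than "no running event loop", and the helper now hands back a fresh loop via `new_event_loop()` instead of the policy's possibly closed cached one. A quick sketch of the intended behavior:

```python
import threading

from pilot.utils.utils import get_or_create_event_loop


async def probe():
    return "ok"


def worker():
    # A non-main thread has no current event loop; the patched helper
    # matches both RuntimeError messages and returns a brand-new loop
    # instead of re-raising.
    loop = get_or_create_event_loop()
    try:
        print(loop.run_until_complete(probe()))  # -> ok
    finally:
        loop.close()


threading.Thread(target=worker).start()
```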
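And a hypothetical round-trip for the two new scenes in `base.py`, assuming `of_mode` resolves a `ChatScene` member by its mode code, as the surrounding enum definition suggests:

```python
from pilot.scene.base import ChatScene

# Assumed lookup key: the first Scene argument (the mode code).
triplet = ChatScene.of_mode("extract_triplet")
entity = ChatScene.of_mode("extract_entity")
print(triplet, entity)
```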