diff --git a/pilot/common/chat_util.py b/pilot/common/chat_util.py
index 159db99d0..0de0b9bda 100644
--- a/pilot/common/chat_util.py
+++ b/pilot/common/chat_util.py
@@ -9,7 +9,7 @@ chat_factory = ChatFactory()
 
 
 async def llm_chat_response_nostream(chat_scene: str, **chat_param):
-    """ llm_chat_response_nostream """
+    """llm_chat_response_nostream"""
     chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param)
     res = await chat.get_llm_response()
     return res
diff --git a/pilot/graph_engine/graph_engine.py b/pilot/graph_engine/graph_engine.py
index c20142123..04c4f54d9 100644
--- a/pilot/graph_engine/graph_engine.py
+++ b/pilot/graph_engine/graph_engine.py
@@ -45,8 +45,7 @@ class RAGGraphEngine:
         **kwargs: Any,
     ) -> None:
         """Initialize params."""
-        # from llama_index.graph_stores import SimpleGraphStore
-        # from llama_index.graph_stores.types import GraphStore
+        from llama_index.graph_stores import SimpleGraphStore
 
         # need to set parameters before building index in base class.
         self.knowledge_source = knowledge_source
@@ -55,8 +54,8 @@ class RAGGraphEngine:
         self.text_splitter = text_splitter
         self.index_struct = index_struct
         self.include_embeddings = include_embeddings
-        # self.graph_store = graph_store or SimpleGraphStore()
-        self.graph_store = graph_store
+        self.graph_store = graph_store or SimpleGraphStore()
+        # self.graph_store = graph_store
         self.max_triplets_per_chunk = max_triplets_per_chunk
         self._max_object_length = max_object_length
         self._extract_triplet_fn = extract_triplet_fn
@@ -103,14 +102,6 @@ class RAGGraphEngine:
                 )
             )
             return triplets
-        # response = self._service_context.llm_predictor.predict(
-        #     self.kg_triple_extract_template,
-        #     text=text,
-        # )
-        # print(response, flush=True)
-        # return self._parse_triplet_response(
-        #     response, max_length=self._max_object_length
-        # )
 
     def _build_index_from_docs(self, documents: List[Document]) -> KG:
         """Build the index from nodes."""
@@ -126,7 +117,6 @@ class RAGGraphEngine:
                     self.graph_store.upsert_triplet(*triplet)
                     index_struct.add_node([subj, obj], text_node)
 
-
         return index_struct
 
     def search(self, query):
@@ -134,4 +124,3 @@ class RAGGraphEngine:
         graph_search = RAGGraphSearch(graph_engine=self)
 
         return graph_search.search(query)
-
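
Note on the graph_engine.py hunks above: the store default flips from injection-only to "graph_store or SimpleGraphStore()", so the engine always has a usable in-memory graph store, and _build_index_from_docs upserts every extracted triplet into it. Below is a minimal sketch of that flow using the llama-index SimpleGraphStore API the patch imports (upsert_triplet, get); the extract_triplets callable is a hypothetical stand-in for the engine's LLM-based extractor, not code from this patch:

    from llama_index.graph_stores import SimpleGraphStore

    def build_graph(chunks, extract_triplets, graph_store=None):
        # Mirror the patched default: fall back to an in-memory store
        # when no graph_store is injected.
        store = graph_store or SimpleGraphStore()
        for text in chunks:
            for subj, rel, obj in extract_triplets(text):
                store.upsert_triplet(subj, rel, obj)
        return store

    store = build_graph(
        ["Curry plays for the Warriors."],
        lambda text: [("Curry", "plays for", "Warriors")],
    )
    print(store.get("Curry"))  # [['plays for', 'Warriors']]
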
diff --git a/pilot/graph_engine/graph_search.py b/pilot/graph_engine/graph_search.py
index 9b06fd234..d1f6a4519 100644
--- a/pilot/graph_engine/graph_search.py
+++ b/pilot/graph_engine/graph_search.py
@@ -4,6 +4,8 @@ from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from typing import List, Optional, Dict, Any, Set, Callable
 
+from langchain.schema import Document
+
 from pilot.graph_engine.node import BaseNode, TextNode, NodeWithScore
 from pilot.graph_engine.search import BaseSearch, SearchMode
 from pilot.utils import utils
@@ -67,14 +69,14 @@ class RAGGraphSearch(BaseSearch):
             logger.warn(f"can not to find graph schema: {e}")
             self._graph_schema = ""
 
-    def _extract_subject_entities(self, query_str: str) -> Set[str]:
+    async def _extract_subject_entities(self, query_str: str) -> Set[str]:
         """extract subject entities."""
         if self.extract_subject_entities_fn is not None:
-            return self.extract_subject_entities_fn(query_str)
+            return await self.extract_subject_entities_fn(query_str)
         else:
-            return self._extract_entities_by_llm(query_str)
+            return await self._extract_entities_by_llm(query_str)
 
-    def _extract_entities_by_llm(self, text: str) -> Set[str]:
+    async def _extract_entities_by_llm(self, text: str) -> Set[str]:
         """extract subject entities from text by llm"""
         from pilot.scene.base import ChatScene
         from pilot.common.chat_util import llm_chat_response_nostream
@@ -86,21 +88,23 @@ class RAGGraphSearch(BaseSearch):
             "select_param": "entity",
             "model_name": self.model_name,
         }
-        loop = utils.get_or_create_event_loop()
-        entities = loop.run_until_complete(
-            llm_chat_response_nostream(
-                ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
-            )
+        # loop = utils.get_or_create_event_loop()
+        # entities = loop.run_until_complete(
+        #     llm_chat_response_nostream(
+        #         ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
+        #     )
+        # )
+        return await llm_chat_response_nostream(
+            ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
         )
-        return entities
 
-    def _search(
+    async def _search(
         self,
         query_str: str,
-    ) -> List[NodeWithScore]:
+    ) -> List[Document]:
         """Get nodes for response."""
         node_visited = set()
-        keywords = self._extract_subject_entities(query_str)
+        keywords = await self._extract_subject_entities(query_str)
         print(f"extract entities: {keywords}\n")
         rel_texts = []
         cur_rel_map = {}
@@ -114,8 +118,8 @@ class RAGGraphSearch(BaseSearch):
                 if node_id in node_visited:
                     continue
 
-                if self._include_text:
-                    chunk_indices_count[node_id] += 1
+                # if self._include_text:
+                #     chunk_indices_count[node_id] += 1
 
                 node_visited.add(node_id)
 
@@ -179,8 +183,11 @@ class RAGGraphSearch(BaseSearch):
             sorted_nodes_with_scores.append(
                 NodeWithScore(node=rel_text_node, score=DEFAULT_NODE_SCORE)
             )
-
-        return sorted_nodes_with_scores
+        docs = [
+            Document(page_content=node.text, metadata=node.metadata)
+            for node in sorted_nodes_with_scores
+        ]
+        return docs
 
     def _get_metadata_for_response(
         self, nodes: List[BaseNode]
@@ -190,4 +197,4 @@ class RAGGraphSearch(BaseSearch):
             if node.metadata is None or "kg_rel_map" not in node.metadata:
                 continue
             return node.metadata
-        raise ValueError("kg_rel_map must be found in at least one Node.")
\ No newline at end of file
+        raise ValueError("kg_rel_map must be found in at least one Node.")
diff --git a/pilot/graph_engine/node.py b/pilot/graph_engine/node.py
index 6f6d45ae4..b23681010 100644
--- a/pilot/graph_engine/node.py
+++ b/pilot/graph_engine/node.py
@@ -21,6 +21,7 @@ WRAP_WIDTH = 70
 
 class BaseComponent(BaseModel):
     """Base component object to caputure class names."""
+
     """reference llama-index"""
 
     @classmethod
diff --git a/pilot/graph_engine/search.py b/pilot/graph_engine/search.py
index 8db837278..297620b00 100644
--- a/pilot/graph_engine/search.py
+++ b/pilot/graph_engine/search.py
@@ -23,7 +23,7 @@ class SearchMode(str, Enum):
 class BaseSearch(ABC):
     """Base Search."""
 
-    def search(self, query: str):
+    async def search(self, query: str):
        """Retrieve nodes given query.
 
         Args:
@@ -32,10 +32,10 @@ class BaseSearch(ABC):
 
         """
         # if isinstance(query, str):
-        return self._search(query)
+        return await self._search(query)
 
     @abstractmethod
-    def _search(self, query: str):
+    async def _search(self, query: str):
         """search nodes given query.
 
         Implemented by the user.
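
Taken together, the graph_search.py and search.py hunks move the whole retrieval path onto coroutines: BaseSearch.search now awaits the abstract _search, entity extraction is awaited directly instead of being driven through utils.get_or_create_event_loop().run_until_complete(...), and results come back as langchain Document objects. A minimal sketch of the resulting calling convention, with a hypothetical DummySearch standing in for RAGGraphSearch:

    import asyncio
    from abc import ABC, abstractmethod
    from typing import List

    from langchain.schema import Document

    class BaseSearch(ABC):
        async def search(self, query: str) -> List[Document]:
            # Public entry point: await the subclass hook.
            return await self._search(query)

        @abstractmethod
        async def _search(self, query: str) -> List[Document]:
            ...

    class DummySearch(BaseSearch):
        async def _search(self, query: str) -> List[Document]:
            # Stand-in for RAGGraphSearch: entity extraction and graph
            # lookups would be awaited here, then wrapped as Documents.
            return [Document(page_content=f"hit for {query}", metadata={})]

    docs = asyncio.run(DummySearch().search("Curry"))
    print(docs[0].page_content)  # hit for Curry

Every caller of search() must itself be async or own an event loop, which is why the chat classes below change as well.
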
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index c1880f48a..bea00bde3 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -105,8 +105,14 @@ class BaseChat(ABC):
             speak_to_user = prompt_define_response
         return speak_to_user
 
-    def __call_base(self):
-        input_values = self.generate_input_values()
+    async def __call_base(self):
+        import inspect
+
+        input_values = (
+            await self.generate_input_values()
+            if inspect.isawaitable(self.generate_input_values())
+            else self.generate_input_values()
+        )
         ### Chat sequence advance
         self.current_message.chat_order = len(self.history_message) + 1
         self.current_message.add_user_message(self.current_user_input)
@@ -146,7 +152,7 @@ class BaseChat(ABC):
 
     async def stream_call(self):
         # TODO Retry when server connection error
-        payload = self.__call_base()
+        payload = await self.__call_base()
 
         self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
         logger.info(f"Request: \n{payload}")
@@ -234,7 +240,7 @@ class BaseChat(ABC):
         return self.current_ai_response()
 
     async def get_llm_response(self):
-        payload = self.__call_base()
+        payload = await self.__call_base()
         logger.info(f"Request: \n{payload}")
         ai_response_text = ""
         try:
diff --git a/pilot/scene/chat_knowledge/extract_entity/chat.py b/pilot/scene/chat_knowledge/extract_entity/chat.py
index bb52961b5..373bb4e5d 100644
--- a/pilot/scene/chat_knowledge/extract_entity/chat.py
+++ b/pilot/scene/chat_knowledge/extract_entity/chat.py
@@ -24,7 +24,7 @@ class ExtractEntity(BaseChat):
         self.user_input = chat_param["current_user_input"]
         self.extract_mode = chat_param["select_param"]
 
-    def generate_input_values(self):
+    async def generate_input_values(self):
         input_values = {
             "text": self.user_input,
         }
diff --git a/pilot/scene/chat_knowledge/extract_triplet/chat.py b/pilot/scene/chat_knowledge/extract_triplet/chat.py
index 11fe871ab..28152b92e 100644
--- a/pilot/scene/chat_knowledge/extract_triplet/chat.py
+++ b/pilot/scene/chat_knowledge/extract_triplet/chat.py
@@ -24,7 +24,7 @@ class ExtractTriplet(BaseChat):
         self.user_input = chat_param["current_user_input"]
         self.extract_mode = chat_param["select_param"]
 
-    def generate_input_values(self):
+    async def generate_input_values(self):
         input_values = {
             "text": self.user_input,
         }
diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py
index c381546f8..ea7ca1922 100644
--- a/pilot/scene/chat_knowledge/v1/chat.py
+++ b/pilot/scene/chat_knowledge/v1/chat.py
@@ -64,7 +64,7 @@ class ChatKnowledge(BaseChat):
         self.prompt_template.template_is_strict = False
 
     async def stream_call(self):
-        input_values = self.generate_input_values()
+        input_values = await self.generate_input_values()
         # Source of knowledge file
         relations = input_values.get("relations")
         last_output = None
@@ -84,14 +84,14 @@ class ChatKnowledge(BaseChat):
             )
             yield last_output
 
-    def generate_input_values(self):
+    async def generate_input_values(self):
         if self.space_context:
             self.prompt_template.template_define = self.space_context["prompt"]["scene"]
             self.prompt_template.template = self.space_context["prompt"]["template"]
-        docs = self.rag_engine.search(query=self.current_user_input)
-        docs = self.knowledge_embedding_client.similar_search(
-            self.current_user_input, self.top_k
-        )
+        docs = await self.rag_engine.search(query=self.current_user_input)
+        # docs = self.knowledge_embedding_client.similar_search(
+        #     self.current_user_input, self.top_k
+        # )
         if not docs:
             raise ValueError(
                 "you have no knowledge space, please add your knowledge space"
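
The base_chat.py change lets __call_base accept both sync and async generate_input_values overrides via inspect.isawaitable. One caveat worth flagging: the conditional expression in the patch calls self.generate_input_values() twice, so a sync override executes twice and an async override creates an extra, never-awaited coroutine during the check. A call-once variant of the same bridge (a sketch, not the patch's code):

    import asyncio
    import inspect

    async def call_base(chat):
        # Call once, then await only if the result is awaitable.
        result = chat.generate_input_values()
        return await result if inspect.isawaitable(result) else result

    class SyncChat:
        def generate_input_values(self):
            return {"text": "hello"}

    class AsyncChat:
        async def generate_input_values(self):
            return {"text": "hello"}

    print(asyncio.run(call_base(SyncChat())))   # {'text': 'hello'}
    print(asyncio.run(call_base(AsyncChat())))  # {'text': 'hello'}
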
diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py
index f4150fa73..95949d319 100644
--- a/pilot/server/knowledge/service.py
+++ b/pilot/server/knowledge/service.py
@@ -261,9 +261,6 @@ class KnowledgeService:
         # docs = engine.search(
         #     "Comparing Curry and James in terms of their positions, playing styles, and achievements in the NBA"
         # )
-        embedding_factory = CFG.SYSTEM_APP.get_component(
-            "embedding_factory", EmbeddingFactory
-        )
         # update document status
         doc.status = SyncStatus.RUNNING.name
         doc.chunk_size = len(chunk_docs)
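
The service.py hunk only drops an unused embedding_factory lookup, but it sits in a synchronous code path, and rag_engine.search is now a coroutine. A sync call site therefore needs an event-loop bridge; the helper below mirrors the get_or_create_event_loop utility referenced in graph_search.py, though this wiring is illustrative rather than part of the patch:

    import asyncio

    def get_or_create_event_loop():
        # Reuse the current thread's loop when present; otherwise create
        # and install a fresh one, as the pilot.utils helper does.
        try:
            return asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            return loop

    def search_sync(engine, query: str):
        # Bridge for synchronous call sites such as KnowledgeService:
        # drive the coroutine to completion and return the Documents.
        loop = get_or_create_event_loop()
        return loop.run_until_complete(engine.search(query))
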