feat: rag graph

aries_ckt 2023-10-16 14:09:04 +08:00
parent 71c31c3e2e
commit b63fa2dfe1
10 changed files with 51 additions and 51 deletions

View File

@@ -9,7 +9,7 @@ chat_factory = ChatFactory()

 async def llm_chat_response_nostream(chat_scene: str, **chat_param):
-    """ llm_chat_response_nostream """
+    """llm_chat_response_nostream"""
     chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param)
     res = await chat.get_llm_response()
     return res
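
A minimal usage sketch of the awaitable helper above. The chat_param keys mirror those used by RAGGraphSearch and ExtractEntity later in this commit; the model name and sample text are illustrative only:

    import asyncio

    from pilot.common.chat_util import llm_chat_response_nostream
    from pilot.scene.base import ChatScene

    async def extract_entities(text: str, model_name: str):
        chat_param = {
            "current_user_input": text,
            "select_param": "entity",
            "model_name": model_name,
        }
        # The factory builds the ExtractEntity chat, then the full
        # (non-streaming) LLM response is awaited and returned.
        return await llm_chat_response_nostream(
            ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
        )

    # asyncio.run(extract_entities("Messi joined Inter Miami in 2023.", "vicuna-13b"))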

View File

@@ -45,8 +45,7 @@ class RAGGraphEngine:
         **kwargs: Any,
     ) -> None:
         """Initialize params."""
-        # from llama_index.graph_stores import SimpleGraphStore
-        # from llama_index.graph_stores.types import GraphStore
+        from llama_index.graph_stores import SimpleGraphStore

         # need to set parameters before building index in base class.
         self.knowledge_source = knowledge_source
@@ -55,8 +54,8 @@ class RAGGraphEngine:
         self.text_splitter = text_splitter
         self.index_struct = index_struct
         self.include_embeddings = include_embeddings
-        # self.graph_store = graph_store or SimpleGraphStore()
-        self.graph_store = graph_store
+        self.graph_store = graph_store or SimpleGraphStore()
+        # self.graph_store = graph_store
         self.max_triplets_per_chunk = max_triplets_per_chunk
         self._max_object_length = max_object_length
         self._extract_triplet_fn = extract_triplet_fn
@@ -103,14 +102,6 @@ class RAGGraphEngine:
                 )
             )
         return triplets
-        # response = self._service_context.llm_predictor.predict(
-        #     self.kg_triple_extract_template,
-        #     text=text,
-        # )
-        # print(response, flush=True)
-        # return self._parse_triplet_response(
-        #     response, max_length=self._max_object_length
-        # )

     def _build_index_from_docs(self, documents: List[Document]) -> KG:
         """Build the index from nodes."""
@@ -126,7 +117,6 @@ class RAGGraphEngine:
             self.graph_store.upsert_triplet(*triplet)
             index_struct.add_node([subj, obj], text_node)
         return index_struct
-

     def search(self, query):
@@ -134,4 +124,3 @@ class RAGGraphEngine:
         graph_search = RAGGraphSearch(graph_engine=self)
         return graph_search.search(query)
-
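
Taken together, the engine now guarantees a graph store and delegates queries to RAGGraphSearch. A hypothetical end-to-end sketch; the constructor keywords follow the fields assigned in __init__ above, but passing pre-split documents and reassigning index_struct are assumptions, since loading is not shown in this diff:

    async def build_and_query(documents):
        engine = RAGGraphEngine(
            knowledge_source=documents,   # assumed: pre-loaded documents
            max_triplets_per_chunk=10,
        )
        # graph_store now defaults to SimpleGraphStore() when none is given.
        engine.index_struct = engine._build_index_from_docs(documents)
        # search() delegates to RAGGraphSearch.search, which is a coroutine
        # after this commit, so callers must await the result.
        return await engine.search("Compare Curry and James in the NBA.")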

View File

@@ -4,6 +4,8 @@ from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from typing import List, Optional, Dict, Any, Set, Callable

+from langchain.schema import Document
+
 from pilot.graph_engine.node import BaseNode, TextNode, NodeWithScore
 from pilot.graph_engine.search import BaseSearch, SearchMode
 from pilot.utils import utils
@@ -67,14 +69,14 @@ class RAGGraphSearch(BaseSearch):
             logger.warn(f"can not to find graph schema: {e}")
             self._graph_schema = ""

-    def _extract_subject_entities(self, query_str: str) -> Set[str]:
+    async def _extract_subject_entities(self, query_str: str) -> Set[str]:
         """extract subject entities."""
         if self.extract_subject_entities_fn is not None:
-            return self.extract_subject_entities_fn(query_str)
+            return await self.extract_subject_entities_fn(query_str)
         else:
-            return self._extract_entities_by_llm(query_str)
+            return await self._extract_entities_by_llm(query_str)

-    def _extract_entities_by_llm(self, text: str) -> Set[str]:
+    async def _extract_entities_by_llm(self, text: str) -> Set[str]:
         """extract subject entities from text by llm"""
         from pilot.scene.base import ChatScene
         from pilot.common.chat_util import llm_chat_response_nostream
@@ -86,21 +88,23 @@ class RAGGraphSearch(BaseSearch):
             "select_param": "entity",
             "model_name": self.model_name,
         }
-        loop = utils.get_or_create_event_loop()
-        entities = loop.run_until_complete(
-            llm_chat_response_nostream(
-                ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
-            )
-        )
-        return entities
+        # loop = utils.get_or_create_event_loop()
+        # entities = loop.run_until_complete(
+        #     llm_chat_response_nostream(
+        #         ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
+        #     )
+        # )
+        return await llm_chat_response_nostream(
+            ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
+        )

-    def _search(
+    async def _search(
         self,
         query_str: str,
-    ) -> List[NodeWithScore]:
+    ) -> List[Document]:
         """Get nodes for response."""
         node_visited = set()
-        keywords = self._extract_subject_entities(query_str)
+        keywords = await self._extract_subject_entities(query_str)
         print(f"extract entities: {keywords}\n")
         rel_texts = []
         cur_rel_map = {}
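
The motivation for this hunk: loop.run_until_complete cannot be invoked from code that is already executing on a running event loop; it raises RuntimeError. Since _search is itself a coroutine now, awaiting the helper directly is the correct pattern. A self-contained illustration:

    import asyncio

    async def broken(make_coro):
        loop = asyncio.get_event_loop()
        # Raises "RuntimeError: This event loop is already running".
        return loop.run_until_complete(make_coro())

    async def fixed(make_coro):
        # Cooperates with the already-running loop instead of re-entering it.
        return await make_coro()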
@@ -114,8 +118,8 @@ class RAGGraphSearch(BaseSearch):
                 if node_id in node_visited:
                     continue

-                if self._include_text:
-                    chunk_indices_count[node_id] += 1
+                # if self._include_text:
+                #     chunk_indices_count[node_id] += 1

                 node_visited.add(node_id)
@@ -179,8 +183,11 @@ class RAGGraphSearch(BaseSearch):
             sorted_nodes_with_scores.append(
                 NodeWithScore(node=rel_text_node, score=DEFAULT_NODE_SCORE)
             )
-
-        return sorted_nodes_with_scores
+        docs = [
+            Document(page_content=node.text, metadata=node.metadata)
+            for node in sorted_nodes_with_scores
+        ]
+        return docs

     def _get_metadata_for_response(
         self, nodes: List[BaseNode]
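
With this hunk the graph search returns langchain Document objects instead of NodeWithScore, so graph results plug into the same downstream code as embedding-retriever results (see ChatKnowledge below). A small consumption sketch, assuming an initialized RAGGraphSearch:

    async def show_results(graph_search: RAGGraphSearch, query: str) -> None:
        docs = await graph_search.search(query)
        for doc in docs:
            # page_content carries the node text; node metadata passes through.
            print(doc.page_content[:80], doc.metadata)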

View File

@@ -21,6 +21,7 @@ WRAP_WIDTH = 70

 class BaseComponent(BaseModel):
     """Base component object to caputure class names."""
+
     """reference llama-index"""

     @classmethod

View File

@@ -23,7 +23,7 @@ class SearchMode(str, Enum):
 class BaseSearch(ABC):
     """Base Search."""

-    def search(self, query: str):
+    async def search(self, query: str):
         """Retrieve nodes given query.

         Args:
@@ -32,10 +32,10 @@ class BaseSearch(ABC):
         """
         # if isinstance(query, str):
-        return self._search(query)
+        return await self._search(query)

     @abstractmethod
-    def _search(self, query: str):
+    async def _search(self, query: str):
         """search nodes given query.

         Implemented by the user.
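
BaseSearch now defines an async contract: search awaits the abstract _search. A minimal hypothetical subclass showing what implementers must provide (KeywordSearch and its matching logic are illustrative, not part of this commit):

    from typing import List

    class KeywordSearch(BaseSearch):
        def __init__(self, corpus: List[str]):
            self.corpus = corpus

        async def _search(self, query: str):
            # Any awaitable work (LLM calls, network IO) fits naturally here.
            return [t for t in self.corpus if query.lower() in t.lower()]

    # results = await KeywordSearch(["graph rag", "vector rag"]).search("graph")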

View File

@@ -105,8 +105,14 @@ class BaseChat(ABC):
             speak_to_user = prompt_define_response
         return speak_to_user

-    def __call_base(self):
-        input_values = self.generate_input_values()
+    async def __call_base(self):
+        import inspect
+
+        input_values = (
+            await self.generate_input_values()
+            if inspect.isawaitable(self.generate_input_values())
+            else self.generate_input_values()
+        )
         ### Chat sequence advance
         self.current_message.chat_order = len(self.history_message) + 1
         self.current_message.add_user_message(self.current_user_input)
@@ -146,7 +152,7 @@ class BaseChat(ABC):
     async def stream_call(self):
         # TODO Retry when server connection error
-        payload = self.__call_base()
+        payload = await self.__call_base()
         self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
         logger.info(f"Request: \n{payload}")
@@ -234,7 +240,7 @@ class BaseChat(ABC):
         return self.current_ai_response()

     async def get_llm_response(self):
-        payload = self.__call_base()
+        payload = await self.__call_base()
         logger.info(f"Request: \n{payload}")
         ai_response_text = ""
         try:
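
__call_base now tolerates both sync and async generate_input_values implementations, which lets the converted chat scenes below go async without breaking the remaining synchronous ones. One caveat worth noting: probing with inspect.isawaitable(self.generate_input_values()) invokes the method an extra time just to test it. A standalone sketch of the same bridge that checks the callable instead of its result (a variant, not the committed code):

    import asyncio
    import inspect

    async def call_maybe_async(fn):
        # Checking the callable avoids creating a throwaway coroutine.
        if inspect.iscoroutinefunction(fn):
            return await fn()
        return fn()

    async def main():
        def sync_values():
            return {"text": "sync"}

        async def async_values():
            return {"text": "async"}

        print(await call_maybe_async(sync_values))   # {'text': 'sync'}
        print(await call_maybe_async(async_values))  # {'text': 'async'}

    asyncio.run(main())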

View File

@@ -24,7 +24,7 @@ class ExtractEntity(BaseChat):
         self.user_input = chat_param["current_user_input"]
         self.extract_mode = chat_param["select_param"]

-    def generate_input_values(self):
+    async def generate_input_values(self):
         input_values = {
             "text": self.user_input,
         }

View File

@@ -24,7 +24,7 @@ class ExtractTriplet(BaseChat):
         self.user_input = chat_param["current_user_input"]
         self.extract_mode = chat_param["select_param"]

-    def generate_input_values(self):
+    async def generate_input_values(self):
         input_values = {
             "text": self.user_input,
         }

View File

@@ -64,7 +64,7 @@ class ChatKnowledge(BaseChat):
         self.prompt_template.template_is_strict = False

     async def stream_call(self):
-        input_values = self.generate_input_values()
+        input_values = await self.generate_input_values()
         # Source of knowledge file
         relations = input_values.get("relations")
         last_output = None
@@ -84,14 +84,14 @@ class ChatKnowledge(BaseChat):
             )
             yield last_output

-    def generate_input_values(self):
+    async def generate_input_values(self):
         if self.space_context:
             self.prompt_template.template_define = self.space_context["prompt"]["scene"]
             self.prompt_template.template = self.space_context["prompt"]["template"]
-        docs = self.rag_engine.search(query=self.current_user_input)
-        docs = self.knowledge_embedding_client.similar_search(
-            self.current_user_input, self.top_k
-        )
+        docs = await self.rag_engine.search(query=self.current_user_input)
+        # docs = self.knowledge_embedding_client.similar_search(
+        #     self.current_user_input, self.top_k
+        # )
         if not docs:
             raise ValueError(
                 "you have no knowledge space, please add your knowledge space"

View File

@@ -261,9 +261,6 @@ class KnowledgeService:
         # docs = engine.search(
         #     "Comparing Curry and James in terms of their positions, playing styles, and achievements in the NBA"
         # )
-        embedding_factory = CFG.SYSTEM_APP.get_component(
-            "embedding_factory", EmbeddingFactory
-        )
         # update document status
         doc.status = SyncStatus.RUNNING.name
         doc.chunk_size = len(chunk_docs)