feat:rag graph

This commit is contained in:
aries_ckt 2023-10-16 14:09:04 +08:00
parent 71c31c3e2e
commit b63fa2dfe1
10 changed files with 51 additions and 51 deletions

View File

@ -45,8 +45,7 @@ class RAGGraphEngine:
**kwargs: Any,
) -> None:
"""Initialize params."""
# from llama_index.graph_stores import SimpleGraphStore
# from llama_index.graph_stores.types import GraphStore
from llama_index.graph_stores import SimpleGraphStore
# need to set parameters before building index in base class.
self.knowledge_source = knowledge_source
@ -55,8 +54,8 @@ class RAGGraphEngine:
self.text_splitter = text_splitter
self.index_struct = index_struct
self.include_embeddings = include_embeddings
# self.graph_store = graph_store or SimpleGraphStore()
self.graph_store = graph_store
self.graph_store = graph_store or SimpleGraphStore()
# self.graph_store = graph_store
self.max_triplets_per_chunk = max_triplets_per_chunk
self._max_object_length = max_object_length
self._extract_triplet_fn = extract_triplet_fn
@ -103,14 +102,6 @@ class RAGGraphEngine:
)
)
return triplets
# response = self._service_context.llm_predictor.predict(
# self.kg_triple_extract_template,
# text=text,
# )
# print(response, flush=True)
# return self._parse_triplet_response(
# response, max_length=self._max_object_length
# )
def _build_index_from_docs(self, documents: List[Document]) -> KG:
"""Build the index from nodes."""
@ -126,7 +117,6 @@ class RAGGraphEngine:
self.graph_store.upsert_triplet(*triplet)
index_struct.add_node([subj, obj], text_node)
return index_struct
def search(self, query):
@ -134,4 +124,3 @@ class RAGGraphEngine:
graph_search = RAGGraphSearch(graph_engine=self)
return graph_search.search(query)

View File

@ -4,6 +4,8 @@ from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional, Dict, Any, Set, Callable
from langchain.schema import Document
from pilot.graph_engine.node import BaseNode, TextNode, NodeWithScore
from pilot.graph_engine.search import BaseSearch, SearchMode
from pilot.utils import utils
@ -67,14 +69,14 @@ class RAGGraphSearch(BaseSearch):
logger.warn(f"can not to find graph schema: {e}")
self._graph_schema = ""
def _extract_subject_entities(self, query_str: str) -> Set[str]:
async def _extract_subject_entities(self, query_str: str) -> Set[str]:
"""extract subject entities."""
if self.extract_subject_entities_fn is not None:
return self.extract_subject_entities_fn(query_str)
return await self.extract_subject_entities_fn(query_str)
else:
return self._extract_entities_by_llm(query_str)
return await self._extract_entities_by_llm(query_str)
def _extract_entities_by_llm(self, text: str) -> Set[str]:
async def _extract_entities_by_llm(self, text: str) -> Set[str]:
"""extract subject entities from text by llm"""
from pilot.scene.base import ChatScene
from pilot.common.chat_util import llm_chat_response_nostream
@ -86,21 +88,23 @@ class RAGGraphSearch(BaseSearch):
"select_param": "entity",
"model_name": self.model_name,
}
loop = utils.get_or_create_event_loop()
entities = loop.run_until_complete(
llm_chat_response_nostream(
# loop = utils.get_or_create_event_loop()
# entities = loop.run_until_complete(
# llm_chat_response_nostream(
# ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
# )
# )
return await llm_chat_response_nostream(
ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
)
)
return entities
def _search(
async def _search(
self,
query_str: str,
) -> List[NodeWithScore]:
) -> List[Document]:
"""Get nodes for response."""
node_visited = set()
keywords = self._extract_subject_entities(query_str)
keywords = await self._extract_subject_entities(query_str)
print(f"extract entities: {keywords}\n")
rel_texts = []
cur_rel_map = {}
@ -114,8 +118,8 @@ class RAGGraphSearch(BaseSearch):
if node_id in node_visited:
continue
if self._include_text:
chunk_indices_count[node_id] += 1
# if self._include_text:
# chunk_indices_count[node_id] += 1
node_visited.add(node_id)
@ -179,8 +183,11 @@ class RAGGraphSearch(BaseSearch):
sorted_nodes_with_scores.append(
NodeWithScore(node=rel_text_node, score=DEFAULT_NODE_SCORE)
)
return sorted_nodes_with_scores
docs = [
Document(page_content=node.text, metadata=node.metadata)
for node in sorted_nodes_with_scores
]
return docs
def _get_metadata_for_response(
self, nodes: List[BaseNode]

View File

@ -21,6 +21,7 @@ WRAP_WIDTH = 70
class BaseComponent(BaseModel):
"""Base component object to caputure class names."""
"""reference llama-index"""
@classmethod

View File

@ -23,7 +23,7 @@ class SearchMode(str, Enum):
class BaseSearch(ABC):
"""Base Search."""
def search(self, query: str):
async def search(self, query: str):
"""Retrieve nodes given query.
Args:
@ -32,10 +32,10 @@ class BaseSearch(ABC):
"""
# if isinstance(query, str):
return self._search(query)
return await self._search(query)
@abstractmethod
def _search(self, query: str):
async def _search(self, query: str):
"""search nodes given query.
Implemented by the user.

View File

@ -105,8 +105,14 @@ class BaseChat(ABC):
speak_to_user = prompt_define_response
return speak_to_user
def __call_base(self):
input_values = self.generate_input_values()
async def __call_base(self):
import inspect
input_values = (
await self.generate_input_values()
if inspect.isawaitable(self.generate_input_values())
else self.generate_input_values()
)
### Chat sequence advance
self.current_message.chat_order = len(self.history_message) + 1
self.current_message.add_user_message(self.current_user_input)
@ -146,7 +152,7 @@ class BaseChat(ABC):
async def stream_call(self):
# TODO Retry when server connection error
payload = self.__call_base()
payload = await self.__call_base()
self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
logger.info(f"Request: \n{payload}")
@ -234,7 +240,7 @@ class BaseChat(ABC):
return self.current_ai_response()
async def get_llm_response(self):
payload = self.__call_base()
payload = await self.__call_base()
logger.info(f"Request: \n{payload}")
ai_response_text = ""
try:

View File

@ -24,7 +24,7 @@ class ExtractEntity(BaseChat):
self.user_input = chat_param["current_user_input"]
self.extract_mode = chat_param["select_param"]
def generate_input_values(self):
async def generate_input_values(self):
input_values = {
"text": self.user_input,
}

View File

@ -24,7 +24,7 @@ class ExtractTriplet(BaseChat):
self.user_input = chat_param["current_user_input"]
self.extract_mode = chat_param["select_param"]
def generate_input_values(self):
async def generate_input_values(self):
input_values = {
"text": self.user_input,
}

View File

@ -64,7 +64,7 @@ class ChatKnowledge(BaseChat):
self.prompt_template.template_is_strict = False
async def stream_call(self):
input_values = self.generate_input_values()
input_values = await self.generate_input_values()
# Source of knowledge file
relations = input_values.get("relations")
last_output = None
@ -84,14 +84,14 @@ class ChatKnowledge(BaseChat):
)
yield last_output
def generate_input_values(self):
async def generate_input_values(self):
if self.space_context:
self.prompt_template.template_define = self.space_context["prompt"]["scene"]
self.prompt_template.template = self.space_context["prompt"]["template"]
docs = self.rag_engine.search(query=self.current_user_input)
docs = self.knowledge_embedding_client.similar_search(
self.current_user_input, self.top_k
)
docs = await self.rag_engine.search(query=self.current_user_input)
# docs = self.knowledge_embedding_client.similar_search(
# self.current_user_input, self.top_k
# )
if not docs:
raise ValueError(
"you have no knowledge space, please add your knowledge space"

View File

@ -261,9 +261,6 @@ class KnowledgeService:
# docs = engine.search(
# "Comparing Curry and James in terms of their positions, playing styles, and achievements in the NBA"
# )
embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory
)
# update document status
doc.status = SyncStatus.RUNNING.name
doc.chunk_size = len(chunk_docs)