mirror of https://github.com/csunny/DB-GPT.git
synced 2025-08-19 16:57:21 +00:00

feat:rag graph

This commit is contained in:
parent 71c31c3e2e
commit b63fa2dfe1
@@ -45,8 +45,7 @@ class RAGGraphEngine:
         **kwargs: Any,
     ) -> None:
         """Initialize params."""
-        # from llama_index.graph_stores import SimpleGraphStore
-        # from llama_index.graph_stores.types import GraphStore
+        from llama_index.graph_stores import SimpleGraphStore
 
         # need to set parameters before building index in base class.
         self.knowledge_source = knowledge_source
@@ -55,8 +54,8 @@ class RAGGraphEngine:
         self.text_splitter = text_splitter
         self.index_struct = index_struct
         self.include_embeddings = include_embeddings
-        # self.graph_store = graph_store or SimpleGraphStore()
-        self.graph_store = graph_store
+        self.graph_store = graph_store or SimpleGraphStore()
+        # self.graph_store = graph_store
         self.max_triplets_per_chunk = max_triplets_per_chunk
         self._max_object_length = max_object_length
         self._extract_triplet_fn = extract_triplet_fn
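
With this default in place the engine falls back to an in-memory store when no graph_store is injected. A minimal sketch of the llama-index SimpleGraphStore calls this code path relies on (upsert_triplet/get; the sample triplet is illustrative only):

    from llama_index.graph_stores import SimpleGraphStore

    store = SimpleGraphStore()
    store.upsert_triplet("Curry", "plays for", "Golden State Warriors")
    # get() returns the outgoing [relation, object] pairs for a subject
    print(store.get("Curry"))  # [['plays for', 'Golden State Warriors']]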
@@ -103,14 +102,6 @@ class RAGGraphEngine:
             )
         )
         return triplets
-        # response = self._service_context.llm_predictor.predict(
-        #     self.kg_triple_extract_template,
-        #     text=text,
-        # )
-        # print(response, flush=True)
-        # return self._parse_triplet_response(
-        #     response, max_length=self._max_object_length
-        # )
 
     def _build_index_from_docs(self, documents: List[Document]) -> KG:
         """Build the index from nodes."""
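
The commented-out llm_predictor fallback is removed for good; extraction now always returns `triplets` from the path above. Callers who want to bypass the LLM can still pass `extract_triplet_fn` (stored as `self._extract_triplet_fn` in the constructor). A hypothetical, purely illustrative callable:

    import re
    from typing import List, Tuple

    def regex_triplet_fn(text: str) -> List[Tuple[str, str, str]]:
        # hypothetical extractor: pull "(subject, relation, object)" patterns
        matches = re.findall(r"\(([^,()]+),([^,()]+),([^,()]+)\)", text)
        return [tuple(p.strip() for p in m) for m in matches]

    print(regex_triplet_fn("(Curry, plays for, Warriors)"))
    # [('Curry', 'plays for', 'Warriors')]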
@@ -126,7 +117,6 @@ class RAGGraphEngine:
             self.graph_store.upsert_triplet(*triplet)
             index_struct.add_node([subj, obj], text_node)
-
 
         return index_struct
 
     def search(self, query):
@@ -134,4 +124,3 @@ class RAGGraphEngine:
 
         graph_search = RAGGraphSearch(graph_engine=self)
         return graph_search.search(query)
-
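
Because BaseSearch.search becomes a coroutine later in this commit, RAGGraphEngine.search now returns an awaitable rather than a list, and callers must await it (ChatKnowledge below does exactly that). A hedged usage sketch, with engine construction elided:

    # assuming `engine` is an initialized RAGGraphEngine
    docs = await engine.search("Comparing Curry and James in the NBA")
    for doc in docs:
        print(doc.page_content, doc.metadata)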
@@ -4,6 +4,8 @@ from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from typing import List, Optional, Dict, Any, Set, Callable
 
+from langchain.schema import Document
+
 from pilot.graph_engine.node import BaseNode, TextNode, NodeWithScore
 from pilot.graph_engine.search import BaseSearch, SearchMode
 from pilot.utils import utils
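
The new import supports the List[Document] return type introduced in _search below; page_content plus metadata is the standard langchain.schema.Document shape:

    from langchain.schema import Document

    doc = Document(page_content="Curry plays for the Warriors", metadata={"source": "graph"})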
@@ -67,14 +69,14 @@ class RAGGraphSearch(BaseSearch):
             logger.warn(f"can not to find graph schema: {e}")
             self._graph_schema = ""
 
-    def _extract_subject_entities(self, query_str: str) -> Set[str]:
+    async def _extract_subject_entities(self, query_str: str) -> Set[str]:
         """extract subject entities."""
         if self.extract_subject_entities_fn is not None:
-            return self.extract_subject_entities_fn(query_str)
+            return await self.extract_subject_entities_fn(query_str)
         else:
-            return self._extract_entities_by_llm(query_str)
+            return await self._extract_entities_by_llm(query_str)
 
-    def _extract_entities_by_llm(self, text: str) -> Set[str]:
+    async def _extract_entities_by_llm(self, text: str) -> Set[str]:
         """extract subject entities from text by llm"""
         from pilot.scene.base import ChatScene
         from pilot.common.chat_util import llm_chat_response_nostream
@@ -86,21 +88,23 @@ class RAGGraphSearch(BaseSearch):
             "select_param": "entity",
             "model_name": self.model_name,
         }
-        loop = utils.get_or_create_event_loop()
-        entities = loop.run_until_complete(
-            llm_chat_response_nostream(
+        # loop = utils.get_or_create_event_loop()
+        # entities = loop.run_until_complete(
+        #     llm_chat_response_nostream(
+        #         ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
+        #     )
+        # )
+        return await llm_chat_response_nostream(
             ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
         )
-        )
-        return entities
 
-    def _search(
+    async def _search(
         self,
         query_str: str,
-    ) -> List[NodeWithScore]:
+    ) -> List[Document]:
         """Get nodes for response."""
         node_visited = set()
-        keywords = self._extract_subject_entities(query_str)
+        keywords = await self._extract_subject_entities(query_str)
         print(f"extract entities: {keywords}\n")
         rel_texts = []
         cur_rel_map = {}
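
Swapping run_until_complete for a plain await matters because _extract_entities_by_llm is now itself called from a coroutine: run_until_complete raises RuntimeError when the loop is already running. A small self-contained illustration of the failure mode (inner() is a stand-in for llm_chat_response_nostream):

    import asyncio

    async def inner() -> str:
        return "entities"

    async def outer() -> str:
        # asyncio.get_event_loop().run_until_complete(inner())
        # ^ RuntimeError: This event loop is already running
        return await inner()  # awaiting directly inside a coroutine is safe

    print(asyncio.run(outer()))  # entities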
@@ -114,8 +118,8 @@ class RAGGraphSearch(BaseSearch):
             if node_id in node_visited:
                 continue
 
-            if self._include_text:
-                chunk_indices_count[node_id] += 1
+            # if self._include_text:
+            #     chunk_indices_count[node_id] += 1
 
             node_visited.add(node_id)
@@ -179,8 +183,11 @@ class RAGGraphSearch(BaseSearch):
             sorted_nodes_with_scores.append(
                 NodeWithScore(node=rel_text_node, score=DEFAULT_NODE_SCORE)
             )
 
-        return sorted_nodes_with_scores
+        docs = [
+            Document(page_content=node.text, metadata=node.metadata)
+            for node in sorted_nodes_with_scores
+        ]
+        return docs
 
     def _get_metadata_for_response(
         self, nodes: List[BaseNode]
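
Downstream consumers now receive langchain Documents instead of NodeWithScore wrappers, so relevance scores are dropped at this boundary. A sketch of consuming the new return type, with searcher construction elided:

    # assuming `searcher` is an initialized RAGGraphSearch
    docs = await searcher.search("who does Curry play for")
    context = "\n".join(d.page_content for d in docs)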
@@ -21,6 +21,7 @@ WRAP_WIDTH = 70
 
 class BaseComponent(BaseModel):
     """Base component object to capture class names."""
 
+    """reference llama-index"""
 
     @classmethod
@@ -23,7 +23,7 @@ class SearchMode(str, Enum):
 class BaseSearch(ABC):
     """Base Search."""
 
-    def search(self, query: str):
+    async def search(self, query: str):
         """Retrieve nodes given query.
 
         Args:
@ -32,10 +32,10 @@ class BaseSearch(ABC):
|
||||
|
||||
"""
|
||||
# if isinstance(query, str):
|
||||
return self._search(query)
|
||||
return await self._search(query)
|
||||
|
||||
@abstractmethod
|
||||
def _search(self, query: str):
|
||||
async def _search(self, query: str):
|
||||
"""search nodes given query.
|
||||
|
||||
Implemented by the user.
|
||||
|
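
Every BaseSearch subclass must now provide an async _search, and every call site must await search(). A minimal conforming subclass, as a sketch (EchoSearch is hypothetical; the Document import mirrors the one added above):

    from langchain.schema import Document

    class EchoSearch(BaseSearch):
        """Toy subclass: returns the query itself as a single Document."""

        async def _search(self, query: str):
            return [Document(page_content=query, metadata={})]

    # inside an async context:
    #     docs = await EchoSearch().search("hello")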
@@ -105,8 +105,14 @@ class BaseChat(ABC):
             speak_to_user = prompt_define_response
         return speak_to_user
 
-    def __call_base(self):
-        input_values = self.generate_input_values()
+    async def __call_base(self):
+        import inspect
+
+        input_values = (
+            await self.generate_input_values()
+            if inspect.isawaitable(self.generate_input_values())
+            else self.generate_input_values()
+        )
         ### Chat sequence advance
         self.current_message.chat_order = len(self.history_message) + 1
         self.current_message.add_user_message(self.current_user_input)
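
One caveat with this pattern: generate_input_values() is invoked twice. When it is a coroutine function, the isawaitable probe creates a second coroutine that is never awaited (CPython emits a "coroutine ... was never awaited" RuntimeWarning), and when it is synchronous its side effects run twice. A sketch that evaluates it once, under the same inspect-based dispatch:

    import inspect

    result = self.generate_input_values()
    input_values = await result if inspect.isawaitable(result) else result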
@@ -146,7 +152,7 @@ class BaseChat(ABC):
 
     async def stream_call(self):
         # TODO Retry when server connection error
-        payload = self.__call_base()
+        payload = await self.__call_base()
 
         self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
         logger.info(f"Request: \n{payload}")
@@ -234,7 +240,7 @@ class BaseChat(ABC):
         return self.current_ai_response()
 
     async def get_llm_response(self):
-        payload = self.__call_base()
+        payload = await self.__call_base()
         logger.info(f"Request: \n{payload}")
         ai_response_text = ""
         try:
@@ -24,7 +24,7 @@ class ExtractEntity(BaseChat):
         self.user_input = chat_param["current_user_input"]
         self.extract_mode = chat_param["select_param"]
 
-    def generate_input_values(self):
+    async def generate_input_values(self):
         input_values = {
             "text": self.user_input,
         }
@@ -24,7 +24,7 @@ class ExtractTriplet(BaseChat):
         self.user_input = chat_param["current_user_input"]
         self.extract_mode = chat_param["select_param"]
 
-    def generate_input_values(self):
+    async def generate_input_values(self):
         input_values = {
             "text": self.user_input,
         }
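
Neither extractor awaits anything internally; marking generate_input_values async still changes the calling convention, because invoking it now returns a coroutine that satisfies the inspect.isawaitable probe in BaseChat.__call_base. A standalone illustration (the free function stands in for the method):

    import asyncio
    import inspect

    async def generate_input_values():
        return {"text": "some user input"}

    coro = generate_input_values()
    print(inspect.isawaitable(coro))  # True
    print(asyncio.run(coro))          # {'text': 'some user input'}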
@@ -64,7 +64,7 @@ class ChatKnowledge(BaseChat):
         self.prompt_template.template_is_strict = False
 
     async def stream_call(self):
-        input_values = self.generate_input_values()
+        input_values = await self.generate_input_values()
         # Source of knowledge file
         relations = input_values.get("relations")
         last_output = None
@@ -84,14 +84,14 @@ class ChatKnowledge(BaseChat):
         )
         yield last_output
 
-    def generate_input_values(self):
+    async def generate_input_values(self):
         if self.space_context:
             self.prompt_template.template_define = self.space_context["prompt"]["scene"]
             self.prompt_template.template = self.space_context["prompt"]["template"]
-        docs = self.rag_engine.search(query=self.current_user_input)
-        docs = self.knowledge_embedding_client.similar_search(
-            self.current_user_input, self.top_k
-        )
+        docs = await self.rag_engine.search(query=self.current_user_input)
+        # docs = self.knowledge_embedding_client.similar_search(
+        #     self.current_user_input, self.top_k
+        # )
         if not docs:
             raise ValueError(
                 "you have no knowledge space, please add your knowledge space"
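
Retrieval here flips from the embedding store's similar_search to the graph engine, and the await consumes the coroutine that RAGGraphEngine.search hands back. A hypothetical continuation showing how the returned Documents might feed the prompt (the context/question keys are illustrative, not the template's actual variables):

    docs = await self.rag_engine.search(query=self.current_user_input)
    context = "\n".join(d.page_content for d in docs)
    input_values = {"context": context, "question": self.current_user_input}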
@@ -261,9 +261,6 @@ class KnowledgeService:
-        # docs = engine.search(
-        #     "Comparing Curry and James in terms of their positions, playing styles, and achievements in the NBA"
-        # )
         embedding_factory = CFG.SYSTEM_APP.get_component(
             "embedding_factory", EmbeddingFactory
         )
         # update document status
         doc.status = SyncStatus.RUNNING.name
         doc.chunk_size = len(chunk_docs)