mirror of
https://github.com/csunny/DB-GPT.git
synced 2026-01-25 14:54:26 +00:00
feat:rag graph
This commit is contained in:
@@ -45,8 +45,7 @@ class RAGGraphEngine:
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize params."""
|
||||
# from llama_index.graph_stores import SimpleGraphStore
|
||||
# from llama_index.graph_stores.types import GraphStore
|
||||
from llama_index.graph_stores import SimpleGraphStore
|
||||
|
||||
# need to set parameters before building index in base class.
|
||||
self.knowledge_source = knowledge_source
|
||||
@@ -55,8 +54,8 @@ class RAGGraphEngine:
|
||||
self.text_splitter = text_splitter
|
||||
self.index_struct = index_struct
|
||||
self.include_embeddings = include_embeddings
|
||||
# self.graph_store = graph_store or SimpleGraphStore()
|
||||
self.graph_store = graph_store
|
||||
self.graph_store = graph_store or SimpleGraphStore()
|
||||
# self.graph_store = graph_store
|
||||
self.max_triplets_per_chunk = max_triplets_per_chunk
|
||||
self._max_object_length = max_object_length
|
||||
self._extract_triplet_fn = extract_triplet_fn
|
||||
@@ -103,14 +102,6 @@ class RAGGraphEngine:
|
||||
)
|
||||
)
|
||||
return triplets
|
||||
# response = self._service_context.llm_predictor.predict(
|
||||
# self.kg_triple_extract_template,
|
||||
# text=text,
|
||||
# )
|
||||
# print(response, flush=True)
|
||||
# return self._parse_triplet_response(
|
||||
# response, max_length=self._max_object_length
|
||||
# )
|
||||
|
||||
def _build_index_from_docs(self, documents: List[Document]) -> KG:
|
||||
"""Build the index from nodes."""
|
||||
@@ -126,7 +117,6 @@ class RAGGraphEngine:
|
||||
self.graph_store.upsert_triplet(*triplet)
|
||||
index_struct.add_node([subj, obj], text_node)
|
||||
|
||||
|
||||
return index_struct
|
||||
|
||||
def search(self, query):
|
||||
@@ -134,4 +124,3 @@ class RAGGraphEngine:
|
||||
|
||||
graph_search = RAGGraphSearch(graph_engine=self)
|
||||
return graph_search.search(query)
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List, Optional, Dict, Any, Set, Callable
|
||||
|
||||
from langchain.schema import Document
|
||||
|
||||
from pilot.graph_engine.node import BaseNode, TextNode, NodeWithScore
|
||||
from pilot.graph_engine.search import BaseSearch, SearchMode
|
||||
from pilot.utils import utils
|
||||
@@ -67,14 +69,14 @@ class RAGGraphSearch(BaseSearch):
|
||||
logger.warn(f"can not to find graph schema: {e}")
|
||||
self._graph_schema = ""
|
||||
|
||||
def _extract_subject_entities(self, query_str: str) -> Set[str]:
|
||||
async def _extract_subject_entities(self, query_str: str) -> Set[str]:
|
||||
"""extract subject entities."""
|
||||
if self.extract_subject_entities_fn is not None:
|
||||
return self.extract_subject_entities_fn(query_str)
|
||||
return await self.extract_subject_entities_fn(query_str)
|
||||
else:
|
||||
return self._extract_entities_by_llm(query_str)
|
||||
return await self._extract_entities_by_llm(query_str)
|
||||
|
||||
def _extract_entities_by_llm(self, text: str) -> Set[str]:
|
||||
async def _extract_entities_by_llm(self, text: str) -> Set[str]:
|
||||
"""extract subject entities from text by llm"""
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.chat_util import llm_chat_response_nostream
|
||||
@@ -86,21 +88,23 @@ class RAGGraphSearch(BaseSearch):
|
||||
"select_param": "entity",
|
||||
"model_name": self.model_name,
|
||||
}
|
||||
loop = utils.get_or_create_event_loop()
|
||||
entities = loop.run_until_complete(
|
||||
llm_chat_response_nostream(
|
||||
ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
|
||||
)
|
||||
# loop = utils.get_or_create_event_loop()
|
||||
# entities = loop.run_until_complete(
|
||||
# llm_chat_response_nostream(
|
||||
# ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
|
||||
# )
|
||||
# )
|
||||
return await llm_chat_response_nostream(
|
||||
ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
|
||||
)
|
||||
return entities
|
||||
|
||||
def _search(
|
||||
async def _search(
|
||||
self,
|
||||
query_str: str,
|
||||
) -> List[NodeWithScore]:
|
||||
) -> List[Document]:
|
||||
"""Get nodes for response."""
|
||||
node_visited = set()
|
||||
keywords = self._extract_subject_entities(query_str)
|
||||
keywords = await self._extract_subject_entities(query_str)
|
||||
print(f"extract entities: {keywords}\n")
|
||||
rel_texts = []
|
||||
cur_rel_map = {}
|
||||
@@ -114,8 +118,8 @@ class RAGGraphSearch(BaseSearch):
|
||||
if node_id in node_visited:
|
||||
continue
|
||||
|
||||
if self._include_text:
|
||||
chunk_indices_count[node_id] += 1
|
||||
# if self._include_text:
|
||||
# chunk_indices_count[node_id] += 1
|
||||
|
||||
node_visited.add(node_id)
|
||||
|
||||
@@ -179,8 +183,11 @@ class RAGGraphSearch(BaseSearch):
|
||||
sorted_nodes_with_scores.append(
|
||||
NodeWithScore(node=rel_text_node, score=DEFAULT_NODE_SCORE)
|
||||
)
|
||||
|
||||
return sorted_nodes_with_scores
|
||||
docs = [
|
||||
Document(page_content=node.text, metadata=node.metadata)
|
||||
for node in sorted_nodes_with_scores
|
||||
]
|
||||
return docs
|
||||
|
||||
def _get_metadata_for_response(
|
||||
self, nodes: List[BaseNode]
|
||||
@@ -190,4 +197,4 @@ class RAGGraphSearch(BaseSearch):
|
||||
if node.metadata is None or "kg_rel_map" not in node.metadata:
|
||||
continue
|
||||
return node.metadata
|
||||
raise ValueError("kg_rel_map must be found in at least one Node.")
|
||||
raise ValueError("kg_rel_map must be found in at least one Node.")
|
||||
|
||||
@@ -21,6 +21,7 @@ WRAP_WIDTH = 70
|
||||
|
||||
class BaseComponent(BaseModel):
|
||||
"""Base component object to caputure class names."""
|
||||
|
||||
"""reference llama-index"""
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -23,7 +23,7 @@ class SearchMode(str, Enum):
|
||||
class BaseSearch(ABC):
|
||||
"""Base Search."""
|
||||
|
||||
def search(self, query: str):
|
||||
async def search(self, query: str):
|
||||
"""Retrieve nodes given query.
|
||||
|
||||
Args:
|
||||
@@ -32,10 +32,10 @@ class BaseSearch(ABC):
|
||||
|
||||
"""
|
||||
# if isinstance(query, str):
|
||||
return self._search(query)
|
||||
return await self._search(query)
|
||||
|
||||
@abstractmethod
|
||||
def _search(self, query: str):
|
||||
async def _search(self, query: str):
|
||||
"""search nodes given query.
|
||||
|
||||
Implemented by the user.
|
||||
|
||||
Reference in New Issue
Block a user