import logging
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional, Set

from langchain.schema import Document

from pilot.graph_engine.node import BaseNode, NodeWithScore, TextNode
from pilot.graph_engine.search import BaseSearch, SearchMode

logger = logging.getLogger(__name__)

# Score assigned to the synthesized relationship node; kept deliberately high
# so downstream similarity filtering does not cut it off.
DEFAULT_NODE_SCORE = 1000.0
GLOBAL_EXPLORE_NODE_LIMIT = 3
REL_TEXT_LIMIT = 30


class RAGGraphSearch(BaseSearch):
    """RAG graph search.

    Retrieves relationship triplets from the knowledge graph for a query,
    using subject-entity extraction to pick the starting nodes.

    Args:
        graph_engine (RAGGraphEngine): The graph engine to search.
        model_name (str): Name of the model used for entity extraction;
            defaults to the graph engine's model.
        max_keywords_per_query (int): Maximum number of keywords to extract
            from the query.
        num_chunks_per_query (int): Maximum number of text chunks to query.
        search_mode (Optional[SearchMode]): Specifies whether to use keywords,
            embeddings, or both to find relevant triplets. Should be one of
            "keyword", "embedding", or "hybrid". Defaults to
            SearchMode.KEYWORD.
        graph_store_query_depth (int): The depth of the graph store query.
        extract_subject_entities_fn (Optional[Callable]): Optional callback
            for extracting subject entities from the query; defaults to
            LLM-based extraction.
    """

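    # Minimal usage sketch (illustrative only): assumes an already-built
    # RAGGraphEngine named `engine`, and that the async `search()` entry
    # point inherited from BaseSearch delegates to `_search` below.
    #
    #     searcher = RAGGraphSearch(graph_engine=engine)
    #     docs = await searcher.search("How are user sessions stored?")
    #     print(docs[0].page_content)
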
    def __init__(
        self,
        graph_engine,
        model_name: Optional[str] = None,
        max_keywords_per_query: int = 10,
        num_chunks_per_query: int = 10,
        search_mode: Optional[SearchMode] = SearchMode.KEYWORD,
        graph_store_query_depth: int = 2,
        extract_subject_entities_fn: Optional[Callable] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize params."""
        # Imported locally to avoid a potential circular import.
        from pilot.graph_engine.graph_engine import RAGGraphEngine

        self.graph_engine: RAGGraphEngine = graph_engine
        self.model_name = model_name or self.graph_engine.model_name
        self._index_struct = self.graph_engine.index_struct
        self.max_keywords_per_query = max_keywords_per_query
        self.num_chunks_per_query = num_chunks_per_query
        self._search_mode = search_mode

        self._graph_store = self.graph_engine.graph_store
        self.graph_store_query_depth = graph_store_query_depth
        self._verbose = kwargs.get("verbose", False)
        refresh_schema = kwargs.get("refresh_schema", False)
        self.extract_subject_entities_fn = extract_subject_entities_fn
        # os.cpu_count() may return None; fall back to a single core.
        self.executor = ThreadPoolExecutor(max_workers=(os.cpu_count() or 1) * 5)
        try:
            self._graph_schema = self._graph_store.get_schema(refresh=refresh_schema)
        except NotImplementedError:
            self._graph_schema = ""
        except Exception as e:
            logger.warning(f"Cannot find graph schema: {e}")
            self._graph_schema = ""

    async def _extract_subject_entities(self, query_str: str) -> Set[str]:
        """Extract subject entities from the query."""
        if self.extract_subject_entities_fn is not None:
            return await self.extract_subject_entities_fn(query_str)
        else:
            return await self._extract_entities_by_llm(query_str)

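    # A hypothetical custom callback (illustrative, not part of this module);
    # any async callable with this signature can be passed as
    # extract_subject_entities_fn:
    #
    #     async def title_case_extractor(query_str: str) -> Set[str]:
    #         # Naive heuristic: treat capitalized tokens as candidate entities.
    #         return {tok.lower() for tok in query_str.split() if tok.istitle()}
    #
    #     searcher = RAGGraphSearch(
    #         engine, extract_subject_entities_fn=title_case_extractor
    #     )
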
    async def _extract_entities_by_llm(self, text: str) -> Set[str]:
        """Extract subject entities from text with the LLM."""
        import uuid

        # Local imports keep module import light and avoid potential cycles.
        from pilot.common.chat_util import llm_chat_response_nostream
        from pilot.scene.base import ChatScene

        chat_param = {
            "chat_session_id": uuid.uuid1(),
            "current_user_input": text,
            "select_param": "entity",
            "model_name": self.model_name,
        }
        return await llm_chat_response_nostream(
            ChatScene.ExtractEntity.value(), chat_param=chat_param
        )

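    # The call above is assumed to resolve to a Set[str] of entity names,
    # e.g. {"alice", "acme"} for "Where does Alice work at Acme?"
    # (illustrative values, not actual model output).
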
    async def _search(
        self,
        query_str: str,
    ) -> List[Document]:
        """Get nodes for response."""
        keywords = await self._extract_subject_entities(query_str)
        logger.info(f"extract entities: {keywords}")
        rel_texts = []
        cur_rel_map = {}
        if self._search_mode != SearchMode.EMBEDDING:
            for keyword in keywords:
                keyword = keyword.lower()
                subjs = {keyword}
                rel_map = self._graph_store.get_rel_map(
                    list(subjs), self.graph_store_query_depth
                )
                logger.debug(f"rel_map: {rel_map}")

                if not rel_map:
                    continue
                rel_texts.extend(
                    [
                        str(rel_obj)
                        for rel_objs in rel_map.values()
                        for rel_obj in rel_objs
                    ]
                )
                cur_rel_map.update(rel_map)

        if not rel_texts:
            logger.info("> No relationships found, returning empty response.")
            return [Document(page_content="No relationships found.")]

        # Add the relationships as a single Node.
        # TODO: make initial text customizable
        rel_initial_text = (
            f"The following is a knowledge sequence with max depth"
            f" {self.graph_store_query_depth} "
            f"in the form of a directed graph like:\n"
            f"`subject -[predicate]->, object, <-[predicate_next_hop]-,"
            f" object_next_hop ...`"
        )
        rel_info = [rel_initial_text] + rel_texts
        rel_node_info = {
            "kg_rel_texts": rel_texts,
            "kg_rel_map": cur_rel_map,
        }
        if self._graph_schema != "":
            rel_node_info["kg_schema"] = {"schema": self._graph_schema}
        rel_info_text = "\n".join(
            [
                str(item)
                for sublist in rel_info
                for item in (sublist if isinstance(sublist, list) else [sublist])
            ]
        )
        if self._verbose:
            print(f"KG context:\n{rel_info_text}\n")
        rel_text_node = TextNode(
            text=rel_info_text,
            metadata=rel_node_info,
            excluded_embed_metadata_keys=["kg_rel_map", "kg_rel_texts"],
            excluded_llm_metadata_keys=["kg_rel_map", "kg_rel_texts"],
        )
        # This node is constructed from rel_texts; give it a high score so it
        # is not cut off by downstream filtering.
        sorted_nodes_with_scores = [
            NodeWithScore(node=rel_text_node, score=DEFAULT_NODE_SCORE)
        ]
        return [
            Document(page_content=node.text, metadata=node.metadata)
            for node in sorted_nodes_with_scores
        ]

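    # Illustrative shape of the data assembled in _search (hypothetical
    # values):
    #
    #     cur_rel_map = {
    #         "alice": ["alice -[works_at]-> acme", "acme <-[founded_by]- bob"],
    #     }
    #
    # Each relationship string becomes one line of rel_info_text, and the
    # whole text is wrapped in one high-score TextNode so downstream
    # filtering does not drop it.
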
    def _get_metadata_for_response(
        self, nodes: List[BaseNode]
    ) -> Optional[Dict[str, Any]]:
        """Get metadata for response."""
        for node in nodes:
            if node.metadata is None or "kg_rel_map" not in node.metadata:
                continue
            return node.metadata
        raise ValueError("kg_rel_map must be found in at least one Node.")