mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-20 01:07:15 +00:00)

feat:rag graph

parent 71c31c3e2e
commit b63fa2dfe1
@@ -9,7 +9,7 @@ chat_factory = ChatFactory()


 async def llm_chat_response_nostream(chat_scene: str, **chat_param):
-    """ llm_chat_response_nostream """
+    """llm_chat_response_nostream"""
     chat: BaseChat = chat_factory.get_implementation(chat_scene, **chat_param)
     res = await chat.get_llm_response()
     return res
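For orientation: the helper above builds a chat implementation for the requested scene and awaits its full, non-streaming LLM response. A minimal usage sketch from async code, reusing the ExtractEntity scene that appears later in this commit (the chat_param values are illustrative and may not be the complete set of required keys):

    import asyncio

    from pilot.common.chat_util import llm_chat_response_nostream
    from pilot.scene.base import ChatScene

    async def demo():
        chat_param = {
            "current_user_input": "Who founded DB-GPT?",  # illustrative
            "select_param": "entity",
            "model_name": "vicuna-13b",  # illustrative
        }
        # await the full response for the ExtractEntity chat scene
        return await llm_chat_response_nostream(
            ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
        )

    asyncio.run(demo())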
@@ -45,8 +45,7 @@ class RAGGraphEngine:
         **kwargs: Any,
     ) -> None:
         """Initialize params."""
-        # from llama_index.graph_stores import SimpleGraphStore
-        # from llama_index.graph_stores.types import GraphStore
+        from llama_index.graph_stores import SimpleGraphStore

         # need to set parameters before building index in base class.
         self.knowledge_source = knowledge_source
@@ -55,8 +54,8 @@ class RAGGraphEngine:
         self.text_splitter = text_splitter
         self.index_struct = index_struct
         self.include_embeddings = include_embeddings
-        # self.graph_store = graph_store or SimpleGraphStore()
-        self.graph_store = graph_store
+        self.graph_store = graph_store or SimpleGraphStore()
+        # self.graph_store = graph_store
         self.max_triplets_per_chunk = max_triplets_per_chunk
         self._max_object_length = max_object_length
         self._extract_triplet_fn = extract_triplet_fn
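The net effect of the two hunks above is that the engine now defaults to llama-index's in-memory SimpleGraphStore when no store is injected. A stripped-down sketch of the pattern (class name illustrative):

    from llama_index.graph_stores import SimpleGraphStore

    class Engine:
        def __init__(self, graph_store=None):
            # use the injected store, else fall back to an in-memory one
            self.graph_store = graph_store or SimpleGraphStore()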
@@ -103,14 +102,6 @@ class RAGGraphEngine:
             )
         )
         return triplets
-        # response = self._service_context.llm_predictor.predict(
-        #     self.kg_triple_extract_template,
-        #     text=text,
-        # )
-        # print(response, flush=True)
-        # return self._parse_triplet_response(
-        #     response, max_length=self._max_object_length
-        # )

     def _build_index_from_docs(self, documents: List[Document]) -> KG:
         """Build the index from nodes."""
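With the commented-out llm_predictor path deleted, triplet extraction is left to the chat-scene flow and to the optional extract_triplet_fn hook stored in __init__ above. The diff does not spell out the hook's contract; a plausible shape, purely illustrative:

    from typing import List, Tuple

    def extract_triplet_fn(text: str) -> List[Tuple[str, str, str]]:
        # return (subject, predicate, object) triples found in the text
        return [("DB-GPT", "is", "a database GPT project")]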
@@ -126,7 +117,6 @@ class RAGGraphEngine:
                 self.graph_store.upsert_triplet(*triplet)
                 index_struct.add_node([subj, obj], text_node)
-

         return index_struct

     def search(self, query):
@@ -134,4 +124,3 @@ class RAGGraphEngine:

         graph_search = RAGGraphSearch(graph_engine=self)
         return graph_search.search(query)
-
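Note that RAGGraphSearch.search becomes a coroutine later in this commit, so RAGGraphEngine.search now returns a coroutine as well and callers have to await it, as ChatKnowledge does further down:

    docs = await engine.search("query against the knowledge graph")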
@@ -4,6 +4,8 @@ from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from typing import List, Optional, Dict, Any, Set, Callable

+from langchain.schema import Document
+
 from pilot.graph_engine.node import BaseNode, TextNode, NodeWithScore
 from pilot.graph_engine.search import BaseSearch, SearchMode
 from pilot.utils import utils
@@ -67,14 +69,14 @@ class RAGGraphSearch(BaseSearch):
             logger.warn(f"can not to find graph schema: {e}")
             self._graph_schema = ""

-    def _extract_subject_entities(self, query_str: str) -> Set[str]:
+    async def _extract_subject_entities(self, query_str: str) -> Set[str]:
         """extract subject entities."""
         if self.extract_subject_entities_fn is not None:
-            return self.extract_subject_entities_fn(query_str)
+            return await self.extract_subject_entities_fn(query_str)
         else:
-            return self._extract_entities_by_llm(query_str)
+            return await self._extract_entities_by_llm(query_str)

-    def _extract_entities_by_llm(self, text: str) -> Set[str]:
+    async def _extract_entities_by_llm(self, text: str) -> Set[str]:
         """extract subject entities from text by llm"""
         from pilot.scene.base import ChatScene
         from pilot.common.chat_util import llm_chat_response_nostream
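Because the methods above are now coroutines and the hook is awaited, a user-supplied extract_subject_entities_fn must itself be awaitable. A minimal illustrative callback:

    from typing import Set

    async def extract_subject_entities_fn(query_str: str) -> Set[str]:
        # naive stand-in: treat capitalized tokens as candidate entities
        return {tok for tok in query_str.split() if tok.istitle()}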
@@ -86,21 +88,23 @@ class RAGGraphSearch(BaseSearch):
             "select_param": "entity",
             "model_name": self.model_name,
         }
-        loop = utils.get_or_create_event_loop()
-        entities = loop.run_until_complete(
-            llm_chat_response_nostream(
+        # loop = utils.get_or_create_event_loop()
+        # entities = loop.run_until_complete(
+        #     llm_chat_response_nostream(
+        #         ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
+        #     )
+        # )
+        return await llm_chat_response_nostream(
             ChatScene.ExtractEntity.value(), **{"chat_param": chat_param}
         )
-        )
-        return entities

-    def _search(
+    async def _search(
         self,
         query_str: str,
-    ) -> List[NodeWithScore]:
+    ) -> List[Document]:
         """Get nodes for response."""
         node_visited = set()
-        keywords = self._extract_subject_entities(query_str)
+        keywords = await self._extract_subject_entities(query_str)
         print(f"extract entities: {keywords}\n")
         rel_texts = []
         cur_rel_map = {}
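The key change above replaces loop.run_until_complete(...) with a plain await: run_until_complete cannot be nested inside an already-running event loop and raises "RuntimeError: This event loop is already running". A minimal reproduction of the problem and the fix:

    import asyncio

    async def inner():
        return "done"

    async def outer():
        loop = asyncio.get_event_loop()
        # loop.run_until_complete(inner())  # would raise RuntimeError here
        return await inner()  # the pattern this commit switches to

    asyncio.run(outer())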
@@ -114,8 +118,8 @@ class RAGGraphSearch(BaseSearch):
                 if node_id in node_visited:
                     continue

-                if self._include_text:
-                    chunk_indices_count[node_id] += 1
+                # if self._include_text:
+                #     chunk_indices_count[node_id] += 1

                 node_visited.add(node_id)

@@ -179,8 +183,11 @@ class RAGGraphSearch(BaseSearch):
             sorted_nodes_with_scores.append(
                 NodeWithScore(node=rel_text_node, score=DEFAULT_NODE_SCORE)
             )
-        return sorted_nodes_with_scores
+        docs = [
+            Document(page_content=node.text, metadata=node.metadata)
+            for node in sorted_nodes_with_scores
+        ]
+        return docs

     def _get_metadata_for_response(
         self, nodes: List[BaseNode]
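The hunk above swaps the internal NodeWithScore result type for langchain Documents, which drops the retrieval score during conversion. If downstream ranking ever needs it, one option (not part of this commit) is to carry the score in metadata:

    docs = [
        Document(
            page_content=node.text,
            metadata={**node.metadata, "score": node.score},
        )
        for node in sorted_nodes_with_scores
    ]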
@@ -21,6 +21,7 @@ WRAP_WIDTH = 70

 class BaseComponent(BaseModel):
     """Base component object to caputure class names."""
+
     """reference llama-index"""

     @classmethod
@@ -23,7 +23,7 @@ class SearchMode(str, Enum):
 class BaseSearch(ABC):
     """Base Search."""

-    def search(self, query: str):
+    async def search(self, query: str):
         """Retrieve nodes given query.

         Args:
@@ -32,10 +32,10 @@ class BaseSearch(ABC):

         """
         # if isinstance(query, str):
-        return self._search(query)
+        return await self._search(query)

     @abstractmethod
-    def _search(self, query: str):
+    async def _search(self, query: str):
         """search nodes given query.

         Implemented by the user.
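Under the new contract, subclasses implement _search as a coroutine and callers go through the async search wrapper. A minimal illustrative subclass:

    class DummySearch(BaseSearch):
        async def _search(self, query: str):
            # real implementations query a graph or vector store here
            return [f"hit for {query}"]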
@@ -105,8 +105,14 @@ class BaseChat(ABC):
         speak_to_user = prompt_define_response
         return speak_to_user

-    def __call_base(self):
-        input_values = self.generate_input_values()
+    async def __call_base(self):
+        import inspect
+
+        input_values = (
+            await self.generate_input_values()
+            if inspect.isawaitable(self.generate_input_values())
+            else self.generate_input_values()
+        )
         ### Chat sequence advance
         self.current_message.chat_order = len(self.history_message) + 1
         self.current_message.add_user_message(self.current_user_input)
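One caveat with the hunk above: generate_input_values() is invoked twice, once inside inspect.isawaitable(...) and once to produce the value, so when the method is async the first call leaves an un-awaited coroutine behind. A single-call variant (a suggestion, not what the commit does):

    result = self.generate_input_values()
    input_values = await result if inspect.isawaitable(result) else result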
@@ -146,7 +152,7 @@ class BaseChat(ABC):

     async def stream_call(self):
         # TODO Retry when server connection error
-        payload = self.__call_base()
+        payload = await self.__call_base()

         self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
         logger.info(f"Request: \n{payload}")
@@ -234,7 +240,7 @@ class BaseChat(ABC):
         return self.current_ai_response()

     async def get_llm_response(self):
-        payload = self.__call_base()
+        payload = await self.__call_base()
         logger.info(f"Request: \n{payload}")
         ai_response_text = ""
         try:
@@ -24,7 +24,7 @@ class ExtractEntity(BaseChat):
         self.user_input = chat_param["current_user_input"]
         self.extract_mode = chat_param["select_param"]

-    def generate_input_values(self):
+    async def generate_input_values(self):
         input_values = {
             "text": self.user_input,
         }
@@ -24,7 +24,7 @@ class ExtractTriplet(BaseChat):
         self.user_input = chat_param["current_user_input"]
         self.extract_mode = chat_param["select_param"]

-    def generate_input_values(self):
+    async def generate_input_values(self):
         input_values = {
             "text": self.user_input,
         }
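Both extract chats keep synchronous bodies but are declared async so they satisfy the awaitable branch in BaseChat.__call_base. Driving one through the helper from the top of this commit might look like this (assuming a matching ChatScene.ExtractTriplet scene value; the chat_param keys follow the constructor above and may be incomplete):

    chat_param = {
        "current_user_input": text,
        "select_param": "triplet",
        "model_name": model_name,
    }
    triplets = await llm_chat_response_nostream(
        ChatScene.ExtractTriplet.value(), **{"chat_param": chat_param}
    )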
@@ -64,7 +64,7 @@ class ChatKnowledge(BaseChat):
         self.prompt_template.template_is_strict = False

     async def stream_call(self):
-        input_values = self.generate_input_values()
+        input_values = await self.generate_input_values()
         # Source of knowledge file
         relations = input_values.get("relations")
         last_output = None
@@ -84,14 +84,14 @@ class ChatKnowledge(BaseChat):
             )
             yield last_output

-    def generate_input_values(self):
+    async def generate_input_values(self):
         if self.space_context:
             self.prompt_template.template_define = self.space_context["prompt"]["scene"]
             self.prompt_template.template = self.space_context["prompt"]["template"]
-        docs = self.rag_engine.search(query=self.current_user_input)
-        docs = self.knowledge_embedding_client.similar_search(
-            self.current_user_input, self.top_k
-        )
+        docs = await self.rag_engine.search(query=self.current_user_input)
+        # docs = self.knowledge_embedding_client.similar_search(
+        #     self.current_user_input, self.top_k
+        # )
         if not docs:
             raise ValueError(
                 "you have no knowledge space, please add your knowledge space"
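The vector path is commented out rather than removed, so graph search fully replaces similar_search here. A hybrid that merges both result sets would also be possible (speculative, not what this commit does):

    graph_docs = await self.rag_engine.search(query=self.current_user_input)
    vector_docs = self.knowledge_embedding_client.similar_search(
        self.current_user_input, self.top_k
    )
    docs = graph_docs + vector_docs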
@@ -261,9 +261,6 @@ class KnowledgeService:
         # docs = engine.search(
         #     "Comparing Curry and James in terms of their positions, playing styles, and achievements in the NBA"
         # )
-        embedding_factory = CFG.SYSTEM_APP.get_component(
-            "embedding_factory", EmbeddingFactory
-        )
         # update document status
         doc.status = SyncStatus.RUNNING.name
         doc.chunk_size = len(chunk_docs)