From 55daa31dd96fb027cb98cd00dd4e1b535c673246 Mon Sep 17 00:00:00 2001 From: SonglinLyu <111941624+SonglinLyu@users.noreply.github.com> Date: Thu, 27 Feb 2025 21:56:19 +0800 Subject: [PATCH] GraphRAG: add new feature and bugfix (#2373) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 秉翟 --- packages/dbgpt-app/src/dbgpt_app/config.py | 20 ------ .../graph_retriever/graph_retriever.py | 23 +++++- .../text_based_graph_retriever.py | 14 ++-- .../rag/transformer/local_text2gql.py | 70 +++++++++++++++++++ .../dbgpt_ext/rag/transformer/text2cypher.py | 0 .../dbgpt_ext/rag/transformer/text2vector.py | 0 .../community/tugraph_store_adapter.py | 4 +- .../knowledge_graph/community_summary.py | 8 +++ .../dbgpt_ext/storage/vector_store/factory.py | 7 +- 9 files changed, 115 insertions(+), 31 deletions(-) create mode 100644 packages/dbgpt-ext/src/dbgpt_ext/rag/transformer/local_text2gql.py delete mode 100644 packages/dbgpt-ext/src/dbgpt_ext/rag/transformer/text2cypher.py delete mode 100644 packages/dbgpt-ext/src/dbgpt_ext/rag/transformer/text2vector.py diff --git a/packages/dbgpt-app/src/dbgpt_app/config.py b/packages/dbgpt-app/src/dbgpt_app/config.py index 9c63fe6e9..31dc6a6b2 100644 --- a/packages/dbgpt-app/src/dbgpt_app/config.py +++ b/packages/dbgpt-app/src/dbgpt_app/config.py @@ -50,26 +50,6 @@ class SystemParameters: ) -@dataclass -class StorageVectorConfig(BaseParameters): - type: str = field( - default="Chroma", - metadata={ - "help": _("default vector type"), - }, - ) - - -@dataclass -class StorageGraphConfig(BaseParameters): - type: str = field( - default="TuGraph", - metadata={ - "help": _("default graph type"), - }, - ) - - @dataclass class StorageConfig(BaseParameters): vector: VectorStoreConfig = field( diff --git a/packages/dbgpt-ext/src/dbgpt_ext/rag/retriever/graph_retriever/graph_retriever.py b/packages/dbgpt-ext/src/dbgpt_ext/rag/retriever/graph_retriever/graph_retriever.py index c80868f6d..8709c644a 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/rag/retriever/graph_retriever/graph_retriever.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/rag/retriever/graph_retriever/graph_retriever.py @@ -5,7 +5,10 @@ import os from typing import List, Tuple, Union from dbgpt.rag.transformer.keyword_extractor import KeywordExtractor +from dbgpt.rag.transformer.simple_intent_translator import SimpleIntentTranslator from dbgpt.storage.graph_store.graph import Graph, MemoryGraph +from dbgpt_ext.rag.transformer.local_text2gql import LocalText2GQL +from dbgpt_ext.rag.transformer.text2gql import Text2GQL from ...transformer.text_embedder import TextEmbedder from .base import GraphRetrieverBase @@ -81,10 +84,25 @@ class GraphRetriever(GraphRetrieverBase): if "TEXT_SEARCH_ENABLED" in os.environ else config.enable_text_search ) + text2gql_model_enabled = ( + os.environ["TEXT2GQL_MODEL_ENABLED"].lower() == "true" + if "TEXT2GQL_MODEL_ENABLED" in os.environ + else config.text2gql_model_enabled + ) + text2gql_model_name = os.getenv( + "TEXT2GQL_MODEL_NAME", + config.text2gql_model_name, + ) self._keyword_extractor = KeywordExtractor(llm_client, model_name) self._text_embedder = TextEmbedder(config.embedding_fn) + intent_interpreter = SimpleIntentTranslator(llm_client, model_name) + if text2gql_model_enabled: + text2gql = LocalText2GQL(text2gql_model_name) + else: + text2gql = Text2GQL(llm_client, model_name) + self._keyword_based_graph_retriever = KeywordBasedGraphRetriever( graph_store_adapter, triplet_topk ) @@ -95,7 +113,10 @@ class GraphRetriever(GraphRetrieverBase): similarity_search_score_threshold, ) self._text_based_graph_retriever = TextBasedGraphRetriever( - graph_store_adapter, triplet_topk, llm_client, model_name + graph_store_adapter, + triplet_topk, + intent_interpreter, + text2gql, ) self._document_graph_retriever = DocumentGraphRetriever( graph_store_adapter, diff --git a/packages/dbgpt-ext/src/dbgpt_ext/rag/retriever/graph_retriever/text_based_graph_retriever.py b/packages/dbgpt-ext/src/dbgpt_ext/rag/retriever/graph_retriever/text_based_graph_retriever.py index 1af80283c..c4754b003 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/rag/retriever/graph_retriever/text_based_graph_retriever.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/rag/retriever/graph_retriever/text_based_graph_retriever.py @@ -4,10 +4,8 @@ import json import logging from typing import Dict, List, Tuple, Union -from dbgpt.rag.transformer.simple_intent_translator import SimpleIntentTranslator from dbgpt.storage.graph_store.graph import Graph, MemoryGraph from dbgpt_ext.rag.retriever.graph_retriever.base import GraphRetrieverBase -from dbgpt_ext.rag.transformer.text2gql import Text2GQL logger = logging.getLogger(__name__) @@ -15,12 +13,18 @@ logger = logging.getLogger(__name__) class TextBasedGraphRetriever(GraphRetrieverBase): """Text Based Graph Retriever class.""" - def __init__(self, graph_store_adapter, triplet_topk, llm_client, model_name): + def __init__( + self, + graph_store_adapter, + triplet_topk, + intent_interpreter, + text2gql, + ): """Initialize Text Based Graph Retriever.""" self._graph_store_adapter = graph_store_adapter self._triplet_topk = triplet_topk - self._intent_interpreter = SimpleIntentTranslator(llm_client, model_name) - self._text2gql = Text2GQL(llm_client, model_name) + self._intent_interpreter = intent_interpreter + self._text2gql = text2gql async def retrieve(self, text: str) -> Tuple[Graph, str]: """Retrieve from triplets graph with text2gql.""" diff --git a/packages/dbgpt-ext/src/dbgpt_ext/rag/transformer/local_text2gql.py b/packages/dbgpt-ext/src/dbgpt_ext/rag/transformer/local_text2gql.py new file mode 100644 index 000000000..db3dc08b4 --- /dev/null +++ b/packages/dbgpt-ext/src/dbgpt_ext/rag/transformer/local_text2gql.py @@ -0,0 +1,70 @@ +"""LocalText2GQL class.""" + +import json +import logging +import re +from typing import Dict, List, Union + +from dbgpt.core import BaseMessage, HumanPromptTemplate +from dbgpt.model.proxy.llms.ollama import OllamaLLMClient +from dbgpt.rag.transformer.llm_translator import LLMTranslator + +LOCAL_TEXT_TO_GQL_PT = """ +A question written in graph query language style is provided below. Given the question, translate the question into a cypher query that can be executed on the given knowledge graph. Make sure the syntax of the translated cypher query is correct. +To help query generation, the schema of the knowledge graph is: +{schema} +--------------------- +Question: {question} +Query: + +""" # noqa: E501 + +logger = logging.getLogger(__name__) + + +class LocalText2GQL(LLMTranslator): + """LocalText2GQL class.""" + + def __init__(self, model_name: str): + """Initialize the LocalText2GQL.""" + super().__init__(OllamaLLMClient(), model_name, LOCAL_TEXT_TO_GQL_PT) + + def _format_messages(self, text: str, history: str = None) -> List[BaseMessage]: + # translate intention to gql with single prompt only. + intention: Dict[str, Union[str, List[str]]] = json.loads(text) + question = intention.get("rewritten_question", "") + schema = intention.get("schema", "") + + template = HumanPromptTemplate.from_template(self._prompt_template) + + messages = ( + template.format_messages( + schema=schema, + question=question, + history=history, + ) + if history is not None + else template.format_messages( + schema=schema, + question=question, + ) + ) + + return messages + + def _parse_response(self, text: str) -> Dict: + """Parse llm response.""" + translation: Dict[str, str] = {} + query = "" + + code_block_pattern = re.compile(r"```cypher(.*?)```", re.S) + + result = re.findall(code_block_pattern, text) + if result: + query = result[0] + else: + query = text + + translation["query"] = query.strip() + + return translation diff --git a/packages/dbgpt-ext/src/dbgpt_ext/rag/transformer/text2cypher.py b/packages/dbgpt-ext/src/dbgpt_ext/rag/transformer/text2cypher.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/packages/dbgpt-ext/src/dbgpt_ext/rag/transformer/text2vector.py b/packages/dbgpt-ext/src/dbgpt_ext/rag/transformer/text2vector.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/packages/dbgpt-ext/src/dbgpt_ext/storage/knowledge_graph/community/tugraph_store_adapter.py b/packages/dbgpt-ext/src/dbgpt_ext/storage/knowledge_graph/community/tugraph_store_adapter.py index 3747cb1d5..56d274cc8 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/storage/knowledge_graph/community/tugraph_store_adapter.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/storage/knowledge_graph/community/tugraph_store_adapter.py @@ -382,9 +382,9 @@ class TuGraphStoreAdapter(GraphStoreAdapter): GraphElemType.ENTITY.value, ) - def delete_document(self, chunk_ids: str) -> None: + def delete_document(self, chunk_id: str) -> None: """Delete document in the graph.""" - chunkids_list = [uuid.strip() for uuid in chunk_ids.split(",")] + chunkids_list = [uuid.strip() for uuid in chunk_id.split(",")] del_chunk_gql = ( f"MATCH(m:{GraphElemType.DOCUMENT.value})-[r]->" f"(n:{GraphElemType.CHUNK.value}) WHERE n.id IN {chunkids_list} DELETE n" diff --git a/packages/dbgpt-ext/src/dbgpt_ext/storage/knowledge_graph/community_summary.py b/packages/dbgpt-ext/src/dbgpt_ext/storage/knowledge_graph/community_summary.py index 7df2f2f2b..540605645 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/storage/knowledge_graph/community_summary.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/storage/knowledge_graph/community_summary.py @@ -210,6 +210,14 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig): default=False, description="Enable text2gql search or not.", ) + text2gql_model_enabled: bool = Field( + default=False, + description="Enable fine-tuned text2gql model for text2gql translation.", + ) + text2gql_model_name: str = Field( + default=None, + description="LLM Model for text2gql translation.", + ) @register_resource( diff --git a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/factory.py b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/factory.py index 65e2541a4..fac690884 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/factory.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/factory.py @@ -3,8 +3,9 @@ import logging from typing import Tuple, Type -from dbgpt.storage import vector_store from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig +from dbgpt_ext.storage import __vector_store__ as vector_store_list +from dbgpt_ext.storage import _select_rag_storage logger = logging.getLogger(__name__) @@ -35,9 +36,9 @@ class VectorStoreFactory: @staticmethod def __find_type(vector_store_type: str) -> Tuple[Type, Type]: - for t in vector_store.__vector_store__: + for t in vector_store_list: if t.lower() == vector_store_type.lower(): - store_cls, cfg_cls = getattr(vector_store, t) + store_cls, cfg_cls = _select_rag_storage(t) if issubclass(store_cls, VectorStoreBase) and issubclass( cfg_cls, VectorStoreConfig ):