mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-01 17:16:51 +00:00
GraphRAG: add new feature and bugfix (#2373)
Co-authored-by: 秉翟 <lyusonglin.lsl@antgroup.com>
This commit is contained in:
@@ -50,26 +50,6 @@ class SystemParameters:
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StorageVectorConfig(BaseParameters):
|
||||
type: str = field(
|
||||
default="Chroma",
|
||||
metadata={
|
||||
"help": _("default vector type"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StorageGraphConfig(BaseParameters):
|
||||
type: str = field(
|
||||
default="TuGraph",
|
||||
metadata={
|
||||
"help": _("default graph type"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StorageConfig(BaseParameters):
|
||||
vector: VectorStoreConfig = field(
|
||||
|
@@ -5,7 +5,10 @@ import os
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from dbgpt.rag.transformer.keyword_extractor import KeywordExtractor
|
||||
from dbgpt.rag.transformer.simple_intent_translator import SimpleIntentTranslator
|
||||
from dbgpt.storage.graph_store.graph import Graph, MemoryGraph
|
||||
from dbgpt_ext.rag.transformer.local_text2gql import LocalText2GQL
|
||||
from dbgpt_ext.rag.transformer.text2gql import Text2GQL
|
||||
|
||||
from ...transformer.text_embedder import TextEmbedder
|
||||
from .base import GraphRetrieverBase
|
||||
@@ -81,10 +84,25 @@ class GraphRetriever(GraphRetrieverBase):
|
||||
if "TEXT_SEARCH_ENABLED" in os.environ
|
||||
else config.enable_text_search
|
||||
)
|
||||
text2gql_model_enabled = (
|
||||
os.environ["TEXT2GQL_MODEL_ENABLED"].lower() == "true"
|
||||
if "TEXT2GQL_MODEL_ENABLED" in os.environ
|
||||
else config.text2gql_model_enabled
|
||||
)
|
||||
text2gql_model_name = os.getenv(
|
||||
"TEXT2GQL_MODEL_NAME",
|
||||
config.text2gql_model_name,
|
||||
)
|
||||
|
||||
self._keyword_extractor = KeywordExtractor(llm_client, model_name)
|
||||
self._text_embedder = TextEmbedder(config.embedding_fn)
|
||||
|
||||
intent_interpreter = SimpleIntentTranslator(llm_client, model_name)
|
||||
if text2gql_model_enabled:
|
||||
text2gql = LocalText2GQL(text2gql_model_name)
|
||||
else:
|
||||
text2gql = Text2GQL(llm_client, model_name)
|
||||
|
||||
self._keyword_based_graph_retriever = KeywordBasedGraphRetriever(
|
||||
graph_store_adapter, triplet_topk
|
||||
)
|
||||
@@ -95,7 +113,10 @@ class GraphRetriever(GraphRetrieverBase):
|
||||
similarity_search_score_threshold,
|
||||
)
|
||||
self._text_based_graph_retriever = TextBasedGraphRetriever(
|
||||
graph_store_adapter, triplet_topk, llm_client, model_name
|
||||
graph_store_adapter,
|
||||
triplet_topk,
|
||||
intent_interpreter,
|
||||
text2gql,
|
||||
)
|
||||
self._document_graph_retriever = DocumentGraphRetriever(
|
||||
graph_store_adapter,
|
||||
|
@@ -4,10 +4,8 @@ import json
|
||||
import logging
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
from dbgpt.rag.transformer.simple_intent_translator import SimpleIntentTranslator
|
||||
from dbgpt.storage.graph_store.graph import Graph, MemoryGraph
|
||||
from dbgpt_ext.rag.retriever.graph_retriever.base import GraphRetrieverBase
|
||||
from dbgpt_ext.rag.transformer.text2gql import Text2GQL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -15,12 +13,18 @@ logger = logging.getLogger(__name__)
|
||||
class TextBasedGraphRetriever(GraphRetrieverBase):
|
||||
"""Text Based Graph Retriever class."""
|
||||
|
||||
def __init__(self, graph_store_adapter, triplet_topk, llm_client, model_name):
|
||||
def __init__(
|
||||
self,
|
||||
graph_store_adapter,
|
||||
triplet_topk,
|
||||
intent_interpreter,
|
||||
text2gql,
|
||||
):
|
||||
"""Initialize Text Based Graph Retriever."""
|
||||
self._graph_store_adapter = graph_store_adapter
|
||||
self._triplet_topk = triplet_topk
|
||||
self._intent_interpreter = SimpleIntentTranslator(llm_client, model_name)
|
||||
self._text2gql = Text2GQL(llm_client, model_name)
|
||||
self._intent_interpreter = intent_interpreter
|
||||
self._text2gql = text2gql
|
||||
|
||||
async def retrieve(self, text: str) -> Tuple[Graph, str]:
|
||||
"""Retrieve from triplets graph with text2gql."""
|
||||
|
@@ -0,0 +1,70 @@
|
||||
"""LocalText2GQL class."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from dbgpt.core import BaseMessage, HumanPromptTemplate
|
||||
from dbgpt.model.proxy.llms.ollama import OllamaLLMClient
|
||||
from dbgpt.rag.transformer.llm_translator import LLMTranslator
|
||||
|
||||
LOCAL_TEXT_TO_GQL_PT = """
|
||||
A question written in graph query language style is provided below. Given the question, translate the question into a cypher query that can be executed on the given knowledge graph. Make sure the syntax of the translated cypher query is correct.
|
||||
To help query generation, the schema of the knowledge graph is:
|
||||
{schema}
|
||||
---------------------
|
||||
Question: {question}
|
||||
Query:
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LocalText2GQL(LLMTranslator):
|
||||
"""LocalText2GQL class."""
|
||||
|
||||
def __init__(self, model_name: str):
|
||||
"""Initialize the LocalText2GQL."""
|
||||
super().__init__(OllamaLLMClient(), model_name, LOCAL_TEXT_TO_GQL_PT)
|
||||
|
||||
def _format_messages(self, text: str, history: str = None) -> List[BaseMessage]:
|
||||
# translate intention to gql with single prompt only.
|
||||
intention: Dict[str, Union[str, List[str]]] = json.loads(text)
|
||||
question = intention.get("rewritten_question", "")
|
||||
schema = intention.get("schema", "")
|
||||
|
||||
template = HumanPromptTemplate.from_template(self._prompt_template)
|
||||
|
||||
messages = (
|
||||
template.format_messages(
|
||||
schema=schema,
|
||||
question=question,
|
||||
history=history,
|
||||
)
|
||||
if history is not None
|
||||
else template.format_messages(
|
||||
schema=schema,
|
||||
question=question,
|
||||
)
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
def _parse_response(self, text: str) -> Dict:
|
||||
"""Parse llm response."""
|
||||
translation: Dict[str, str] = {}
|
||||
query = ""
|
||||
|
||||
code_block_pattern = re.compile(r"```cypher(.*?)```", re.S)
|
||||
|
||||
result = re.findall(code_block_pattern, text)
|
||||
if result:
|
||||
query = result[0]
|
||||
else:
|
||||
query = text
|
||||
|
||||
translation["query"] = query.strip()
|
||||
|
||||
return translation
|
@@ -382,9 +382,9 @@ class TuGraphStoreAdapter(GraphStoreAdapter):
|
||||
GraphElemType.ENTITY.value,
|
||||
)
|
||||
|
||||
def delete_document(self, chunk_ids: str) -> None:
|
||||
def delete_document(self, chunk_id: str) -> None:
|
||||
"""Delete document in the graph."""
|
||||
chunkids_list = [uuid.strip() for uuid in chunk_ids.split(",")]
|
||||
chunkids_list = [uuid.strip() for uuid in chunk_id.split(",")]
|
||||
del_chunk_gql = (
|
||||
f"MATCH(m:{GraphElemType.DOCUMENT.value})-[r]->"
|
||||
f"(n:{GraphElemType.CHUNK.value}) WHERE n.id IN {chunkids_list} DELETE n"
|
||||
|
@@ -210,6 +210,14 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig):
|
||||
default=False,
|
||||
description="Enable text2gql search or not.",
|
||||
)
|
||||
text2gql_model_enabled: bool = Field(
|
||||
default=False,
|
||||
description="Enable fine-tuned text2gql model for text2gql translation.",
|
||||
)
|
||||
text2gql_model_name: str = Field(
|
||||
default=None,
|
||||
description="LLM Model for text2gql translation.",
|
||||
)
|
||||
|
||||
|
||||
@register_resource(
|
||||
|
@@ -3,8 +3,9 @@
|
||||
import logging
|
||||
from typing import Tuple, Type
|
||||
|
||||
from dbgpt.storage import vector_store
|
||||
from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig
|
||||
from dbgpt_ext.storage import __vector_store__ as vector_store_list
|
||||
from dbgpt_ext.storage import _select_rag_storage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -35,9 +36,9 @@ class VectorStoreFactory:
|
||||
|
||||
@staticmethod
|
||||
def __find_type(vector_store_type: str) -> Tuple[Type, Type]:
|
||||
for t in vector_store.__vector_store__:
|
||||
for t in vector_store_list:
|
||||
if t.lower() == vector_store_type.lower():
|
||||
store_cls, cfg_cls = getattr(vector_store, t)
|
||||
store_cls, cfg_cls = _select_rag_storage(t)
|
||||
if issubclass(store_cls, VectorStoreBase) and issubclass(
|
||||
cfg_cls, VectorStoreConfig
|
||||
):
|
||||
|
Reference in New Issue
Block a user