mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-12 20:53:48 +00:00
feat: add GraphRAG framework and integrate TuGraph (#1506)
Co-authored-by: KingSkyLi <15566300566@163.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""Base Assembler."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional
|
||||
|
||||
@@ -37,13 +38,15 @@ class BaseAssembler(ABC):
|
||||
)
|
||||
self._chunks: List[Chunk] = []
|
||||
metadata = {
|
||||
"knowledge_cls": self._knowledge.__class__.__name__
|
||||
if self._knowledge
|
||||
else None,
|
||||
"knowledge_cls": (
|
||||
self._knowledge.__class__.__name__ if self._knowledge else None
|
||||
),
|
||||
"knowledge_type": self._knowledge.type().value if self._knowledge else None,
|
||||
"path": self._knowledge._path
|
||||
if self._knowledge and hasattr(self._knowledge, "_path")
|
||||
else None,
|
||||
"path": (
|
||||
self._knowledge._path
|
||||
if self._knowledge and hasattr(self._knowledge, "_path")
|
||||
else None
|
||||
),
|
||||
"chunk_parameters": self._chunk_parameters.dict(),
|
||||
}
|
||||
with root_tracer.start_span("BaseAssembler.load_knowledge", metadata=metadata):
|
||||
@@ -70,6 +73,14 @@ class BaseAssembler(ABC):
|
||||
List[str]: List of persisted chunk ids.
|
||||
"""
|
||||
|
||||
async def apersist(self) -> List[str]:
    """Persist chunks asynchronously.

    This base implementation only raises ``NotImplementedError``.
    NOTE(review): subclasses that support async persistence are presumably
    expected to override it - confirm against the concrete assemblers.

    Returns:
        List[str]: List of persisted chunk ids.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
|
||||
|
||||
def get_chunks(self) -> List[Chunk]:
|
||||
"""Return chunks."""
|
||||
return self._chunks
|
||||
|
@@ -106,6 +106,14 @@ class EmbeddingAssembler(BaseAssembler):
|
||||
"""
|
||||
return self._vector_store_connector.load_document(self._chunks)
|
||||
|
||||
async def apersist(self) -> List[str]:
    """Asynchronously persist the assembled chunks into the store.

    Awaits the connector's ``aload_document`` with the chunks currently held
    by this assembler and hands back whatever it returns.

    Returns:
        List[str]: List of chunk ids.
    """
    connector = self._vector_store_connector
    chunk_ids = await connector.aload_document(self._chunks)
    return chunk_ids
|
||||
|
||||
def _extract_info(self, chunks) -> List[Chunk]:
|
||||
"""Extract info from chunks."""
|
||||
return []
|
||||
|
1
dbgpt/rag/index/__init__.py
Normal file
1
dbgpt/rag/index/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Module for index."""
|
168
dbgpt/rag/index/base.py
Normal file
168
dbgpt/rag/index/base.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Index store base class."""
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
|
||||
from dbgpt.core import Chunk, Embeddings
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IndexStoreConfig(BaseModel):
    """Configuration model shared by index stores.

    Per ``model_config``, arbitrary field types are allowed and extra
    (undeclared) fields are accepted.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")

    # Name of the index store.
    name: str = Field(
        default="dbgpt_collection",
        description="The name of index store, if not set, will use the default name.",
    )
    # Embedding function; optional.
    embedding_fn: Optional[Embeddings] = Field(
        default=None,
        description=(
            "The embedding function of vector store, if not set, will use the "
            "default embedding function."
        ),
    )
    # Batch size used when loading chunks.
    max_chunks_once_load: int = Field(
        default=10,
        description=(
            "The max number of chunks to load at once. If your document is "
            "large, you can set this value to a larger number to speed up the loading "
            "process. Default is 10."
        ),
    )
    # Worker thread count used when loading chunks.
    max_threads: int = Field(
        default=1,
        description=(
            "The max number of threads to use. Default is 1. If you set this "
            "bigger than 1, please make sure your vector store is thread-safe."
        ),
    )

    def to_dict(self, **kwargs) -> Dict[str, Any]:
        """Return this config as a dict.

        Any keyword arguments are forwarded unchanged to ``model_to_dict``.
        """
        as_dict = model_to_dict(self, **kwargs)
        return as_dict
|
||||
|
||||
|
||||
class IndexStoreBase(ABC):
    """Index store base class.

    Concrete stores implement loading, similarity search and deletion. This
    base class adds batched loading (``load_document_with_limit``) and thin
    convenience wrappers built on the abstract methods.
    """

    @abstractmethod
    def load_document(self, chunks: List[Chunk]) -> List[str]:
        """Load document in index database.

        Args:
            chunks(List[Chunk]): document chunks.

        Return:
            List[str]: chunk ids.
        """

    @abstractmethod
    async def aload_document(self, chunks: List[Chunk]) -> List[str]:
        """Asynchronously load document in index database.

        Declared as a coroutine so that ``await store.aload_document(chunks)``
        is valid for every implementation.

        Args:
            chunks(List[Chunk]): document chunks.

        Return:
            List[str]: chunk ids.
        """

    @abstractmethod
    def similar_search_with_scores(
        self,
        text,
        topk,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Similar search with scores in index database.

        Args:
            text(str): The query text.
            topk(int): The number of similar documents to return.
            score_threshold(float): A floating point value between 0 and 1.
            filters(Optional[MetadataFilters]): metadata filters.

        Return:
            List[Chunk]: The similar documents.
        """

    @abstractmethod
    def delete_by_ids(self, ids: str):
        """Delete docs.

        Args:
            ids(str): The vector ids to delete, separated by comma.
        """

    @abstractmethod
    def delete_vector_name(self, index_name: str):
        """Delete index by name.

        Args:
            index_name(str): The name of index to delete.
        """

    def load_document_with_limit(
        self, chunks: List[Chunk], max_chunks_once_load: int = 10, max_threads: int = 1
    ) -> List[str]:
        """Load document in index database with specified limit.

        The chunks are split into consecutive groups of at most
        ``max_chunks_once_load`` items; each group is submitted to a thread
        pool as one ``load_document`` call and the returned ids are collected
        in submission order.

        Args:
            chunks(List[Chunk]): Document chunks.
            max_chunks_once_load(int): Max number of chunks to load at once.
            max_threads(int): Max number of threads to use.

        Return:
            List[str]: Chunk ids.

        Raises:
            ValueError: If ``max_chunks_once_load`` is less than 1.
        """
        # A zero step makes range() fail with an unrelated message and a
        # negative step silently yields no groups (nothing gets loaded), so
        # reject both up front.
        if max_chunks_once_load < 1:
            raise ValueError("max_chunks_once_load must be >= 1")

        # Group the chunks into chunks of size max_chunks
        chunk_groups = [
            chunks[i : i + max_chunks_once_load]
            for i in range(0, len(chunks), max_chunks_once_load)
        ]
        logger.info(
            f"Loading {len(chunks)} chunks in {len(chunk_groups)} groups with "
            f"{max_threads} threads."
        )
        ids = []
        loaded_cnt = 0
        start_time = time.time()
        with ThreadPoolExecutor(max_workers=max_threads) as executor:
            tasks = [
                executor.submit(self.load_document, chunk_group)
                for chunk_group in chunk_groups
            ]
            # Iterating in submission order keeps the ids in input order; an
            # exception raised by a worker propagates from result().
            for future in tasks:
                success_ids = future.result()
                ids.extend(success_ids)
                loaded_cnt += len(success_ids)
                logger.info(f"Loaded {loaded_cnt} chunks, total {len(chunks)} chunks.")
        logger.info(
            f"Loaded {len(chunks)} chunks in {time.time() - start_time} seconds"
        )
        return ids

    def similar_search(
        self, text: str, topk: int, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Similar search in index database.

        Delegates to ``similar_search_with_scores`` with a score threshold of
        1.0.

        Args:
            text(str): The query text.
            topk(int): The number of similar documents to return.
            filters(Optional[MetadataFilters]): metadata filters.

        Return:
            List[Chunk]: The similar documents.
        """
        return self.similar_search_with_scores(text, topk, 1.0, filters)

    async def asimilar_search_with_scores(
        self,
        doc: str,
        topk: int,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Async similar_search_with_score in vector database.

        This default implementation calls the synchronous
        ``similar_search_with_scores`` directly, so it runs inline in the
        calling coroutine.

        Args:
            doc(str): The query text.
            topk(int): The number of similar documents to return.
            score_threshold(float): Forwarded to ``similar_search_with_scores``.
            filters(Optional[MetadataFilters]): metadata filters.

        Return:
            List[Chunk]: The similar documents.
        """
        return self.similar_search_with_scores(doc, topk, score_threshold, filters)
|
@@ -229,6 +229,6 @@ class EmbeddingRetriever(BaseRetriever):
|
||||
self, query, score_threshold, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Similar search with score."""
|
||||
return self._vector_store_connector.similar_search_with_scores(
|
||||
return await self._vector_store_connector.asimilar_search_with_scores(
|
||||
query, self._top_k, score_threshold, filters
|
||||
)
|
||||
|
1
dbgpt/rag/transformer/__init__.py
Normal file
1
dbgpt/rag/transformer/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Module for transformer."""
|
26
dbgpt/rag/transformer/base.py
Normal file
26
dbgpt/rag/transformer/base.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Transformer base class."""
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TransformerBase:
    """Transformer base class.

    Common root of the embedder, extractor and translator base classes in
    this module; it defines no attributes or methods of its own.
    """
|
||||
|
||||
|
||||
class EmbedderBase(TransformerBase, ABC):
    """Embedder base class.

    Marker base class: no abstract methods are declared here yet.
    """
|
||||
|
||||
|
||||
class ExtractorBase(TransformerBase, ABC):
    """Extractor base class.

    Subclasses must implement the asynchronous ``extract`` method.
    """

    @abstractmethod
    async def extract(self, text: str, limit: Optional[int] = None) -> List:
        """Extract results from text.

        Args:
            text(str): The input text to extract from.
            limit(Optional[int]): Optional limit supplied by the caller.
                NOTE(review): presumably the maximum number of results to
                return - confirm against the implementations.

        Return:
            List: The extracted results; the element type is chosen by the
                subclass.
        """
|
||||
|
||||
|
||||
class TranslatorBase(TransformerBase, ABC):
    """Translator base class.

    Marker base class: no abstract methods are declared here yet.
    """
|
50
dbgpt/rag/transformer/keyword_extractor.py
Normal file
50
dbgpt/rag/transformer/keyword_extractor.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""KeywordExtractor class."""
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.core import LLMClient
|
||||
from dbgpt.rag.transformer.llm_extractor import LLMExtractor
|
||||
|
||||
# Prompt template for keyword extraction; ``{text}`` is the only placeholder.
# The few-shot examples show the reply shape the parser relies on:
# comma-separated keywords, a semicolon, then comma-separated synonyms.
KEYWORD_EXTRACT_PT = (
    "A question is provided below. Given the question, extract "
    "keywords from the text. Focus on extracting the keywords that we can use "
    "to best lookup answers to the question.\n"
    "Generate as more as possible synonyms or alias of the keywords "
    "considering possible cases of capitalization, pluralization, "
    "common expressions, etc.\n"
    "Avoid stopwords.\n"
    # The newline keeps this instruction from running into the next one.
    "Provide the keywords and synonyms in comma-separated format.\n"
    "Formatted keywords and synonyms text should be separated by a semicolon.\n"
    "---------------------\n"
    "Example:\n"
    "Text: Alice is Bob's mother.\n"
    "Keywords:\nAlice,mother,Bob;mummy\n"
    "Text: Philz is a coffee shop founded in Berkeley in 1982.\n"
    "Keywords:\nPhilz,coffee shop,Berkeley,1982;coffee bar,coffee house\n"
    "---------------------\n"
    "Text: {text}\n"
    "Keywords:\n"
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KeywordExtractor(LLMExtractor):
    """KeywordExtractor class.

    Prompts the LLM with ``KEYWORD_EXTRACT_PT`` and parses the reply into a
    list of unique keywords.
    """

    def __init__(self, llm_client: LLMClient, model_name: str):
        """Initialize the KeywordExtractor.

        Args:
            llm_client(LLMClient): Client used to call the model.
            model_name(str): Name of the model to use.
        """
        super().__init__(llm_client, model_name, KEYWORD_EXTRACT_PT)

    def _parse_response(self, text: str, limit: Optional[int] = None) -> List[str]:
        """Parse the LLM reply into unique keywords.

        The reply is split on ``;`` and then on ``,``; surrounding whitespace
        is removed and empty items are ignored.

        Args:
            text(str): Raw LLM reply.
            limit(Optional[int]): Stop once this many unique keywords have
                been collected; a falsy value means no limit.

        Return:
            List[str]: Unique keywords in first-seen order.
        """
        # A dict serves as an insertion-ordered set, so the result order is
        # stable across runs instead of depending on string hash order.
        seen = {}

        for part in text.split(";"):
            for raw in part.strip().split(","):
                keyword = raw.strip()
                if not keyword:
                    continue
                seen[keyword] = None
                if limit and len(seen) >= limit:
                    return list(seen)

        return list(seen)
|
50
dbgpt/rag/transformer/llm_extractor.py
Normal file
50
dbgpt/rag/transformer/llm_extractor.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""LLMExtractor class."""
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.core import HumanPromptTemplate, LLMClient, ModelMessage, ModelRequest
|
||||
from dbgpt.rag.transformer.base import ExtractorBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMExtractor(ExtractorBase, ABC):
    """LLMExtractor class.

    Formats a prompt template with the input text, sends it to an LLM and
    delegates parsing of the reply to ``_parse_response``.
    """

    def __init__(self, llm_client: LLMClient, model_name: str, prompt_template: str):
        """Initialize the LLMExtractor.

        Args:
            llm_client(LLMClient): Client used to list models and generate.
            model_name(str): Model to use; if falsy, the first model reported
                by the client is picked on the first ``extract`` call.
            prompt_template(str): Template formatted with a ``text`` variable.
        """
        self._llm_client = llm_client
        self._model_name = model_name
        self._prompt_template = prompt_template

    async def extract(self, text: str, limit: Optional[int] = None) -> List:
        """Extract by LLM.

        Args:
            text(str): The text to extract from.
            limit(Optional[int]): Forwarded to ``_parse_response``. ``None``
                and ``0`` are falsy and skip validation.

        Return:
            List: Parsed results, or an empty list if the LLM request fails.

        Raises:
            ValueError: If ``limit`` is negative.
            Exception: If no model name is set and the client reports no
                models.
        """
        # Validate first so an invalid limit never costs an LLM round trip.
        if limit and limit < 1:
            raise ValueError("optional argument limit >= 1")

        template = HumanPromptTemplate.from_template(self._prompt_template)
        messages = template.format_messages(text=text)

        # use default model if needed
        if not self._model_name:
            models = await self._llm_client.models()
            if not models:
                raise Exception("No models available")
            # Remembered on the instance, so later calls skip the lookup.
            self._model_name = models[0].model
            logger.info(f"Using model {self._model_name} to extract")

        model_messages = ModelMessage.from_base_messages(messages)
        request = ModelRequest(model=self._model_name, messages=model_messages)
        response = await self._llm_client.generate(request=request)

        if not response.success:
            code = str(response.error_code)
            reason = response.text
            logger.error(f"request llm failed ({code}) {reason}")
            return []

        return self._parse_response(response.text, limit)

    @abstractmethod
    def _parse_response(self, text: str, limit: Optional[int] = None) -> List:
        """Parse llm response.

        Args:
            text(str): Raw text of the LLM reply.
            limit(Optional[int]): The ``limit`` given to ``extract``.

        Return:
            List: The parsed results.
        """
|
10
dbgpt/rag/transformer/text2cypher.py
Normal file
10
dbgpt/rag/transformer/text2cypher.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""Text2Cypher class."""
|
||||
import logging
|
||||
|
||||
from dbgpt.rag.transformer.base import TranslatorBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Text2Cypher(TranslatorBase):
    """Text2Cypher class.

    Placeholder: the class body defines no behavior yet.
    """
|
10
dbgpt/rag/transformer/text2gql.py
Normal file
10
dbgpt/rag/transformer/text2gql.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""Text2GQL class."""
|
||||
import logging
|
||||
|
||||
from dbgpt.rag.transformer.base import TranslatorBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Text2GQL(TranslatorBase):
    """Text2GQL class.

    Placeholder: the class body defines no behavior yet.
    """
|
10
dbgpt/rag/transformer/text2vector.py
Normal file
10
dbgpt/rag/transformer/text2vector.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""Text2Vector class."""
|
||||
import logging
|
||||
|
||||
from dbgpt.rag.transformer.base import EmbedderBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Text2Vector(EmbedderBase):
    """Text2Vector class.

    Placeholder: the class body defines no behavior yet.
    """
|
71
dbgpt/rag/transformer/triplet_extractor.py
Normal file
71
dbgpt/rag/transformer/triplet_extractor.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""TripletExtractor class."""
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
from dbgpt.core import LLMClient
|
||||
from dbgpt.rag.transformer.llm_extractor import LLMExtractor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Prompt template for triplet extraction; ``{text}`` is the only placeholder.
# Every few-shot example uses the same shape ("Triplets:" on its own line,
# then one parenthesized triplet per line).
TRIPLET_EXTRACT_PT = (
    "Some text is provided below. Given the text, "
    "extract as many knowledge triplets as possible "
    "in the form of (subject, predicate, object).\n"
    "Avoid stopwords.\n"
    "---------------------\n"
    "Example:\n"
    "Text: Alice is Bob's mother.\n"
    "Triplets:\n(Alice, is mother of, Bob)\n"
    "Text: Alice has 2 apples.\n"
    "Triplets:\n(Alice, has 2, apple)\n"
    "Text: Alice was given 1 apple by Bob.\n"
    # The receiver is Alice; the example must not teach self-referencing edges.
    "Triplets:\n(Bob, gives 1 apple, Alice)\n"
    "Text: Alice was pushed by Bob.\n"
    "Triplets:\n(Bob, pushes, Alice)\n"
    "Text: Bob's mother Alice has 2 apples.\n"
    "Triplets:\n(Alice, is mother of, Bob)\n(Alice, has 2, apple)\n"
    "Text: A Big monkey climbed up the tall fruit tree and picked 3 peaches.\n"
    "Triplets:\n(monkey, climbed up, fruit tree)\n(monkey, picked 3, peach)\n"
    "Text: Alice has 2 apples, she gives 1 to Bob.\n"
    "Triplets:\n"
    "(Alice, has 2, apple)\n(Alice, gives 1 apple, Bob)\n"
    "Text: Philz is a coffee shop founded in Berkeley in 1982.\n"
    "Triplets:\n"
    "(Philz, is, coffee shop)\n(Philz, founded in, Berkeley)\n"
    "(Philz, founded in, 1982)\n"
    "---------------------\n"
    "Text: {text}\n"
    "Triplets:\n"
)
|
||||
|
||||
|
||||
class TripletExtractor(LLMExtractor):
    """TripletExtractor class.

    Prompts the LLM with ``TRIPLET_EXTRACT_PT`` and parses the reply into
    ``(subject, predicate, object)`` tuples.
    """

    def __init__(self, llm_client: LLMClient, model_name: str):
        """Initialize the TripletExtractor.

        Args:
            llm_client(LLMClient): Client used to call the model.
            model_name(str): Name of the model to use.
        """
        super().__init__(llm_client, model_name, TRIPLET_EXTRACT_PT)

    def _parse_response(
        self, text: str, limit: Optional[int] = None
    ) -> List[Tuple[Any, ...]]:
        """Parse the LLM reply into triplets.

        Each line is scanned for parenthesized groups. A group is kept only
        if it splits on ``,`` into exactly three non-empty parts and every
        part is still non-empty after surrounding punctuation is removed.

        Args:
            text(str): Raw LLM reply.
            limit(Optional[int]): Stop once this many triplets have been
                collected; a falsy value means no limit.

        Return:
            List[Tuple[Any, ...]]: Triplets in the order they appear.
        """
        # ASCII and full-width punctuation removed from both ends of a part.
        strip_chars = (
            "`~!@#$%^&*()-=+[]\\{}|;':\",./<>?"
            "·!¥&*()—【】、「」;‘’:“”,。、《》?"
        )
        triplets = []

        for line in text.split("\n"):
            for match in re.findall(r"\((.*?)\)", line):
                parts = [piece.strip() for piece in match.split(",")]
                parts = [piece for piece in parts if piece]
                if len(parts) != 3:
                    continue
                # Removing punctuation can expose inner whitespace (e.g.
                # "- Bob" -> " Bob"), so trim whitespace once more.
                parts = [piece.strip(strip_chars).strip() for piece in parts]
                if not all(parts):
                    # A part made only of punctuation would give a triplet
                    # with an empty subject, predicate or object; drop it.
                    continue
                triplets.append(tuple(parts))
                if limit and len(triplets) >= limit:
                    return triplets

        return triplets
|
Reference in New Issue
Block a user