Feat rdb summary wide table (#2035)

Co-authored-by: dongzhancai1 <dongzhancai1@jd.com>
Co-authored-by: dong <dongzhancai@iie2.com>
This commit is contained in:
Cooper
2024-12-18 20:34:21 +08:00
committed by GitHub
parent 7f4b5e79cf
commit 9b0161e521
17 changed files with 948 additions and 243 deletions

View File

@@ -264,6 +264,9 @@ class Config(metaclass=Singleton):
# EMBEDDING Configuration
self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
self.EMBEDDING_MODEL_MAX_SEQ_LEN = int(
os.getenv("EMBEDDING_MODEL_MAX_SEQ_LEN", 512)
)
# Rerank model configuration
self.RERANK_MODEL = os.getenv("RERANK_MODEL")
self.RERANK_MODEL_PATH = os.getenv("RERANK_MODEL_PATH")

View File

@@ -55,9 +55,6 @@ class ChatWithDbQA(BaseChat):
if self.db_name:
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
try:
# table_infos = client.get_db_summary(
# dbname=self.db_name, query=self.current_user_input, topk=self.top_k
# )
table_infos = await blocking_func_to_async(
self._executor,
client.get_db_summary,

View File

@@ -1,12 +1,15 @@
"""DBSchemaAssembler."""
import os
from typing import Any, List, Optional
from dbgpt.core import Chunk
from dbgpt.core import Chunk, Embeddings
from dbgpt.datasource.base import BaseConnector
from ...serve.rag.connector import VectorStoreConnector
from ...storage.vector_store.base import VectorStoreConfig
from ..assembler.base import BaseAssembler
from ..chunk_manager import ChunkParameters
from ..index.base import IndexStoreBase
from ..embedding.embedding_factory import DefaultEmbeddingFactory
from ..knowledge.datasource import DatasourceKnowledge
from ..retriever.db_schema import DBSchemaRetriever
@@ -35,23 +38,64 @@ class DBSchemaAssembler(BaseAssembler):
def __init__(
self,
connector: BaseConnector,
index_store: IndexStoreBase,
table_vector_store_connector: VectorStoreConnector,
field_vector_store_connector: VectorStoreConnector = None,
chunk_parameters: Optional[ChunkParameters] = None,
embedding_model: Optional[str] = None,
embeddings: Optional[Embeddings] = None,
max_seq_length: int = 512,
**kwargs: Any,
) -> None:
"""Initialize with Embedding Assembler arguments.
Args:
connector: (BaseConnector) BaseConnector connection.
index_store: (IndexStoreBase) IndexStoreBase to use.
table_vector_store_connector: VectorStoreConnector to load
and retrieve table info.
field_vector_store_connector: VectorStoreConnector to load
and retrieve field info.
chunk_parameters: (Optional[ChunkParameters]) ChunkParameters to use for chunking.
embedding_model: (Optional[str]) Embedding model to use.
embeddings: (Optional[Embeddings]) Embeddings to use.
"""
knowledge = DatasourceKnowledge(connector)
self._connector = connector
self._index_store = index_store
self._table_vector_store_connector = table_vector_store_connector
field_vector_store_config = VectorStoreConfig(
name=table_vector_store_connector.vector_store_config.name + "_field"
)
self._field_vector_store_connector = (
field_vector_store_connector
or VectorStoreConnector.from_default(
os.getenv("VECTOR_STORE_TYPE", "Chroma"),
self._table_vector_store_connector.current_embeddings,
vector_store_config=field_vector_store_config,
)
)
self._embedding_model = embedding_model
if self._embedding_model and not embeddings:
embeddings = DefaultEmbeddingFactory(
default_model_name=self._embedding_model
).create(self._embedding_model)
if (
embeddings
and self._table_vector_store_connector.vector_store_config.embedding_fn
is None
):
self._table_vector_store_connector.vector_store_config.embedding_fn = (
embeddings
)
if (
embeddings
and self._field_vector_store_connector.vector_store_config.embedding_fn
is None
):
self._field_vector_store_connector.vector_store_config.embedding_fn = (
embeddings
)
knowledge = DatasourceKnowledge(connector, model_dimension=max_seq_length)
super().__init__(
knowledge=knowledge,
chunk_parameters=chunk_parameters,
@@ -62,23 +106,36 @@ class DBSchemaAssembler(BaseAssembler):
def load_from_connection(
cls,
connector: BaseConnector,
index_store: IndexStoreBase,
table_vector_store_connector: VectorStoreConnector,
field_vector_store_connector: VectorStoreConnector = None,
chunk_parameters: Optional[ChunkParameters] = None,
embedding_model: Optional[str] = None,
embeddings: Optional[Embeddings] = None,
max_seq_length: int = 512,
) -> "DBSchemaAssembler":
"""Load document embedding into vector store from path.
Args:
connector: (BaseConnector) BaseConnector connection.
index_store: (IndexStoreBase) IndexStoreBase to use.
table_vector_store_connector: used to load table chunks.
field_vector_store_connector: used to load field chunks
if field in table is too much.
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
chunking.
embedding_model: (Optional[str]) Embedding model to use.
embeddings: (Optional[Embeddings]) Embeddings to use.
max_seq_length: Embedding model max sequence length
Returns:
DBSchemaAssembler
"""
return cls(
connector=connector,
index_store=index_store,
table_vector_store_connector=table_vector_store_connector,
field_vector_store_connector=field_vector_store_connector,
embedding_model=embedding_model,
chunk_parameters=chunk_parameters,
embeddings=embeddings,
max_seq_length=max_seq_length,
)
def get_chunks(self) -> List[Chunk]:
@@ -91,7 +148,19 @@ class DBSchemaAssembler(BaseAssembler):
Returns:
List[str]: List of chunk ids.
"""
return self._index_store.load_document(self._chunks)
table_chunks, field_chunks = [], []
for chunk in self._chunks:
metadata = chunk.metadata
if metadata.get("separated"):
if metadata.get("part") == "table":
table_chunks.append(chunk)
else:
field_chunks.append(chunk)
else:
table_chunks.append(chunk)
self._field_vector_store_connector.load_document(field_chunks)
return self._table_vector_store_connector.load_document(table_chunks)
def _extract_info(self, chunks) -> List[Chunk]:
"""Extract info from chunks."""
@@ -110,5 +179,6 @@ class DBSchemaAssembler(BaseAssembler):
top_k=top_k,
connector=self._connector,
is_embeddings=True,
index_store=self._index_store,
table_vector_store_connector=self._table_vector_store_connector,
field_vector_store_connector=self._field_vector_store_connector,
)

View File

@@ -1,76 +1,117 @@
from unittest.mock import MagicMock
from typing import List
from unittest.mock import MagicMock, patch
import pytest
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
from dbgpt.storage.vector_store.chroma_store import ChromaStore
import dbgpt
from dbgpt.core import Chunk
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
@pytest.fixture
def mock_db_connection():
"""Create a temporary database connection for testing."""
connect = SQLiteTempConnector.create_temporary_db()
connect.create_temp_tables(
{
"user": {
"columns": {
"id": "INTEGER PRIMARY KEY",
"name": "TEXT",
"age": "INTEGER",
},
"data": [
(1, "Tom", 10),
(2, "Jerry", 16),
(3, "Jack", 18),
(4, "Alice", 20),
(5, "Bob", 22),
],
}
}
return MagicMock()
@pytest.fixture
def mock_table_vector_store_connector():
mock_connector = MagicMock()
mock_connector.vector_store_config.name = "table_name"
chunk = Chunk(
content="table_name: user\ncomment: user about dbgpt",
metadata={
"field_num": 6,
"part": "table",
"separated": 1,
"table_name": "user",
},
)
return connect
mock_connector.similar_search_with_scores = MagicMock(return_value=[chunk])
return mock_connector
@pytest.fixture
def mock_chunk_parameters():
return MagicMock(spec=ChunkParameters)
def mock_field_vector_store_connector():
mock_connector = MagicMock()
chunk1 = Chunk(
content="name,age",
metadata={
"field_num": 6,
"part": "field",
"part_index": 0,
"separated": 1,
"table_name": "user",
},
)
chunk2 = Chunk(
content="address,gender",
metadata={
"field_num": 6,
"part": "field",
"part_index": 1,
"separated": 1,
"table_name": "user",
},
)
chunk3 = Chunk(
content="mail,phone",
metadata={
"field_num": 6,
"part": "field",
"part_index": 2,
"separated": 1,
"table_name": "user",
},
)
mock_connector.similar_search_with_scores = MagicMock(
return_value=[chunk1, chunk2, chunk3]
)
return mock_connector
@pytest.fixture
def mock_embedding_factory():
return MagicMock(spec=EmbeddingFactory)
@pytest.fixture
def mock_vector_store_connector():
return MagicMock(spec=ChromaStore)
@pytest.fixture
def mock_knowledge():
return MagicMock(spec=Knowledge)
def test_load_knowledge(
def dbstruct_retriever(
mock_db_connection,
mock_knowledge,
mock_chunk_parameters,
mock_embedding_factory,
mock_vector_store_connector,
mock_table_vector_store_connector,
mock_field_vector_store_connector,
):
mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
mock_chunk_parameters.text_splitter = CharacterTextSplitter()
mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
assembler = EmbeddingAssembler(
knowledge=mock_knowledge,
chunk_parameters=mock_chunk_parameters,
embeddings=mock_embedding_factory.create(),
index_store=mock_vector_store_connector,
return DBSchemaRetriever(
connector=mock_db_connection,
table_vector_store_connector=mock_table_vector_store_connector,
field_vector_store_connector=mock_field_vector_store_connector,
separator="--table-field-separator--",
)
def mock_parse_db_summary() -> str:
    """Stand-in for ``_parse_db_summary`` returning a fixed wide-table summary."""
    sections = [
        "table_name: user\ncomment: user about dbgpt",
        "--table-field-separator--",
        "name,age\naddress,gender\nmail,phone",
    ]
    return "\n".join(sections)
# Mocking the _parse_db_summary method in your test function
@patch.object(
    dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
)
def test_retrieve_with_mocked_summary(dbstruct_retriever):
    """Retrieved chunk should contain table info and field parts joined by the separator."""
    query = "Table summary"
    # _retrieve merges the table chunk with its field chunks using the
    # "--table-field-separator--" configured on the retriever fixture.
    chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
    assert isinstance(chunks[0], Chunk)
    assert chunks[0].content == (
        "table_name: user\ncomment: user about dbgpt\n"
        "--table-field-separator--\n"
        "name,age\naddress,gender\nmail,phone"
    )
def async_mock_parse_db_summary() -> str:
    """Asynchronous patch for _parse_db_summary method."""
    table_part = "table_name: user\ncomment: user about dbgpt"
    field_part = "name,age\naddress,gender\nmail,phone"
    return table_part + "\n--table-field-separator--\n" + field_part
assembler.load_knowledge(knowledge=mock_knowledge)
assert len(assembler._chunks) == 0

View File

@@ -5,9 +5,9 @@ import pytest
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
from dbgpt.storage.vector_store.chroma_store import ChromaStore
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddings, EmbeddingFactory
from dbgpt.rag.text_splitter.text_splitter import RDBTextSplitter
from dbgpt.serve.rag.connector import VectorStoreConnector
@pytest.fixture
@@ -21,14 +21,22 @@ def mock_db_connection():
"id": "INTEGER PRIMARY KEY",
"name": "TEXT",
"age": "INTEGER",
},
"data": [
(1, "Tom", 10),
(2, "Jerry", 16),
(3, "Jack", 18),
(4, "Alice", 20),
(5, "Bob", 22),
],
"address": "TEXT",
"phone": "TEXT",
"email": "TEXT",
"gender": "TEXT",
"birthdate": "TEXT",
"occupation": "TEXT",
"education": "TEXT",
"marital_status": "TEXT",
"nationality": "TEXT",
"height": "REAL",
"weight": "REAL",
"blood_type": "TEXT",
"emergency_contact": "TEXT",
"created_at": "TEXT",
"updated_at": "TEXT",
}
}
}
)
@@ -46,23 +54,29 @@ def mock_embedding_factory():
@pytest.fixture
def mock_vector_store_connector():
return MagicMock(spec=ChromaStore)
def mock_table_vector_store_connector():
mock_connector = MagicMock(spec=VectorStoreConnector)
mock_connector.vector_store_config.name = "table_vector_store_name"
mock_connector.current_embeddings = DefaultEmbeddings()
return mock_connector
def test_load_knowledge(
mock_db_connection,
mock_chunk_parameters,
mock_embedding_factory,
mock_vector_store_connector,
mock_table_vector_store_connector,
):
mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
mock_chunk_parameters.text_splitter = CharacterTextSplitter()
mock_chunk_parameters.text_splitter = RDBTextSplitter(
separator="--table-field-separator--"
)
mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
assembler = DBSchemaAssembler(
connector=mock_db_connection,
chunk_parameters=mock_chunk_parameters,
embeddings=mock_embedding_factory.create(),
index_store=mock_vector_store_connector,
table_vector_store_connector=mock_table_vector_store_connector,
max_seq_length=10,
)
assert len(assembler._chunks) == 1
assert len(assembler._chunks) > 1

View File

@@ -5,7 +5,7 @@ from dbgpt.core import Document
from dbgpt.datasource import BaseConnector
from ..summary.gdbms_db_summary import _parse_db_summary as _parse_gdb_summary
from ..summary.rdbms_db_summary import _parse_db_summary
from ..summary.rdbms_db_summary import _parse_db_summary_with_metadata
from .base import ChunkStrategy, DocumentType, Knowledge, KnowledgeType
@@ -15,9 +15,11 @@ class DatasourceKnowledge(Knowledge):
def __init__(
self,
connector: BaseConnector,
summary_template: str = "{table_name}({columns})",
summary_template: str = "table_name: {table_name}",
separator: str = "--table-field-separator--",
knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
model_dimension: int = 512,
**kwargs: Any,
) -> None:
"""Create Datasource Knowledge with Knowledge arguments.
@@ -25,11 +27,17 @@ class DatasourceKnowledge(Knowledge):
Args:
connector(BaseConnector): connector
summary_template(str, optional): summary template
separator(str, optional): separator used to separate
table's basic info and fields.
defaults to `--table-field-separator--`
knowledge_type(KnowledgeType, optional): knowledge type
metadata(Dict[str, Union[str, List[str]], optional): metadata
model_dimension(int, optional): The threshold for splitting field string
"""
self._separator = separator
self._connector = connector
self._summary_template = summary_template
self._model_dimension = model_dimension
super().__init__(knowledge_type=knowledge_type, metadata=metadata, **kwargs)
def _load(self) -> List[Document]:
@@ -37,13 +45,23 @@ class DatasourceKnowledge(Knowledge):
docs = []
if self._connector.is_graph_type():
db_summary = _parse_gdb_summary(self._connector, self._summary_template)
for table_summary in db_summary:
metadata = {"source": "database"}
docs.append(Document(content=table_summary, metadata=metadata))
else:
db_summary = _parse_db_summary(self._connector, self._summary_template)
for table_summary in db_summary:
metadata = {"source": "database"}
if self._metadata:
metadata.update(self._metadata) # type: ignore
docs.append(Document(content=table_summary, metadata=metadata))
db_summary_with_metadata = _parse_db_summary_with_metadata(
self._connector,
self._summary_template,
self._separator,
self._model_dimension,
)
for summary, table_metadata in db_summary_with_metadata:
metadata = {"source": "database"}
if self._metadata:
metadata.update(self._metadata) # type: ignore
table_metadata.update(metadata)
docs.append(Document(content=summary, metadata=table_metadata))
return docs
@classmethod

View File

@@ -1,14 +1,15 @@
"""The DBSchema Retriever Operator."""
import os
from typing import List, Optional
from dbgpt.core import Chunk
from dbgpt.core.interface.operators.retriever import RetrieverOperator
from dbgpt.datasource.base import BaseConnector
from dbgpt.serve.rag.connector import VectorStoreConnector
from ...storage.vector_store.base import VectorStoreConfig
from ..assembler.db_schema import DBSchemaAssembler
from ..chunk_manager import ChunkParameters
from ..index.base import IndexStoreBase
from ..retriever.db_schema import DBSchemaRetriever
from .assembler import AssemblerOperator
@@ -19,13 +20,14 @@ class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]):
Args:
connector (BaseConnector): The connection.
top_k (int, optional): The top k. Defaults to 4.
table_vector_store_connector (VectorStoreConnector): The table vector store
connector.
field_vector_store_connector (VectorStoreConnector): The field vector store
connector.
"""
def __init__(
self,
index_store: IndexStoreBase,
table_vector_store_connector: VectorStoreConnector,
field_vector_store_connector: VectorStoreConnector,
top_k: int = 4,
connector: Optional[BaseConnector] = None,
**kwargs
@@ -35,7 +37,8 @@ class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]):
self._retriever = DBSchemaRetriever(
top_k=top_k,
connector=connector,
index_store=index_store,
table_vector_store_connector=table_vector_store_connector,
field_vector_store_connector=field_vector_store_connector,
)
def retrieve(self, query: str) -> List[Chunk]:
@@ -53,7 +56,8 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
def __init__(
self,
connector: BaseConnector,
index_store: IndexStoreBase,
table_vector_store_connector: VectorStoreConnector,
field_vector_store_connector: VectorStoreConnector = None,
chunk_parameters: Optional[ChunkParameters] = None,
**kwargs
):
@@ -61,14 +65,26 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
Args:
connector (BaseConnector): The connection.
table_vector_store_connector (VectorStoreConnector): The table vector store
connector.
field_vector_store_connector (VectorStoreConnector, optional): The field
vector store connector. Defaults to None.
chunk_parameters (Optional[ChunkParameters], optional): The chunk
parameters.
"""
if not chunk_parameters:
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
self._chunk_parameters = chunk_parameters
self._index_store = index_store
self._table_vector_store_connector = table_vector_store_connector
field_vector_store_config = VectorStoreConfig(
name=table_vector_store_connector.vector_store_config.name + "_field"
)
self._field_vector_store_connector = (
field_vector_store_connector
or VectorStoreConnector.from_default(
os.getenv("VECTOR_STORE_TYPE", "Chroma"),
self._table_vector_store_connector.current_embeddings,
vector_store_config=field_vector_store_config,
)
)
self._connector = connector
super().__init__(**kwargs)
@@ -84,7 +100,8 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
assembler = DBSchemaAssembler.load_from_connection(
connector=self._connector,
chunk_parameters=self._chunk_parameters,
index_store=self._index_store,
table_vector_store_connector=self._table_vector_store_connector,
field_vector_store_connector=self._field_vector_store_connector,
)
assembler.persist()
return assembler.get_chunks()

View File

@@ -1,18 +1,23 @@
"""DBSchema retriever."""
import logging
import os
from typing import List, Optional
from functools import reduce
from typing import List, Optional, cast
from dbgpt._private.config import Config
from dbgpt.core import Chunk
from dbgpt.datasource.base import BaseConnector
from dbgpt.rag.index.base import IndexStoreBase
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.chat_util import run_async_tasks
from dbgpt.rag.summary.gdbms_db_summary import _parse_db_summary
from dbgpt.serve.rag.connector import VectorStoreConnector
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters
from dbgpt.util.chat_util import run_tasks
from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
from dbgpt.util.tracer import root_tracer
logger = logging.getLogger(__name__)
CFG = Config()
class DBSchemaRetriever(BaseRetriever):
@@ -20,7 +25,9 @@ class DBSchemaRetriever(BaseRetriever):
def __init__(
self,
index_store: IndexStoreBase,
table_vector_store_connector: VectorStoreConnector,
field_vector_store_connector: VectorStoreConnector = None,
separator: str = "--table-field-separator--",
top_k: int = 4,
connector: Optional[BaseConnector] = None,
query_rewrite: bool = False,
@@ -30,7 +37,11 @@ class DBSchemaRetriever(BaseRetriever):
"""Create DBSchemaRetriever.
Args:
index_store(IndexStore): index connector
table_vector_store_connector: VectorStoreConnector
to load and retrieve table info.
field_vector_store_connector: VectorStoreConnector
to load and retrieve field info.
separator: field/table separator
top_k (int): top k
connector (Optional[BaseConnector]): RDBMSConnector.
query_rewrite (bool): query rewrite
@@ -70,34 +81,42 @@ class DBSchemaRetriever(BaseRetriever):
connector = _create_temporary_connection()
vector_store_config = ChromaVectorConfig(name="vector_store_name")
embedding_model_path = "{your_embedding_model_path}"
embedding_fn = embedding_factory.create(model_name=embedding_model_path)
config = ChromaVectorConfig(
persist_path=PILOT_PATH,
name="dbschema_rag_test",
embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(
MODEL_PATH, "text2vec-large-chinese"
),
).create(),
vector_connector = VectorStoreConnector.from_default(
"Chroma",
vector_store_config=vector_store_config,
embedding_fn=embedding_fn,
)
vector_store = ChromaStore(config)
# get db struct retriever
retriever = DBSchemaRetriever(
top_k=3,
index_store=vector_store,
vector_store_connector=vector_connector,
connector=connector,
)
chunks = retriever.retrieve("show columns from table")
result = [chunk.content for chunk in chunks]
print(f"db struct rag example results:{result}")
"""
self._separator = separator
self._top_k = top_k
self._connector = connector
self._query_rewrite = query_rewrite
self._index_store = index_store
self._table_vector_store_connector = table_vector_store_connector
field_vector_store_config = VectorStoreConfig(
name=table_vector_store_connector.vector_store_config.name + "_field"
)
self._field_vector_store_connector = (
field_vector_store_connector
or VectorStoreConnector.from_default(
os.getenv("VECTOR_STORE_TYPE", "Chroma"),
self._table_vector_store_connector.current_embeddings,
vector_store_config=field_vector_store_config,
)
)
self._need_embeddings = False
if self._index_store:
if self._table_vector_store_connector:
self._need_embeddings = True
self._rerank = rerank or DefaultRanker(self._top_k)
@@ -114,15 +133,8 @@ class DBSchemaRetriever(BaseRetriever):
List[Chunk]: list of chunks
"""
if self._need_embeddings:
queries = [query]
candidates = [
self._index_store.similar_search(query, self._top_k, filters)
for query in queries
]
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
return self._similarity_search(query, filters)
else:
if not self._connector:
raise RuntimeError("RDBMSConnector connection is required.")
table_summaries = _parse_db_summary(self._connector)
return [Chunk(content=table_summary) for table_summary in table_summaries]
@@ -156,30 +168,11 @@ class DBSchemaRetriever(BaseRetriever):
Returns:
List[Chunk]: list of chunks
"""
if self._need_embeddings:
queries = [query]
candidates = [
self._similarity_search(
query, filters, root_tracer.get_current_span_id()
)
for query in queries
]
result_candidates = await run_async_tasks(
tasks=candidates, concurrency_limit=1
)
return cast(List[Chunk], reduce(lambda x, y: x + y, result_candidates))
else:
from dbgpt.rag.summary.rdbms_db_summary import ( # noqa: F401
_parse_db_summary,
)
table_summaries = await run_async_tasks(
tasks=[self._aparse_db_summary(root_tracer.get_current_span_id())],
concurrency_limit=1,
)
return [
Chunk(content=table_summary) for table_summary in table_summaries[0]
]
return await blocking_func_to_async_no_executor(
func=self._retrieve,
query=query,
filters=filters,
)
async def _aretrieve_with_score(
self,
@@ -196,34 +189,40 @@ class DBSchemaRetriever(BaseRetriever):
"""
return await self._aretrieve(query, filters)
async def _similarity_search(
self,
query,
filters: Optional[MetadataFilters] = None,
parent_span_id: Optional[str] = None,
def _retrieve_field(self, table_chunk: Chunk, query) -> Chunk:
metadata = table_chunk.metadata
metadata["part"] = "field"
filters = [MetadataFilter(key=k, value=v) for k, v in metadata.items()]
field_chunks = self._field_vector_store_connector.similar_search_with_scores(
query, self._top_k, 0, MetadataFilters(filters=filters)
)
field_contents = [chunk.content for chunk in field_chunks]
table_chunk.content += "\n" + self._separator + "\n" + "\n".join(field_contents)
return table_chunk
def _similarity_search(
self, query, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Similar search."""
with root_tracer.start_span(
"dbgpt.rag.retriever.db_schema._similarity_search",
parent_span_id,
metadata={"query": query},
):
return await blocking_func_to_async_no_executor(
self._index_store.similar_search, query, self._top_k, filters
)
table_chunks = self._table_vector_store_connector.similar_search_with_scores(
query, self._top_k, 0, filters
)
async def _aparse_db_summary(
self, parent_span_id: Optional[str] = None
) -> List[str]:
"""Similar search."""
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
not_sep_chunks = [
chunk for chunk in table_chunks if not chunk.metadata.get("separated")
]
separated_chunks = [
chunk for chunk in table_chunks if chunk.metadata.get("separated")
]
if not separated_chunks:
return not_sep_chunks
if not self._connector:
raise RuntimeError("RDBMSConnector connection is required.")
with root_tracer.start_span(
"dbgpt.rag.retriever.db_schema._aparse_db_summary",
parent_span_id,
):
return await blocking_func_to_async_no_executor(
_parse_db_summary, self._connector
)
# Create tasks list
tasks = [
lambda c=chunk: self._retrieve_field(c, query) for chunk in separated_chunks
]
# Run tasks concurrently
separated_result = run_tasks(tasks, concurrency_limit=3)
# Combine and return results
return not_sep_chunks + separated_result

View File

@@ -15,42 +15,53 @@ def mock_db_connection():
@pytest.fixture
def mock_vector_store_connector():
def mock_table_vector_store_connector():
mock_connector = MagicMock()
mock_connector.similar_search.return_value = [Chunk(content="Table summary")] * 4
mock_connector.vector_store_config.name = "table_name"
mock_connector.similar_search_with_scores.return_value = [
Chunk(content="Table summary")
] * 4
return mock_connector
@pytest.fixture
def db_struct_retriever(mock_db_connection, mock_vector_store_connector):
def mock_field_vector_store_connector():
mock_connector = MagicMock()
mock_connector.similar_search_with_scores.return_value = [
Chunk(content="Field summary")
] * 4
return mock_connector
@pytest.fixture
def dbstruct_retriever(
mock_db_connection,
mock_table_vector_store_connector,
mock_field_vector_store_connector,
):
return DBSchemaRetriever(
connector=mock_db_connection,
index_store=mock_vector_store_connector,
table_vector_store_connector=mock_table_vector_store_connector,
field_vector_store_connector=mock_field_vector_store_connector,
)
def mock_parse_db_summary(conn) -> List[str]:
def mock_parse_db_summary() -> str:
"""Patch _parse_db_summary method."""
return ["Table summary"]
return "Table summary"
# Mocking the _parse_db_summary method in your test function
@patch.object(
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
)
def test_retrieve_with_mocked_summary(db_struct_retriever):
def test_retrieve_with_mocked_summary(dbstruct_retriever):
query = "Table summary"
chunks: List[Chunk] = db_struct_retriever._retrieve(query)
chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
assert isinstance(chunks[0], Chunk)
assert chunks[0].content == "Table summary"
@pytest.mark.asyncio
@patch.object(
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
)
async def test_aretrieve_with_mocked_summary(db_struct_retriever):
query = "Table summary"
chunks: List[Chunk] = await db_struct_retriever._aretrieve(query)
assert isinstance(chunks[0], Chunk)
assert chunks[0].content == "Table summary"
async def async_mock_parse_db_summary() -> str:
"""Asynchronous patch for _parse_db_summary method."""
return "Table summary"

View File

@@ -2,13 +2,15 @@
import logging
import traceback
from typing import List
from dbgpt._private.config import Config
from dbgpt.component import SystemApp
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
from dbgpt.rag import ChunkParameters
from dbgpt.rag.summary.gdbms_db_summary import GdbmsSummary
from dbgpt.rag.summary.rdbms_db_summary import RdbmsSummary
from dbgpt.rag.text_splitter.text_splitter import RDBTextSplitter
from dbgpt.serve.rag.connector import VectorStoreConnector
logger = logging.getLogger(__name__)
@@ -47,22 +49,26 @@ class DBSummaryClient:
logger.info("db summary embedding success")
def get_db_summary(self, dbname, query, topk) -> List[str]:
def get_db_summary(self, dbname, query, topk):
"""Get user query related tables info."""
from dbgpt.serve.rag.connector import VectorStoreConnector
from dbgpt.storage.vector_store.base import VectorStoreConfig
vector_store_config = VectorStoreConfig(name=dbname + "_profile")
vector_connector = VectorStoreConnector.from_default(
vector_store_name = dbname + "_profile"
table_vector_store_config = VectorStoreConfig(name=vector_store_name)
table_vector_connector = VectorStoreConnector.from_default(
CFG.VECTOR_STORE_TYPE,
embedding_fn=self.embeddings,
vector_store_config=vector_store_config,
self.embeddings,
vector_store_config=table_vector_store_config,
)
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
retriever = DBSchemaRetriever(
top_k=topk, index_store=vector_connector.index_client
top_k=topk,
table_vector_store_connector=table_vector_connector,
separator="--table-field-separator--",
)
table_docs = retriever.retrieve(query)
ans = [d.content for d in table_docs]
return ans
@@ -92,18 +98,23 @@ class DBSummaryClient:
from dbgpt.serve.rag.connector import VectorStoreConnector
from dbgpt.storage.vector_store.base import VectorStoreConfig
vector_store_config = VectorStoreConfig(name=vector_store_name)
vector_connector = VectorStoreConnector.from_default(
table_vector_store_config = VectorStoreConfig(name=vector_store_name)
table_vector_connector = VectorStoreConnector.from_default(
CFG.VECTOR_STORE_TYPE,
self.embeddings,
vector_store_config=vector_store_config,
vector_store_config=table_vector_store_config,
)
if not vector_connector.vector_name_exists():
if not table_vector_connector.vector_name_exists():
from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
chunk_parameters = ChunkParameters(
text_splitter=RDBTextSplitter(separator="--table-field-separator--")
)
db_assembler = DBSchemaAssembler.load_from_connection(
connector=db_summary_client.db,
index_store=vector_connector.index_client,
table_vector_store_connector=table_vector_connector,
chunk_parameters=chunk_parameters,
max_seq_length=CFG.EMBEDDING_MODEL_MAX_SEQ_LEN,
)
if len(db_assembler.get_chunks()) > 0:
@@ -115,16 +126,26 @@ class DBSummaryClient:
def delete_db_profile(self, dbname):
"""Delete db profile."""
vector_store_name = dbname + "_profile"
table_vector_store_name = dbname + "_profile"
field_vector_store_name = dbname + "_profile_field"
from dbgpt.serve.rag.connector import VectorStoreConnector
from dbgpt.storage.vector_store.base import VectorStoreConfig
vector_store_config = VectorStoreConfig(name=vector_store_name)
vector_connector = VectorStoreConnector.from_default(
table_vector_store_config = VectorStoreConfig(name=table_vector_store_name)
field_vector_store_config = VectorStoreConfig(name=field_vector_store_name)
table_vector_connector = VectorStoreConnector.from_default(
CFG.VECTOR_STORE_TYPE,
self.embeddings,
vector_store_config=vector_store_config,
vector_store_config=table_vector_store_config,
)
vector_connector.delete_vector_name(vector_store_name)
field_vector_connector = VectorStoreConnector.from_default(
CFG.VECTOR_STORE_TYPE,
self.embeddings,
vector_store_config=field_vector_store_config,
)
table_vector_connector.delete_vector_name(table_vector_store_name)
field_vector_connector.delete_vector_name(field_vector_store_name)
logger.info(f"delete db profile {dbname} success")
@staticmethod

View File

@@ -1,6 +1,6 @@
"""Summary for rdbms database."""
import re
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from dbgpt._private.config import Config
from dbgpt.datasource import BaseConnector
@@ -80,6 +80,134 @@ def _parse_db_summary(
return table_info_summaries
def _parse_db_summary_with_metadata(
    conn: BaseConnector,
    summary_template: str = "table_name: {table_name}",
    separator: str = "--table-field-separator--",
    model_dimension: int = 512,
) -> List[Tuple[str, Dict[str, Any]]]:
    """Collect a (summary, metadata) pair for every table in the database.

    Args:
        conn (BaseConnector): database connection
        summary_template (str): summary template
        separator (str, optional): marker separating a table's basic info
            from its field list. Defaults to ``--table-field-separator--``.
        model_dimension (int, optional): threshold used when splitting the
            field string into chunks.

    Returns:
        One ``(summary, metadata)`` tuple per table.
    """
    summaries: List[Tuple[str, Dict[str, Any]]] = []
    for table_name in conn.get_table_names():
        summaries.append(
            _parse_table_summary_with_metadata(
                conn, summary_template, separator, table_name, model_dimension
            )
        )
    return summaries
def _split_columns_str(columns: List[str], model_dimension: int):
"""Split columns str.
Args:
columns (List[str]): fields string
model_dimension (int, optional): The threshold for splitting field string.
"""
result = []
current_string = ""
current_length = 0
for element_str in columns:
element_length = len(element_str)
# If adding the current element's length would exceed the threshold,
# add the current string to results and reset
if current_length + element_length > model_dimension:
result.append(current_string.strip()) # Remove trailing spaces
current_string = element_str
current_length = element_length
else:
# If current string is empty, add element directly
if current_string:
current_string += "," + element_str
else:
current_string = element_str
current_length += element_length + 1 # Add length of space
# Handle the last string segment
if current_string:
result.append(current_string.strip())
return result
def _parse_table_summary_with_metadata(
    conn: BaseConnector,
    summary_template: str,
    separator,
    table_name: str,
    model_dimension=512,
) -> Tuple[str, Dict[str, Any]]:
    """Build the summary string and metadata for a single table.

    The summary holds the table's basic info (name, comment, index keys),
    then ``separator`` on its own line, then the column list packed into
    comma-joined lines no longer than ``model_dimension``.

    Args:
        conn (BaseConnector): database connection
        summary_template (str): summary template, formatted with ``table_name``
        separator (str, optional): marker separating the table's
            basic info from its fields. Defaults to ``--table-field-separator--``.
        model_dimension (int, optional): threshold used when splitting the
            field string into multiple lines.

    Returns:
        Tuple of the summary text and its metadata. ``metadata['separated']``
        is 1 when the field list spans more than one line, else 0.

    Examples:
        metadata: {'table_name': 'asd', 'separated': 0/1}
        table_name: table1
        table_comment: comment
        index_keys: keys
        --table-field-separator--
        (column1,comment), (column2, comment), (column3, comment)
        (column4,comment), (column5, comment), (column6, comment)
    """
    columns = []
    # 'separated' is flipped to 1 below if the field list needs several lines.
    metadata = {"table_name": table_name, "separated": 0}
    # Render each column as "name (comment)" when a comment exists.
    for column in conn.get_columns(table_name):
        if column.get("comment"):
            columns.append(f"{column['name']} ({column.get('comment')})")
        else:
            columns.append(f"{column['name']}")
    metadata.update({"field_num": len(columns)})
    separated_columns = _split_columns_str(columns, model_dimension=model_dimension)
    if len(separated_columns) > 1:
        metadata["separated"] = 1
    column_str = "\n".join(separated_columns)
    # Obtain index information
    index_keys = []
    raw_indexes = conn.get_indexes(table_name)
    for index in raw_indexes:
        if isinstance(index, tuple):  # Process tuple type index information
            index_name, index_creation_command = index
            # Extract column names using re
            matched_columns = re.findall(r"\(([^)]+)\)", index_creation_command)
            if matched_columns:
                key_str = ", ".join(matched_columns)
                index_keys.append(f"{index_name}(`{key_str}`) ")
        else:
            # Dict-shaped index info: expects 'name' and 'column_names' keys
            # — TODO confirm against each connector's get_indexes contract.
            key_str = ", ".join(index["column_names"])
            index_keys.append(f"{index['name']}(`{key_str}`) ")
    table_str = summary_template.format(table_name=table_name)
    # Table comment lookup is best-effort; fall back to an empty comment.
    try:
        comment = conn.get_table_comment(table_name)
    except Exception:
        comment = dict(text=None)
    if comment.get("text"):
        table_str += f"\ntable_comment: {comment.get('text')}"
    if len(index_keys) > 0:
        index_key_str = ", ".join(index_keys)
        table_str += f"\nindex_keys: {index_key_str}"
    table_str += f"\n{separator}\n{column_str}"
    return table_str, metadata
def _parse_table_summary(
conn: BaseConnector, summary_template: str, table_name: str
) -> str:

View File

@@ -912,3 +912,42 @@ class PageTextSplitter(TextSplitter):
new_doc = Chunk(content=text, metadata=copy.deepcopy(_metadatas[i]))
chunks.append(new_doc)
return chunks
class RDBTextSplitter(TextSplitter):
    """Split relational database summaries into table chunks and field chunks."""

    def __init__(self, **kwargs):
        """Create a new RDBTextSplitter.

        Keyword arguments (including ``separator``, the marker between a
        table's basic info and its field list) are forwarded to the
        ``TextSplitter`` base class.
        """
        super().__init__(**kwargs)

    def split_text(self, text: str, **kwargs):
        """Not supported; this splitter operates on whole documents."""
        pass

    def split_documents(self, documents: Iterable[Document], **kwargs) -> List[Chunk]:
        """Split each document into one table chunk plus per-line field chunks.

        Documents flagged ``separated`` in their metadata are cut at the
        configured separator: the part before it becomes a ``part='table'``
        chunk; each line after it becomes a ``part='field'`` chunk carrying a
        ``part_index``. Unflagged documents pass through as a single chunk.
        """
        chunks = []
        for doc in documents:
            metadata = doc.metadata
            content = doc.content
            # Split at most once so a separator occurring inside the field
            # text cannot drop trailing content; also avoids an IndexError
            # when the flag is set but the separator is absent.
            parts = (
                content.split(self._separator, 1)
                if metadata.get("separated")
                else [content]
            )
            if len(parts) == 2:
                # separate table and field
                table_part, field_part = parts
                table_metadata = copy.deepcopy(metadata)
                field_metadata = copy.deepcopy(metadata)
                table_metadata["part"] = "table"  # identify of table_chunk
                field_metadata["part"] = "field"  # identify of field_chunk
                chunks.append(Chunk(content=table_part, metadata=table_metadata))
                for i, sub_part in enumerate(field_part.split("\n")):
                    sub_metadata = copy.deepcopy(field_metadata)
                    sub_metadata["part_index"] = i
                    chunks.append(Chunk(content=sub_part, metadata=sub_metadata))
            else:
                # Not separated (or separator missing): keep the document whole.
                chunks.append(Chunk(content=content, metadata=metadata))
        return chunks

View File

@@ -1,5 +1,6 @@
import asyncio
from typing import Any, Coroutine, List
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable, Coroutine, List
async def llm_chat_response_nostream(chat_scene: str, **chat_param):
@@ -47,13 +48,34 @@ async def run_async_tasks(
def run_tasks(
    tasks: List[Callable],
    concurrency_limit: int = None,
) -> List[Any]:
    """Run a list of tasks concurrently using a thread pool.

    Args:
        tasks: List of callable functions to execute (each takes no arguments).
        concurrency_limit: Maximum number of concurrent threads; ``None``
            lets ``ThreadPoolExecutor`` choose its default.

    Returns:
        List of results from all tasks in the order they were submitted.

    Raises:
        Exception: the first exception raised by any task, re-raised with its
            original traceback after cancelling not-yet-started tasks.
    """
    max_workers = concurrency_limit if concurrency_limit else None
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks and get futures
        futures = [executor.submit(task) for task in tasks]
        # Collect results in submission order, propagating any exception
        results = []
        for future in futures:
            try:
                results.append(future.result())
            except Exception:
                # Best-effort: cancel futures that have not started yet,
                # then re-raise preserving the original traceback.
                for f in futures:
                    f.cancel()
                raise
    return results