Feat rdb summary wide table (#2035)

Co-authored-by: dongzhancai1 <dongzhancai1@jd.com>
Co-authored-by: dong <dongzhancai@iie2.com>
This commit is contained in:
Cooper 2024-12-18 20:34:21 +08:00 committed by GitHub
parent 7f4b5e79cf
commit 9b0161e521
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 948 additions and 243 deletions

View File

@ -66,6 +66,7 @@ QUANTIZE_8bit=True
#** EMBEDDING SETTINGS **# #** EMBEDDING SETTINGS **#
#*******************************************************************# #*******************************************************************#
EMBEDDING_MODEL=text2vec EMBEDDING_MODEL=text2vec
EMBEDDING_MODEL_MAX_SEQ_LEN=512
#EMBEDDING_MODEL=m3e-large #EMBEDDING_MODEL=m3e-large
#EMBEDDING_MODEL=bge-large-en #EMBEDDING_MODEL=bge-large-en
#EMBEDDING_MODEL=bge-large-zh #EMBEDDING_MODEL=bge-large-zh

View File

@ -264,6 +264,9 @@ class Config(metaclass=Singleton):
# EMBEDDING Configuration # EMBEDDING Configuration
self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec") self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
self.EMBEDDING_MODEL_MAX_SEQ_LEN = int(
os.getenv("EMBEDDING_MODEL_MAX_SEQ_LEN", 512)
)
# Rerank model configuration # Rerank model configuration
self.RERANK_MODEL = os.getenv("RERANK_MODEL") self.RERANK_MODEL = os.getenv("RERANK_MODEL")
self.RERANK_MODEL_PATH = os.getenv("RERANK_MODEL_PATH") self.RERANK_MODEL_PATH = os.getenv("RERANK_MODEL_PATH")

View File

@ -55,9 +55,6 @@ class ChatWithDbQA(BaseChat):
if self.db_name: if self.db_name:
client = DBSummaryClient(system_app=CFG.SYSTEM_APP) client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
try: try:
# table_infos = client.get_db_summary(
# dbname=self.db_name, query=self.current_user_input, topk=self.top_k
# )
table_infos = await blocking_func_to_async( table_infos = await blocking_func_to_async(
self._executor, self._executor,
client.get_db_summary, client.get_db_summary,

View File

@ -1,12 +1,15 @@
"""DBSchemaAssembler.""" """DBSchemaAssembler."""
import os
from typing import Any, List, Optional from typing import Any, List, Optional
from dbgpt.core import Chunk from dbgpt.core import Chunk, Embeddings
from dbgpt.datasource.base import BaseConnector from dbgpt.datasource.base import BaseConnector
from ...serve.rag.connector import VectorStoreConnector
from ...storage.vector_store.base import VectorStoreConfig
from ..assembler.base import BaseAssembler from ..assembler.base import BaseAssembler
from ..chunk_manager import ChunkParameters from ..chunk_manager import ChunkParameters
from ..index.base import IndexStoreBase from ..embedding.embedding_factory import DefaultEmbeddingFactory
from ..knowledge.datasource import DatasourceKnowledge from ..knowledge.datasource import DatasourceKnowledge
from ..retriever.db_schema import DBSchemaRetriever from ..retriever.db_schema import DBSchemaRetriever
@ -35,23 +38,64 @@ class DBSchemaAssembler(BaseAssembler):
def __init__( def __init__(
self, self,
connector: BaseConnector, connector: BaseConnector,
index_store: IndexStoreBase, table_vector_store_connector: VectorStoreConnector,
field_vector_store_connector: VectorStoreConnector = None,
chunk_parameters: Optional[ChunkParameters] = None, chunk_parameters: Optional[ChunkParameters] = None,
embedding_model: Optional[str] = None,
embeddings: Optional[Embeddings] = None,
max_seq_length: int = 512,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Initialize with Embedding Assembler arguments. """Initialize with Embedding Assembler arguments.
Args: Args:
connector: (BaseConnector) BaseConnector connection. connector: (BaseConnector) BaseConnector connection.
index_store: (IndexStoreBase) IndexStoreBase to use. table_vector_store_connector: VectorStoreConnector to load
and retrieve table info.
field_vector_store_connector: VectorStoreConnector to load
and retrieve field info.
chunk_manager: (Optional[ChunkManager]) ChunkManager to use for chunking. chunk_manager: (Optional[ChunkManager]) ChunkManager to use for chunking.
embedding_model: (Optional[str]) Embedding model to use. embedding_model: (Optional[str]) Embedding model to use.
embeddings: (Optional[Embeddings]) Embeddings to use. embeddings: (Optional[Embeddings]) Embeddings to use.
""" """
knowledge = DatasourceKnowledge(connector)
self._connector = connector self._connector = connector
self._index_store = index_store self._table_vector_store_connector = table_vector_store_connector
field_vector_store_config = VectorStoreConfig(
name=table_vector_store_connector.vector_store_config.name + "_field"
)
self._field_vector_store_connector = (
field_vector_store_connector
or VectorStoreConnector.from_default(
os.getenv("VECTOR_STORE_TYPE", "Chroma"),
self._table_vector_store_connector.current_embeddings,
vector_store_config=field_vector_store_config,
)
)
self._embedding_model = embedding_model
if self._embedding_model and not embeddings:
embeddings = DefaultEmbeddingFactory(
default_model_name=self._embedding_model
).create(self._embedding_model)
if (
embeddings
and self._table_vector_store_connector.vector_store_config.embedding_fn
is None
):
self._table_vector_store_connector.vector_store_config.embedding_fn = (
embeddings
)
if (
embeddings
and self._field_vector_store_connector.vector_store_config.embedding_fn
is None
):
self._field_vector_store_connector.vector_store_config.embedding_fn = (
embeddings
)
knowledge = DatasourceKnowledge(connector, model_dimension=max_seq_length)
super().__init__( super().__init__(
knowledge=knowledge, knowledge=knowledge,
chunk_parameters=chunk_parameters, chunk_parameters=chunk_parameters,
@ -62,23 +106,36 @@ class DBSchemaAssembler(BaseAssembler):
def load_from_connection( def load_from_connection(
cls, cls,
connector: BaseConnector, connector: BaseConnector,
index_store: IndexStoreBase, table_vector_store_connector: VectorStoreConnector,
field_vector_store_connector: VectorStoreConnector = None,
chunk_parameters: Optional[ChunkParameters] = None, chunk_parameters: Optional[ChunkParameters] = None,
embedding_model: Optional[str] = None,
embeddings: Optional[Embeddings] = None,
max_seq_length: int = 512,
) -> "DBSchemaAssembler": ) -> "DBSchemaAssembler":
"""Load document embedding into vector store from path. """Load document embedding into vector store from path.
Args: Args:
connector: (BaseConnector) BaseConnector connection. connector: (BaseConnector) BaseConnector connection.
index_store: (IndexStoreBase) IndexStoreBase to use. table_vector_store_connector: used to load table chunks.
field_vector_store_connector: used to load field chunks
if field in table is too much.
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
chunking. chunking.
embedding_model: (Optional[str]) Embedding model to use.
embeddings: (Optional[Embeddings]) Embeddings to use.
max_seq_length: Embedding model max sequence length
Returns: Returns:
DBSchemaAssembler DBSchemaAssembler
""" """
return cls( return cls(
connector=connector, connector=connector,
index_store=index_store, table_vector_store_connector=table_vector_store_connector,
field_vector_store_connector=field_vector_store_connector,
embedding_model=embedding_model,
chunk_parameters=chunk_parameters, chunk_parameters=chunk_parameters,
embeddings=embeddings,
max_seq_length=max_seq_length,
) )
def get_chunks(self) -> List[Chunk]: def get_chunks(self) -> List[Chunk]:
@ -91,7 +148,19 @@ class DBSchemaAssembler(BaseAssembler):
Returns: Returns:
List[str]: List of chunk ids. List[str]: List of chunk ids.
""" """
return self._index_store.load_document(self._chunks) table_chunks, field_chunks = [], []
for chunk in self._chunks:
metadata = chunk.metadata
if metadata.get("separated"):
if metadata.get("part") == "table":
table_chunks.append(chunk)
else:
field_chunks.append(chunk)
else:
table_chunks.append(chunk)
self._field_vector_store_connector.load_document(field_chunks)
return self._table_vector_store_connector.load_document(table_chunks)
def _extract_info(self, chunks) -> List[Chunk]: def _extract_info(self, chunks) -> List[Chunk]:
"""Extract info from chunks.""" """Extract info from chunks."""
@ -110,5 +179,6 @@ class DBSchemaAssembler(BaseAssembler):
top_k=top_k, top_k=top_k,
connector=self._connector, connector=self._connector,
is_embeddings=True, is_embeddings=True,
index_store=self._index_store, table_vector_store_connector=self._table_vector_store_connector,
field_vector_store_connector=self._field_vector_store_connector,
) )

View File

@ -1,76 +1,117 @@
from unittest.mock import MagicMock from typing import List
from unittest.mock import MagicMock, patch
import pytest import pytest
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector import dbgpt
from dbgpt.rag.assembler.embedding import EmbeddingAssembler from dbgpt.core import Chunk
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
from dbgpt.storage.vector_store.chroma_store import ChromaStore
@pytest.fixture @pytest.fixture
def mock_db_connection(): def mock_db_connection():
"""Create a temporary database connection for testing.""" return MagicMock()
connect = SQLiteTempConnector.create_temporary_db()
connect.create_temp_tables(
{ @pytest.fixture
"user": { def mock_table_vector_store_connector():
"columns": { mock_connector = MagicMock()
"id": "INTEGER PRIMARY KEY", mock_connector.vector_store_config.name = "table_name"
"name": "TEXT", chunk = Chunk(
"age": "INTEGER", content="table_name: user\ncomment: user about dbgpt",
}, metadata={
"data": [ "field_num": 6,
(1, "Tom", 10), "part": "table",
(2, "Jerry", 16), "separated": 1,
(3, "Jack", 18), "table_name": "user",
(4, "Alice", 20), },
(5, "Bob", 22),
],
}
}
) )
return connect mock_connector.similar_search_with_scores = MagicMock(return_value=[chunk])
return mock_connector
@pytest.fixture @pytest.fixture
def mock_chunk_parameters(): def mock_field_vector_store_connector():
return MagicMock(spec=ChunkParameters) mock_connector = MagicMock()
chunk1 = Chunk(
content="name,age",
metadata={
"field_num": 6,
"part": "field",
"part_index": 0,
"separated": 1,
"table_name": "user",
},
)
chunk2 = Chunk(
content="address,gender",
metadata={
"field_num": 6,
"part": "field",
"part_index": 1,
"separated": 1,
"table_name": "user",
},
)
chunk3 = Chunk(
content="mail,phone",
metadata={
"field_num": 6,
"part": "field",
"part_index": 2,
"separated": 1,
"table_name": "user",
},
)
mock_connector.similar_search_with_scores = MagicMock(
return_value=[chunk1, chunk2, chunk3]
)
return mock_connector
@pytest.fixture @pytest.fixture
def mock_embedding_factory(): def dbstruct_retriever(
return MagicMock(spec=EmbeddingFactory)
@pytest.fixture
def mock_vector_store_connector():
return MagicMock(spec=ChromaStore)
@pytest.fixture
def mock_knowledge():
return MagicMock(spec=Knowledge)
def test_load_knowledge(
mock_db_connection, mock_db_connection,
mock_knowledge, mock_table_vector_store_connector,
mock_chunk_parameters, mock_field_vector_store_connector,
mock_embedding_factory,
mock_vector_store_connector,
): ):
mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE" return DBSchemaRetriever(
mock_chunk_parameters.text_splitter = CharacterTextSplitter() connector=mock_db_connection,
mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE table_vector_store_connector=mock_table_vector_store_connector,
assembler = EmbeddingAssembler( field_vector_store_connector=mock_field_vector_store_connector,
knowledge=mock_knowledge, separator="--table-field-separator--",
chunk_parameters=mock_chunk_parameters, )
embeddings=mock_embedding_factory.create(),
index_store=mock_vector_store_connector,
def mock_parse_db_summary() -> str:
"""Patch _parse_db_summary method."""
return (
"table_name: user\ncomment: user about dbgpt\n"
"--table-field-separator--\n"
"name,age\naddress,gender\nmail,phone"
)
# Mocking the _parse_db_summary method in your test function
@patch.object(
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
)
def test_retrieve_with_mocked_summary(dbstruct_retriever):
query = "Table summary"
chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
assert isinstance(chunks[0], Chunk)
assert chunks[0].content == (
"table_name: user\ncomment: user about dbgpt\n"
"--table-field-separator--\n"
"name,age\naddress,gender\nmail,phone"
)
def async_mock_parse_db_summary() -> str:
"""Asynchronous patch for _parse_db_summary method."""
return (
"table_name: user\ncomment: user about dbgpt\n"
"--table-field-separator--\n"
"name,age\naddress,gender\nmail,phone"
) )
assembler.load_knowledge(knowledge=mock_knowledge)
assert len(assembler._chunks) == 0

View File

@ -5,9 +5,9 @@ import pytest
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.rag.assembler.db_schema import DBSchemaAssembler from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddings, EmbeddingFactory
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter from dbgpt.rag.text_splitter.text_splitter import RDBTextSplitter
from dbgpt.storage.vector_store.chroma_store import ChromaStore from dbgpt.serve.rag.connector import VectorStoreConnector
@pytest.fixture @pytest.fixture
@ -21,14 +21,22 @@ def mock_db_connection():
"id": "INTEGER PRIMARY KEY", "id": "INTEGER PRIMARY KEY",
"name": "TEXT", "name": "TEXT",
"age": "INTEGER", "age": "INTEGER",
}, "address": "TEXT",
"data": [ "phone": "TEXT",
(1, "Tom", 10), "email": "TEXT",
(2, "Jerry", 16), "gender": "TEXT",
(3, "Jack", 18), "birthdate": "TEXT",
(4, "Alice", 20), "occupation": "TEXT",
(5, "Bob", 22), "education": "TEXT",
], "marital_status": "TEXT",
"nationality": "TEXT",
"height": "REAL",
"weight": "REAL",
"blood_type": "TEXT",
"emergency_contact": "TEXT",
"created_at": "TEXT",
"updated_at": "TEXT",
}
} }
} }
) )
@ -46,23 +54,29 @@ def mock_embedding_factory():
@pytest.fixture @pytest.fixture
def mock_vector_store_connector(): def mock_table_vector_store_connector():
return MagicMock(spec=ChromaStore) mock_connector = MagicMock(spec=VectorStoreConnector)
mock_connector.vector_store_config.name = "table_vector_store_name"
mock_connector.current_embeddings = DefaultEmbeddings()
return mock_connector
def test_load_knowledge( def test_load_knowledge(
mock_db_connection, mock_db_connection,
mock_chunk_parameters, mock_chunk_parameters,
mock_embedding_factory, mock_embedding_factory,
mock_vector_store_connector, mock_table_vector_store_connector,
): ):
mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE" mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
mock_chunk_parameters.text_splitter = CharacterTextSplitter() mock_chunk_parameters.text_splitter = RDBTextSplitter(
separator="--table-field-separator--"
)
mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
assembler = DBSchemaAssembler( assembler = DBSchemaAssembler(
connector=mock_db_connection, connector=mock_db_connection,
chunk_parameters=mock_chunk_parameters, chunk_parameters=mock_chunk_parameters,
embeddings=mock_embedding_factory.create(), embeddings=mock_embedding_factory.create(),
index_store=mock_vector_store_connector, table_vector_store_connector=mock_table_vector_store_connector,
max_seq_length=10,
) )
assert len(assembler._chunks) == 1 assert len(assembler._chunks) > 1

View File

@ -5,7 +5,7 @@ from dbgpt.core import Document
from dbgpt.datasource import BaseConnector from dbgpt.datasource import BaseConnector
from ..summary.gdbms_db_summary import _parse_db_summary as _parse_gdb_summary from ..summary.gdbms_db_summary import _parse_db_summary as _parse_gdb_summary
from ..summary.rdbms_db_summary import _parse_db_summary from ..summary.rdbms_db_summary import _parse_db_summary_with_metadata
from .base import ChunkStrategy, DocumentType, Knowledge, KnowledgeType from .base import ChunkStrategy, DocumentType, Knowledge, KnowledgeType
@ -15,9 +15,11 @@ class DatasourceKnowledge(Knowledge):
def __init__( def __init__(
self, self,
connector: BaseConnector, connector: BaseConnector,
summary_template: str = "{table_name}({columns})", summary_template: str = "table_name: {table_name}",
separator: str = "--table-field-separator--",
knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT, knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
metadata: Optional[Dict[str, Union[str, List[str]]]] = None, metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
model_dimension: int = 512,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Create Datasource Knowledge with Knowledge arguments. """Create Datasource Knowledge with Knowledge arguments.
@ -25,11 +27,17 @@ class DatasourceKnowledge(Knowledge):
Args: Args:
connector(BaseConnector): connector connector(BaseConnector): connector
summary_template(str, optional): summary template summary_template(str, optional): summary template
separator(str, optional): separator used to separate
table's basic info and fields.
defaults `--table-field-separator--`
knowledge_type(KnowledgeType, optional): knowledge type knowledge_type(KnowledgeType, optional): knowledge type
metadata(Dict[str, Union[str, List[str]], optional): metadata metadata(Dict[str, Union[str, List[str]], optional): metadata
model_dimension(int, optional): The threshold for splitting field string
""" """
self._separator = separator
self._connector = connector self._connector = connector
self._summary_template = summary_template self._summary_template = summary_template
self._model_dimension = model_dimension
super().__init__(knowledge_type=knowledge_type, metadata=metadata, **kwargs) super().__init__(knowledge_type=knowledge_type, metadata=metadata, **kwargs)
def _load(self) -> List[Document]: def _load(self) -> List[Document]:
@ -37,13 +45,23 @@ class DatasourceKnowledge(Knowledge):
docs = [] docs = []
if self._connector.is_graph_type(): if self._connector.is_graph_type():
db_summary = _parse_gdb_summary(self._connector, self._summary_template) db_summary = _parse_gdb_summary(self._connector, self._summary_template)
for table_summary in db_summary:
metadata = {"source": "database"}
docs.append(Document(content=table_summary, metadata=metadata))
else: else:
db_summary = _parse_db_summary(self._connector, self._summary_template) db_summary_with_metadata = _parse_db_summary_with_metadata(
for table_summary in db_summary: self._connector,
metadata = {"source": "database"} self._summary_template,
if self._metadata: self._separator,
metadata.update(self._metadata) # type: ignore self._model_dimension,
docs.append(Document(content=table_summary, metadata=metadata)) )
for summary, table_metadata in db_summary_with_metadata:
metadata = {"source": "database"}
if self._metadata:
metadata.update(self._metadata) # type: ignore
table_metadata.update(metadata)
docs.append(Document(content=summary, metadata=table_metadata))
return docs return docs
@classmethod @classmethod

View File

@ -1,14 +1,15 @@
"""The DBSchema Retriever Operator.""" """The DBSchema Retriever Operator."""
import os
from typing import List, Optional from typing import List, Optional
from dbgpt.core import Chunk from dbgpt.core import Chunk
from dbgpt.core.interface.operators.retriever import RetrieverOperator from dbgpt.core.interface.operators.retriever import RetrieverOperator
from dbgpt.datasource.base import BaseConnector from dbgpt.datasource.base import BaseConnector
from dbgpt.serve.rag.connector import VectorStoreConnector
from ...storage.vector_store.base import VectorStoreConfig
from ..assembler.db_schema import DBSchemaAssembler from ..assembler.db_schema import DBSchemaAssembler
from ..chunk_manager import ChunkParameters from ..chunk_manager import ChunkParameters
from ..index.base import IndexStoreBase
from ..retriever.db_schema import DBSchemaRetriever from ..retriever.db_schema import DBSchemaRetriever
from .assembler import AssemblerOperator from .assembler import AssemblerOperator
@ -19,13 +20,14 @@ class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]):
Args: Args:
connector (BaseConnector): The connection. connector (BaseConnector): The connection.
top_k (int, optional): The top k. Defaults to 4. top_k (int, optional): The top k. Defaults to 4.
index_store (IndexStoreBase, optional): The vector store                                    table_vector_store_connector (VectorStoreConnector, optional): The vector store
connector. Defaults to None.                                    connector. Defaults to None.
""" """
def __init__( def __init__(
self, self,
index_store: IndexStoreBase, table_vector_store_connector: VectorStoreConnector,
field_vector_store_connector: VectorStoreConnector,
top_k: int = 4, top_k: int = 4,
connector: Optional[BaseConnector] = None, connector: Optional[BaseConnector] = None,
**kwargs **kwargs
@ -35,7 +37,8 @@ class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]):
self._retriever = DBSchemaRetriever( self._retriever = DBSchemaRetriever(
top_k=top_k, top_k=top_k,
connector=connector, connector=connector,
index_store=index_store, table_vector_store_connector=table_vector_store_connector,
field_vector_store_connector=field_vector_store_connector,
) )
def retrieve(self, query: str) -> List[Chunk]: def retrieve(self, query: str) -> List[Chunk]:
@ -53,7 +56,8 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
def __init__( def __init__(
self, self,
connector: BaseConnector, connector: BaseConnector,
index_store: IndexStoreBase, table_vector_store_connector: VectorStoreConnector,
field_vector_store_connector: VectorStoreConnector = None,
chunk_parameters: Optional[ChunkParameters] = None, chunk_parameters: Optional[ChunkParameters] = None,
**kwargs **kwargs
): ):
@ -61,14 +65,26 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
Args: Args:
connector (BaseConnector): The connection. connector (BaseConnector): The connection.
index_store (IndexStoreBase): The Storage IndexStoreBase.                                    table_vector_store_connector (VectorStoreConnector): The vector store connector.
chunk_parameters (Optional[ChunkParameters], optional): The chunk chunk_parameters (Optional[ChunkParameters], optional): The chunk
parameters. parameters.
""" """
if not chunk_parameters: if not chunk_parameters:
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE") chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
self._chunk_parameters = chunk_parameters self._chunk_parameters = chunk_parameters
self._index_store = index_store self._table_vector_store_connector = table_vector_store_connector
field_vector_store_config = VectorStoreConfig(
name=table_vector_store_connector.vector_store_config.name + "_field"
)
self._field_vector_store_connector = (
field_vector_store_connector
or VectorStoreConnector.from_default(
os.getenv("VECTOR_STORE_TYPE", "Chroma"),
self._table_vector_store_connector.current_embeddings,
vector_store_config=field_vector_store_config,
)
)
self._connector = connector self._connector = connector
super().__init__(**kwargs) super().__init__(**kwargs)
@ -84,7 +100,8 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
assembler = DBSchemaAssembler.load_from_connection( assembler = DBSchemaAssembler.load_from_connection(
connector=self._connector, connector=self._connector,
chunk_parameters=self._chunk_parameters, chunk_parameters=self._chunk_parameters,
index_store=self._index_store, table_vector_store_connector=self._table_vector_store_connector,
field_vector_store_connector=self._field_vector_store_connector,
) )
assembler.persist() assembler.persist()
return assembler.get_chunks() return assembler.get_chunks()

View File

@ -1,18 +1,23 @@
"""DBSchema retriever.""" """DBSchema retriever."""
import logging
import os
from typing import List, Optional
from functools import reduce from dbgpt._private.config import Config
from typing import List, Optional, cast
from dbgpt.core import Chunk from dbgpt.core import Chunk
from dbgpt.datasource.base import BaseConnector from dbgpt.datasource.base import BaseConnector
from dbgpt.rag.index.base import IndexStoreBase
from dbgpt.rag.retriever.base import BaseRetriever from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary from dbgpt.rag.summary.gdbms_db_summary import _parse_db_summary
from dbgpt.storage.vector_store.filters import MetadataFilters from dbgpt.serve.rag.connector import VectorStoreConnector
from dbgpt.util.chat_util import run_async_tasks from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters
from dbgpt.util.chat_util import run_tasks
from dbgpt.util.executor_utils import blocking_func_to_async_no_executor from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
from dbgpt.util.tracer import root_tracer
logger = logging.getLogger(__name__)
CFG = Config()
class DBSchemaRetriever(BaseRetriever): class DBSchemaRetriever(BaseRetriever):
@ -20,7 +25,9 @@ class DBSchemaRetriever(BaseRetriever):
def __init__( def __init__(
self, self,
index_store: IndexStoreBase, table_vector_store_connector: VectorStoreConnector,
field_vector_store_connector: VectorStoreConnector = None,
separator: str = "--table-field-separator--",
top_k: int = 4, top_k: int = 4,
connector: Optional[BaseConnector] = None, connector: Optional[BaseConnector] = None,
query_rewrite: bool = False, query_rewrite: bool = False,
@ -30,7 +37,11 @@ class DBSchemaRetriever(BaseRetriever):
"""Create DBSchemaRetriever. """Create DBSchemaRetriever.
Args: Args:
index_store(IndexStore): index connector table_vector_store_connector: VectorStoreConnector
to load and retrieve table info.
field_vector_store_connector: VectorStoreConnector
to load and retrieve field info.
separator: field/table separator
top_k (int): top k top_k (int): top k
connector (Optional[BaseConnector]): RDBMSConnector. connector (Optional[BaseConnector]): RDBMSConnector.
query_rewrite (bool): query rewrite query_rewrite (bool): query rewrite
@ -70,34 +81,42 @@ class DBSchemaRetriever(BaseRetriever):
connector = _create_temporary_connection() connector = _create_temporary_connection()
vector_store_config = ChromaVectorConfig(name="vector_store_name")
embedding_model_path = "{your_embedding_model_path}"
embedding_fn = embedding_factory.create(model_name=embedding_model_path) embedding_fn = embedding_factory.create(model_name=embedding_model_path)
config = ChromaVectorConfig( vector_connector = VectorStoreConnector.from_default(
persist_path=PILOT_PATH, "Chroma",
name="dbschema_rag_test", vector_store_config=vector_store_config,
embedding_fn=DefaultEmbeddingFactory( embedding_fn=embedding_fn,
default_model_name=os.path.join(
MODEL_PATH, "text2vec-large-chinese"
),
).create(),
) )
vector_store = ChromaStore(config)
# get db struct retriever # get db struct retriever
retriever = DBSchemaRetriever( retriever = DBSchemaRetriever(
top_k=3, top_k=3,
index_store=vector_store,                                    table_vector_store_connector=vector_connector,
connector=connector, connector=connector,
) )
chunks = retriever.retrieve("show columns from table") chunks = retriever.retrieve("show columns from table")
result = [chunk.content for chunk in chunks] result = [chunk.content for chunk in chunks]
print(f"db struct rag example results:{result}") print(f"db struct rag example results:{result}")
""" """
self._separator = separator
self._top_k = top_k self._top_k = top_k
self._connector = connector self._connector = connector
self._query_rewrite = query_rewrite self._query_rewrite = query_rewrite
self._index_store = index_store self._table_vector_store_connector = table_vector_store_connector
field_vector_store_config = VectorStoreConfig(
name=table_vector_store_connector.vector_store_config.name + "_field"
)
self._field_vector_store_connector = (
field_vector_store_connector
or VectorStoreConnector.from_default(
os.getenv("VECTOR_STORE_TYPE", "Chroma"),
self._table_vector_store_connector.current_embeddings,
vector_store_config=field_vector_store_config,
)
)
self._need_embeddings = False self._need_embeddings = False
if self._index_store: if self._table_vector_store_connector:
self._need_embeddings = True self._need_embeddings = True
self._rerank = rerank or DefaultRanker(self._top_k) self._rerank = rerank or DefaultRanker(self._top_k)
@ -114,15 +133,8 @@ class DBSchemaRetriever(BaseRetriever):
List[Chunk]: list of chunks List[Chunk]: list of chunks
""" """
if self._need_embeddings: if self._need_embeddings:
queries = [query] return self._similarity_search(query, filters)
candidates = [
self._index_store.similar_search(query, self._top_k, filters)
for query in queries
]
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
else: else:
if not self._connector:
raise RuntimeError("RDBMSConnector connection is required.")
table_summaries = _parse_db_summary(self._connector) table_summaries = _parse_db_summary(self._connector)
return [Chunk(content=table_summary) for table_summary in table_summaries] return [Chunk(content=table_summary) for table_summary in table_summaries]
@ -156,30 +168,11 @@ class DBSchemaRetriever(BaseRetriever):
Returns: Returns:
List[Chunk]: list of chunks List[Chunk]: list of chunks
""" """
if self._need_embeddings: return await blocking_func_to_async_no_executor(
queries = [query] func=self._retrieve,
candidates = [ query=query,
self._similarity_search( filters=filters,
query, filters, root_tracer.get_current_span_id() )
)
for query in queries
]
result_candidates = await run_async_tasks(
tasks=candidates, concurrency_limit=1
)
return cast(List[Chunk], reduce(lambda x, y: x + y, result_candidates))
else:
from dbgpt.rag.summary.rdbms_db_summary import ( # noqa: F401
_parse_db_summary,
)
table_summaries = await run_async_tasks(
tasks=[self._aparse_db_summary(root_tracer.get_current_span_id())],
concurrency_limit=1,
)
return [
Chunk(content=table_summary) for table_summary in table_summaries[0]
]
async def _aretrieve_with_score( async def _aretrieve_with_score(
self, self,
@ -196,34 +189,40 @@ class DBSchemaRetriever(BaseRetriever):
""" """
return await self._aretrieve(query, filters) return await self._aretrieve(query, filters)
async def _similarity_search( def _retrieve_field(self, table_chunk: Chunk, query) -> Chunk:
self, metadata = table_chunk.metadata
query, metadata["part"] = "field"
filters: Optional[MetadataFilters] = None, filters = [MetadataFilter(key=k, value=v) for k, v in metadata.items()]
parent_span_id: Optional[str] = None, field_chunks = self._field_vector_store_connector.similar_search_with_scores(
query, self._top_k, 0, MetadataFilters(filters=filters)
)
field_contents = [chunk.content for chunk in field_chunks]
table_chunk.content += "\n" + self._separator + "\n" + "\n".join(field_contents)
return table_chunk
def _similarity_search(
self, query, filters: Optional[MetadataFilters] = None
) -> List[Chunk]: ) -> List[Chunk]:
"""Similar search.""" """Similar search."""
with root_tracer.start_span( table_chunks = self._table_vector_store_connector.similar_search_with_scores(
"dbgpt.rag.retriever.db_schema._similarity_search", query, self._top_k, 0, filters
parent_span_id, )
metadata={"query": query},
):
return await blocking_func_to_async_no_executor(
self._index_store.similar_search, query, self._top_k, filters
)
async def _aparse_db_summary( not_sep_chunks = [
self, parent_span_id: Optional[str] = None chunk for chunk in table_chunks if not chunk.metadata.get("separated")
) -> List[str]: ]
"""Similar search.""" separated_chunks = [
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary chunk for chunk in table_chunks if chunk.metadata.get("separated")
]
if not separated_chunks:
return not_sep_chunks
if not self._connector: # Create tasks list
raise RuntimeError("RDBMSConnector connection is required.") tasks = [
with root_tracer.start_span( lambda c=chunk: self._retrieve_field(c, query) for chunk in separated_chunks
"dbgpt.rag.retriever.db_schema._aparse_db_summary", ]
parent_span_id, # Run tasks concurrently
): separated_result = run_tasks(tasks, concurrency_limit=3)
return await blocking_func_to_async_no_executor(
_parse_db_summary, self._connector # Combine and return results
) return not_sep_chunks + separated_result

View File

@ -15,42 +15,53 @@ def mock_db_connection():
@pytest.fixture @pytest.fixture
def mock_vector_store_connector(): def mock_table_vector_store_connector():
mock_connector = MagicMock() mock_connector = MagicMock()
mock_connector.similar_search.return_value = [Chunk(content="Table summary")] * 4 mock_connector.vector_store_config.name = "table_name"
mock_connector.similar_search_with_scores.return_value = [
Chunk(content="Table summary")
] * 4
return mock_connector return mock_connector
@pytest.fixture @pytest.fixture
def db_struct_retriever(mock_db_connection, mock_vector_store_connector): def mock_field_vector_store_connector():
mock_connector = MagicMock()
mock_connector.similar_search_with_scores.return_value = [
Chunk(content="Field summary")
] * 4
return mock_connector
@pytest.fixture
def dbstruct_retriever(
mock_db_connection,
mock_table_vector_store_connector,
mock_field_vector_store_connector,
):
return DBSchemaRetriever( return DBSchemaRetriever(
connector=mock_db_connection, connector=mock_db_connection,
index_store=mock_vector_store_connector, table_vector_store_connector=mock_table_vector_store_connector,
field_vector_store_connector=mock_field_vector_store_connector,
) )
def mock_parse_db_summary(conn) -> List[str]: def mock_parse_db_summary() -> str:
"""Patch _parse_db_summary method.""" """Patch _parse_db_summary method."""
return ["Table summary"] return "Table summary"
# Mocking the _parse_db_summary method in your test function # Mocking the _parse_db_summary method in your test function
@patch.object( @patch.object(
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
) )
def test_retrieve_with_mocked_summary(db_struct_retriever): def test_retrieve_with_mocked_summary(dbstruct_retriever):
query = "Table summary" query = "Table summary"
chunks: List[Chunk] = db_struct_retriever._retrieve(query) chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
assert isinstance(chunks[0], Chunk) assert isinstance(chunks[0], Chunk)
assert chunks[0].content == "Table summary" assert chunks[0].content == "Table summary"
@pytest.mark.asyncio async def async_mock_parse_db_summary() -> str:
@patch.object( """Asynchronous patch for _parse_db_summary method."""
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary return "Table summary"
)
async def test_aretrieve_with_mocked_summary(db_struct_retriever):
query = "Table summary"
chunks: List[Chunk] = await db_struct_retriever._aretrieve(query)
assert isinstance(chunks[0], Chunk)
assert chunks[0].content == "Table summary"

View File

@ -2,13 +2,15 @@
import logging import logging
import traceback import traceback
from typing import List
from dbgpt._private.config import Config from dbgpt._private.config import Config
from dbgpt.component import SystemApp from dbgpt.component import SystemApp
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
from dbgpt.rag import ChunkParameters
from dbgpt.rag.summary.gdbms_db_summary import GdbmsSummary from dbgpt.rag.summary.gdbms_db_summary import GdbmsSummary
from dbgpt.rag.summary.rdbms_db_summary import RdbmsSummary from dbgpt.rag.summary.rdbms_db_summary import RdbmsSummary
from dbgpt.rag.text_splitter.text_splitter import RDBTextSplitter
from dbgpt.serve.rag.connector import VectorStoreConnector
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -47,22 +49,26 @@ class DBSummaryClient:
logger.info("db summary embedding success") logger.info("db summary embedding success")
def get_db_summary(self, dbname, query, topk) -> List[str]: def get_db_summary(self, dbname, query, topk):
"""Get user query related tables info.""" """Get user query related tables info."""
from dbgpt.serve.rag.connector import VectorStoreConnector
from dbgpt.storage.vector_store.base import VectorStoreConfig from dbgpt.storage.vector_store.base import VectorStoreConfig
vector_store_config = VectorStoreConfig(name=dbname + "_profile") vector_store_name = dbname + "_profile"
vector_connector = VectorStoreConnector.from_default( table_vector_store_config = VectorStoreConfig(name=vector_store_name)
table_vector_connector = VectorStoreConnector.from_default(
CFG.VECTOR_STORE_TYPE, CFG.VECTOR_STORE_TYPE,
embedding_fn=self.embeddings, self.embeddings,
vector_store_config=vector_store_config, vector_store_config=table_vector_store_config,
) )
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
retriever = DBSchemaRetriever( retriever = DBSchemaRetriever(
top_k=topk, index_store=vector_connector.index_client top_k=topk,
table_vector_store_connector=table_vector_connector,
separator="--table-field-separator--",
) )
table_docs = retriever.retrieve(query) table_docs = retriever.retrieve(query)
ans = [d.content for d in table_docs] ans = [d.content for d in table_docs]
return ans return ans
@ -92,18 +98,23 @@ class DBSummaryClient:
from dbgpt.serve.rag.connector import VectorStoreConnector from dbgpt.serve.rag.connector import VectorStoreConnector
from dbgpt.storage.vector_store.base import VectorStoreConfig from dbgpt.storage.vector_store.base import VectorStoreConfig
vector_store_config = VectorStoreConfig(name=vector_store_name) table_vector_store_config = VectorStoreConfig(name=vector_store_name)
vector_connector = VectorStoreConnector.from_default( table_vector_connector = VectorStoreConnector.from_default(
CFG.VECTOR_STORE_TYPE, CFG.VECTOR_STORE_TYPE,
self.embeddings, self.embeddings,
vector_store_config=vector_store_config, vector_store_config=table_vector_store_config,
) )
if not vector_connector.vector_name_exists(): if not table_vector_connector.vector_name_exists():
from dbgpt.rag.assembler.db_schema import DBSchemaAssembler from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
chunk_parameters = ChunkParameters(
text_splitter=RDBTextSplitter(separator="--table-field-separator--")
)
db_assembler = DBSchemaAssembler.load_from_connection( db_assembler = DBSchemaAssembler.load_from_connection(
connector=db_summary_client.db, connector=db_summary_client.db,
index_store=vector_connector.index_client, table_vector_store_connector=table_vector_connector,
chunk_parameters=chunk_parameters,
max_seq_length=CFG.EMBEDDING_MODEL_MAX_SEQ_LEN,
) )
if len(db_assembler.get_chunks()) > 0: if len(db_assembler.get_chunks()) > 0:
@ -115,16 +126,26 @@ class DBSummaryClient:
def delete_db_profile(self, dbname): def delete_db_profile(self, dbname):
"""Delete db profile.""" """Delete db profile."""
vector_store_name = dbname + "_profile" vector_store_name = dbname + "_profile"
table_vector_store_name = dbname + "_profile"
field_vector_store_name = dbname + "_profile_field"
from dbgpt.serve.rag.connector import VectorStoreConnector from dbgpt.serve.rag.connector import VectorStoreConnector
from dbgpt.storage.vector_store.base import VectorStoreConfig from dbgpt.storage.vector_store.base import VectorStoreConfig
vector_store_config = VectorStoreConfig(name=vector_store_name) table_vector_store_config = VectorStoreConfig(name=vector_store_name)
vector_connector = VectorStoreConnector.from_default( field_vector_store_config = VectorStoreConfig(name=field_vector_store_name)
table_vector_connector = VectorStoreConnector.from_default(
CFG.VECTOR_STORE_TYPE, CFG.VECTOR_STORE_TYPE,
self.embeddings, self.embeddings,
vector_store_config=vector_store_config, vector_store_config=table_vector_store_config,
) )
vector_connector.delete_vector_name(vector_store_name) field_vector_connector = VectorStoreConnector.from_default(
CFG.VECTOR_STORE_TYPE,
self.embeddings,
vector_store_config=field_vector_store_config,
)
table_vector_connector.delete_vector_name(table_vector_store_name)
field_vector_connector.delete_vector_name(field_vector_store_name)
logger.info(f"delete db profile {dbname} success") logger.info(f"delete db profile {dbname} success")
@staticmethod @staticmethod

View File

@ -1,6 +1,6 @@
"""Summary for rdbms database.""" """Summary for rdbms database."""
import re import re
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from dbgpt._private.config import Config from dbgpt._private.config import Config
from dbgpt.datasource import BaseConnector from dbgpt.datasource import BaseConnector
@ -80,6 +80,134 @@ def _parse_db_summary(
return table_info_summaries return table_info_summaries
def _parse_db_summary_with_metadata(
    conn: BaseConnector,
    summary_template: str = "table_name: {table_name}",
    separator: str = "--table-field-separator--",
    model_dimension: int = 512,
) -> List[Tuple[str, Dict[str, Any]]]:
    """Build a (summary, metadata) pair for every table in the database.

    Args:
        conn (BaseConnector): database connection
        summary_template (str): template applied to each table name
        separator (str, optional): marker placed between a table's basic info
            and its field list. Defaults to ``--table-field-separator--``.
        model_dimension (int, optional): threshold used when splitting the
            field string into segments.

    Returns:
        List[Tuple[str, Dict[str, Any]]]: one (summary text, metadata) tuple
        per table, in the order returned by ``conn.get_table_names()``.
    """
    summaries: List[Tuple[str, Dict[str, Any]]] = []
    for name in conn.get_table_names():
        summaries.append(
            _parse_table_summary_with_metadata(
                conn, summary_template, separator, name, model_dimension
            )
        )
    return summaries
def _split_columns_str(columns: List[str], model_dimension: int):
"""Split columns str.
Args:
columns (List[str]): fields string
model_dimension (int, optional): The threshold for splitting field string.
"""
result = []
current_string = ""
current_length = 0
for element_str in columns:
element_length = len(element_str)
# If adding the current element's length would exceed the threshold,
# add the current string to results and reset
if current_length + element_length > model_dimension:
result.append(current_string.strip()) # Remove trailing spaces
current_string = element_str
current_length = element_length
else:
# If current string is empty, add element directly
if current_string:
current_string += "," + element_str
else:
current_string = element_str
current_length += element_length + 1 # Add length of space
# Handle the last string segment
if current_string:
result.append(current_string.strip())
return result
def _parse_table_summary_with_metadata(
    conn: BaseConnector,
    summary_template: str,
    separator: str,
    table_name: str,
    model_dimension: int = 512,
) -> Tuple[str, Dict[str, Any]]:
    """Get table summary for table.

    Builds a textual summary (name, comment, index keys, then a separator and
    the column list) plus a metadata dict describing the table.

    Args:
        conn (BaseConnector): database connection
        summary_template (str): summary template
        separator(str, optional): separator used to separate table's
            basic info and fields. defaults to `-- table-field-separator--`
        table_name (str): name of the table to summarize
        model_dimension(int, optional): The threshold for splitting field string

    Returns:
        Tuple[str, Dict[str, Any]]: the summary text and its metadata,
        e.g. ``{'table_name': ..., 'separated': 0/1, 'field_num': ...}``.

    Examples:
        metadata: {'table_name': 'asd', 'separated': 0/1}
        table_name: table1
        table_comment: comment
        index_keys: keys
        --table-field-separator--
        (column1,comment), (column2, comment), (column3, comment)
        (column4,comment), (column5, comment), (column6, comment)
    """
    # Render each column as "name (comment)" when a comment exists.
    columns = []
    metadata = {"table_name": table_name, "separated": 0}
    for column in conn.get_columns(table_name):
        if column.get("comment"):
            columns.append(f"{column['name']} ({column.get('comment')})")
        else:
            columns.append(f"{column['name']}")
    metadata.update({"field_num": len(columns)})
    # Pack the columns into size-bounded segments; more than one segment
    # means the field list will be stored as separate chunks downstream.
    separated_columns = _split_columns_str(columns, model_dimension=model_dimension)
    if len(separated_columns) > 1:
        metadata["separated"] = 1
    column_str = "\n".join(separated_columns)
    # Obtain index information
    index_keys = []
    raw_indexes = conn.get_indexes(table_name)
    # NOTE(review): connectors appear to return either (name, DDL) tuples or
    # dicts with 'name'/'column_names' keys — both shapes are handled below;
    # confirm against the concrete connector implementations.
    for index in raw_indexes:
        if isinstance(index, tuple):  # Process tuple type index information
            index_name, index_creation_command = index
            # Extract column names using re
            matched_columns = re.findall(r"\(([^)]+)\)", index_creation_command)
            if matched_columns:
                key_str = ", ".join(matched_columns)
                index_keys.append(f"{index_name}(`{key_str}`) ")
        else:
            key_str = ", ".join(index["column_names"])
            index_keys.append(f"{index['name']}(`{key_str}`) ")
    table_str = summary_template.format(table_name=table_name)
    # Table comment is best-effort: some connectors raise when unsupported.
    try:
        comment = conn.get_table_comment(table_name)
    except Exception:
        comment = dict(text=None)
    if comment.get("text"):
        table_str += f"\ntable_comment: {comment.get('text')}"
    if len(index_keys) > 0:
        index_key_str = ", ".join(index_keys)
        table_str += f"\nindex_keys: {index_key_str}"
    # Basic info and the field list are joined by the configured separator so
    # they can be split apart again by RDBTextSplitter.
    table_str += f"\n{separator}\n{column_str}"
    return table_str, metadata
def _parse_table_summary( def _parse_table_summary(
conn: BaseConnector, summary_template: str, table_name: str conn: BaseConnector, summary_template: str, table_name: str
) -> str: ) -> str:

View File

@ -912,3 +912,42 @@ class PageTextSplitter(TextSplitter):
new_doc = Chunk(content=text, metadata=copy.deepcopy(_metadatas[i])) new_doc = Chunk(content=text, metadata=copy.deepcopy(_metadatas[i]))
chunks.append(new_doc) chunks.append(new_doc)
return chunks return chunks
class RDBTextSplitter(TextSplitter):
    """Split relational database tables and fields.

    Documents produced by the RDB summary (``_parse_table_summary_with_metadata``)
    contain a table-info part and a field-list part joined by a separator; this
    splitter turns each such document into one table chunk plus one chunk per
    field segment. Documents without the ``separated`` metadata flag pass
    through as a single chunk.
    """

    def __init__(self, **kwargs):
        """Create a new TextSplitter."""
        super().__init__(**kwargs)

    def split_text(self, text: str, **kwargs):
        """Split text into a couple of parts.

        Not used by this splitter: splitting is driven by document metadata in
        :meth:`split_documents`.
        """
        pass

    def split_documents(self, documents: Iterable[Document], **kwargs) -> List[Chunk]:
        """Split documents into table chunks and per-segment field chunks.

        Args:
            documents (Iterable[Document]): documents to split.

        Returns:
            List[Chunk]: for each document flagged ``separated`` in its
            metadata, a table chunk (``part='table'``) followed by one field
            chunk per line of the field part (``part='field'`` with a
            ``part_index``); otherwise a single pass-through chunk.
        """
        chunks = []
        for doc in documents:
            metadata = doc.metadata
            content = doc.content
            # Split at the first separator only (maxsplit=1) so a separator
            # occurring inside the field text does not truncate it; the
            # previous unbounded split dropped everything after parts[1].
            parts = (
                content.split(self._separator, 1)
                if metadata.get("separated")
                else [content]
            )
            if len(parts) == 2:
                # separate table and field
                table_part, field_part = parts
                table_metadata = copy.deepcopy(metadata)
                field_metadata = copy.deepcopy(metadata)
                table_metadata["part"] = "table"  # identify of table_chunk
                field_metadata["part"] = "field"  # identify of field_chunk
                chunks.append(Chunk(content=table_part, metadata=table_metadata))
                for i, sub_part in enumerate(field_part.split("\n")):
                    sub_metadata = copy.deepcopy(field_metadata)
                    sub_metadata["part_index"] = i
                    chunks.append(Chunk(content=sub_part, metadata=sub_metadata))
            else:
                # Not flagged as separated, or the separator is missing from
                # the content (previously an IndexError): keep one chunk.
                chunks.append(Chunk(content=content, metadata=metadata))
        return chunks

View File

@ -1,5 +1,6 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable, Coroutine, List, Optional
async def llm_chat_response_nostream(chat_scene: str, **chat_param): async def llm_chat_response_nostream(chat_scene: str, **chat_param):
@ -47,13 +48,34 @@ async def run_async_tasks(
def run_tasks(
    tasks: List[Callable],
    concurrency_limit: Optional[int] = None,
) -> List[Any]:
    """Run a list of tasks concurrently using a thread pool.

    Args:
        tasks: List of callable functions to execute.
        concurrency_limit: Maximum number of concurrent threads; ``None``
            (or 0) lets ``ThreadPoolExecutor`` pick its default.

    Returns:
        List of results from all tasks in the order they were submitted.

    Raises:
        Exception: the first exception raised by a task, after cancelling any
            tasks that have not started yet.
    """
    max_workers = concurrency_limit if concurrency_limit else None
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks and get futures.
        futures = [executor.submit(task) for task in tasks]
        # Collect results in submission order, raising any exceptions.
        results: List[Any] = []
        for future in futures:
            try:
                results.append(future.result())
            except Exception:
                # Cancel any pending futures before propagating; bare `raise`
                # preserves the original traceback.
                for pending in futures:
                    pending.cancel()
                raise
        return results

View File

@ -0,0 +1,317 @@
-- Wide (denormalized) order table: one row carries the full order context
-- (order, amounts, user, product, address, shipping, payment, promotion,
-- after-sale, invoice, delivery-window, tags, commercial metrics).
CREATE TABLE order_wide_table (
    -- order_base
    order_id TEXT, -- Order ID
    order_no TEXT, -- Order number
    parent_order_no TEXT, -- Parent order number
    order_type INTEGER, -- Order type: 1 physical, 2 virtual, 3 mixed
    order_status INTEGER, -- Order status
    order_source TEXT, -- Order source
    order_source_detail TEXT, -- Order source detail
    create_time DATETIME, -- Creation time
    pay_time DATETIME, -- Payment time
    finish_time DATETIME, -- Completion time
    close_time DATETIME, -- Close time
    cancel_time DATETIME, -- Cancellation time
    cancel_reason TEXT, -- Cancellation reason
    order_remark TEXT, -- Order remark
    seller_remark TEXT, -- Seller remark
    buyer_remark TEXT, -- Buyer remark
    is_deleted INTEGER, -- Whether deleted
    delete_time DATETIME, -- Deletion time
    order_ip TEXT, -- Ordering IP
    order_platform TEXT, -- Ordering platform
    order_device TEXT, -- Ordering device
    order_app_version TEXT, -- App version number

    -- order_amount
    currency TEXT, -- Currency type
    exchange_rate REAL, -- Exchange rate
    original_amount REAL, -- Original amount
    discount_amount REAL, -- Discount amount
    coupon_amount REAL, -- Coupon amount
    points_amount REAL, -- Points deduction amount
    shipping_amount REAL, -- Shipping fee
    insurance_amount REAL, -- Insurance fee
    tax_amount REAL, -- Tax
    tariff_amount REAL, -- Tariff
    payment_amount REAL, -- Actual paid amount
    commission_amount REAL, -- Commission amount
    platform_fee REAL, -- Platform fee
    seller_income REAL, -- Seller net income
    payment_currency TEXT, -- Payment currency
    payment_exchange_rate REAL, -- Payment exchange rate

    -- user_info
    user_id TEXT, -- User ID
    user_name TEXT, -- User name
    user_nickname TEXT, -- User nickname
    user_level INTEGER, -- User level
    user_type INTEGER, -- User type
    register_time DATETIME, -- Registration time
    register_source TEXT, -- Registration source
    mobile TEXT, -- Mobile number
    mobile_area TEXT, -- Mobile area code
    email TEXT, -- Email
    is_vip INTEGER, -- Whether VIP
    vip_level INTEGER, -- VIP level
    vip_expire_time DATETIME, -- VIP expiry time
    user_age INTEGER, -- User age
    user_gender INTEGER, -- User gender
    user_birthday DATE, -- User birthday
    user_avatar TEXT, -- User avatar
    user_province TEXT, -- User province
    user_city TEXT, -- User city
    user_district TEXT, -- User district
    last_login_time DATETIME, -- Last login time
    last_login_ip TEXT, -- Last login IP
    user_credit_score INTEGER, -- User credit score
    total_order_count INTEGER, -- Historical order count
    total_order_amount REAL, -- Historical order amount

    -- product_info
    product_id TEXT, -- Product ID
    product_code TEXT, -- Product code
    product_name TEXT, -- Product name
    product_short_name TEXT, -- Product short name
    product_type INTEGER, -- Product type
    product_status INTEGER, -- Product status
    category_id TEXT, -- Category ID
    category_name TEXT, -- Category name
    category_path TEXT, -- Category path
    brand_id TEXT, -- Brand ID
    brand_name TEXT, -- Brand name
    brand_english_name TEXT, -- Brand English name
    seller_id TEXT, -- Seller ID
    seller_name TEXT, -- Seller name
    seller_type INTEGER, -- Seller type
    shop_id TEXT, -- Shop ID
    shop_name TEXT, -- Shop name
    product_price REAL, -- Product price
    market_price REAL, -- Market price
    cost_price REAL, -- Cost price
    wholesale_price REAL, -- Wholesale price
    product_quantity INTEGER, -- Product quantity
    product_unit TEXT, -- Product unit
    product_weight REAL, -- Product weight (grams)
    product_volume REAL, -- Product volume (cm³)
    product_spec TEXT, -- Product specification
    product_color TEXT, -- Product color
    product_size TEXT, -- Product size
    product_material TEXT, -- Product material
    product_origin TEXT, -- Product origin
    product_shelf_life INTEGER, -- Shelf life (days)
    manufacture_date DATE, -- Manufacture date
    expiry_date DATE, -- Expiry date
    batch_number TEXT, -- Batch number
    product_barcode TEXT, -- Product barcode
    warehouse_id TEXT, -- Shipping warehouse ID
    warehouse_name TEXT, -- Shipping warehouse name

    -- address_info
    receiver_name TEXT, -- Receiver name
    receiver_mobile TEXT, -- Receiver mobile
    receiver_tel TEXT, -- Receiver telephone
    receiver_email TEXT, -- Receiver email
    receiver_country TEXT, -- Country
    receiver_province TEXT, -- Province
    receiver_city TEXT, -- City
    receiver_district TEXT, -- District
    receiver_street TEXT, -- Street
    receiver_address TEXT, -- Detailed address
    receiver_zip TEXT, -- Zip code
    address_type INTEGER, -- Address type
    is_default INTEGER, -- Whether default address
    longitude REAL, -- Longitude
    latitude REAL, -- Latitude
    address_label TEXT, -- Address label

    -- shipping_info
    shipping_type INTEGER, -- Delivery type
    shipping_method TEXT, -- Delivery method name
    shipping_company TEXT, -- Courier company
    shipping_company_code TEXT, -- Courier company code
    shipping_no TEXT, -- Tracking number
    shipping_time DATETIME, -- Shipping time
    shipping_remark TEXT, -- Shipping remark
    expect_receive_time DATETIME, -- Estimated delivery time
    receive_time DATETIME, -- Receipt time
    sign_type INTEGER, -- Sign-off type
    shipping_status INTEGER, -- Logistics status
    tracking_url TEXT, -- Logistics tracking URL
    is_free_shipping INTEGER, -- Whether free shipping
    shipping_insurance REAL, -- Shipping insurance amount
    shipping_distance REAL, -- Delivery distance
    delivered_time DATETIME, -- Delivered time
    delivery_staff_id TEXT, -- Delivery staff ID
    delivery_staff_name TEXT, -- Delivery staff name
    delivery_staff_mobile TEXT, -- Delivery staff mobile

    -- payment_info
    payment_id TEXT, -- Payment ID
    payment_no TEXT, -- Payment number
    payment_type INTEGER, -- Payment type
    payment_method TEXT, -- Payment method name
    payment_status INTEGER, -- Payment status
    payment_platform TEXT, -- Payment platform
    transaction_id TEXT, -- Transaction serial number
    payment_time DATETIME, -- Payment time
    payment_account TEXT, -- Payment account
    payment_bank TEXT, -- Payment bank
    payment_card_type TEXT, -- Payment card type
    payment_card_no TEXT, -- Payment card number
    payment_scene TEXT, -- Payment scene
    payment_client_ip TEXT, -- Payment client IP
    payment_device TEXT, -- Payment device
    payment_remark TEXT, -- Payment remark
    payment_voucher TEXT, -- Payment voucher

    -- promotion_info
    promotion_id TEXT, -- Campaign ID
    promotion_name TEXT, -- Campaign name
    promotion_type INTEGER, -- Campaign type
    promotion_desc TEXT, -- Campaign description
    promotion_start_time DATETIME, -- Campaign start time
    promotion_end_time DATETIME, -- Campaign end time
    coupon_id TEXT, -- Coupon ID
    coupon_code TEXT, -- Coupon code
    coupon_type INTEGER, -- Coupon type
    coupon_name TEXT, -- Coupon name
    coupon_desc TEXT, -- Coupon description
    points_used INTEGER, -- Points used
    points_gained INTEGER, -- Points gained
    points_multiple REAL, -- Points multiplier
    is_first_order INTEGER, -- Whether first order
    is_new_customer INTEGER, -- Whether new customer
    marketing_channel TEXT, -- Marketing channel
    marketing_source TEXT, -- Marketing source
    referral_code TEXT, -- Referral code
    referral_user_id TEXT, -- Referrer user ID

    -- after_sale_info
    refund_id TEXT, -- Refund ID
    refund_no TEXT, -- Refund number
    refund_type INTEGER, -- Refund type
    refund_status INTEGER, -- Refund status
    refund_reason TEXT, -- Refund reason
    refund_desc TEXT, -- Refund description
    refund_time DATETIME, -- Refund time
    refund_amount REAL, -- Refund amount
    return_shipping_no TEXT, -- Return tracking number
    return_shipping_company TEXT, -- Return courier company
    return_shipping_time DATETIME, -- Return shipping time
    refund_evidence TEXT, -- Refund evidence
    complaint_id TEXT, -- Complaint ID
    complaint_type INTEGER, -- Complaint type
    complaint_status INTEGER, -- Complaint status
    complaint_content TEXT, -- Complaint content
    complaint_time DATETIME, -- Complaint time
    complaint_handle_time DATETIME, -- Complaint handling time
    complaint_handle_result TEXT, -- Complaint handling result
    evaluation_score INTEGER, -- Review score
    evaluation_content TEXT, -- Review content
    evaluation_time DATETIME, -- Review time
    evaluation_reply TEXT, -- Review reply
    evaluation_reply_time DATETIME, -- Review reply time
    evaluation_images TEXT, -- Review images
    evaluation_videos TEXT, -- Review videos
    is_anonymous INTEGER, -- Whether anonymous review

    -- invoice_info
    invoice_type INTEGER, -- Invoice type
    invoice_title TEXT, -- Invoice title
    invoice_content TEXT, -- Invoice content
    tax_no TEXT, -- Tax number
    invoice_amount REAL, -- Invoice amount
    invoice_status INTEGER, -- Invoice status
    invoice_time DATETIME, -- Invoicing time
    invoice_number TEXT, -- Invoice number
    invoice_code TEXT, -- Invoice code
    company_name TEXT, -- Company name
    company_address TEXT, -- Company address
    company_tel TEXT, -- Company telephone
    company_bank TEXT, -- Bank of deposit
    company_account TEXT, -- Bank account number

    -- delivery_time_info
    expect_delivery_time DATETIME, -- Expected delivery time
    delivery_period_type INTEGER, -- Delivery period type
    delivery_period_start TEXT, -- Delivery period start
    delivery_period_end TEXT, -- Delivery period end
    delivery_priority INTEGER, -- Delivery priority

    -- tag_info
    order_tags TEXT, -- Order tags
    user_tags TEXT, -- User tags
    product_tags TEXT, -- Product tags
    risk_level INTEGER, -- Risk level
    risk_tags TEXT, -- Risk tags
    business_tags TEXT, -- Business tags

    -- commercial_info
    gross_profit REAL, -- Gross profit
    gross_profit_rate REAL, -- Gross profit rate
    settlement_amount REAL, -- Settlement amount
    settlement_time DATETIME, -- Settlement time
    settlement_cycle INTEGER, -- Settlement cycle
    settlement_status INTEGER, -- Settlement status
    commission_rate REAL, -- Commission rate
    platform_service_fee REAL, -- Platform service fee
    ad_cost REAL, -- Advertising cost
    promotion_cost REAL -- Promotion cost
);
-- Insert sample data (three example orders; value lists follow the column
-- list order given in the INSERT column clause).
INSERT INTO order_wide_table (
    -- Basic order information
    order_id, order_no, order_type, order_status, create_time, order_source,
    -- Order amounts
    original_amount, payment_amount, shipping_amount,
    -- User information
    user_id, user_name, user_level, mobile,
    -- Product information
    product_id, product_name, product_quantity, product_price,
    -- Receiver information
    receiver_name, receiver_mobile, receiver_address,
    -- Shipping information
    shipping_no, shipping_status,
    -- Payment information
    payment_type, payment_status,
    -- Marketing information
    promotion_id, coupon_amount,
    -- Invoice information
    invoice_type, invoice_title
) VALUES
(
    'ORD20240101001', 'NO20240101001', 1, 2, '2024-01-01 10:00:00', 'APP',
    199.99, 188.88, 10.00,
    'USER001', '张三', 2, '13800138000',
    'PRD001', 'iPhone 15 手机壳', 2, 89.99,
    '李四', '13900139000', '北京市朝阳区XX路XX号',
    'SF123456789', 1,
    1, 1,
    'PROM001', 20.00,
    1, '个人'
),
(
    'ORD20240101002', 'NO20240101002', 1, 1, '2024-01-01 11:00:00', 'H5',
    299.99, 279.99, 0.00,
    'USER002', '王五', 3, '13700137000',
    'PRD002', 'AirPods Pro 保护套', 1, 299.99,
    '赵六', '13600136000', '上海市浦东新区XX路XX号',
    'YT987654321', 2,
    2, 2,
    'PROM002', 10.00,
    2, '上海科技有限公司'
),
(
    'ORD20240101003', 'NO20240101003', 2, 3, '2024-01-01 12:00:00', 'WEB',
    1999.99, 1899.99, 0.00,
    'USER003', '陈七', 4, '13500135000',
    'PRD003', 'MacBook Pro 电脑包', 1, 1999.99,
    '孙八', '13400134000', '广州市天河区XX路XX号',
    'JD123123123', 3,
    3, 1,
    'PROM003', 100.00,
    1, '个人'
);

View File

@ -4,7 +4,8 @@ from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.rag.assembler import DBSchemaAssembler from dbgpt.rag.assembler import DBSchemaAssembler
from dbgpt.rag.embedding import DefaultEmbeddingFactory from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig from dbgpt.serve.rag.connector import VectorStoreConnector
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
"""DB struct rag example. """DB struct rag example.
pre-requirements: pre-requirements:
@ -12,7 +13,7 @@ from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorCon
``` ```
embedding_model_path = "{your_embedding_model_path}" embedding_model_path = "{your_embedding_model_path}"
``` ```
Examples: Examples:
..code-block:: shell ..code-block:: shell
python examples/rag/db_schema_rag_example.py python examples/rag/db_schema_rag_example.py
@ -45,27 +46,26 @@ def _create_temporary_connection():
def _create_vector_connector(): def _create_vector_connector():
"""Create vector connector.""" """Create vector connector."""
config = ChromaVectorConfig( return VectorStoreConnector.from_default(
persist_path=PILOT_PATH, "Chroma",
name="dbschema_rag_test", vector_store_config=ChromaVectorConfig(
name="db_schema_vector_store_name",
persist_path=os.path.join(PILOT_PATH, "data"),
),
embedding_fn=DefaultEmbeddingFactory( embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"), default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
).create(), ).create(),
) )
return ChromaStore(config)
if __name__ == "__main__": if __name__ == "__main__":
connection = _create_temporary_connection() connection = _create_temporary_connection()
index_store = _create_vector_connector() vector_connector = _create_vector_connector()
assembler = DBSchemaAssembler.load_from_connection( assembler = DBSchemaAssembler.load_from_connection(
connector=connection, connector=connection, table_vector_store_connector=vector_connector
index_store=index_store,
) )
assembler.persist() assembler.persist()
# get db schema retriever # get db schema retriever
retriever = assembler.as_retriever(top_k=1) retriever = assembler.as_retriever(top_k=1)
chunks = retriever.retrieve("show columns from user") chunks = retriever.retrieve("show columns from user")
print(f"db schema rag example results:{[chunk.content for chunk in chunks]}") print(f"db schema rag example results:{[chunk.content for chunk in chunks]}")
index_store.delete_vector_name("dbschema_rag_test")

View File

@ -15,6 +15,7 @@ fi
DEFAULT_DB_FILE="DB-GPT/pilot/data/default_sqlite.db" DEFAULT_DB_FILE="DB-GPT/pilot/data/default_sqlite.db"
DEFAULT_SQL_FILE="DB-GPT/docker/examples/sqls/*_sqlite.sql" DEFAULT_SQL_FILE="DB-GPT/docker/examples/sqls/*_sqlite.sql"
DB_FILE="$WORK_DIR/pilot/data/default_sqlite.db" DB_FILE="$WORK_DIR/pilot/data/default_sqlite.db"
WIDE_DB_FILE="$WORK_DIR/pilot/data/wide_sqlite.db"
SQL_FILE="" SQL_FILE=""
usage () { usage () {
@ -61,6 +62,12 @@ if [ -n $SQL_FILE ];then
sqlite3 $DB_FILE < "$file" sqlite3 $DB_FILE < "$file"
done done
for file in $WORK_DIR/docker/examples/sqls/*_sqlite_wide.sql
do
echo "execute sql file: $file"
sqlite3 $WIDE_DB_FILE < "$file"
done
else else
echo "Execute SQL file ${SQL_FILE}" echo "Execute SQL file ${SQL_FILE}"
sqlite3 $DB_FILE < $SQL_FILE sqlite3 $DB_FILE < $SQL_FILE