Feat rdb summary wide table (#2035)

Co-authored-by: dongzhancai1 <dongzhancai1@jd.com>
Co-authored-by: dong <dongzhancai@iie2.com>
This commit is contained in:
Cooper 2024-12-18 20:34:21 +08:00 committed by GitHub
parent 7f4b5e79cf
commit 9b0161e521
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 948 additions and 243 deletions

View File

@ -66,6 +66,7 @@ QUANTIZE_8bit=True
#** EMBEDDING SETTINGS **#
#*******************************************************************#
EMBEDDING_MODEL=text2vec
EMBEDDING_MODEL_MAX_SEQ_LEN=512
#EMBEDDING_MODEL=m3e-large
#EMBEDDING_MODEL=bge-large-en
#EMBEDDING_MODEL=bge-large-zh

View File

@ -264,6 +264,9 @@ class Config(metaclass=Singleton):
# EMBEDDING Configuration
self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
self.EMBEDDING_MODEL_MAX_SEQ_LEN = int(
os.getenv("MEMBEDDING_MODEL_MAX_SEQ_LEN", 512)
)
# Rerank model configuration
self.RERANK_MODEL = os.getenv("RERANK_MODEL")
self.RERANK_MODEL_PATH = os.getenv("RERANK_MODEL_PATH")

View File

@ -55,9 +55,6 @@ class ChatWithDbQA(BaseChat):
if self.db_name:
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
try:
# table_infos = client.get_db_summary(
# dbname=self.db_name, query=self.current_user_input, topk=self.top_k
# )
table_infos = await blocking_func_to_async(
self._executor,
client.get_db_summary,

View File

@ -1,12 +1,15 @@
"""DBSchemaAssembler."""
import os
from typing import Any, List, Optional
from dbgpt.core import Chunk
from dbgpt.core import Chunk, Embeddings
from dbgpt.datasource.base import BaseConnector
from ...serve.rag.connector import VectorStoreConnector
from ...storage.vector_store.base import VectorStoreConfig
from ..assembler.base import BaseAssembler
from ..chunk_manager import ChunkParameters
from ..index.base import IndexStoreBase
from ..embedding.embedding_factory import DefaultEmbeddingFactory
from ..knowledge.datasource import DatasourceKnowledge
from ..retriever.db_schema import DBSchemaRetriever
@ -35,23 +38,64 @@ class DBSchemaAssembler(BaseAssembler):
def __init__(
self,
connector: BaseConnector,
index_store: IndexStoreBase,
table_vector_store_connector: VectorStoreConnector,
field_vector_store_connector: VectorStoreConnector = None,
chunk_parameters: Optional[ChunkParameters] = None,
embedding_model: Optional[str] = None,
embeddings: Optional[Embeddings] = None,
max_seq_length: int = 512,
**kwargs: Any,
) -> None:
"""Initialize with Embedding Assembler arguments.
Args:
connector: (BaseConnector) BaseConnector connection.
index_store: (IndexStoreBase) IndexStoreBase to use.
table_vector_store_connector: VectorStoreConnector to load
and retrieve table info.
field_vector_store_connector: VectorStoreConnector to load
and retrieve field info.
chunk_manager: (Optional[ChunkManager]) ChunkManager to use for chunking.
embedding_model: (Optional[str]) Embedding model to use.
embeddings: (Optional[Embeddings]) Embeddings to use.
"""
knowledge = DatasourceKnowledge(connector)
self._connector = connector
self._index_store = index_store
self._table_vector_store_connector = table_vector_store_connector
field_vector_store_config = VectorStoreConfig(
name=table_vector_store_connector.vector_store_config.name + "_field"
)
self._field_vector_store_connector = (
field_vector_store_connector
or VectorStoreConnector.from_default(
os.getenv("VECTOR_STORE_TYPE", "Chroma"),
self._table_vector_store_connector.current_embeddings,
vector_store_config=field_vector_store_config,
)
)
self._embedding_model = embedding_model
if self._embedding_model and not embeddings:
embeddings = DefaultEmbeddingFactory(
default_model_name=self._embedding_model
).create(self._embedding_model)
if (
embeddings
and self._table_vector_store_connector.vector_store_config.embedding_fn
is None
):
self._table_vector_store_connector.vector_store_config.embedding_fn = (
embeddings
)
if (
embeddings
and self._field_vector_store_connector.vector_store_config.embedding_fn
is None
):
self._field_vector_store_connector.vector_store_config.embedding_fn = (
embeddings
)
knowledge = DatasourceKnowledge(connector, model_dimension=max_seq_length)
super().__init__(
knowledge=knowledge,
chunk_parameters=chunk_parameters,
@ -62,23 +106,36 @@ class DBSchemaAssembler(BaseAssembler):
def load_from_connection(
cls,
connector: BaseConnector,
index_store: IndexStoreBase,
table_vector_store_connector: VectorStoreConnector,
field_vector_store_connector: VectorStoreConnector = None,
chunk_parameters: Optional[ChunkParameters] = None,
embedding_model: Optional[str] = None,
embeddings: Optional[Embeddings] = None,
max_seq_length: int = 512,
) -> "DBSchemaAssembler":
"""Load document embedding into vector store from path.
Args:
connector: (BaseConnector) BaseConnector connection.
index_store: (IndexStoreBase) IndexStoreBase to use.
table_vector_store_connector: used to load table chunks.
field_vector_store_connector: used to load field chunks
if field in table is too much.
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
chunking.
embedding_model: (Optional[str]) Embedding model to use.
embeddings: (Optional[Embeddings]) Embeddings to use.
max_seq_length: Embedding model max sequence length
Returns:
DBSchemaAssembler
"""
return cls(
connector=connector,
index_store=index_store,
table_vector_store_connector=table_vector_store_connector,
field_vector_store_connector=field_vector_store_connector,
embedding_model=embedding_model,
chunk_parameters=chunk_parameters,
embeddings=embeddings,
max_seq_length=max_seq_length,
)
def get_chunks(self) -> List[Chunk]:
@ -91,7 +148,19 @@ class DBSchemaAssembler(BaseAssembler):
Returns:
List[str]: List of chunk ids.
"""
return self._index_store.load_document(self._chunks)
table_chunks, field_chunks = [], []
for chunk in self._chunks:
metadata = chunk.metadata
if metadata.get("separated"):
if metadata.get("part") == "table":
table_chunks.append(chunk)
else:
field_chunks.append(chunk)
else:
table_chunks.append(chunk)
self._field_vector_store_connector.load_document(field_chunks)
return self._table_vector_store_connector.load_document(table_chunks)
def _extract_info(self, chunks) -> List[Chunk]:
"""Extract info from chunks."""
@ -110,5 +179,6 @@ class DBSchemaAssembler(BaseAssembler):
top_k=top_k,
connector=self._connector,
is_embeddings=True,
index_store=self._index_store,
table_vector_store_connector=self._table_vector_store_connector,
field_vector_store_connector=self._field_vector_store_connector,
)

View File

@ -1,76 +1,117 @@
from unittest.mock import MagicMock
from typing import List
from unittest.mock import MagicMock, patch
import pytest
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import Knowledge
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
from dbgpt.storage.vector_store.chroma_store import ChromaStore
import dbgpt
from dbgpt.core import Chunk
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
@pytest.fixture
def mock_db_connection():
"""Create a temporary database connection for testing."""
connect = SQLiteTempConnector.create_temporary_db()
connect.create_temp_tables(
{
"user": {
"columns": {
"id": "INTEGER PRIMARY KEY",
"name": "TEXT",
"age": "INTEGER",
},
"data": [
(1, "Tom", 10),
(2, "Jerry", 16),
(3, "Jack", 18),
(4, "Alice", 20),
(5, "Bob", 22),
],
}
}
return MagicMock()
@pytest.fixture
def mock_table_vector_store_connector():
mock_connector = MagicMock()
mock_connector.vector_store_config.name = "table_name"
chunk = Chunk(
content="table_name: user\ncomment: user about dbgpt",
metadata={
"field_num": 6,
"part": "table",
"separated": 1,
"table_name": "user",
},
)
return connect
mock_connector.similar_search_with_scores = MagicMock(return_value=[chunk])
return mock_connector
@pytest.fixture
def mock_chunk_parameters():
return MagicMock(spec=ChunkParameters)
def mock_field_vector_store_connector():
mock_connector = MagicMock()
chunk1 = Chunk(
content="name,age",
metadata={
"field_num": 6,
"part": "field",
"part_index": 0,
"separated": 1,
"table_name": "user",
},
)
chunk2 = Chunk(
content="address,gender",
metadata={
"field_num": 6,
"part": "field",
"part_index": 1,
"separated": 1,
"table_name": "user",
},
)
chunk3 = Chunk(
content="mail,phone",
metadata={
"field_num": 6,
"part": "field",
"part_index": 2,
"separated": 1,
"table_name": "user",
},
)
mock_connector.similar_search_with_scores = MagicMock(
return_value=[chunk1, chunk2, chunk3]
)
return mock_connector
@pytest.fixture
def mock_embedding_factory():
    """Fixture: a MagicMock constrained to the EmbeddingFactory interface."""
    factory_stub = MagicMock(spec=EmbeddingFactory)
    return factory_stub
@pytest.fixture
def mock_vector_store_connector():
return MagicMock(spec=ChromaStore)
@pytest.fixture
def mock_knowledge():
return MagicMock(spec=Knowledge)
def test_load_knowledge(
def dbstruct_retriever(
mock_db_connection,
mock_knowledge,
mock_chunk_parameters,
mock_embedding_factory,
mock_vector_store_connector,
mock_table_vector_store_connector,
mock_field_vector_store_connector,
):
mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
mock_chunk_parameters.text_splitter = CharacterTextSplitter()
mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
assembler = EmbeddingAssembler(
knowledge=mock_knowledge,
chunk_parameters=mock_chunk_parameters,
embeddings=mock_embedding_factory.create(),
index_store=mock_vector_store_connector,
return DBSchemaRetriever(
connector=mock_db_connection,
table_vector_store_connector=mock_table_vector_store_connector,
field_vector_store_connector=mock_field_vector_store_connector,
separator="--table-field-separator--",
)
def mock_parse_db_summary() -> str:
    """Patch _parse_db_summary method."""
    table_info = "table_name: user\ncomment: user about dbgpt"
    field_info = "name,age\naddress,gender\nmail,phone"
    return "\n".join([table_info, "--table-field-separator--", field_info])
# Patch the module-level _parse_db_summary so no real DB summary is built.
@patch.object(
    dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
)
def test_retrieve_with_mocked_summary(dbstruct_retriever):
    """Retrieval should join table info and field info with the separator."""
    expected_content = "\n".join(
        [
            "table_name: user\ncomment: user about dbgpt",
            "--table-field-separator--",
            "name,age\naddress,gender\nmail,phone",
        ]
    )
    retrieved: List[Chunk] = dbstruct_retriever._retrieve("Table summary")
    first = retrieved[0]
    assert isinstance(first, Chunk)
    assert first.content == expected_content
def async_mock_parse_db_summary() -> str:
    """Asynchronous patch for _parse_db_summary method."""
    summary_lines = [
        "table_name: user\ncomment: user about dbgpt",
        "--table-field-separator--",
        "name,age\naddress,gender\nmail,phone",
    ]
    return "\n".join(summary_lines)
assembler.load_knowledge(knowledge=mock_knowledge)
assert len(assembler._chunks) == 0

View File

@ -5,9 +5,9 @@ import pytest
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
from dbgpt.storage.vector_store.chroma_store import ChromaStore
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddings, EmbeddingFactory
from dbgpt.rag.text_splitter.text_splitter import RDBTextSplitter
from dbgpt.serve.rag.connector import VectorStoreConnector
@pytest.fixture
@ -21,14 +21,22 @@ def mock_db_connection():
"id": "INTEGER PRIMARY KEY",
"name": "TEXT",
"age": "INTEGER",
},
"data": [
(1, "Tom", 10),
(2, "Jerry", 16),
(3, "Jack", 18),
(4, "Alice", 20),
(5, "Bob", 22),
],
"address": "TEXT",
"phone": "TEXT",
"email": "TEXT",
"gender": "TEXT",
"birthdate": "TEXT",
"occupation": "TEXT",
"education": "TEXT",
"marital_status": "TEXT",
"nationality": "TEXT",
"height": "REAL",
"weight": "REAL",
"blood_type": "TEXT",
"emergency_contact": "TEXT",
"created_at": "TEXT",
"updated_at": "TEXT",
}
}
}
)
@ -46,23 +54,29 @@ def mock_embedding_factory():
@pytest.fixture
def mock_vector_store_connector():
return MagicMock(spec=ChromaStore)
def mock_table_vector_store_connector():
mock_connector = MagicMock(spec=VectorStoreConnector)
mock_connector.vector_store_config.name = "table_vector_store_name"
mock_connector.current_embeddings = DefaultEmbeddings()
return mock_connector
def test_load_knowledge(
mock_db_connection,
mock_chunk_parameters,
mock_embedding_factory,
mock_vector_store_connector,
mock_table_vector_store_connector,
):
mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
mock_chunk_parameters.text_splitter = CharacterTextSplitter()
mock_chunk_parameters.text_splitter = RDBTextSplitter(
separator="--table-field-separator--"
)
mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
assembler = DBSchemaAssembler(
connector=mock_db_connection,
chunk_parameters=mock_chunk_parameters,
embeddings=mock_embedding_factory.create(),
index_store=mock_vector_store_connector,
table_vector_store_connector=mock_table_vector_store_connector,
max_seq_length=10,
)
assert len(assembler._chunks) == 1
assert len(assembler._chunks) > 1

View File

@ -5,7 +5,7 @@ from dbgpt.core import Document
from dbgpt.datasource import BaseConnector
from ..summary.gdbms_db_summary import _parse_db_summary as _parse_gdb_summary
from ..summary.rdbms_db_summary import _parse_db_summary
from ..summary.rdbms_db_summary import _parse_db_summary_with_metadata
from .base import ChunkStrategy, DocumentType, Knowledge, KnowledgeType
@ -15,9 +15,11 @@ class DatasourceKnowledge(Knowledge):
def __init__(
self,
connector: BaseConnector,
summary_template: str = "{table_name}({columns})",
summary_template: str = "table_name: {table_name}",
separator: str = "--table-field-separator--",
knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
model_dimension: int = 512,
**kwargs: Any,
) -> None:
"""Create Datasource Knowledge with Knowledge arguments.
@ -25,11 +27,17 @@ class DatasourceKnowledge(Knowledge):
Args:
connector(BaseConnector): connector
summary_template(str, optional): summary template
separator(str, optional): separator used to separate
table's basic info and fields.
defaults to `--table-field-separator--`
knowledge_type(KnowledgeType, optional): knowledge type
metadata(Dict[str, Union[str, List[str]], optional): metadata
model_dimension(int, optional): The threshold for splitting field string
"""
self._separator = separator
self._connector = connector
self._summary_template = summary_template
self._model_dimension = model_dimension
super().__init__(knowledge_type=knowledge_type, metadata=metadata, **kwargs)
def _load(self) -> List[Document]:
@ -37,13 +45,23 @@ class DatasourceKnowledge(Knowledge):
docs = []
if self._connector.is_graph_type():
db_summary = _parse_gdb_summary(self._connector, self._summary_template)
for table_summary in db_summary:
metadata = {"source": "database"}
docs.append(Document(content=table_summary, metadata=metadata))
else:
db_summary = _parse_db_summary(self._connector, self._summary_template)
for table_summary in db_summary:
metadata = {"source": "database"}
if self._metadata:
metadata.update(self._metadata) # type: ignore
docs.append(Document(content=table_summary, metadata=metadata))
db_summary_with_metadata = _parse_db_summary_with_metadata(
self._connector,
self._summary_template,
self._separator,
self._model_dimension,
)
for summary, table_metadata in db_summary_with_metadata:
metadata = {"source": "database"}
if self._metadata:
metadata.update(self._metadata) # type: ignore
table_metadata.update(metadata)
docs.append(Document(content=summary, metadata=table_metadata))
return docs
@classmethod

View File

@ -1,14 +1,15 @@
"""The DBSchema Retriever Operator."""
import os
from typing import List, Optional
from dbgpt.core import Chunk
from dbgpt.core.interface.operators.retriever import RetrieverOperator
from dbgpt.datasource.base import BaseConnector
from dbgpt.serve.rag.connector import VectorStoreConnector
from ...storage.vector_store.base import VectorStoreConfig
from ..assembler.db_schema import DBSchemaAssembler
from ..chunk_manager import ChunkParameters
from ..index.base import IndexStoreBase
from ..retriever.db_schema import DBSchemaRetriever
from .assembler import AssemblerOperator
@ -19,13 +20,14 @@ class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]):
Args:
connector (BaseConnector): The connection.
top_k (int, optional): The top k. Defaults to 4.
index_store (IndexStoreBase, optional): The vector store
vector_store_connector (VectorStoreConnector, optional): The vector store
connector. Defaults to None.
"""
def __init__(
self,
index_store: IndexStoreBase,
table_vector_store_connector: VectorStoreConnector,
field_vector_store_connector: VectorStoreConnector,
top_k: int = 4,
connector: Optional[BaseConnector] = None,
**kwargs
@ -35,7 +37,8 @@ class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]):
self._retriever = DBSchemaRetriever(
top_k=top_k,
connector=connector,
index_store=index_store,
table_vector_store_connector=table_vector_store_connector,
field_vector_store_connector=field_vector_store_connector,
)
def retrieve(self, query: str) -> List[Chunk]:
@ -53,7 +56,8 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
def __init__(
self,
connector: BaseConnector,
index_store: IndexStoreBase,
table_vector_store_connector: VectorStoreConnector,
field_vector_store_connector: VectorStoreConnector = None,
chunk_parameters: Optional[ChunkParameters] = None,
**kwargs
):
@ -61,14 +65,26 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
Args:
connector (BaseConnector): The connection.
index_store (IndexStoreBase): The Storage IndexStoreBase.
vector_store_connector (VectorStoreConnector): The vector store connector.
chunk_parameters (Optional[ChunkParameters], optional): The chunk
parameters.
"""
if not chunk_parameters:
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
self._chunk_parameters = chunk_parameters
self._index_store = index_store
self._table_vector_store_connector = table_vector_store_connector
field_vector_store_config = VectorStoreConfig(
name=table_vector_store_connector.vector_store_config.name + "_field"
)
self._field_vector_store_connector = (
field_vector_store_connector
or VectorStoreConnector.from_default(
os.getenv("VECTOR_STORE_TYPE", "Chroma"),
self._table_vector_store_connector.current_embeddings,
vector_store_config=field_vector_store_config,
)
)
self._connector = connector
super().__init__(**kwargs)
@ -84,7 +100,8 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
assembler = DBSchemaAssembler.load_from_connection(
connector=self._connector,
chunk_parameters=self._chunk_parameters,
index_store=self._index_store,
table_vector_store_connector=self._table_vector_store_connector,
field_vector_store_connector=self._field_vector_store_connector,
)
assembler.persist()
return assembler.get_chunks()

View File

@ -1,18 +1,23 @@
"""DBSchema retriever."""
import logging
import os
from typing import List, Optional
from functools import reduce
from typing import List, Optional, cast
from dbgpt._private.config import Config
from dbgpt.core import Chunk
from dbgpt.datasource.base import BaseConnector
from dbgpt.rag.index.base import IndexStoreBase
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.chat_util import run_async_tasks
from dbgpt.rag.summary.gdbms_db_summary import _parse_db_summary
from dbgpt.serve.rag.connector import VectorStoreConnector
from dbgpt.storage.vector_store.base import VectorStoreConfig
from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters
from dbgpt.util.chat_util import run_tasks
from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
from dbgpt.util.tracer import root_tracer
logger = logging.getLogger(__name__)
CFG = Config()
class DBSchemaRetriever(BaseRetriever):
@ -20,7 +25,9 @@ class DBSchemaRetriever(BaseRetriever):
def __init__(
self,
index_store: IndexStoreBase,
table_vector_store_connector: VectorStoreConnector,
field_vector_store_connector: VectorStoreConnector = None,
separator: str = "--table-field-separator--",
top_k: int = 4,
connector: Optional[BaseConnector] = None,
query_rewrite: bool = False,
@ -30,7 +37,11 @@ class DBSchemaRetriever(BaseRetriever):
"""Create DBSchemaRetriever.
Args:
index_store(IndexStore): index connector
table_vector_store_connector: VectorStoreConnector
to load and retrieve table info.
field_vector_store_connector: VectorStoreConnector
to load and retrieve field info.
separator: field/table separator
top_k (int): top k
connector (Optional[BaseConnector]): RDBMSConnector.
query_rewrite (bool): query rewrite
@ -70,34 +81,42 @@ class DBSchemaRetriever(BaseRetriever):
connector = _create_temporary_connection()
vector_store_config = ChromaVectorConfig(name="vector_store_name")
embedding_model_path = "{your_embedding_model_path}"
embedding_fn = embedding_factory.create(model_name=embedding_model_path)
config = ChromaVectorConfig(
persist_path=PILOT_PATH,
name="dbschema_rag_test",
embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(
MODEL_PATH, "text2vec-large-chinese"
),
).create(),
vector_connector = VectorStoreConnector.from_default(
"Chroma",
vector_store_config=vector_store_config,
embedding_fn=embedding_fn,
)
vector_store = ChromaStore(config)
# get db struct retriever
retriever = DBSchemaRetriever(
top_k=3,
index_store=vector_store,
vector_store_connector=vector_connector,
connector=connector,
)
chunks = retriever.retrieve("show columns from table")
result = [chunk.content for chunk in chunks]
print(f"db struct rag example results:{result}")
"""
self._separator = separator
self._top_k = top_k
self._connector = connector
self._query_rewrite = query_rewrite
self._index_store = index_store
self._table_vector_store_connector = table_vector_store_connector
field_vector_store_config = VectorStoreConfig(
name=table_vector_store_connector.vector_store_config.name + "_field"
)
self._field_vector_store_connector = (
field_vector_store_connector
or VectorStoreConnector.from_default(
os.getenv("VECTOR_STORE_TYPE", "Chroma"),
self._table_vector_store_connector.current_embeddings,
vector_store_config=field_vector_store_config,
)
)
self._need_embeddings = False
if self._index_store:
if self._table_vector_store_connector:
self._need_embeddings = True
self._rerank = rerank or DefaultRanker(self._top_k)
@ -114,15 +133,8 @@ class DBSchemaRetriever(BaseRetriever):
List[Chunk]: list of chunks
"""
if self._need_embeddings:
queries = [query]
candidates = [
self._index_store.similar_search(query, self._top_k, filters)
for query in queries
]
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
return self._similarity_search(query, filters)
else:
if not self._connector:
raise RuntimeError("RDBMSConnector connection is required.")
table_summaries = _parse_db_summary(self._connector)
return [Chunk(content=table_summary) for table_summary in table_summaries]
@ -156,30 +168,11 @@ class DBSchemaRetriever(BaseRetriever):
Returns:
List[Chunk]: list of chunks
"""
if self._need_embeddings:
queries = [query]
candidates = [
self._similarity_search(
query, filters, root_tracer.get_current_span_id()
)
for query in queries
]
result_candidates = await run_async_tasks(
tasks=candidates, concurrency_limit=1
)
return cast(List[Chunk], reduce(lambda x, y: x + y, result_candidates))
else:
from dbgpt.rag.summary.rdbms_db_summary import ( # noqa: F401
_parse_db_summary,
)
table_summaries = await run_async_tasks(
tasks=[self._aparse_db_summary(root_tracer.get_current_span_id())],
concurrency_limit=1,
)
return [
Chunk(content=table_summary) for table_summary in table_summaries[0]
]
return await blocking_func_to_async_no_executor(
func=self._retrieve,
query=query,
filters=filters,
)
async def _aretrieve_with_score(
self,
@ -196,34 +189,40 @@ class DBSchemaRetriever(BaseRetriever):
"""
return await self._aretrieve(query, filters)
async def _similarity_search(
self,
query,
filters: Optional[MetadataFilters] = None,
parent_span_id: Optional[str] = None,
def _retrieve_field(self, table_chunk: Chunk, query) -> Chunk:
metadata = table_chunk.metadata
metadata["part"] = "field"
filters = [MetadataFilter(key=k, value=v) for k, v in metadata.items()]
field_chunks = self._field_vector_store_connector.similar_search_with_scores(
query, self._top_k, 0, MetadataFilters(filters=filters)
)
field_contents = [chunk.content for chunk in field_chunks]
table_chunk.content += "\n" + self._separator + "\n" + "\n".join(field_contents)
return table_chunk
def _similarity_search(
self, query, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Similar search."""
with root_tracer.start_span(
"dbgpt.rag.retriever.db_schema._similarity_search",
parent_span_id,
metadata={"query": query},
):
return await blocking_func_to_async_no_executor(
self._index_store.similar_search, query, self._top_k, filters
)
table_chunks = self._table_vector_store_connector.similar_search_with_scores(
query, self._top_k, 0, filters
)
async def _aparse_db_summary(
self, parent_span_id: Optional[str] = None
) -> List[str]:
"""Similar search."""
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
not_sep_chunks = [
chunk for chunk in table_chunks if not chunk.metadata.get("separated")
]
separated_chunks = [
chunk for chunk in table_chunks if chunk.metadata.get("separated")
]
if not separated_chunks:
return not_sep_chunks
if not self._connector:
raise RuntimeError("RDBMSConnector connection is required.")
with root_tracer.start_span(
"dbgpt.rag.retriever.db_schema._aparse_db_summary",
parent_span_id,
):
return await blocking_func_to_async_no_executor(
_parse_db_summary, self._connector
)
# Create tasks list
tasks = [
lambda c=chunk: self._retrieve_field(c, query) for chunk in separated_chunks
]
# Run tasks concurrently
separated_result = run_tasks(tasks, concurrency_limit=3)
# Combine and return results
return not_sep_chunks + separated_result

View File

@ -15,42 +15,53 @@ def mock_db_connection():
@pytest.fixture
def mock_vector_store_connector():
def mock_table_vector_store_connector():
mock_connector = MagicMock()
mock_connector.similar_search.return_value = [Chunk(content="Table summary")] * 4
mock_connector.vector_store_config.name = "table_name"
mock_connector.similar_search_with_scores.return_value = [
Chunk(content="Table summary")
] * 4
return mock_connector
@pytest.fixture
def db_struct_retriever(mock_db_connection, mock_vector_store_connector):
def mock_field_vector_store_connector():
mock_connector = MagicMock()
mock_connector.similar_search_with_scores.return_value = [
Chunk(content="Field summary")
] * 4
return mock_connector
@pytest.fixture
def dbstruct_retriever(
mock_db_connection,
mock_table_vector_store_connector,
mock_field_vector_store_connector,
):
return DBSchemaRetriever(
connector=mock_db_connection,
index_store=mock_vector_store_connector,
table_vector_store_connector=mock_table_vector_store_connector,
field_vector_store_connector=mock_field_vector_store_connector,
)
def mock_parse_db_summary(conn) -> List[str]:
def mock_parse_db_summary() -> str:
"""Patch _parse_db_summary method."""
return ["Table summary"]
return "Table summary"
# Mocking the _parse_db_summary method in your test function
@patch.object(
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
)
def test_retrieve_with_mocked_summary(db_struct_retriever):
def test_retrieve_with_mocked_summary(dbstruct_retriever):
query = "Table summary"
chunks: List[Chunk] = db_struct_retriever._retrieve(query)
chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
assert isinstance(chunks[0], Chunk)
assert chunks[0].content == "Table summary"
@pytest.mark.asyncio
@patch.object(
    dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
)
async def test_aretrieve_with_mocked_summary(db_struct_retriever):
    """Async retrieval should surface the patched summary as chunk content."""
    # NOTE(review): the retriever fixture elsewhere in this file appears to be
    # named `dbstruct_retriever`, while this parameter is `db_struct_retriever`
    # — confirm the names match, else pytest raises a missing-fixture error.
    query = "Table summary"
    chunks: List[Chunk] = await db_struct_retriever._aretrieve(query)
    assert isinstance(chunks[0], Chunk)
    assert chunks[0].content == "Table summary"
async def async_mock_parse_db_summary() -> str:
    """Asynchronous replacement for ``_parse_db_summary``: canned summary."""
    summary_text = "Table summary"
    return summary_text

View File

@ -2,13 +2,15 @@
import logging
import traceback
from typing import List
from dbgpt._private.config import Config
from dbgpt.component import SystemApp
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
from dbgpt.rag import ChunkParameters
from dbgpt.rag.summary.gdbms_db_summary import GdbmsSummary
from dbgpt.rag.summary.rdbms_db_summary import RdbmsSummary
from dbgpt.rag.text_splitter.text_splitter import RDBTextSplitter
from dbgpt.serve.rag.connector import VectorStoreConnector
logger = logging.getLogger(__name__)
@ -47,22 +49,26 @@ class DBSummaryClient:
logger.info("db summary embedding success")
def get_db_summary(self, dbname, query, topk) -> List[str]:
def get_db_summary(self, dbname, query, topk):
"""Get user query related tables info."""
from dbgpt.serve.rag.connector import VectorStoreConnector
from dbgpt.storage.vector_store.base import VectorStoreConfig
vector_store_config = VectorStoreConfig(name=dbname + "_profile")
vector_connector = VectorStoreConnector.from_default(
vector_store_name = dbname + "_profile"
table_vector_store_config = VectorStoreConfig(name=vector_store_name)
table_vector_connector = VectorStoreConnector.from_default(
CFG.VECTOR_STORE_TYPE,
embedding_fn=self.embeddings,
vector_store_config=vector_store_config,
self.embeddings,
vector_store_config=table_vector_store_config,
)
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
retriever = DBSchemaRetriever(
top_k=topk, index_store=vector_connector.index_client
top_k=topk,
table_vector_store_connector=table_vector_connector,
separator="--table-field-separator--",
)
table_docs = retriever.retrieve(query)
ans = [d.content for d in table_docs]
return ans
@ -92,18 +98,23 @@ class DBSummaryClient:
from dbgpt.serve.rag.connector import VectorStoreConnector
from dbgpt.storage.vector_store.base import VectorStoreConfig
vector_store_config = VectorStoreConfig(name=vector_store_name)
vector_connector = VectorStoreConnector.from_default(
table_vector_store_config = VectorStoreConfig(name=vector_store_name)
table_vector_connector = VectorStoreConnector.from_default(
CFG.VECTOR_STORE_TYPE,
self.embeddings,
vector_store_config=vector_store_config,
vector_store_config=table_vector_store_config,
)
if not vector_connector.vector_name_exists():
if not table_vector_connector.vector_name_exists():
from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
chunk_parameters = ChunkParameters(
text_splitter=RDBTextSplitter(separator="--table-field-separator--")
)
db_assembler = DBSchemaAssembler.load_from_connection(
connector=db_summary_client.db,
index_store=vector_connector.index_client,
table_vector_store_connector=table_vector_connector,
chunk_parameters=chunk_parameters,
max_seq_length=CFG.EMBEDDING_MODEL_MAX_SEQ_LEN,
)
if len(db_assembler.get_chunks()) > 0:
@ -115,16 +126,26 @@ class DBSummaryClient:
def delete_db_profile(self, dbname):
    """Delete the stored schema profile for a database.

    Removes both vector stores created for the database: the table-level
    profile store (``<dbname>_profile``) and the field-level store
    (``<dbname>_profile_field``).

    Args:
        dbname: name of the database whose profile should be deleted.
    """
    table_vector_store_name = dbname + "_profile"
    field_vector_store_name = dbname + "_profile_field"
    from dbgpt.serve.rag.connector import VectorStoreConnector
    from dbgpt.storage.vector_store.base import VectorStoreConfig

    # NOTE: the diff residue referenced the removed name `vector_store_name`
    # here; use the renamed table-level store name instead.
    table_vector_store_config = VectorStoreConfig(name=table_vector_store_name)
    field_vector_store_config = VectorStoreConfig(name=field_vector_store_name)
    table_vector_connector = VectorStoreConnector.from_default(
        CFG.VECTOR_STORE_TYPE,
        self.embeddings,
        vector_store_config=table_vector_store_config,
    )
    field_vector_connector = VectorStoreConnector.from_default(
        CFG.VECTOR_STORE_TYPE,
        self.embeddings,
        vector_store_config=field_vector_store_config,
    )
    table_vector_connector.delete_vector_name(table_vector_store_name)
    field_vector_connector.delete_vector_name(field_vector_store_name)
    logger.info(f"delete db profile {dbname} success")
@staticmethod

View File

@ -1,6 +1,6 @@
"""Summary for rdbms database."""
import re
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from dbgpt._private.config import Config
from dbgpt.datasource import BaseConnector
@ -80,6 +80,134 @@ def _parse_db_summary(
return table_info_summaries
def _parse_db_summary_with_metadata(
    conn: BaseConnector,
    summary_template: str = "table_name: {table_name}",
    separator: str = "--table-field-separator--",
    model_dimension: int = 512,
) -> List[Tuple[str, Dict[str, Any]]]:
    """Build a (summary, metadata) pair for every table in the database.

    Args:
        conn (BaseConnector): database connection
        summary_template (str): template used to render each table summary
        separator (str, optional): separator placed between a table's
            basic info and its field list
        model_dimension (int, optional): threshold used when splitting the
            field string into segments
    """
    summaries: List[Tuple[str, Dict[str, Any]]] = []
    for table_name in conn.get_table_names():
        summaries.append(
            _parse_table_summary_with_metadata(
                conn, summary_template, separator, table_name, model_dimension
            )
        )
    return summaries
def _split_columns_str(columns: List[str], model_dimension: int):
"""Split columns str.
Args:
columns (List[str]): fields string
model_dimension (int, optional): The threshold for splitting field string.
"""
result = []
current_string = ""
current_length = 0
for element_str in columns:
element_length = len(element_str)
# If adding the current element's length would exceed the threshold,
# add the current string to results and reset
if current_length + element_length > model_dimension:
result.append(current_string.strip()) # Remove trailing spaces
current_string = element_str
current_length = element_length
else:
# If current string is empty, add element directly
if current_string:
current_string += "," + element_str
else:
current_string = element_str
current_length += element_length + 1 # Add length of space
# Handle the last string segment
if current_string:
result.append(current_string.strip())
return result
def _parse_table_summary_with_metadata(
    conn: BaseConnector,
    summary_template: str,
    separator: str,
    table_name: str,
    model_dimension: int = 512,
) -> Tuple[str, Dict[str, Any]]:
    """Build a summary string and metadata dict for a single table.

    Args:
        conn (BaseConnector): database connection
        summary_template (str): template rendered with ``table_name``
        separator (str): separator placed between the table's basic info
            and its field list, e.g. ``--table-field-separator--``
        table_name (str): name of the table to summarize
        model_dimension (int, optional): threshold used to split the field
            string into segments. Defaults to 512.

    Returns:
        Tuple[str, Dict[str, Any]]: the summary text and its metadata,
        e.g. metadata ``{'table_name': 'asd', 'separated': 0/1, 'field_num': N}``
        and text of the form::

            table_name: table1
            table_comment: comment
            index_keys: keys
            --table-field-separator--
            (column1,comment), (column2, comment), (column3, comment)
            (column4,comment), (column5, comment), (column6, comment)
    """
    columns = []
    # separated=0 means the field list fits in one segment; set to 1 below
    # when _split_columns_str has to break it into multiple segments.
    metadata = {"table_name": table_name, "separated": 0}
    for column in conn.get_columns(table_name):
        # Include the column comment when the connector provides one.
        if column.get("comment"):
            columns.append(f"{column['name']} ({column.get('comment')})")
        else:
            columns.append(f"{column['name']}")
    metadata.update({"field_num": len(columns)})
    separated_columns = _split_columns_str(columns, model_dimension=model_dimension)
    if len(separated_columns) > 1:
        metadata["separated"] = 1
    column_str = "\n".join(separated_columns)
    # Obtain index information; connectors return either tuples of
    # (name, creation SQL) or dicts with 'name'/'column_names'.
    index_keys = []
    raw_indexes = conn.get_indexes(table_name)
    for index in raw_indexes:
        if isinstance(index, tuple):  # Process tuple type index information
            index_name, index_creation_command = index
            # Extract column names from the parenthesized part of the DDL.
            matched_columns = re.findall(r"\(([^)]+)\)", index_creation_command)
            if matched_columns:
                key_str = ", ".join(matched_columns)
                index_keys.append(f"{index_name}(`{key_str}`) ")
        else:
            key_str = ", ".join(index["column_names"])
            index_keys.append(f"{index['name']}(`{key_str}`) ")
    table_str = summary_template.format(table_name=table_name)
    try:
        comment = conn.get_table_comment(table_name)
    except Exception:
        # Some connectors do not implement table comments; fall back to none.
        comment = dict(text=None)
    if comment.get("text"):
        table_str += f"\ntable_comment: {comment.get('text')}"
    if len(index_keys) > 0:
        index_key_str = ", ".join(index_keys)
        table_str += f"\nindex_keys: {index_key_str}"
    # Basic info and field list are joined with the separator so a splitter
    # can later divide them into table-level and field-level chunks.
    table_str += f"\n{separator}\n{column_str}"
    return table_str, metadata
def _parse_table_summary(
conn: BaseConnector, summary_template: str, table_name: str
) -> str:

View File

@ -912,3 +912,42 @@ class PageTextSplitter(TextSplitter):
new_doc = Chunk(content=text, metadata=copy.deepcopy(_metadatas[i]))
chunks.append(new_doc)
return chunks
class RDBTextSplitter(TextSplitter):
    """Split relational database tables and fields into chunks."""

    def __init__(self, **kwargs):
        """Create a new RDBTextSplitter.

        Keyword args are forwarded to ``TextSplitter`` (including the
        ``separator`` dividing table info from the field list).
        """
        super().__init__(**kwargs)

    def split_text(self, text: str, **kwargs):
        """Split text into a couple of parts.

        Not used by this splitter; splitting is driven per-document via
        :meth:`split_documents`.
        """
        pass

    def split_documents(self, documents: Iterable[Document], **kwargs) -> List[Chunk]:
        """Split documents into one table chunk plus per-line field chunks.

        Documents whose metadata has ``separated`` set are divided at the
        configured separator: the part before it becomes a ``part='table'``
        chunk, each line after it becomes a ``part='field'`` chunk with a
        ``part_index``. Other documents pass through as a single chunk.
        """
        chunks = []
        for doc in documents:
            metadata = doc.metadata
            content = doc.content
            if metadata.get("separated"):
                # separate table and field; split only on the FIRST separator
                # so any later occurrence inside the field section is kept
                # (plain split() silently dropped parts[2:]).
                parts = content.split(self._separator, 1)
                table_part, field_part = parts[0], parts[1]
                table_metadata, field_metadata = copy.deepcopy(metadata), copy.deepcopy(
                    metadata
                )
                table_metadata["part"] = "table"  # identify of table_chunk
                field_metadata["part"] = "field"  # identify of field_chunk
                table_chunk = Chunk(content=table_part, metadata=table_metadata)
                chunks.append(table_chunk)
                field_parts = field_part.split("\n")
                for i, sub_part in enumerate(field_parts):
                    sub_metadata = copy.deepcopy(field_metadata)
                    sub_metadata["part_index"] = i
                    field_chunk = Chunk(content=sub_part, metadata=sub_metadata)
                    chunks.append(field_chunk)
            else:
                chunks.append(Chunk(content=content, metadata=metadata))
        return chunks

View File

@ -1,5 +1,6 @@
import asyncio
from typing import Any, Coroutine, List
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable, Coroutine, List
async def llm_chat_response_nostream(chat_scene: str, **chat_param):
@ -47,13 +48,34 @@ async def run_async_tasks(
def run_tasks(
    tasks: List[Callable],
    concurrency_limit: int = None,
) -> List[Any]:
    """Run a list of tasks concurrently using a thread pool.

    Args:
        tasks: List of callable functions to execute.
        concurrency_limit: Maximum number of concurrent threads (optional;
            ``None`` lets ``ThreadPoolExecutor`` pick its default).

    Returns:
        List of results from all tasks, in the order they were submitted.

    Raises:
        Exception: the first exception raised by any task; remaining
            pending tasks are cancelled before re-raising.
    """
    max_workers = concurrency_limit if concurrency_limit else None
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks up front so they run concurrently.
        futures = [executor.submit(task) for task in tasks]

        # Collect results in submission order, propagating any exception.
        results = []
        for future in futures:
            try:
                results.append(future.result())
            except Exception:
                # Cancel any not-yet-started futures before propagating.
                for f in futures:
                    f.cancel()
                raise
        return results

View File

@ -0,0 +1,317 @@
-- Wide (denormalized) order table: one row joins order, user, product,
-- address, shipping, payment, promotion, after-sale, invoice and
-- settlement attributes for analytics/demo purposes.
CREATE TABLE order_wide_table (
    -- order_base
    order_id TEXT,                    -- order ID
    order_no TEXT,                    -- order number
    parent_order_no TEXT,             -- parent order number
    order_type INTEGER,               -- order type: 1 physical, 2 virtual, 3 mixed
    order_status INTEGER,             -- order status
    order_source TEXT,                -- order source
    order_source_detail TEXT,         -- order source detail
    create_time DATETIME,             -- creation time
    pay_time DATETIME,                -- payment time
    finish_time DATETIME,             -- completion time
    close_time DATETIME,              -- close time
    cancel_time DATETIME,             -- cancellation time
    cancel_reason TEXT,               -- cancellation reason
    order_remark TEXT,                -- order remark
    seller_remark TEXT,               -- seller remark
    buyer_remark TEXT,                -- buyer remark
    is_deleted INTEGER,               -- whether deleted
    delete_time DATETIME,             -- deletion time
    order_ip TEXT,                    -- ordering IP
    order_platform TEXT,              -- ordering platform
    order_device TEXT,                -- ordering device
    order_app_version TEXT,           -- app version number

    -- order_amount
    currency TEXT,                    -- currency type
    exchange_rate REAL,               -- exchange rate
    original_amount REAL,             -- original amount
    discount_amount REAL,             -- discount amount
    coupon_amount REAL,               -- coupon amount
    points_amount REAL,               -- points deduction amount
    shipping_amount REAL,             -- shipping fee
    insurance_amount REAL,            -- shipping insurance fee
    tax_amount REAL,                  -- tax
    tariff_amount REAL,               -- tariff
    payment_amount REAL,              -- actual paid amount
    commission_amount REAL,           -- commission amount
    platform_fee REAL,                -- platform fee
    seller_income REAL,               -- seller net income
    payment_currency TEXT,            -- payment currency
    payment_exchange_rate REAL,       -- payment exchange rate

    -- user_info
    user_id TEXT,                     -- user ID
    user_name TEXT,                   -- user name
    user_nickname TEXT,               -- user nickname
    user_level INTEGER,               -- user level
    user_type INTEGER,                -- user type
    register_time DATETIME,           -- registration time
    register_source TEXT,             -- registration source
    mobile TEXT,                      -- mobile number
    mobile_area TEXT,                 -- mobile area code
    email TEXT,                       -- email
    is_vip INTEGER,                   -- whether VIP
    vip_level INTEGER,                -- VIP level
    vip_expire_time DATETIME,         -- VIP expiration time
    user_age INTEGER,                 -- user age
    user_gender INTEGER,              -- user gender
    user_birthday DATE,               -- user birthday
    user_avatar TEXT,                 -- user avatar
    user_province TEXT,               -- user's province
    user_city TEXT,                   -- user's city
    user_district TEXT,               -- user's district
    last_login_time DATETIME,         -- last login time
    last_login_ip TEXT,               -- last login IP
    user_credit_score INTEGER,        -- user credit score
    total_order_count INTEGER,        -- historical order count
    total_order_amount REAL,          -- historical order amount

    -- product_info
    product_id TEXT,                  -- product ID
    product_code TEXT,                -- product code
    product_name TEXT,                -- product name
    product_short_name TEXT,          -- product short name
    product_type INTEGER,             -- product type
    product_status INTEGER,           -- product status
    category_id TEXT,                 -- category ID
    category_name TEXT,               -- category name
    category_path TEXT,               -- category path
    brand_id TEXT,                    -- brand ID
    brand_name TEXT,                  -- brand name
    brand_english_name TEXT,          -- brand English name
    seller_id TEXT,                   -- seller ID
    seller_name TEXT,                 -- seller name
    seller_type INTEGER,              -- seller type
    shop_id TEXT,                     -- shop ID
    shop_name TEXT,                   -- shop name
    product_price REAL,               -- product price
    market_price REAL,                -- market price
    cost_price REAL,                  -- cost price
    wholesale_price REAL,             -- wholesale price
    product_quantity INTEGER,         -- product quantity
    product_unit TEXT,                -- product unit
    product_weight REAL,              -- product weight (grams)
    product_volume REAL,              -- product volume (cm³)
    product_spec TEXT,                -- product specification
    product_color TEXT,               -- product color
    product_size TEXT,                -- product size
    product_material TEXT,            -- product material
    product_origin TEXT,              -- product origin
    product_shelf_life INTEGER,       -- shelf life (days)
    manufacture_date DATE,            -- manufacture date
    expiry_date DATE,                 -- expiry date
    batch_number TEXT,                -- batch number
    product_barcode TEXT,             -- product barcode
    warehouse_id TEXT,                -- shipping warehouse ID
    warehouse_name TEXT,              -- shipping warehouse name

    -- address_info
    receiver_name TEXT,               -- receiver name
    receiver_mobile TEXT,             -- receiver mobile
    receiver_tel TEXT,                -- receiver telephone
    receiver_email TEXT,              -- receiver email
    receiver_country TEXT,            -- country
    receiver_province TEXT,           -- province
    receiver_city TEXT,               -- city
    receiver_district TEXT,           -- district/county
    receiver_street TEXT,             -- street
    receiver_address TEXT,            -- detailed address
    receiver_zip TEXT,                -- zip code
    address_type INTEGER,             -- address type
    is_default INTEGER,               -- whether default address
    longitude REAL,                   -- longitude
    latitude REAL,                    -- latitude
    address_label TEXT,               -- address label

    -- shipping_info
    shipping_type INTEGER,            -- delivery method
    shipping_method TEXT,             -- delivery method name
    shipping_company TEXT,            -- courier company
    shipping_company_code TEXT,       -- courier company code
    shipping_no TEXT,                 -- tracking number
    shipping_time DATETIME,           -- shipping time
    shipping_remark TEXT,             -- shipping remark
    expect_receive_time DATETIME,     -- expected delivery time
    receive_time DATETIME,            -- receipt time
    sign_type INTEGER,                -- sign-off type
    shipping_status INTEGER,          -- logistics status
    tracking_url TEXT,                -- logistics tracking URL
    is_free_shipping INTEGER,         -- whether free shipping
    shipping_insurance REAL,          -- shipping insurance amount
    shipping_distance REAL,           -- delivery distance
    delivered_time DATETIME,          -- delivered time
    delivery_staff_id TEXT,           -- courier staff ID
    delivery_staff_name TEXT,         -- courier staff name
    delivery_staff_mobile TEXT,       -- courier staff phone

    -- payment_info
    payment_id TEXT,                  -- payment ID
    payment_no TEXT,                  -- payment number
    payment_type INTEGER,             -- payment type
    payment_method TEXT,              -- payment method name
    payment_status INTEGER,           -- payment status
    payment_platform TEXT,            -- payment platform
    transaction_id TEXT,              -- transaction serial number
    payment_time DATETIME,            -- payment time
    payment_account TEXT,             -- payment account
    payment_bank TEXT,                -- payment bank
    payment_card_type TEXT,           -- payment card type
    payment_card_no TEXT,             -- payment card number
    payment_scene TEXT,               -- payment scene
    payment_client_ip TEXT,           -- payment IP
    payment_device TEXT,              -- payment device
    payment_remark TEXT,              -- payment remark
    payment_voucher TEXT,             -- payment voucher

    -- promotion_info
    promotion_id TEXT,                -- campaign ID
    promotion_name TEXT,              -- campaign name
    promotion_type INTEGER,           -- campaign type
    promotion_desc TEXT,              -- campaign description
    promotion_start_time DATETIME,    -- campaign start time
    promotion_end_time DATETIME,      -- campaign end time
    coupon_id TEXT,                   -- coupon ID
    coupon_code TEXT,                 -- coupon code
    coupon_type INTEGER,              -- coupon type
    coupon_name TEXT,                 -- coupon name
    coupon_desc TEXT,                 -- coupon description
    points_used INTEGER,              -- points used
    points_gained INTEGER,            -- points gained
    points_multiple REAL,             -- points multiplier
    is_first_order INTEGER,           -- whether first order
    is_new_customer INTEGER,          -- whether new customer
    marketing_channel TEXT,           -- marketing channel
    marketing_source TEXT,            -- marketing source
    referral_code TEXT,               -- referral code
    referral_user_id TEXT,            -- referrer user ID

    -- after_sale_info
    refund_id TEXT,                   -- refund ID
    refund_no TEXT,                   -- refund number
    refund_type INTEGER,              -- refund type
    refund_status INTEGER,            -- refund status
    refund_reason TEXT,               -- refund reason
    refund_desc TEXT,                 -- refund description
    refund_time DATETIME,             -- refund time
    refund_amount REAL,               -- refund amount
    return_shipping_no TEXT,          -- return tracking number
    return_shipping_company TEXT,     -- return courier company
    return_shipping_time DATETIME,    -- return time
    refund_evidence TEXT,             -- refund evidence
    complaint_id TEXT,                -- complaint ID
    complaint_type INTEGER,           -- complaint type
    complaint_status INTEGER,         -- complaint status
    complaint_content TEXT,           -- complaint content
    complaint_time DATETIME,          -- complaint time
    complaint_handle_time DATETIME,   -- complaint handling time
    complaint_handle_result TEXT,     -- complaint handling result
    evaluation_score INTEGER,         -- review score
    evaluation_content TEXT,          -- review content
    evaluation_time DATETIME,         -- review time
    evaluation_reply TEXT,            -- review reply
    evaluation_reply_time DATETIME,   -- review reply time
    evaluation_images TEXT,           -- review images
    evaluation_videos TEXT,           -- review videos
    is_anonymous INTEGER,             -- whether anonymous review

    -- invoice_info
    invoice_type INTEGER,             -- invoice type
    invoice_title TEXT,               -- invoice title
    invoice_content TEXT,             -- invoice content
    tax_no TEXT,                      -- tax number
    invoice_amount REAL,              -- invoice amount
    invoice_status INTEGER,           -- invoice status
    invoice_time DATETIME,            -- invoicing time
    invoice_number TEXT,              -- invoice number
    invoice_code TEXT,                -- invoice code
    company_name TEXT,                -- company name
    company_address TEXT,             -- company address
    company_tel TEXT,                 -- company telephone
    company_bank TEXT,                -- bank of deposit
    company_account TEXT,             -- bank account number

    -- delivery_time_info
    expect_delivery_time DATETIME,    -- expected delivery time
    delivery_period_type INTEGER,     -- delivery period type
    delivery_period_start TEXT,       -- delivery period start
    delivery_period_end TEXT,         -- delivery period end
    delivery_priority INTEGER,        -- delivery priority

    -- tag_info
    order_tags TEXT,                  -- order tags
    user_tags TEXT,                   -- user tags
    product_tags TEXT,                -- product tags
    risk_level INTEGER,               -- risk level
    risk_tags TEXT,                   -- risk tags
    business_tags TEXT,               -- business tags

    -- commercial_info
    gross_profit REAL,                -- gross profit
    gross_profit_rate REAL,           -- gross profit rate
    settlement_amount REAL,           -- settlement amount
    settlement_time DATETIME,         -- settlement time
    settlement_cycle INTEGER,         -- settlement cycle
    settlement_status INTEGER,        -- settlement status
    commission_rate REAL,             -- commission rate
    platform_service_fee REAL,        -- platform service fee
    ad_cost REAL,                     -- advertising cost
    promotion_cost REAL               -- promotion cost
);
-- Insert sample data (three demo orders; only a subset of columns is filled)
INSERT INTO order_wide_table (
    -- basic order info
    order_id, order_no, order_type, order_status, create_time, order_source,
    -- order amounts
    original_amount, payment_amount, shipping_amount,
    -- user info
    user_id, user_name, user_level, mobile,
    -- product info
    product_id, product_name, product_quantity, product_price,
    -- receiver info
    receiver_name, receiver_mobile, receiver_address,
    -- shipping info
    shipping_no, shipping_status,
    -- payment info
    payment_type, payment_status,
    -- marketing info
    promotion_id, coupon_amount,
    -- invoice info
    invoice_type, invoice_title
) VALUES
(
    'ORD20240101001', 'NO20240101001', 1, 2, '2024-01-01 10:00:00', 'APP',
    199.99, 188.88, 10.00,
    'USER001', '张三', 2, '13800138000',
    'PRD001', 'iPhone 15 手机壳', 2, 89.99,
    '李四', '13900139000', '北京市朝阳区XX路XX号',
    'SF123456789', 1,
    1, 1,
    'PROM001', 20.00,
    1, '个人'
),
(
    'ORD20240101002', 'NO20240101002', 1, 1, '2024-01-01 11:00:00', 'H5',
    299.99, 279.99, 0.00,
    'USER002', '王五', 3, '13700137000',
    'PRD002', 'AirPods Pro 保护套', 1, 299.99,
    '赵六', '13600136000', '上海市浦东新区XX路XX号',
    'YT987654321', 2,
    2, 2,
    'PROM002', 10.00,
    2, '上海科技有限公司'
),
(
    'ORD20240101003', 'NO20240101003', 2, 3, '2024-01-01 12:00:00', 'WEB',
    1999.99, 1899.99, 0.00,
    'USER003', '陈七', 4, '13500135000',
    'PRD003', 'MacBook Pro 电脑包', 1, 1999.99,
    '孙八', '13400134000', '广州市天河区XX路XX号',
    'JD123123123', 3,
    3, 1,
    'PROM003', 100.00,
    1, '个人'
);

View File

@ -4,7 +4,8 @@ from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
from dbgpt.rag.assembler import DBSchemaAssembler
from dbgpt.rag.embedding import DefaultEmbeddingFactory
from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig
from dbgpt.serve.rag.connector import VectorStoreConnector
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
"""DB struct rag example.
pre-requirements:
@ -12,7 +13,7 @@ from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorCon
```
embedding_model_path = "{your_embedding_model_path}"
```
Examples:
..code-block:: shell
python examples/rag/db_schema_rag_example.py
@ -45,27 +46,26 @@ def _create_temporary_connection():
def _create_vector_connector():
"""Create vector connector."""
config = ChromaVectorConfig(
persist_path=PILOT_PATH,
name="dbschema_rag_test",
return VectorStoreConnector.from_default(
"Chroma",
vector_store_config=ChromaVectorConfig(
name="db_schema_vector_store_name",
persist_path=os.path.join(PILOT_PATH, "data"),
),
embedding_fn=DefaultEmbeddingFactory(
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
).create(),
)
return ChromaStore(config)
if __name__ == "__main__":
connection = _create_temporary_connection()
index_store = _create_vector_connector()
vector_connector = _create_vector_connector()
assembler = DBSchemaAssembler.load_from_connection(
connector=connection,
index_store=index_store,
connector=connection, table_vector_store_connector=vector_connector
)
assembler.persist()
# get db schema retriever
retriever = assembler.as_retriever(top_k=1)
chunks = retriever.retrieve("show columns from user")
print(f"db schema rag example results:{[chunk.content for chunk in chunks]}")
index_store.delete_vector_name("dbschema_rag_test")

View File

@ -15,6 +15,7 @@ fi
DEFAULT_DB_FILE="DB-GPT/pilot/data/default_sqlite.db"
DEFAULT_SQL_FILE="DB-GPT/docker/examples/sqls/*_sqlite.sql"
DB_FILE="$WORK_DIR/pilot/data/default_sqlite.db"
WIDE_DB_FILE="$WORK_DIR/pilot/data/wide_sqlite.db"
SQL_FILE=""
usage () {
@ -61,6 +62,12 @@ if [ -n $SQL_FILE ];then
sqlite3 $DB_FILE < "$file"
done
for file in $WORK_DIR/docker/examples/sqls/*_sqlite_wide.sql
do
echo "execute sql file: $file"
sqlite3 $WIDE_DB_FILE < "$file"
done
else
echo "Execute SQL file ${SQL_FILE}"
sqlite3 $DB_FILE < $SQL_FILE