mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-23 20:26:15 +00:00
Feat rdb summary wide table (#2035)
Co-authored-by: dongzhancai1 <dongzhancai1@jd.com> Co-authored-by: dong <dongzhancai@iie2.com>
This commit is contained in:
parent
7f4b5e79cf
commit
9b0161e521
@ -66,6 +66,7 @@ QUANTIZE_8bit=True
|
|||||||
#** EMBEDDING SETTINGS **#
|
#** EMBEDDING SETTINGS **#
|
||||||
#*******************************************************************#
|
#*******************************************************************#
|
||||||
EMBEDDING_MODEL=text2vec
|
EMBEDDING_MODEL=text2vec
|
||||||
|
EMBEDDING_MODEL_MAX_SEQ_LEN=512
|
||||||
#EMBEDDING_MODEL=m3e-large
|
#EMBEDDING_MODEL=m3e-large
|
||||||
#EMBEDDING_MODEL=bge-large-en
|
#EMBEDDING_MODEL=bge-large-en
|
||||||
#EMBEDDING_MODEL=bge-large-zh
|
#EMBEDDING_MODEL=bge-large-zh
|
||||||
|
@ -264,6 +264,9 @@ class Config(metaclass=Singleton):
|
|||||||
|
|
||||||
# EMBEDDING Configuration
|
# EMBEDDING Configuration
|
||||||
self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
|
self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
|
||||||
|
self.EMBEDDING_MODEL_MAX_SEQ_LEN = int(
|
||||||
|
os.getenv("EMBEDDING_MODEL_MAX_SEQ_LEN", 512)
|
||||||
|
)
|
||||||
# Rerank model configuration
|
# Rerank model configuration
|
||||||
self.RERANK_MODEL = os.getenv("RERANK_MODEL")
|
self.RERANK_MODEL = os.getenv("RERANK_MODEL")
|
||||||
self.RERANK_MODEL_PATH = os.getenv("RERANK_MODEL_PATH")
|
self.RERANK_MODEL_PATH = os.getenv("RERANK_MODEL_PATH")
|
||||||
|
@ -55,9 +55,6 @@ class ChatWithDbQA(BaseChat):
|
|||||||
if self.db_name:
|
if self.db_name:
|
||||||
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
|
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
|
||||||
try:
|
try:
|
||||||
# table_infos = client.get_db_summary(
|
|
||||||
# dbname=self.db_name, query=self.current_user_input, topk=self.top_k
|
|
||||||
# )
|
|
||||||
table_infos = await blocking_func_to_async(
|
table_infos = await blocking_func_to_async(
|
||||||
self._executor,
|
self._executor,
|
||||||
client.get_db_summary,
|
client.get_db_summary,
|
||||||
|
@ -1,12 +1,15 @@
|
|||||||
"""DBSchemaAssembler."""
|
"""DBSchemaAssembler."""
|
||||||
|
import os
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from dbgpt.core import Chunk
|
from dbgpt.core import Chunk, Embeddings
|
||||||
from dbgpt.datasource.base import BaseConnector
|
from dbgpt.datasource.base import BaseConnector
|
||||||
|
|
||||||
|
from ...serve.rag.connector import VectorStoreConnector
|
||||||
|
from ...storage.vector_store.base import VectorStoreConfig
|
||||||
from ..assembler.base import BaseAssembler
|
from ..assembler.base import BaseAssembler
|
||||||
from ..chunk_manager import ChunkParameters
|
from ..chunk_manager import ChunkParameters
|
||||||
from ..index.base import IndexStoreBase
|
from ..embedding.embedding_factory import DefaultEmbeddingFactory
|
||||||
from ..knowledge.datasource import DatasourceKnowledge
|
from ..knowledge.datasource import DatasourceKnowledge
|
||||||
from ..retriever.db_schema import DBSchemaRetriever
|
from ..retriever.db_schema import DBSchemaRetriever
|
||||||
|
|
||||||
@ -35,23 +38,64 @@ class DBSchemaAssembler(BaseAssembler):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
connector: BaseConnector,
|
connector: BaseConnector,
|
||||||
index_store: IndexStoreBase,
|
table_vector_store_connector: VectorStoreConnector,
|
||||||
|
field_vector_store_connector: VectorStoreConnector = None,
|
||||||
chunk_parameters: Optional[ChunkParameters] = None,
|
chunk_parameters: Optional[ChunkParameters] = None,
|
||||||
|
embedding_model: Optional[str] = None,
|
||||||
|
embeddings: Optional[Embeddings] = None,
|
||||||
|
max_seq_length: int = 512,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize with Embedding Assembler arguments.
|
"""Initialize with Embedding Assembler arguments.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
connector: (BaseConnector) BaseConnector connection.
|
connector: (BaseConnector) BaseConnector connection.
|
||||||
index_store: (IndexStoreBase) IndexStoreBase to use.
|
table_vector_store_connector: VectorStoreConnector to load
|
||||||
|
and retrieve table info.
|
||||||
|
field_vector_store_connector: VectorStoreConnector to load
|
||||||
|
and retrieve field info.
|
||||||
chunk_parameters: (Optional[ChunkParameters]) Parameters used for chunking.
|
chunk_parameters: (Optional[ChunkParameters]) Parameters used for chunking.
|
||||||
embedding_model: (Optional[str]) Embedding model to use.
|
embedding_model: (Optional[str]) Embedding model to use.
|
||||||
embeddings: (Optional[Embeddings]) Embeddings to use.
|
embeddings: (Optional[Embeddings]) Embeddings to use.
|
||||||
"""
|
"""
|
||||||
knowledge = DatasourceKnowledge(connector)
|
|
||||||
self._connector = connector
|
self._connector = connector
|
||||||
self._index_store = index_store
|
self._table_vector_store_connector = table_vector_store_connector
|
||||||
|
|
||||||
|
field_vector_store_config = VectorStoreConfig(
|
||||||
|
name=table_vector_store_connector.vector_store_config.name + "_field"
|
||||||
|
)
|
||||||
|
self._field_vector_store_connector = (
|
||||||
|
field_vector_store_connector
|
||||||
|
or VectorStoreConnector.from_default(
|
||||||
|
os.getenv("VECTOR_STORE_TYPE", "Chroma"),
|
||||||
|
self._table_vector_store_connector.current_embeddings,
|
||||||
|
vector_store_config=field_vector_store_config,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self._embedding_model = embedding_model
|
||||||
|
if self._embedding_model and not embeddings:
|
||||||
|
embeddings = DefaultEmbeddingFactory(
|
||||||
|
default_model_name=self._embedding_model
|
||||||
|
).create(self._embedding_model)
|
||||||
|
|
||||||
|
if (
|
||||||
|
embeddings
|
||||||
|
and self._table_vector_store_connector.vector_store_config.embedding_fn
|
||||||
|
is None
|
||||||
|
):
|
||||||
|
self._table_vector_store_connector.vector_store_config.embedding_fn = (
|
||||||
|
embeddings
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
embeddings
|
||||||
|
and self._field_vector_store_connector.vector_store_config.embedding_fn
|
||||||
|
is None
|
||||||
|
):
|
||||||
|
self._field_vector_store_connector.vector_store_config.embedding_fn = (
|
||||||
|
embeddings
|
||||||
|
)
|
||||||
|
knowledge = DatasourceKnowledge(connector, model_dimension=max_seq_length)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
knowledge=knowledge,
|
knowledge=knowledge,
|
||||||
chunk_parameters=chunk_parameters,
|
chunk_parameters=chunk_parameters,
|
||||||
@ -62,23 +106,36 @@ class DBSchemaAssembler(BaseAssembler):
|
|||||||
def load_from_connection(
|
def load_from_connection(
|
||||||
cls,
|
cls,
|
||||||
connector: BaseConnector,
|
connector: BaseConnector,
|
||||||
index_store: IndexStoreBase,
|
table_vector_store_connector: VectorStoreConnector,
|
||||||
|
field_vector_store_connector: VectorStoreConnector = None,
|
||||||
chunk_parameters: Optional[ChunkParameters] = None,
|
chunk_parameters: Optional[ChunkParameters] = None,
|
||||||
|
embedding_model: Optional[str] = None,
|
||||||
|
embeddings: Optional[Embeddings] = None,
|
||||||
|
max_seq_length: int = 512,
|
||||||
) -> "DBSchemaAssembler":
|
) -> "DBSchemaAssembler":
|
||||||
"""Load document embedding into vector store from a database connector.
|
"""Load document embedding into vector store from a database connector.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
connector: (BaseConnector) BaseConnector connection.
|
connector: (BaseConnector) BaseConnector connection.
|
||||||
index_store: (IndexStoreBase) IndexStoreBase to use.
|
table_vector_store_connector: used to load table chunks.
|
||||||
|
field_vector_store_connector: used to load field chunks
|
||||||
|
when a table has too many fields.
|
||||||
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
|
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
|
||||||
chunking.
|
chunking.
|
||||||
|
embedding_model: (Optional[str]) Embedding model to use.
|
||||||
|
embeddings: (Optional[Embeddings]) Embeddings to use.
|
||||||
|
max_seq_length: Embedding model max sequence length
|
||||||
Returns:
|
Returns:
|
||||||
DBSchemaAssembler
|
DBSchemaAssembler
|
||||||
"""
|
"""
|
||||||
return cls(
|
return cls(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
index_store=index_store,
|
table_vector_store_connector=table_vector_store_connector,
|
||||||
|
field_vector_store_connector=field_vector_store_connector,
|
||||||
|
embedding_model=embedding_model,
|
||||||
chunk_parameters=chunk_parameters,
|
chunk_parameters=chunk_parameters,
|
||||||
|
embeddings=embeddings,
|
||||||
|
max_seq_length=max_seq_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_chunks(self) -> List[Chunk]:
|
def get_chunks(self) -> List[Chunk]:
|
||||||
@ -91,7 +148,19 @@ class DBSchemaAssembler(BaseAssembler):
|
|||||||
Returns:
|
Returns:
|
||||||
List[str]: List of chunk ids.
|
List[str]: List of chunk ids.
|
||||||
"""
|
"""
|
||||||
return self._index_store.load_document(self._chunks)
|
table_chunks, field_chunks = [], []
|
||||||
|
for chunk in self._chunks:
|
||||||
|
metadata = chunk.metadata
|
||||||
|
if metadata.get("separated"):
|
||||||
|
if metadata.get("part") == "table":
|
||||||
|
table_chunks.append(chunk)
|
||||||
|
else:
|
||||||
|
field_chunks.append(chunk)
|
||||||
|
else:
|
||||||
|
table_chunks.append(chunk)
|
||||||
|
|
||||||
|
self._field_vector_store_connector.load_document(field_chunks)
|
||||||
|
return self._table_vector_store_connector.load_document(table_chunks)
|
||||||
|
|
||||||
def _extract_info(self, chunks) -> List[Chunk]:
|
def _extract_info(self, chunks) -> List[Chunk]:
|
||||||
"""Extract info from chunks."""
|
"""Extract info from chunks."""
|
||||||
@ -110,5 +179,6 @@ class DBSchemaAssembler(BaseAssembler):
|
|||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
connector=self._connector,
|
connector=self._connector,
|
||||||
is_embeddings=True,
|
is_embeddings=True,
|
||||||
index_store=self._index_store,
|
table_vector_store_connector=self._table_vector_store_connector,
|
||||||
|
field_vector_store_connector=self._field_vector_store_connector,
|
||||||
)
|
)
|
||||||
|
@ -1,76 +1,117 @@
|
|||||||
from unittest.mock import MagicMock
|
from typing import List
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
|
import dbgpt
|
||||||
from dbgpt.rag.assembler.embedding import EmbeddingAssembler
|
from dbgpt.core import Chunk
|
||||||
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
|
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
|
||||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||||
from dbgpt.rag.knowledge.base import Knowledge
|
|
||||||
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
|
|
||||||
from dbgpt.storage.vector_store.chroma_store import ChromaStore
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_db_connection():
|
def mock_db_connection():
|
||||||
"""Create a temporary database connection for testing."""
|
return MagicMock()
|
||||||
connect = SQLiteTempConnector.create_temporary_db()
|
|
||||||
connect.create_temp_tables(
|
|
||||||
{
|
@pytest.fixture
|
||||||
"user": {
|
def mock_table_vector_store_connector():
|
||||||
"columns": {
|
mock_connector = MagicMock()
|
||||||
"id": "INTEGER PRIMARY KEY",
|
mock_connector.vector_store_config.name = "table_name"
|
||||||
"name": "TEXT",
|
chunk = Chunk(
|
||||||
"age": "INTEGER",
|
content="table_name: user\ncomment: user about dbgpt",
|
||||||
},
|
metadata={
|
||||||
"data": [
|
"field_num": 6,
|
||||||
(1, "Tom", 10),
|
"part": "table",
|
||||||
(2, "Jerry", 16),
|
"separated": 1,
|
||||||
(3, "Jack", 18),
|
"table_name": "user",
|
||||||
(4, "Alice", 20),
|
},
|
||||||
(5, "Bob", 22),
|
|
||||||
],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
return connect
|
mock_connector.similar_search_with_scores = MagicMock(return_value=[chunk])
|
||||||
|
return mock_connector
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_chunk_parameters():
|
def mock_field_vector_store_connector():
|
||||||
return MagicMock(spec=ChunkParameters)
|
mock_connector = MagicMock()
|
||||||
|
chunk1 = Chunk(
|
||||||
|
content="name,age",
|
||||||
|
metadata={
|
||||||
|
"field_num": 6,
|
||||||
|
"part": "field",
|
||||||
|
"part_index": 0,
|
||||||
|
"separated": 1,
|
||||||
|
"table_name": "user",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
chunk2 = Chunk(
|
||||||
|
content="address,gender",
|
||||||
|
metadata={
|
||||||
|
"field_num": 6,
|
||||||
|
"part": "field",
|
||||||
|
"part_index": 1,
|
||||||
|
"separated": 1,
|
||||||
|
"table_name": "user",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
chunk3 = Chunk(
|
||||||
|
content="mail,phone",
|
||||||
|
metadata={
|
||||||
|
"field_num": 6,
|
||||||
|
"part": "field",
|
||||||
|
"part_index": 2,
|
||||||
|
"separated": 1,
|
||||||
|
"table_name": "user",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
mock_connector.similar_search_with_scores = MagicMock(
|
||||||
|
return_value=[chunk1, chunk2, chunk3]
|
||||||
|
)
|
||||||
|
return mock_connector
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_embedding_factory():
|
def dbstruct_retriever(
|
||||||
return MagicMock(spec=EmbeddingFactory)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_vector_store_connector():
|
|
||||||
return MagicMock(spec=ChromaStore)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_knowledge():
|
|
||||||
return MagicMock(spec=Knowledge)
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_knowledge(
|
|
||||||
mock_db_connection,
|
mock_db_connection,
|
||||||
mock_knowledge,
|
mock_table_vector_store_connector,
|
||||||
mock_chunk_parameters,
|
mock_field_vector_store_connector,
|
||||||
mock_embedding_factory,
|
|
||||||
mock_vector_store_connector,
|
|
||||||
):
|
):
|
||||||
mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
|
return DBSchemaRetriever(
|
||||||
mock_chunk_parameters.text_splitter = CharacterTextSplitter()
|
connector=mock_db_connection,
|
||||||
mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
|
table_vector_store_connector=mock_table_vector_store_connector,
|
||||||
assembler = EmbeddingAssembler(
|
field_vector_store_connector=mock_field_vector_store_connector,
|
||||||
knowledge=mock_knowledge,
|
separator="--table-field-separator--",
|
||||||
chunk_parameters=mock_chunk_parameters,
|
)
|
||||||
embeddings=mock_embedding_factory.create(),
|
|
||||||
index_store=mock_vector_store_connector,
|
|
||||||
|
def mock_parse_db_summary() -> str:
|
||||||
|
"""Patch _parse_db_summary method."""
|
||||||
|
return (
|
||||||
|
"table_name: user\ncomment: user about dbgpt\n"
|
||||||
|
"--table-field-separator--\n"
|
||||||
|
"name,age\naddress,gender\nmail,phone"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Mocking the _parse_db_summary method in your test function
|
||||||
|
@patch.object(
|
||||||
|
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
|
||||||
|
)
|
||||||
|
def test_retrieve_with_mocked_summary(dbstruct_retriever):
|
||||||
|
query = "Table summary"
|
||||||
|
chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
|
||||||
|
assert isinstance(chunks[0], Chunk)
|
||||||
|
assert chunks[0].content == (
|
||||||
|
"table_name: user\ncomment: user about dbgpt\n"
|
||||||
|
"--table-field-separator--\n"
|
||||||
|
"name,age\naddress,gender\nmail,phone"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def async_mock_parse_db_summary() -> str:
|
||||||
|
"""Asynchronous patch for _parse_db_summary method."""
|
||||||
|
return (
|
||||||
|
"table_name: user\ncomment: user about dbgpt\n"
|
||||||
|
"--table-field-separator--\n"
|
||||||
|
"name,age\naddress,gender\nmail,phone"
|
||||||
)
|
)
|
||||||
assembler.load_knowledge(knowledge=mock_knowledge)
|
|
||||||
assert len(assembler._chunks) == 0
|
|
||||||
|
@ -5,9 +5,9 @@ import pytest
|
|||||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
|
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
|
||||||
from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
|
from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
|
||||||
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
|
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
|
||||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddings, EmbeddingFactory
|
||||||
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
|
from dbgpt.rag.text_splitter.text_splitter import RDBTextSplitter
|
||||||
from dbgpt.storage.vector_store.chroma_store import ChromaStore
|
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -21,14 +21,22 @@ def mock_db_connection():
|
|||||||
"id": "INTEGER PRIMARY KEY",
|
"id": "INTEGER PRIMARY KEY",
|
||||||
"name": "TEXT",
|
"name": "TEXT",
|
||||||
"age": "INTEGER",
|
"age": "INTEGER",
|
||||||
},
|
"address": "TEXT",
|
||||||
"data": [
|
"phone": "TEXT",
|
||||||
(1, "Tom", 10),
|
"email": "TEXT",
|
||||||
(2, "Jerry", 16),
|
"gender": "TEXT",
|
||||||
(3, "Jack", 18),
|
"birthdate": "TEXT",
|
||||||
(4, "Alice", 20),
|
"occupation": "TEXT",
|
||||||
(5, "Bob", 22),
|
"education": "TEXT",
|
||||||
],
|
"marital_status": "TEXT",
|
||||||
|
"nationality": "TEXT",
|
||||||
|
"height": "REAL",
|
||||||
|
"weight": "REAL",
|
||||||
|
"blood_type": "TEXT",
|
||||||
|
"emergency_contact": "TEXT",
|
||||||
|
"created_at": "TEXT",
|
||||||
|
"updated_at": "TEXT",
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -46,23 +54,29 @@ def mock_embedding_factory():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_vector_store_connector():
|
def mock_table_vector_store_connector():
|
||||||
return MagicMock(spec=ChromaStore)
|
mock_connector = MagicMock(spec=VectorStoreConnector)
|
||||||
|
mock_connector.vector_store_config.name = "table_vector_store_name"
|
||||||
|
mock_connector.current_embeddings = DefaultEmbeddings()
|
||||||
|
return mock_connector
|
||||||
|
|
||||||
|
|
||||||
def test_load_knowledge(
|
def test_load_knowledge(
|
||||||
mock_db_connection,
|
mock_db_connection,
|
||||||
mock_chunk_parameters,
|
mock_chunk_parameters,
|
||||||
mock_embedding_factory,
|
mock_embedding_factory,
|
||||||
mock_vector_store_connector,
|
mock_table_vector_store_connector,
|
||||||
):
|
):
|
||||||
mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
|
mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
|
||||||
mock_chunk_parameters.text_splitter = CharacterTextSplitter()
|
mock_chunk_parameters.text_splitter = RDBTextSplitter(
|
||||||
|
separator="--table-field-separator--"
|
||||||
|
)
|
||||||
mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
|
mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
|
||||||
assembler = DBSchemaAssembler(
|
assembler = DBSchemaAssembler(
|
||||||
connector=mock_db_connection,
|
connector=mock_db_connection,
|
||||||
chunk_parameters=mock_chunk_parameters,
|
chunk_parameters=mock_chunk_parameters,
|
||||||
embeddings=mock_embedding_factory.create(),
|
embeddings=mock_embedding_factory.create(),
|
||||||
index_store=mock_vector_store_connector,
|
table_vector_store_connector=mock_table_vector_store_connector,
|
||||||
|
max_seq_length=10,
|
||||||
)
|
)
|
||||||
assert len(assembler._chunks) == 1
|
assert len(assembler._chunks) > 1
|
||||||
|
@ -5,7 +5,7 @@ from dbgpt.core import Document
|
|||||||
from dbgpt.datasource import BaseConnector
|
from dbgpt.datasource import BaseConnector
|
||||||
|
|
||||||
from ..summary.gdbms_db_summary import _parse_db_summary as _parse_gdb_summary
|
from ..summary.gdbms_db_summary import _parse_db_summary as _parse_gdb_summary
|
||||||
from ..summary.rdbms_db_summary import _parse_db_summary
|
from ..summary.rdbms_db_summary import _parse_db_summary_with_metadata
|
||||||
from .base import ChunkStrategy, DocumentType, Knowledge, KnowledgeType
|
from .base import ChunkStrategy, DocumentType, Knowledge, KnowledgeType
|
||||||
|
|
||||||
|
|
||||||
@ -15,9 +15,11 @@ class DatasourceKnowledge(Knowledge):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
connector: BaseConnector,
|
connector: BaseConnector,
|
||||||
summary_template: str = "{table_name}({columns})",
|
summary_template: str = "table_name: {table_name}",
|
||||||
|
separator: str = "--table-field-separator--",
|
||||||
knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
|
knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
|
||||||
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||||
|
model_dimension: int = 512,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create Datasource Knowledge with Knowledge arguments.
|
"""Create Datasource Knowledge with Knowledge arguments.
|
||||||
@ -25,11 +27,17 @@ class DatasourceKnowledge(Knowledge):
|
|||||||
Args:
|
Args:
|
||||||
connector(BaseConnector): connector
|
connector(BaseConnector): connector
|
||||||
summary_template(str, optional): summary template
|
summary_template(str, optional): summary template
|
||||||
|
separator(str, optional): separator used to separate
|
||||||
|
table's basic info and fields.
|
||||||
|
defaults to `--table-field-separator--`.
|
||||||
knowledge_type(KnowledgeType, optional): knowledge type
|
knowledge_type(KnowledgeType, optional): knowledge type
|
||||||
metadata(Dict[str, Union[str, List[str]], optional): metadata
|
metadata(Dict[str, Union[str, List[str]], optional): metadata
|
||||||
|
model_dimension(int, optional): The threshold for splitting field string
|
||||||
"""
|
"""
|
||||||
|
self._separator = separator
|
||||||
self._connector = connector
|
self._connector = connector
|
||||||
self._summary_template = summary_template
|
self._summary_template = summary_template
|
||||||
|
self._model_dimension = model_dimension
|
||||||
super().__init__(knowledge_type=knowledge_type, metadata=metadata, **kwargs)
|
super().__init__(knowledge_type=knowledge_type, metadata=metadata, **kwargs)
|
||||||
|
|
||||||
def _load(self) -> List[Document]:
|
def _load(self) -> List[Document]:
|
||||||
@ -37,13 +45,23 @@ class DatasourceKnowledge(Knowledge):
|
|||||||
docs = []
|
docs = []
|
||||||
if self._connector.is_graph_type():
|
if self._connector.is_graph_type():
|
||||||
db_summary = _parse_gdb_summary(self._connector, self._summary_template)
|
db_summary = _parse_gdb_summary(self._connector, self._summary_template)
|
||||||
|
for table_summary in db_summary:
|
||||||
|
metadata = {"source": "database"}
|
||||||
|
docs.append(Document(content=table_summary, metadata=metadata))
|
||||||
else:
|
else:
|
||||||
db_summary = _parse_db_summary(self._connector, self._summary_template)
|
db_summary_with_metadata = _parse_db_summary_with_metadata(
|
||||||
for table_summary in db_summary:
|
self._connector,
|
||||||
metadata = {"source": "database"}
|
self._summary_template,
|
||||||
if self._metadata:
|
self._separator,
|
||||||
metadata.update(self._metadata) # type: ignore
|
self._model_dimension,
|
||||||
docs.append(Document(content=table_summary, metadata=metadata))
|
)
|
||||||
|
for summary, table_metadata in db_summary_with_metadata:
|
||||||
|
metadata = {"source": "database"}
|
||||||
|
|
||||||
|
if self._metadata:
|
||||||
|
metadata.update(self._metadata) # type: ignore
|
||||||
|
table_metadata.update(metadata)
|
||||||
|
docs.append(Document(content=summary, metadata=table_metadata))
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -1,14 +1,15 @@
|
|||||||
"""The DBSchema Retriever Operator."""
|
"""The DBSchema Retriever Operator."""
|
||||||
|
import os
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from dbgpt.core import Chunk
|
from dbgpt.core import Chunk
|
||||||
from dbgpt.core.interface.operators.retriever import RetrieverOperator
|
from dbgpt.core.interface.operators.retriever import RetrieverOperator
|
||||||
from dbgpt.datasource.base import BaseConnector
|
from dbgpt.datasource.base import BaseConnector
|
||||||
|
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||||
|
|
||||||
|
from ...storage.vector_store.base import VectorStoreConfig
|
||||||
from ..assembler.db_schema import DBSchemaAssembler
|
from ..assembler.db_schema import DBSchemaAssembler
|
||||||
from ..chunk_manager import ChunkParameters
|
from ..chunk_manager import ChunkParameters
|
||||||
from ..index.base import IndexStoreBase
|
|
||||||
from ..retriever.db_schema import DBSchemaRetriever
|
from ..retriever.db_schema import DBSchemaRetriever
|
||||||
from .assembler import AssemblerOperator
|
from .assembler import AssemblerOperator
|
||||||
|
|
||||||
@ -19,13 +20,14 @@ class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]):
|
|||||||
Args:
|
Args:
|
||||||
connector (BaseConnector): The connection.
|
connector (BaseConnector): The connection.
|
||||||
top_k (int, optional): The top k. Defaults to 4.
|
top_k (int, optional): The top k. Defaults to 4.
|
||||||
index_store (IndexStoreBase, optional): The vector store
|
table_vector_store_connector (VectorStoreConnector): The vector store
|
||||||
connector. Defaults to None.
|
connector for table info (field info uses field_vector_store_connector).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
index_store: IndexStoreBase,
|
table_vector_store_connector: VectorStoreConnector,
|
||||||
|
field_vector_store_connector: VectorStoreConnector,
|
||||||
top_k: int = 4,
|
top_k: int = 4,
|
||||||
connector: Optional[BaseConnector] = None,
|
connector: Optional[BaseConnector] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
@ -35,7 +37,8 @@ class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]):
|
|||||||
self._retriever = DBSchemaRetriever(
|
self._retriever = DBSchemaRetriever(
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
connector=connector,
|
connector=connector,
|
||||||
index_store=index_store,
|
table_vector_store_connector=table_vector_store_connector,
|
||||||
|
field_vector_store_connector=field_vector_store_connector,
|
||||||
)
|
)
|
||||||
|
|
||||||
def retrieve(self, query: str) -> List[Chunk]:
|
def retrieve(self, query: str) -> List[Chunk]:
|
||||||
@ -53,7 +56,8 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
connector: BaseConnector,
|
connector: BaseConnector,
|
||||||
index_store: IndexStoreBase,
|
table_vector_store_connector: VectorStoreConnector,
|
||||||
|
field_vector_store_connector: VectorStoreConnector = None,
|
||||||
chunk_parameters: Optional[ChunkParameters] = None,
|
chunk_parameters: Optional[ChunkParameters] = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
@ -61,14 +65,26 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
connector (BaseConnector): The connection.
|
connector (BaseConnector): The connection.
|
||||||
index_store (IndexStoreBase): The Storage IndexStoreBase.
|
table_vector_store_connector (VectorStoreConnector): Table vector store.
|
||||||
chunk_parameters (Optional[ChunkParameters], optional): The chunk
|
chunk_parameters (Optional[ChunkParameters], optional): The chunk
|
||||||
parameters.
|
parameters.
|
||||||
"""
|
"""
|
||||||
if not chunk_parameters:
|
if not chunk_parameters:
|
||||||
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
|
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
|
||||||
self._chunk_parameters = chunk_parameters
|
self._chunk_parameters = chunk_parameters
|
||||||
self._index_store = index_store
|
self._table_vector_store_connector = table_vector_store_connector
|
||||||
|
|
||||||
|
field_vector_store_config = VectorStoreConfig(
|
||||||
|
name=table_vector_store_connector.vector_store_config.name + "_field"
|
||||||
|
)
|
||||||
|
self._field_vector_store_connector = (
|
||||||
|
field_vector_store_connector
|
||||||
|
or VectorStoreConnector.from_default(
|
||||||
|
os.getenv("VECTOR_STORE_TYPE", "Chroma"),
|
||||||
|
self._table_vector_store_connector.current_embeddings,
|
||||||
|
vector_store_config=field_vector_store_config,
|
||||||
|
)
|
||||||
|
)
|
||||||
self._connector = connector
|
self._connector = connector
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
@ -84,7 +100,8 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
|
|||||||
assembler = DBSchemaAssembler.load_from_connection(
|
assembler = DBSchemaAssembler.load_from_connection(
|
||||||
connector=self._connector,
|
connector=self._connector,
|
||||||
chunk_parameters=self._chunk_parameters,
|
chunk_parameters=self._chunk_parameters,
|
||||||
index_store=self._index_store,
|
table_vector_store_connector=self._table_vector_store_connector,
|
||||||
|
field_vector_store_connector=self._field_vector_store_connector,
|
||||||
)
|
)
|
||||||
assembler.persist()
|
assembler.persist()
|
||||||
return assembler.get_chunks()
|
return assembler.get_chunks()
|
||||||
|
@ -1,18 +1,23 @@
|
|||||||
"""DBSchema retriever."""
|
"""DBSchema retriever."""
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
from functools import reduce
|
from dbgpt._private.config import Config
|
||||||
from typing import List, Optional, cast
|
|
||||||
|
|
||||||
from dbgpt.core import Chunk
|
from dbgpt.core import Chunk
|
||||||
from dbgpt.datasource.base import BaseConnector
|
from dbgpt.datasource.base import BaseConnector
|
||||||
from dbgpt.rag.index.base import IndexStoreBase
|
|
||||||
from dbgpt.rag.retriever.base import BaseRetriever
|
from dbgpt.rag.retriever.base import BaseRetriever
|
||||||
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
|
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
|
||||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
from dbgpt.rag.summary.gdbms_db_summary import _parse_db_summary
|
||||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||||
from dbgpt.util.chat_util import run_async_tasks
|
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||||
|
from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters
|
||||||
|
from dbgpt.util.chat_util import run_tasks
|
||||||
from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
|
from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
|
||||||
from dbgpt.util.tracer import root_tracer
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
|
||||||
|
|
||||||
class DBSchemaRetriever(BaseRetriever):
|
class DBSchemaRetriever(BaseRetriever):
|
||||||
@ -20,7 +25,9 @@ class DBSchemaRetriever(BaseRetriever):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
index_store: IndexStoreBase,
|
table_vector_store_connector: VectorStoreConnector,
|
||||||
|
field_vector_store_connector: VectorStoreConnector = None,
|
||||||
|
separator: str = "--table-field-separator--",
|
||||||
top_k: int = 4,
|
top_k: int = 4,
|
||||||
connector: Optional[BaseConnector] = None,
|
connector: Optional[BaseConnector] = None,
|
||||||
query_rewrite: bool = False,
|
query_rewrite: bool = False,
|
||||||
@ -30,7 +37,11 @@ class DBSchemaRetriever(BaseRetriever):
|
|||||||
"""Create DBSchemaRetriever.
|
"""Create DBSchemaRetriever.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
index_store(IndexStore): index connector
|
table_vector_store_connector: VectorStoreConnector
|
||||||
|
to load and retrieve table info.
|
||||||
|
field_vector_store_connector: VectorStoreConnector
|
||||||
|
to load and retrieve field info.
|
||||||
|
separator: field/table separator
|
||||||
top_k (int): top k
|
top_k (int): top k
|
||||||
connector (Optional[BaseConnector]): RDBMSConnector.
|
connector (Optional[BaseConnector]): RDBMSConnector.
|
||||||
query_rewrite (bool): query rewrite
|
query_rewrite (bool): query rewrite
|
||||||
@ -70,34 +81,42 @@ class DBSchemaRetriever(BaseRetriever):
|
|||||||
|
|
||||||
|
|
||||||
connector = _create_temporary_connection()
|
connector = _create_temporary_connection()
|
||||||
|
vector_store_config = ChromaVectorConfig(name="vector_store_name")
|
||||||
|
embedding_model_path = "{your_embedding_model_path}"
|
||||||
embedding_fn = embedding_factory.create(model_name=embedding_model_path)
|
embedding_fn = embedding_factory.create(model_name=embedding_model_path)
|
||||||
config = ChromaVectorConfig(
|
vector_connector = VectorStoreConnector.from_default(
|
||||||
persist_path=PILOT_PATH,
|
"Chroma",
|
||||||
name="dbschema_rag_test",
|
vector_store_config=vector_store_config,
|
||||||
embedding_fn=DefaultEmbeddingFactory(
|
embedding_fn=embedding_fn,
|
||||||
default_model_name=os.path.join(
|
|
||||||
MODEL_PATH, "text2vec-large-chinese"
|
|
||||||
),
|
|
||||||
).create(),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
vector_store = ChromaStore(config)
|
|
||||||
# get db struct retriever
|
# get db struct retriever
|
||||||
retriever = DBSchemaRetriever(
|
retriever = DBSchemaRetriever(
|
||||||
top_k=3,
|
top_k=3,
|
||||||
index_store=vector_store,
|
table_vector_store_connector=vector_connector,
|
||||||
connector=connector,
|
connector=connector,
|
||||||
)
|
)
|
||||||
chunks = retriever.retrieve("show columns from table")
|
chunks = retriever.retrieve("show columns from table")
|
||||||
result = [chunk.content for chunk in chunks]
|
result = [chunk.content for chunk in chunks]
|
||||||
print(f"db struct rag example results:{result}")
|
print(f"db struct rag example results:{result}")
|
||||||
"""
|
"""
|
||||||
|
self._separator = separator
|
||||||
self._top_k = top_k
|
self._top_k = top_k
|
||||||
self._connector = connector
|
self._connector = connector
|
||||||
self._query_rewrite = query_rewrite
|
self._query_rewrite = query_rewrite
|
||||||
self._index_store = index_store
|
self._table_vector_store_connector = table_vector_store_connector
|
||||||
|
field_vector_store_config = VectorStoreConfig(
|
||||||
|
name=table_vector_store_connector.vector_store_config.name + "_field"
|
||||||
|
)
|
||||||
|
self._field_vector_store_connector = (
|
||||||
|
field_vector_store_connector
|
||||||
|
or VectorStoreConnector.from_default(
|
||||||
|
os.getenv("VECTOR_STORE_TYPE", "Chroma"),
|
||||||
|
self._table_vector_store_connector.current_embeddings,
|
||||||
|
vector_store_config=field_vector_store_config,
|
||||||
|
)
|
||||||
|
)
|
||||||
self._need_embeddings = False
|
self._need_embeddings = False
|
||||||
if self._index_store:
|
if self._table_vector_store_connector:
|
||||||
self._need_embeddings = True
|
self._need_embeddings = True
|
||||||
self._rerank = rerank or DefaultRanker(self._top_k)
|
self._rerank = rerank or DefaultRanker(self._top_k)
|
||||||
|
|
||||||
@ -114,15 +133,8 @@ class DBSchemaRetriever(BaseRetriever):
|
|||||||
List[Chunk]: list of chunks
|
List[Chunk]: list of chunks
|
||||||
"""
|
"""
|
||||||
if self._need_embeddings:
|
if self._need_embeddings:
|
||||||
queries = [query]
|
return self._similarity_search(query, filters)
|
||||||
candidates = [
|
|
||||||
self._index_store.similar_search(query, self._top_k, filters)
|
|
||||||
for query in queries
|
|
||||||
]
|
|
||||||
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
|
|
||||||
else:
|
else:
|
||||||
if not self._connector:
|
|
||||||
raise RuntimeError("RDBMSConnector connection is required.")
|
|
||||||
table_summaries = _parse_db_summary(self._connector)
|
table_summaries = _parse_db_summary(self._connector)
|
||||||
return [Chunk(content=table_summary) for table_summary in table_summaries]
|
return [Chunk(content=table_summary) for table_summary in table_summaries]
|
||||||
|
|
||||||
@ -156,30 +168,11 @@ class DBSchemaRetriever(BaseRetriever):
|
|||||||
Returns:
|
Returns:
|
||||||
List[Chunk]: list of chunks
|
List[Chunk]: list of chunks
|
||||||
"""
|
"""
|
||||||
if self._need_embeddings:
|
return await blocking_func_to_async_no_executor(
|
||||||
queries = [query]
|
func=self._retrieve,
|
||||||
candidates = [
|
query=query,
|
||||||
self._similarity_search(
|
filters=filters,
|
||||||
query, filters, root_tracer.get_current_span_id()
|
)
|
||||||
)
|
|
||||||
for query in queries
|
|
||||||
]
|
|
||||||
result_candidates = await run_async_tasks(
|
|
||||||
tasks=candidates, concurrency_limit=1
|
|
||||||
)
|
|
||||||
return cast(List[Chunk], reduce(lambda x, y: x + y, result_candidates))
|
|
||||||
else:
|
|
||||||
from dbgpt.rag.summary.rdbms_db_summary import ( # noqa: F401
|
|
||||||
_parse_db_summary,
|
|
||||||
)
|
|
||||||
|
|
||||||
table_summaries = await run_async_tasks(
|
|
||||||
tasks=[self._aparse_db_summary(root_tracer.get_current_span_id())],
|
|
||||||
concurrency_limit=1,
|
|
||||||
)
|
|
||||||
return [
|
|
||||||
Chunk(content=table_summary) for table_summary in table_summaries[0]
|
|
||||||
]
|
|
||||||
|
|
||||||
async def _aretrieve_with_score(
|
async def _aretrieve_with_score(
|
||||||
self,
|
self,
|
||||||
@ -196,34 +189,40 @@ class DBSchemaRetriever(BaseRetriever):
|
|||||||
"""
|
"""
|
||||||
return await self._aretrieve(query, filters)
|
return await self._aretrieve(query, filters)
|
||||||
|
|
||||||
async def _similarity_search(
|
def _retrieve_field(self, table_chunk: Chunk, query) -> Chunk:
|
||||||
self,
|
metadata = table_chunk.metadata
|
||||||
query,
|
metadata["part"] = "field"
|
||||||
filters: Optional[MetadataFilters] = None,
|
filters = [MetadataFilter(key=k, value=v) for k, v in metadata.items()]
|
||||||
parent_span_id: Optional[str] = None,
|
field_chunks = self._field_vector_store_connector.similar_search_with_scores(
|
||||||
|
query, self._top_k, 0, MetadataFilters(filters=filters)
|
||||||
|
)
|
||||||
|
field_contents = [chunk.content for chunk in field_chunks]
|
||||||
|
table_chunk.content += "\n" + self._separator + "\n" + "\n".join(field_contents)
|
||||||
|
return table_chunk
|
||||||
|
|
||||||
|
def _similarity_search(
|
||||||
|
self, query, filters: Optional[MetadataFilters] = None
|
||||||
) -> List[Chunk]:
|
) -> List[Chunk]:
|
||||||
"""Similar search."""
|
"""Similar search."""
|
||||||
with root_tracer.start_span(
|
table_chunks = self._table_vector_store_connector.similar_search_with_scores(
|
||||||
"dbgpt.rag.retriever.db_schema._similarity_search",
|
query, self._top_k, 0, filters
|
||||||
parent_span_id,
|
)
|
||||||
metadata={"query": query},
|
|
||||||
):
|
|
||||||
return await blocking_func_to_async_no_executor(
|
|
||||||
self._index_store.similar_search, query, self._top_k, filters
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _aparse_db_summary(
|
not_sep_chunks = [
|
||||||
self, parent_span_id: Optional[str] = None
|
chunk for chunk in table_chunks if not chunk.metadata.get("separated")
|
||||||
) -> List[str]:
|
]
|
||||||
"""Similar search."""
|
separated_chunks = [
|
||||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
chunk for chunk in table_chunks if chunk.metadata.get("separated")
|
||||||
|
]
|
||||||
|
if not separated_chunks:
|
||||||
|
return not_sep_chunks
|
||||||
|
|
||||||
if not self._connector:
|
# Create tasks list
|
||||||
raise RuntimeError("RDBMSConnector connection is required.")
|
tasks = [
|
||||||
with root_tracer.start_span(
|
lambda c=chunk: self._retrieve_field(c, query) for chunk in separated_chunks
|
||||||
"dbgpt.rag.retriever.db_schema._aparse_db_summary",
|
]
|
||||||
parent_span_id,
|
# Run tasks concurrently
|
||||||
):
|
separated_result = run_tasks(tasks, concurrency_limit=3)
|
||||||
return await blocking_func_to_async_no_executor(
|
|
||||||
_parse_db_summary, self._connector
|
# Combine and return results
|
||||||
)
|
return not_sep_chunks + separated_result
|
||||||
|
@ -15,42 +15,53 @@ def mock_db_connection():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_vector_store_connector():
|
def mock_table_vector_store_connector():
|
||||||
mock_connector = MagicMock()
|
mock_connector = MagicMock()
|
||||||
mock_connector.similar_search.return_value = [Chunk(content="Table summary")] * 4
|
mock_connector.vector_store_config.name = "table_name"
|
||||||
|
mock_connector.similar_search_with_scores.return_value = [
|
||||||
|
Chunk(content="Table summary")
|
||||||
|
] * 4
|
||||||
return mock_connector
|
return mock_connector
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def db_struct_retriever(mock_db_connection, mock_vector_store_connector):
|
def mock_field_vector_store_connector():
|
||||||
|
mock_connector = MagicMock()
|
||||||
|
mock_connector.similar_search_with_scores.return_value = [
|
||||||
|
Chunk(content="Field summary")
|
||||||
|
] * 4
|
||||||
|
return mock_connector
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dbstruct_retriever(
|
||||||
|
mock_db_connection,
|
||||||
|
mock_table_vector_store_connector,
|
||||||
|
mock_field_vector_store_connector,
|
||||||
|
):
|
||||||
return DBSchemaRetriever(
|
return DBSchemaRetriever(
|
||||||
connector=mock_db_connection,
|
connector=mock_db_connection,
|
||||||
index_store=mock_vector_store_connector,
|
table_vector_store_connector=mock_table_vector_store_connector,
|
||||||
|
field_vector_store_connector=mock_field_vector_store_connector,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def mock_parse_db_summary(conn) -> List[str]:
|
def mock_parse_db_summary() -> str:
|
||||||
"""Patch _parse_db_summary method."""
|
"""Patch _parse_db_summary method."""
|
||||||
return ["Table summary"]
|
return "Table summary"
|
||||||
|
|
||||||
|
|
||||||
# Mocking the _parse_db_summary method in your test function
|
# Mocking the _parse_db_summary method in your test function
|
||||||
@patch.object(
|
@patch.object(
|
||||||
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
|
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
|
||||||
)
|
)
|
||||||
def test_retrieve_with_mocked_summary(db_struct_retriever):
|
def test_retrieve_with_mocked_summary(dbstruct_retriever):
|
||||||
query = "Table summary"
|
query = "Table summary"
|
||||||
chunks: List[Chunk] = db_struct_retriever._retrieve(query)
|
chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
|
||||||
assert isinstance(chunks[0], Chunk)
|
assert isinstance(chunks[0], Chunk)
|
||||||
assert chunks[0].content == "Table summary"
|
assert chunks[0].content == "Table summary"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
async def async_mock_parse_db_summary() -> str:
|
||||||
@patch.object(
|
"""Asynchronous patch for _parse_db_summary method."""
|
||||||
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
|
return "Table summary"
|
||||||
)
|
|
||||||
async def test_aretrieve_with_mocked_summary(db_struct_retriever):
|
|
||||||
query = "Table summary"
|
|
||||||
chunks: List[Chunk] = await db_struct_retriever._aretrieve(query)
|
|
||||||
assert isinstance(chunks[0], Chunk)
|
|
||||||
assert chunks[0].content == "Table summary"
|
|
||||||
|
@ -2,13 +2,15 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from dbgpt._private.config import Config
|
from dbgpt._private.config import Config
|
||||||
from dbgpt.component import SystemApp
|
from dbgpt.component import SystemApp
|
||||||
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
|
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
|
||||||
|
from dbgpt.rag import ChunkParameters
|
||||||
from dbgpt.rag.summary.gdbms_db_summary import GdbmsSummary
|
from dbgpt.rag.summary.gdbms_db_summary import GdbmsSummary
|
||||||
from dbgpt.rag.summary.rdbms_db_summary import RdbmsSummary
|
from dbgpt.rag.summary.rdbms_db_summary import RdbmsSummary
|
||||||
|
from dbgpt.rag.text_splitter.text_splitter import RDBTextSplitter
|
||||||
|
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -47,22 +49,26 @@ class DBSummaryClient:
|
|||||||
|
|
||||||
logger.info("db summary embedding success")
|
logger.info("db summary embedding success")
|
||||||
|
|
||||||
def get_db_summary(self, dbname, query, topk) -> List[str]:
|
def get_db_summary(self, dbname, query, topk):
|
||||||
"""Get user query related tables info."""
|
"""Get user query related tables info."""
|
||||||
from dbgpt.serve.rag.connector import VectorStoreConnector
|
|
||||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||||
|
|
||||||
vector_store_config = VectorStoreConfig(name=dbname + "_profile")
|
vector_store_name = dbname + "_profile"
|
||||||
vector_connector = VectorStoreConnector.from_default(
|
table_vector_store_config = VectorStoreConfig(name=vector_store_name)
|
||||||
|
table_vector_connector = VectorStoreConnector.from_default(
|
||||||
CFG.VECTOR_STORE_TYPE,
|
CFG.VECTOR_STORE_TYPE,
|
||||||
embedding_fn=self.embeddings,
|
self.embeddings,
|
||||||
vector_store_config=vector_store_config,
|
vector_store_config=table_vector_store_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
|
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
|
||||||
|
|
||||||
retriever = DBSchemaRetriever(
|
retriever = DBSchemaRetriever(
|
||||||
top_k=topk, index_store=vector_connector.index_client
|
top_k=topk,
|
||||||
|
table_vector_store_connector=table_vector_connector,
|
||||||
|
separator="--table-field-separator--",
|
||||||
)
|
)
|
||||||
|
|
||||||
table_docs = retriever.retrieve(query)
|
table_docs = retriever.retrieve(query)
|
||||||
ans = [d.content for d in table_docs]
|
ans = [d.content for d in table_docs]
|
||||||
return ans
|
return ans
|
||||||
@ -92,18 +98,23 @@ class DBSummaryClient:
|
|||||||
from dbgpt.serve.rag.connector import VectorStoreConnector
|
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||||
|
|
||||||
vector_store_config = VectorStoreConfig(name=vector_store_name)
|
table_vector_store_config = VectorStoreConfig(name=vector_store_name)
|
||||||
vector_connector = VectorStoreConnector.from_default(
|
table_vector_connector = VectorStoreConnector.from_default(
|
||||||
CFG.VECTOR_STORE_TYPE,
|
CFG.VECTOR_STORE_TYPE,
|
||||||
self.embeddings,
|
self.embeddings,
|
||||||
vector_store_config=vector_store_config,
|
vector_store_config=table_vector_store_config,
|
||||||
)
|
)
|
||||||
if not vector_connector.vector_name_exists():
|
if not table_vector_connector.vector_name_exists():
|
||||||
from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
|
from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
|
||||||
|
|
||||||
|
chunk_parameters = ChunkParameters(
|
||||||
|
text_splitter=RDBTextSplitter(separator="--table-field-separator--")
|
||||||
|
)
|
||||||
db_assembler = DBSchemaAssembler.load_from_connection(
|
db_assembler = DBSchemaAssembler.load_from_connection(
|
||||||
connector=db_summary_client.db,
|
connector=db_summary_client.db,
|
||||||
index_store=vector_connector.index_client,
|
table_vector_store_connector=table_vector_connector,
|
||||||
|
chunk_parameters=chunk_parameters,
|
||||||
|
max_seq_length=CFG.EMBEDDING_MODEL_MAX_SEQ_LEN,
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(db_assembler.get_chunks()) > 0:
|
if len(db_assembler.get_chunks()) > 0:
|
||||||
@ -115,16 +126,26 @@ class DBSummaryClient:
|
|||||||
def delete_db_profile(self, dbname):
|
def delete_db_profile(self, dbname):
|
||||||
"""Delete db profile."""
|
"""Delete db profile."""
|
||||||
vector_store_name = dbname + "_profile"
|
vector_store_name = dbname + "_profile"
|
||||||
|
table_vector_store_name = dbname + "_profile"
|
||||||
|
field_vector_store_name = dbname + "_profile_field"
|
||||||
from dbgpt.serve.rag.connector import VectorStoreConnector
|
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||||
|
|
||||||
vector_store_config = VectorStoreConfig(name=vector_store_name)
|
table_vector_store_config = VectorStoreConfig(name=vector_store_name)
|
||||||
vector_connector = VectorStoreConnector.from_default(
|
field_vector_store_config = VectorStoreConfig(name=field_vector_store_name)
|
||||||
|
table_vector_connector = VectorStoreConnector.from_default(
|
||||||
CFG.VECTOR_STORE_TYPE,
|
CFG.VECTOR_STORE_TYPE,
|
||||||
self.embeddings,
|
self.embeddings,
|
||||||
vector_store_config=vector_store_config,
|
vector_store_config=table_vector_store_config,
|
||||||
)
|
)
|
||||||
vector_connector.delete_vector_name(vector_store_name)
|
field_vector_connector = VectorStoreConnector.from_default(
|
||||||
|
CFG.VECTOR_STORE_TYPE,
|
||||||
|
self.embeddings,
|
||||||
|
vector_store_config=field_vector_store_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
table_vector_connector.delete_vector_name(table_vector_store_name)
|
||||||
|
field_vector_connector.delete_vector_name(field_vector_store_name)
|
||||||
logger.info(f"delete db profile {dbname} success")
|
logger.info(f"delete db profile {dbname} success")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""Summary for rdbms database."""
|
"""Summary for rdbms database."""
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from dbgpt._private.config import Config
|
from dbgpt._private.config import Config
|
||||||
from dbgpt.datasource import BaseConnector
|
from dbgpt.datasource import BaseConnector
|
||||||
@ -80,6 +80,134 @@ def _parse_db_summary(
|
|||||||
return table_info_summaries
|
return table_info_summaries
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_db_summary_with_metadata(
    conn: BaseConnector,
    summary_template: str = "table_name: {table_name}",
    separator: str = "--table-field-separator--",
    model_dimension: int = 512,
) -> List[Tuple[str, Dict[str, Any]]]:
    """Build a summary string plus metadata for every table of a database.

    Args:
        conn (BaseConnector): database connection; its ``get_table_names()``
            result determines which tables are summarized and in what order.
        summary_template (str): template for the table header; formatted with
            ``table_name``.
        separator (str, optional): marker placed between a table's basic info
            and its fields. Defaults to ``--table-field-separator--``.
        model_dimension (int, optional): length threshold handed to the
            per-table summarizer for splitting the field string.

    Returns:
        List[Tuple[str, Dict[str, Any]]]: one ``(summary, metadata)`` pair per
        table, as produced by ``_parse_table_summary_with_metadata``.
    """
    summaries_with_metadata = []
    for name in conn.get_table_names():
        summary_and_meta = _parse_table_summary_with_metadata(
            conn, summary_template, separator, name, model_dimension
        )
        summaries_with_metadata.append(summary_and_meta)
    return summaries_with_metadata
|
||||||
|
|
||||||
|
|
||||||
|
def _split_columns_str(columns: List[str], model_dimension: int):
|
||||||
|
"""Split columns str.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
columns (List[str]): fields string
|
||||||
|
model_dimension (int, optional): The threshold for splitting field string.
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
current_string = ""
|
||||||
|
current_length = 0
|
||||||
|
|
||||||
|
for element_str in columns:
|
||||||
|
element_length = len(element_str)
|
||||||
|
|
||||||
|
# If adding the current element's length would exceed the threshold,
|
||||||
|
# add the current string to results and reset
|
||||||
|
if current_length + element_length > model_dimension:
|
||||||
|
result.append(current_string.strip()) # Remove trailing spaces
|
||||||
|
current_string = element_str
|
||||||
|
current_length = element_length
|
||||||
|
else:
|
||||||
|
# If current string is empty, add element directly
|
||||||
|
if current_string:
|
||||||
|
current_string += "," + element_str
|
||||||
|
else:
|
||||||
|
current_string = element_str
|
||||||
|
current_length += element_length + 1 # Add length of space
|
||||||
|
|
||||||
|
# Handle the last string segment
|
||||||
|
if current_string:
|
||||||
|
result.append(current_string.strip())
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_table_summary_with_metadata(
    conn: BaseConnector,
    summary_template: str,
    separator,
    table_name: str,
    model_dimension=512,
) -> Tuple[str, Dict[str, Any]]:
    """Build the summary string and metadata for a single table.

    Args:
        conn (BaseConnector): database connection; queried for the table's
            columns, indexes and comment.
        summary_template (str): template for the header line; formatted with
            ``table_name``.
        separator: marker line placed between the table's basic info and its
            fields.
        table_name (str): table to summarize.
        model_dimension (int, optional): length threshold used to split the
            field string into several lines.

    Returns:
        Tuple[str, Dict[str, Any]]: ``(summary, metadata)``. ``metadata`` holds
        ``table_name``, ``separated`` (1 when the fields were split into more
        than one line, else 0) and ``field_num`` (number of columns).

    Examples:
        metadata: {'table_name': 'table1', 'separated': 1, 'field_num': 6}

        table_name: table1
        table_comment: comment
        index_keys: keys
        --table-field-separator--
        column1 (comment),column2 (comment),column3 (comment)
        column4 (comment),column5 (comment),column6 (comment)
    """
    # One description per column: "name (comment)" when a comment exists,
    # otherwise just the name.
    column_descs = []
    for col in conn.get_columns(table_name):
        if col.get("comment"):
            desc = f"{col['name']} ({col.get('comment')})"
        else:
            desc = f"{col['name']}"
        column_descs.append(desc)

    column_groups = _split_columns_str(column_descs, model_dimension=model_dimension)
    metadata = {
        "table_name": table_name,
        "separated": 1 if len(column_groups) > 1 else 0,
        "field_num": len(column_descs),
    }
    column_str = "\n".join(column_groups)

    # Index information comes either as (name, creation_command) tuples or as
    # mappings carrying "name" and "column_names".
    index_descs = []
    for index in conn.get_indexes(table_name):
        if not isinstance(index, tuple):
            joined_keys = ", ".join(index["column_names"])
            index_descs.append(f"{index['name']}(`{joined_keys}`) ")
            continue
        index_name, creation_command = index
        # Pull the parenthesised column lists out of the creation command.
        captured = re.findall(r"\(([^)]+)\)", creation_command)
        if captured:
            joined_keys = ", ".join(captured)
            index_descs.append(f"{index_name}(`{joined_keys}`) ")

    summary_lines = [summary_template.format(table_name=table_name)]

    # The table comment is best-effort: any failure is treated as "no comment".
    try:
        comment = conn.get_table_comment(table_name)
    except Exception:
        comment = {"text": None}
    if comment.get("text"):
        summary_lines.append(f"table_comment: {comment.get('text')}")

    if index_descs:
        summary_lines.append("index_keys: " + ", ".join(index_descs))

    summary_lines.append(f"{separator}\n{column_str}")
    return "\n".join(summary_lines), metadata
|
||||||
|
|
||||||
|
|
||||||
def _parse_table_summary(
|
def _parse_table_summary(
|
||||||
conn: BaseConnector, summary_template: str, table_name: str
|
conn: BaseConnector, summary_template: str, table_name: str
|
||||||
) -> str:
|
) -> str:
|
||||||
|
@ -912,3 +912,42 @@ class PageTextSplitter(TextSplitter):
|
|||||||
new_doc = Chunk(content=text, metadata=copy.deepcopy(_metadatas[i]))
|
new_doc = Chunk(content=text, metadata=copy.deepcopy(_metadatas[i]))
|
||||||
chunks.append(new_doc)
|
chunks.append(new_doc)
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
class RDBTextSplitter(TextSplitter):
    """Split relational database table summaries into table and field chunks."""

    def __init__(self, **kwargs):
        """Create a new RDBTextSplitter.

        Args:
            **kwargs: forwarded unchanged to ``TextSplitter.__init__``.
        """
        super().__init__(**kwargs)

    def split_text(self, text: str, **kwargs):
        """Split text into a couple of parts.

        Not implemented for this splitter: it does nothing and returns
        ``None``. Splitting is done per document in ``split_documents``.
        """
        pass

    def split_documents(self, documents: Iterable[Document], **kwargs) -> List[Chunk]:
        """Split table summary documents into chunks.

        A document whose metadata has a truthy ``separated`` flag is cut at the
        first occurrence of ``self._separator``:

        * the text before it becomes one chunk tagged ``part="table"``;
        * the text after it is split on newlines, and every non-blank line
          becomes a chunk tagged ``part="field"`` with a ``part_index`` equal
          to the line's position in that newline split.

        Blank lines in the field section are skipped, so no empty-content chunk
        is produced; the section normally starts with one because the separator
        sits on a line of its own. If the separator is missing, only the table
        chunk is emitted. Any other document is emitted unchanged as a single
        chunk.

        Args:
            documents (Iterable[Document]): documents carrying ``content`` and
                ``metadata``.

        Returns:
            List[Chunk]: the chunks, in document order.
        """
        chunks = []
        for doc in documents:
            metadata = doc.metadata
            content = doc.content
            if not metadata.get("separated"):
                chunks.append(Chunk(content=content, metadata=metadata))
                continue

            # partition() never raises when the separator is absent; the field
            # section is then simply empty.
            table_part, _, field_part = content.partition(self._separator)

            table_metadata = copy.deepcopy(metadata)
            table_metadata["part"] = "table"  # identifies the table chunk
            chunks.append(Chunk(content=table_part, metadata=table_metadata))

            field_metadata = copy.deepcopy(metadata)
            field_metadata["part"] = "field"  # identifies the field chunks
            for i, sub_part in enumerate(field_part.split("\n")):
                if not sub_part.strip():
                    # Skip blank lines (e.g. the one right after the
                    # separator) instead of emitting an empty chunk.
                    continue
                sub_metadata = copy.deepcopy(field_metadata)
                sub_metadata["part_index"] = i
                chunks.append(Chunk(content=sub_part, metadata=sub_metadata))
        return chunks
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import Any, Coroutine, List
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from typing import Any, Callable, Coroutine, List
|
||||||
|
|
||||||
|
|
||||||
async def llm_chat_response_nostream(chat_scene: str, **chat_param):
|
async def llm_chat_response_nostream(chat_scene: str, **chat_param):
|
||||||
@ -47,13 +48,34 @@ async def run_async_tasks(
|
|||||||
|
|
||||||
|
|
||||||
def run_tasks(
|
def run_tasks(
|
||||||
tasks: List[Coroutine],
|
tasks: List[Callable],
|
||||||
|
concurrency_limit: int = None,
|
||||||
) -> List[Any]:
|
) -> List[Any]:
|
||||||
"""Run a list of async tasks."""
|
"""
|
||||||
tasks_to_execute: List[Any] = tasks
|
Run a list of tasks concurrently using a thread pool.
|
||||||
|
|
||||||
async def _gather() -> List[Any]:
|
Args:
|
||||||
return await asyncio.gather(*tasks_to_execute)
|
tasks: List of callable functions to execute
|
||||||
|
concurrency_limit: Maximum number of concurrent threads (optional)
|
||||||
|
|
||||||
outputs: List[Any] = asyncio.run(_gather())
|
Returns:
|
||||||
return outputs
|
List of results from all tasks in the order they were submitted
|
||||||
|
"""
|
||||||
|
max_workers = concurrency_limit if concurrency_limit else None
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
|
# Submit all tasks and get futures
|
||||||
|
futures = [executor.submit(task) for task in tasks]
|
||||||
|
|
||||||
|
# Collect results in order, raising any exceptions
|
||||||
|
results = []
|
||||||
|
for future in futures:
|
||||||
|
try:
|
||||||
|
results.append(future.result())
|
||||||
|
except Exception as e:
|
||||||
|
# Cancel any pending futures
|
||||||
|
for f in futures:
|
||||||
|
f.cancel()
|
||||||
|
raise e
|
||||||
|
|
||||||
|
return results
|
||||||
|
317
docker/examples/sqls/case_3_order_wide_table_sqlite_wide.sql
Normal file
317
docker/examples/sqls/case_3_order_wide_table_sqlite_wide.sql
Normal file
@ -0,0 +1,317 @@
|
|||||||
|
CREATE TABLE order_wide_table (
|
||||||
|
|
||||||
|
-- order_base
|
||||||
|
order_id TEXT, -- 订单ID
|
||||||
|
order_no TEXT, -- 订单编号
|
||||||
|
parent_order_no TEXT, -- 父订单编号
|
||||||
|
order_type INTEGER, -- 订单类型:1实物2虚拟3混合
|
||||||
|
order_status INTEGER, -- 订单状态
|
||||||
|
order_source TEXT, -- 订单来源
|
||||||
|
order_source_detail TEXT, -- 订单来源详情
|
||||||
|
create_time DATETIME, -- 创建时间
|
||||||
|
pay_time DATETIME, -- 支付时间
|
||||||
|
finish_time DATETIME, -- 完成时间
|
||||||
|
close_time DATETIME, -- 关闭时间
|
||||||
|
cancel_time DATETIME, -- 取消时间
|
||||||
|
cancel_reason TEXT, -- 取消原因
|
||||||
|
order_remark TEXT, -- 订单备注
|
||||||
|
seller_remark TEXT, -- 卖家备注
|
||||||
|
buyer_remark TEXT, -- 买家备注
|
||||||
|
is_deleted INTEGER, -- 是否删除
|
||||||
|
delete_time DATETIME, -- 删除时间
|
||||||
|
order_ip TEXT, -- 下单IP
|
||||||
|
order_platform TEXT, -- 下单平台
|
||||||
|
order_device TEXT, -- 下单设备
|
||||||
|
order_app_version TEXT, -- APP版本号
|
||||||
|
|
||||||
|
-- order_amount
|
||||||
|
currency TEXT, -- 货币类型
|
||||||
|
exchange_rate REAL, -- 汇率
|
||||||
|
original_amount REAL, -- 原始金额
|
||||||
|
discount_amount REAL, -- 优惠金额
|
||||||
|
coupon_amount REAL, -- 优惠券金额
|
||||||
|
points_amount REAL, -- 积分抵扣金额
|
||||||
|
shipping_amount REAL, -- 运费
|
||||||
|
insurance_amount REAL, -- 保价费
|
||||||
|
tax_amount REAL, -- 税费
|
||||||
|
tariff_amount REAL, -- 关税
|
||||||
|
payment_amount REAL, -- 实付金额
|
||||||
|
commission_amount REAL, -- 佣金金额
|
||||||
|
platform_fee REAL, -- 平台费用
|
||||||
|
seller_income REAL, -- 卖家实收
|
||||||
|
payment_currency TEXT, -- 支付货币
|
||||||
|
payment_exchange_rate REAL, -- 支付汇率
|
||||||
|
|
||||||
|
-- user_info
|
||||||
|
user_id TEXT, -- 用户ID
|
||||||
|
user_name TEXT, -- 用户名
|
||||||
|
user_nickname TEXT, -- 用户昵称
|
||||||
|
user_level INTEGER, -- 用户等级
|
||||||
|
user_type INTEGER, -- 用户类型
|
||||||
|
register_time DATETIME, -- 注册时间
|
||||||
|
register_source TEXT, -- 注册来源
|
||||||
|
mobile TEXT, -- 手机号
|
||||||
|
mobile_area TEXT, -- 手机号区号
|
||||||
|
email TEXT, -- 邮箱
|
||||||
|
is_vip INTEGER, -- 是否VIP
|
||||||
|
vip_level INTEGER, -- VIP等级
|
||||||
|
vip_expire_time DATETIME, -- VIP过期时间
|
||||||
|
user_age INTEGER, -- 用户年龄
|
||||||
|
user_gender INTEGER, -- 用户性别
|
||||||
|
user_birthday DATE, -- 用户生日
|
||||||
|
user_avatar TEXT, -- 用户头像
|
||||||
|
user_province TEXT, -- 用户所在省
|
||||||
|
user_city TEXT, -- 用户所在市
|
||||||
|
user_district TEXT, -- 用户所在区
|
||||||
|
last_login_time DATETIME, -- 最后登录时间
|
||||||
|
last_login_ip TEXT, -- 最后登录IP
|
||||||
|
user_credit_score INTEGER, -- 用户信用分
|
||||||
|
total_order_count INTEGER, -- 历史订单数
|
||||||
|
total_order_amount REAL, -- 历史订单金额
|
||||||
|
|
||||||
|
-- product_info
|
||||||
|
product_id TEXT, -- 商品ID
|
||||||
|
product_code TEXT, -- 商品编码
|
||||||
|
product_name TEXT, -- 商品名称
|
||||||
|
product_short_name TEXT, -- 商品短名称
|
||||||
|
product_type INTEGER, -- 商品类型
|
||||||
|
product_status INTEGER, -- 商品状态
|
||||||
|
category_id TEXT, -- 类目ID
|
||||||
|
category_name TEXT, -- 类目名称
|
||||||
|
category_path TEXT, -- 类目路径
|
||||||
|
brand_id TEXT, -- 品牌ID
|
||||||
|
brand_name TEXT, -- 品牌名称
|
||||||
|
brand_english_name TEXT, -- 品牌英文名
|
||||||
|
seller_id TEXT, -- 卖家ID
|
||||||
|
seller_name TEXT, -- 卖家名称
|
||||||
|
seller_type INTEGER, -- 卖家类型
|
||||||
|
shop_id TEXT, -- 店铺ID
|
||||||
|
shop_name TEXT, -- 店铺名称
|
||||||
|
product_price REAL, -- 商品价格
|
||||||
|
market_price REAL, -- 市场价
|
||||||
|
cost_price REAL, -- 成本价
|
||||||
|
wholesale_price REAL, -- 批发价
|
||||||
|
product_quantity INTEGER, -- 商品数量
|
||||||
|
product_unit TEXT, -- 商品单位
|
||||||
|
product_weight REAL, -- 商品重量(克)
|
||||||
|
product_volume REAL, -- 商品体积(cm³)
|
||||||
|
product_spec TEXT, -- 商品规格
|
||||||
|
product_color TEXT, -- 商品颜色
|
||||||
|
product_size TEXT, -- 商品尺寸
|
||||||
|
product_material TEXT, -- 商品材质
|
||||||
|
product_origin TEXT, -- 商品产地
|
||||||
|
product_shelf_life INTEGER, -- 保质期(天)
|
||||||
|
manufacture_date DATE, -- 生产日期
|
||||||
|
expiry_date DATE, -- 过期日期
|
||||||
|
batch_number TEXT, -- 批次号
|
||||||
|
product_barcode TEXT, -- 商品条码
|
||||||
|
warehouse_id TEXT, -- 发货仓库ID
|
||||||
|
warehouse_name TEXT, -- 发货仓库名称
|
||||||
|
|
||||||
|
-- address_info
|
||||||
|
receiver_name TEXT, -- 收货人姓名
|
||||||
|
receiver_mobile TEXT, -- 收货人手机
|
||||||
|
receiver_tel TEXT, -- 收货人电话
|
||||||
|
receiver_email TEXT, -- 收货人邮箱
|
||||||
|
receiver_country TEXT, -- 国家
|
||||||
|
receiver_province TEXT, -- 省份
|
||||||
|
receiver_city TEXT, -- 城市
|
||||||
|
receiver_district TEXT, -- 区县
|
||||||
|
receiver_street TEXT, -- 街道
|
||||||
|
receiver_address TEXT, -- 详细地址
|
||||||
|
receiver_zip TEXT, -- 邮编
|
||||||
|
address_type INTEGER, -- 地址类型
|
||||||
|
is_default INTEGER, -- 是否默认地址
|
||||||
|
longitude REAL, -- 经度
|
||||||
|
latitude REAL, -- 纬度
|
||||||
|
address_label TEXT, -- 地址标签
|
||||||
|
|
||||||
|
-- shipping_info
|
||||||
|
shipping_type INTEGER, -- 配送方式
|
||||||
|
shipping_method TEXT, -- 配送方式名称
|
||||||
|
shipping_company TEXT, -- 快递公司
|
||||||
|
shipping_company_code TEXT, -- 快递公司编码
|
||||||
|
shipping_no TEXT, -- 快递单号
|
||||||
|
shipping_time DATETIME, -- 发货时间
|
||||||
|
shipping_remark TEXT, -- 发货备注
|
||||||
|
expect_receive_time DATETIME, -- 预计送达时间
|
||||||
|
receive_time DATETIME, -- 收货时间
|
||||||
|
sign_type INTEGER, -- 签收类型
|
||||||
|
shipping_status INTEGER, -- 物流状态
|
||||||
|
tracking_url TEXT, -- 物流跟踪URL
|
||||||
|
is_free_shipping INTEGER, -- 是否包邮
|
||||||
|
shipping_insurance REAL, -- 运费险金额
|
||||||
|
shipping_distance REAL, -- 配送距离
|
||||||
|
delivered_time DATETIME, -- 送达时间
|
||||||
|
delivery_staff_id TEXT, -- 配送员ID
|
||||||
|
delivery_staff_name TEXT, -- 配送员姓名
|
||||||
|
delivery_staff_mobile TEXT, -- 配送员电话
|
||||||
|
|
||||||
|
-- payment_info
|
||||||
|
payment_id TEXT, -- 支付ID
|
||||||
|
payment_no TEXT, -- 支付单号
|
||||||
|
payment_type INTEGER, -- 支付方式
|
||||||
|
payment_method TEXT, -- 支付方式名称
|
||||||
|
payment_status INTEGER, -- 支付状态
|
||||||
|
payment_platform TEXT, -- 支付平台
|
||||||
|
transaction_id TEXT, -- 交易流水号
|
||||||
|
payment_time DATETIME, -- 支付时间
|
||||||
|
payment_account TEXT, -- 支付账号
|
||||||
|
payment_bank TEXT, -- 支付银行
|
||||||
|
payment_card_type TEXT, -- 支付卡类型
|
||||||
|
payment_card_no TEXT, -- 支付卡号
|
||||||
|
payment_scene TEXT, -- 支付场景
|
||||||
|
payment_client_ip TEXT, -- 支付IP
|
||||||
|
payment_device TEXT, -- 支付设备
|
||||||
|
payment_remark TEXT, -- 支付备注
|
||||||
|
payment_voucher TEXT, -- 支付凭证
|
||||||
|
|
||||||
|
-- promotion_info
|
||||||
|
promotion_id TEXT, -- 活动ID
|
||||||
|
promotion_name TEXT, -- 活动名称
|
||||||
|
promotion_type INTEGER, -- 活动类型
|
||||||
|
promotion_desc TEXT, -- 活动描述
|
||||||
|
promotion_start_time DATETIME, -- 活动开始时间
|
||||||
|
promotion_end_time DATETIME, -- 活动结束时间
|
||||||
|
coupon_id TEXT, -- 优惠券ID
|
||||||
|
coupon_code TEXT, -- 优惠券码
|
||||||
|
coupon_type INTEGER, -- 优惠券类型
|
||||||
|
coupon_name TEXT, -- 优惠券名称
|
||||||
|
coupon_desc TEXT, -- 优惠券描述
|
||||||
|
points_used INTEGER, -- 使用积分
|
||||||
|
points_gained INTEGER, -- 获得积分
|
||||||
|
points_multiple REAL, -- 积分倍率
|
||||||
|
is_first_order INTEGER, -- 是否首单
|
||||||
|
is_new_customer INTEGER, -- 是否新客
|
||||||
|
marketing_channel TEXT, -- 营销渠道
|
||||||
|
marketing_source TEXT, -- 营销来源
|
||||||
|
referral_code TEXT, -- 推荐码
|
||||||
|
referral_user_id TEXT, -- 推荐人ID
|
||||||
|
|
||||||
|
-- after_sale_info
|
||||||
|
refund_id TEXT, -- 退款ID
|
||||||
|
refund_no TEXT, -- 退款单号
|
||||||
|
refund_type INTEGER, -- 退款类型
|
||||||
|
refund_status INTEGER, -- 退款状态
|
||||||
|
refund_reason TEXT, -- 退款原因
|
||||||
|
refund_desc TEXT, -- 退款描述
|
||||||
|
refund_time DATETIME, -- 退款时间
|
||||||
|
refund_amount REAL, -- 退款金额
|
||||||
|
return_shipping_no TEXT, -- 退货快递单号
|
||||||
|
return_shipping_company TEXT, -- 退货快递公司
|
||||||
|
return_shipping_time DATETIME, -- 退货时间
|
||||||
|
refund_evidence TEXT, -- 退款凭证
|
||||||
|
complaint_id TEXT, -- 投诉ID
|
||||||
|
complaint_type INTEGER, -- 投诉类型
|
||||||
|
complaint_status INTEGER, -- 投诉状态
|
||||||
|
complaint_content TEXT, -- 投诉内容
|
||||||
|
complaint_time DATETIME, -- 投诉时间
|
||||||
|
complaint_handle_time DATETIME, -- 投诉处理时间
|
||||||
|
complaint_handle_result TEXT, -- 投诉处理结果
|
||||||
|
evaluation_score INTEGER, -- 评价分数
|
||||||
|
evaluation_content TEXT, -- 评价内容
|
||||||
|
evaluation_time DATETIME, -- 评价时间
|
||||||
|
evaluation_reply TEXT, -- 评价回复
|
||||||
|
evaluation_reply_time DATETIME, -- 评价回复时间
|
||||||
|
evaluation_images TEXT, -- 评价图片
|
||||||
|
evaluation_videos TEXT, -- 评价视频
|
||||||
|
is_anonymous INTEGER, -- 是否匿名评价
|
||||||
|
|
||||||
|
-- invoice_info
|
||||||
|
invoice_type INTEGER, -- 发票类型
|
||||||
|
invoice_title TEXT, -- 发票抬头
|
||||||
|
invoice_content TEXT, -- 发票内容
|
||||||
|
tax_no TEXT, -- 税号
|
||||||
|
invoice_amount REAL, -- 发票金额
|
||||||
|
invoice_status INTEGER, -- 发票状态
|
||||||
|
invoice_time DATETIME, -- 开票时间
|
||||||
|
invoice_number TEXT, -- 发票号码
|
||||||
|
invoice_code TEXT, -- 发票代码
|
||||||
|
company_name TEXT, -- 单位名称
|
||||||
|
company_address TEXT, -- 单位地址
|
||||||
|
company_tel TEXT, -- 单位电话
|
||||||
|
company_bank TEXT, -- 开户银行
|
||||||
|
company_account TEXT, -- 银行账号
|
||||||
|
|
||||||
|
-- delivery_time_info
|
||||||
|
expect_delivery_time DATETIME, -- 期望配送时间
|
||||||
|
delivery_period_type INTEGER, -- 配送时段类型
|
||||||
|
delivery_period_start TEXT, -- 配送时段开始
|
||||||
|
delivery_period_end TEXT, -- 配送时段结束
|
||||||
|
delivery_priority INTEGER, -- 配送优先级
|
||||||
|
|
||||||
|
-- tag_info
|
||||||
|
order_tags TEXT, -- 订单标签
|
||||||
|
user_tags TEXT, -- 用户标签
|
||||||
|
product_tags TEXT, -- 商品标签
|
||||||
|
risk_level INTEGER, -- 风险等级
|
||||||
|
risk_tags TEXT, -- 风险标签
|
||||||
|
business_tags TEXT, -- 业务标签
|
||||||
|
|
||||||
|
-- commercial_info
|
||||||
|
gross_profit REAL, -- 毛利
|
||||||
|
gross_profit_rate REAL, -- 毛利率
|
||||||
|
settlement_amount REAL, -- 结算金额
|
||||||
|
settlement_time DATETIME, -- 结算时间
|
||||||
|
settlement_cycle INTEGER, -- 结算周期
|
||||||
|
settlement_status INTEGER, -- 结算状态
|
||||||
|
commission_rate REAL, -- 佣金比例
|
||||||
|
platform_service_fee REAL, -- 平台服务费
|
||||||
|
ad_cost REAL, -- 广告费用
|
||||||
|
promotion_cost REAL -- 推广费用
|
||||||
|
);
|
||||||
|
|
||||||
|
-- 插入示例数据
|
||||||
|
INSERT INTO order_wide_table (
|
||||||
|
-- 基础订单信息
|
||||||
|
order_id, order_no, order_type, order_status, create_time, order_source,
|
||||||
|
-- 订单金额
|
||||||
|
original_amount, payment_amount, shipping_amount,
|
||||||
|
-- 用户信息
|
||||||
|
user_id, user_name, user_level, mobile,
|
||||||
|
-- 商品信息
|
||||||
|
product_id, product_name, product_quantity, product_price,
|
||||||
|
-- 收货信息
|
||||||
|
receiver_name, receiver_mobile, receiver_address,
|
||||||
|
-- 物流信息
|
||||||
|
shipping_no, shipping_status,
|
||||||
|
-- 支付信息
|
||||||
|
payment_type, payment_status,
|
||||||
|
-- 营销信息
|
||||||
|
promotion_id, coupon_amount,
|
||||||
|
-- 发票信息
|
||||||
|
invoice_type, invoice_title
|
||||||
|
) VALUES
|
||||||
|
(
|
||||||
|
'ORD20240101001', 'NO20240101001', 1, 2, '2024-01-01 10:00:00', 'APP',
|
||||||
|
199.99, 188.88, 10.00,
|
||||||
|
'USER001', '张三', 2, '13800138000',
|
||||||
|
'PRD001', 'iPhone 15 手机壳', 2, 89.99,
|
||||||
|
'李四', '13900139000', '北京市朝阳区XX路XX号',
|
||||||
|
'SF123456789', 1,
|
||||||
|
1, 1,
|
||||||
|
'PROM001', 20.00,
|
||||||
|
1, '个人'
|
||||||
|
),
|
||||||
|
(
|
||||||
|
'ORD20240101002', 'NO20240101002', 1, 1, '2024-01-01 11:00:00', 'H5',
|
||||||
|
299.99, 279.99, 0.00,
|
||||||
|
'USER002', '王五', 3, '13700137000',
|
||||||
|
'PRD002', 'AirPods Pro 保护套', 1, 299.99,
|
||||||
|
'赵六', '13600136000', '上海市浦东新区XX路XX号',
|
||||||
|
'YT987654321', 2,
|
||||||
|
2, 2,
|
||||||
|
'PROM002', 10.00,
|
||||||
|
2, '上海科技有限公司'
|
||||||
|
),
|
||||||
|
(
|
||||||
|
'ORD20240101003', 'NO20240101003', 2, 3, '2024-01-01 12:00:00', 'WEB',
|
||||||
|
1999.99, 1899.99, 0.00,
|
||||||
|
'USER003', '陈七', 4, '13500135000',
|
||||||
|
'PRD003', 'MacBook Pro 电脑包', 1, 1999.99,
|
||||||
|
'孙八', '13400134000', '广州市天河区XX路XX号',
|
||||||
|
'JD123123123', 3,
|
||||||
|
3, 1,
|
||||||
|
'PROM003', 100.00,
|
||||||
|
1, '个人'
|
||||||
|
);
|
@ -4,7 +4,8 @@ from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
|
|||||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
|
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
|
||||||
from dbgpt.rag.assembler import DBSchemaAssembler
|
from dbgpt.rag.assembler import DBSchemaAssembler
|
||||||
from dbgpt.rag.embedding import DefaultEmbeddingFactory
|
from dbgpt.rag.embedding import DefaultEmbeddingFactory
|
||||||
from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig
|
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||||
|
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||||
|
|
||||||
"""DB struct rag example.
|
"""DB struct rag example.
|
||||||
pre-requirements:
|
pre-requirements:
|
||||||
@ -12,7 +13,7 @@ from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorCon
|
|||||||
```
|
```
|
||||||
embedding_model_path = "{your_embedding_model_path}"
|
embedding_model_path = "{your_embedding_model_path}"
|
||||||
```
|
```
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
.. code-block:: shell
|
.. code-block:: shell
|
||||||
python examples/rag/db_schema_rag_example.py
|
python examples/rag/db_schema_rag_example.py
|
||||||
@ -45,27 +46,26 @@ def _create_temporary_connection():
|
|||||||
|
|
||||||
def _create_vector_connector():
|
def _create_vector_connector():
|
||||||
"""Create vector connector."""
|
"""Create vector connector."""
|
||||||
config = ChromaVectorConfig(
|
return VectorStoreConnector.from_default(
|
||||||
persist_path=PILOT_PATH,
|
"Chroma",
|
||||||
name="dbschema_rag_test",
|
vector_store_config=ChromaVectorConfig(
|
||||||
|
name="db_schema_vector_store_name",
|
||||||
|
persist_path=os.path.join(PILOT_PATH, "data"),
|
||||||
|
),
|
||||||
embedding_fn=DefaultEmbeddingFactory(
|
embedding_fn=DefaultEmbeddingFactory(
|
||||||
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
|
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
|
||||||
).create(),
|
).create(),
|
||||||
)
|
)
|
||||||
|
|
||||||
return ChromaStore(config)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
connection = _create_temporary_connection()
|
connection = _create_temporary_connection()
|
||||||
index_store = _create_vector_connector()
|
vector_connector = _create_vector_connector()
|
||||||
assembler = DBSchemaAssembler.load_from_connection(
|
assembler = DBSchemaAssembler.load_from_connection(
|
||||||
connector=connection,
|
connector=connection, table_vector_store_connector=vector_connector
|
||||||
index_store=index_store,
|
|
||||||
)
|
)
|
||||||
assembler.persist()
|
assembler.persist()
|
||||||
# get db schema retriever
|
# get db schema retriever
|
||||||
retriever = assembler.as_retriever(top_k=1)
|
retriever = assembler.as_retriever(top_k=1)
|
||||||
chunks = retriever.retrieve("show columns from user")
|
chunks = retriever.retrieve("show columns from user")
|
||||||
print(f"db schema rag example results:{[chunk.content for chunk in chunks]}")
|
print(f"db schema rag example results:{[chunk.content for chunk in chunks]}")
|
||||||
index_store.delete_vector_name("dbschema_rag_test")
|
|
||||||
|
@ -15,6 +15,7 @@ fi
|
|||||||
DEFAULT_DB_FILE="DB-GPT/pilot/data/default_sqlite.db"
|
DEFAULT_DB_FILE="DB-GPT/pilot/data/default_sqlite.db"
|
||||||
DEFAULT_SQL_FILE="DB-GPT/docker/examples/sqls/*_sqlite.sql"
|
DEFAULT_SQL_FILE="DB-GPT/docker/examples/sqls/*_sqlite.sql"
|
||||||
DB_FILE="$WORK_DIR/pilot/data/default_sqlite.db"
|
DB_FILE="$WORK_DIR/pilot/data/default_sqlite.db"
|
||||||
|
WIDE_DB_FILE="$WORK_DIR/pilot/data/wide_sqlite.db"
|
||||||
SQL_FILE=""
|
SQL_FILE=""
|
||||||
|
|
||||||
usage () {
|
usage () {
|
||||||
@ -61,6 +62,12 @@ if [ -n $SQL_FILE ];then
|
|||||||
sqlite3 $DB_FILE < "$file"
|
sqlite3 $DB_FILE < "$file"
|
||||||
done
|
done
|
||||||
|
|
||||||
|
for file in $WORK_DIR/docker/examples/sqls/*_sqlite_wide.sql
|
||||||
|
do
|
||||||
|
echo "execute sql file: $file"
|
||||||
|
sqlite3 $WIDE_DB_FILE < "$file"
|
||||||
|
done
|
||||||
|
|
||||||
else
|
else
|
||||||
echo "Execute SQL file ${SQL_FILE}"
|
echo "Execute SQL file ${SQL_FILE}"
|
||||||
sqlite3 $DB_FILE < $SQL_FILE
|
sqlite3 $DB_FILE < $SQL_FILE
|
||||||
|
Loading…
Reference in New Issue
Block a user