mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-23 20:26:15 +00:00
Feat rdb summary wide table (#2035)
Co-authored-by: dongzhancai1 <dongzhancai1@jd.com> Co-authored-by: dong <dongzhancai@iie2.com>
This commit is contained in:
parent
7f4b5e79cf
commit
9b0161e521
@ -66,6 +66,7 @@ QUANTIZE_8bit=True
|
||||
#** EMBEDDING SETTINGS **#
|
||||
#*******************************************************************#
|
||||
EMBEDDING_MODEL=text2vec
|
||||
EMBEDDING_MODEL_MAX_SEQ_LEN=512
|
||||
#EMBEDDING_MODEL=m3e-large
|
||||
#EMBEDDING_MODEL=bge-large-en
|
||||
#EMBEDDING_MODEL=bge-large-zh
|
||||
|
@ -264,6 +264,9 @@ class Config(metaclass=Singleton):
|
||||
|
||||
# EMBEDDING Configuration
|
||||
self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
|
||||
self.EMBEDDING_MODEL_MAX_SEQ_LEN = int(
|
||||
os.getenv("MEMBEDDING_MODEL_MAX_SEQ_LEN", 512)
|
||||
)
|
||||
# Rerank model configuration
|
||||
self.RERANK_MODEL = os.getenv("RERANK_MODEL")
|
||||
self.RERANK_MODEL_PATH = os.getenv("RERANK_MODEL_PATH")
|
||||
|
@ -55,9 +55,6 @@ class ChatWithDbQA(BaseChat):
|
||||
if self.db_name:
|
||||
client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
|
||||
try:
|
||||
# table_infos = client.get_db_summary(
|
||||
# dbname=self.db_name, query=self.current_user_input, topk=self.top_k
|
||||
# )
|
||||
table_infos = await blocking_func_to_async(
|
||||
self._executor,
|
||||
client.get_db_summary,
|
||||
|
@ -1,12 +1,15 @@
|
||||
"""DBSchemaAssembler."""
|
||||
import os
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core import Chunk, Embeddings
|
||||
from dbgpt.datasource.base import BaseConnector
|
||||
|
||||
from ...serve.rag.connector import VectorStoreConnector
|
||||
from ...storage.vector_store.base import VectorStoreConfig
|
||||
from ..assembler.base import BaseAssembler
|
||||
from ..chunk_manager import ChunkParameters
|
||||
from ..index.base import IndexStoreBase
|
||||
from ..embedding.embedding_factory import DefaultEmbeddingFactory
|
||||
from ..knowledge.datasource import DatasourceKnowledge
|
||||
from ..retriever.db_schema import DBSchemaRetriever
|
||||
|
||||
@ -35,23 +38,64 @@ class DBSchemaAssembler(BaseAssembler):
|
||||
def __init__(
|
||||
self,
|
||||
connector: BaseConnector,
|
||||
index_store: IndexStoreBase,
|
||||
table_vector_store_connector: VectorStoreConnector,
|
||||
field_vector_store_connector: VectorStoreConnector = None,
|
||||
chunk_parameters: Optional[ChunkParameters] = None,
|
||||
embedding_model: Optional[str] = None,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
max_seq_length: int = 512,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize with Embedding Assembler arguments.
|
||||
|
||||
Args:
|
||||
connector: (BaseConnector) BaseConnector connection.
|
||||
index_store: (IndexStoreBase) IndexStoreBase to use.
|
||||
table_vector_store_connector: VectorStoreConnector to load
|
||||
and retrieve table info.
|
||||
field_vector_store_connector: VectorStoreConnector to load
|
||||
and retrieve field info.
|
||||
chunk_manager: (Optional[ChunkManager]) ChunkManager to use for chunking.
|
||||
embedding_model: (Optional[str]) Embedding model to use.
|
||||
embeddings: (Optional[Embeddings]) Embeddings to use.
|
||||
"""
|
||||
knowledge = DatasourceKnowledge(connector)
|
||||
self._connector = connector
|
||||
self._index_store = index_store
|
||||
self._table_vector_store_connector = table_vector_store_connector
|
||||
|
||||
field_vector_store_config = VectorStoreConfig(
|
||||
name=table_vector_store_connector.vector_store_config.name + "_field"
|
||||
)
|
||||
self._field_vector_store_connector = (
|
||||
field_vector_store_connector
|
||||
or VectorStoreConnector.from_default(
|
||||
os.getenv("VECTOR_STORE_TYPE", "Chroma"),
|
||||
self._table_vector_store_connector.current_embeddings,
|
||||
vector_store_config=field_vector_store_config,
|
||||
)
|
||||
)
|
||||
|
||||
self._embedding_model = embedding_model
|
||||
if self._embedding_model and not embeddings:
|
||||
embeddings = DefaultEmbeddingFactory(
|
||||
default_model_name=self._embedding_model
|
||||
).create(self._embedding_model)
|
||||
|
||||
if (
|
||||
embeddings
|
||||
and self._table_vector_store_connector.vector_store_config.embedding_fn
|
||||
is None
|
||||
):
|
||||
self._table_vector_store_connector.vector_store_config.embedding_fn = (
|
||||
embeddings
|
||||
)
|
||||
if (
|
||||
embeddings
|
||||
and self._field_vector_store_connector.vector_store_config.embedding_fn
|
||||
is None
|
||||
):
|
||||
self._field_vector_store_connector.vector_store_config.embedding_fn = (
|
||||
embeddings
|
||||
)
|
||||
knowledge = DatasourceKnowledge(connector, model_dimension=max_seq_length)
|
||||
super().__init__(
|
||||
knowledge=knowledge,
|
||||
chunk_parameters=chunk_parameters,
|
||||
@ -62,23 +106,36 @@ class DBSchemaAssembler(BaseAssembler):
|
||||
def load_from_connection(
|
||||
cls,
|
||||
connector: BaseConnector,
|
||||
index_store: IndexStoreBase,
|
||||
table_vector_store_connector: VectorStoreConnector,
|
||||
field_vector_store_connector: VectorStoreConnector = None,
|
||||
chunk_parameters: Optional[ChunkParameters] = None,
|
||||
embedding_model: Optional[str] = None,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
max_seq_length: int = 512,
|
||||
) -> "DBSchemaAssembler":
|
||||
"""Load document embedding into vector store from path.
|
||||
|
||||
Args:
|
||||
connector: (BaseConnector) BaseConnector connection.
|
||||
index_store: (IndexStoreBase) IndexStoreBase to use.
|
||||
table_vector_store_connector: used to load table chunks.
|
||||
field_vector_store_connector: used to load field chunks
|
||||
if field in table is too much.
|
||||
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
|
||||
chunking.
|
||||
embedding_model: (Optional[str]) Embedding model to use.
|
||||
embeddings: (Optional[Embeddings]) Embeddings to use.
|
||||
max_seq_length: Embedding model max sequence length
|
||||
Returns:
|
||||
DBSchemaAssembler
|
||||
"""
|
||||
return cls(
|
||||
connector=connector,
|
||||
index_store=index_store,
|
||||
table_vector_store_connector=table_vector_store_connector,
|
||||
field_vector_store_connector=field_vector_store_connector,
|
||||
embedding_model=embedding_model,
|
||||
chunk_parameters=chunk_parameters,
|
||||
embeddings=embeddings,
|
||||
max_seq_length=max_seq_length,
|
||||
)
|
||||
|
||||
def get_chunks(self) -> List[Chunk]:
|
||||
@ -91,7 +148,19 @@ class DBSchemaAssembler(BaseAssembler):
|
||||
Returns:
|
||||
List[str]: List of chunk ids.
|
||||
"""
|
||||
return self._index_store.load_document(self._chunks)
|
||||
table_chunks, field_chunks = [], []
|
||||
for chunk in self._chunks:
|
||||
metadata = chunk.metadata
|
||||
if metadata.get("separated"):
|
||||
if metadata.get("part") == "table":
|
||||
table_chunks.append(chunk)
|
||||
else:
|
||||
field_chunks.append(chunk)
|
||||
else:
|
||||
table_chunks.append(chunk)
|
||||
|
||||
self._field_vector_store_connector.load_document(field_chunks)
|
||||
return self._table_vector_store_connector.load_document(table_chunks)
|
||||
|
||||
def _extract_info(self, chunks) -> List[Chunk]:
|
||||
"""Extract info from chunks."""
|
||||
@ -110,5 +179,6 @@ class DBSchemaAssembler(BaseAssembler):
|
||||
top_k=top_k,
|
||||
connector=self._connector,
|
||||
is_embeddings=True,
|
||||
index_store=self._index_store,
|
||||
table_vector_store_connector=self._table_vector_store_connector,
|
||||
field_vector_store_connector=self._field_vector_store_connector,
|
||||
)
|
||||
|
@ -1,76 +1,117 @@
|
||||
from unittest.mock import MagicMock
|
||||
from typing import List
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
|
||||
from dbgpt.rag.assembler.embedding import EmbeddingAssembler
|
||||
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.rag.knowledge.base import Knowledge
|
||||
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaStore
|
||||
import dbgpt
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_connection():
|
||||
"""Create a temporary database connection for testing."""
|
||||
connect = SQLiteTempConnector.create_temporary_db()
|
||||
connect.create_temp_tables(
|
||||
{
|
||||
"user": {
|
||||
"columns": {
|
||||
"id": "INTEGER PRIMARY KEY",
|
||||
"name": "TEXT",
|
||||
"age": "INTEGER",
|
||||
},
|
||||
"data": [
|
||||
(1, "Tom", 10),
|
||||
(2, "Jerry", 16),
|
||||
(3, "Jack", 18),
|
||||
(4, "Alice", 20),
|
||||
(5, "Bob", 22),
|
||||
],
|
||||
}
|
||||
}
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_table_vector_store_connector():
|
||||
mock_connector = MagicMock()
|
||||
mock_connector.vector_store_config.name = "table_name"
|
||||
chunk = Chunk(
|
||||
content="table_name: user\ncomment: user about dbgpt",
|
||||
metadata={
|
||||
"field_num": 6,
|
||||
"part": "table",
|
||||
"separated": 1,
|
||||
"table_name": "user",
|
||||
},
|
||||
)
|
||||
return connect
|
||||
mock_connector.similar_search_with_scores = MagicMock(return_value=[chunk])
|
||||
return mock_connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_chunk_parameters():
|
||||
return MagicMock(spec=ChunkParameters)
|
||||
def mock_field_vector_store_connector():
|
||||
mock_connector = MagicMock()
|
||||
chunk1 = Chunk(
|
||||
content="name,age",
|
||||
metadata={
|
||||
"field_num": 6,
|
||||
"part": "field",
|
||||
"part_index": 0,
|
||||
"separated": 1,
|
||||
"table_name": "user",
|
||||
},
|
||||
)
|
||||
chunk2 = Chunk(
|
||||
content="address,gender",
|
||||
metadata={
|
||||
"field_num": 6,
|
||||
"part": "field",
|
||||
"part_index": 1,
|
||||
"separated": 1,
|
||||
"table_name": "user",
|
||||
},
|
||||
)
|
||||
chunk3 = Chunk(
|
||||
content="mail,phone",
|
||||
metadata={
|
||||
"field_num": 6,
|
||||
"part": "field",
|
||||
"part_index": 2,
|
||||
"separated": 1,
|
||||
"table_name": "user",
|
||||
},
|
||||
)
|
||||
mock_connector.similar_search_with_scores = MagicMock(
|
||||
return_value=[chunk1, chunk2, chunk3]
|
||||
)
|
||||
return mock_connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embedding_factory():
|
||||
return MagicMock(spec=EmbeddingFactory)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_connector():
|
||||
return MagicMock(spec=ChromaStore)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_knowledge():
|
||||
return MagicMock(spec=Knowledge)
|
||||
|
||||
|
||||
def test_load_knowledge(
|
||||
def dbstruct_retriever(
|
||||
mock_db_connection,
|
||||
mock_knowledge,
|
||||
mock_chunk_parameters,
|
||||
mock_embedding_factory,
|
||||
mock_vector_store_connector,
|
||||
mock_table_vector_store_connector,
|
||||
mock_field_vector_store_connector,
|
||||
):
|
||||
mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
|
||||
mock_chunk_parameters.text_splitter = CharacterTextSplitter()
|
||||
mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
|
||||
assembler = EmbeddingAssembler(
|
||||
knowledge=mock_knowledge,
|
||||
chunk_parameters=mock_chunk_parameters,
|
||||
embeddings=mock_embedding_factory.create(),
|
||||
index_store=mock_vector_store_connector,
|
||||
return DBSchemaRetriever(
|
||||
connector=mock_db_connection,
|
||||
table_vector_store_connector=mock_table_vector_store_connector,
|
||||
field_vector_store_connector=mock_field_vector_store_connector,
|
||||
separator="--table-field-separator--",
|
||||
)
|
||||
|
||||
|
||||
def mock_parse_db_summary() -> str:
|
||||
"""Patch _parse_db_summary method."""
|
||||
return (
|
||||
"table_name: user\ncomment: user about dbgpt\n"
|
||||
"--table-field-separator--\n"
|
||||
"name,age\naddress,gender\nmail,phone"
|
||||
)
|
||||
|
||||
|
||||
# Mocking the _parse_db_summary method in your test function
|
||||
@patch.object(
|
||||
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
|
||||
)
|
||||
def test_retrieve_with_mocked_summary(dbstruct_retriever):
|
||||
query = "Table summary"
|
||||
chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
|
||||
assert isinstance(chunks[0], Chunk)
|
||||
assert chunks[0].content == (
|
||||
"table_name: user\ncomment: user about dbgpt\n"
|
||||
"--table-field-separator--\n"
|
||||
"name,age\naddress,gender\nmail,phone"
|
||||
)
|
||||
|
||||
|
||||
def async_mock_parse_db_summary() -> str:
|
||||
"""Asynchronous patch for _parse_db_summary method."""
|
||||
return (
|
||||
"table_name: user\ncomment: user about dbgpt\n"
|
||||
"--table-field-separator--\n"
|
||||
"name,age\naddress,gender\nmail,phone"
|
||||
)
|
||||
assembler.load_knowledge(knowledge=mock_knowledge)
|
||||
assert len(assembler._chunks) == 0
|
||||
|
@ -5,9 +5,9 @@ import pytest
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
|
||||
from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
|
||||
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaStore
|
||||
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddings, EmbeddingFactory
|
||||
from dbgpt.rag.text_splitter.text_splitter import RDBTextSplitter
|
||||
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -21,14 +21,22 @@ def mock_db_connection():
|
||||
"id": "INTEGER PRIMARY KEY",
|
||||
"name": "TEXT",
|
||||
"age": "INTEGER",
|
||||
},
|
||||
"data": [
|
||||
(1, "Tom", 10),
|
||||
(2, "Jerry", 16),
|
||||
(3, "Jack", 18),
|
||||
(4, "Alice", 20),
|
||||
(5, "Bob", 22),
|
||||
],
|
||||
"address": "TEXT",
|
||||
"phone": "TEXT",
|
||||
"email": "TEXT",
|
||||
"gender": "TEXT",
|
||||
"birthdate": "TEXT",
|
||||
"occupation": "TEXT",
|
||||
"education": "TEXT",
|
||||
"marital_status": "TEXT",
|
||||
"nationality": "TEXT",
|
||||
"height": "REAL",
|
||||
"weight": "REAL",
|
||||
"blood_type": "TEXT",
|
||||
"emergency_contact": "TEXT",
|
||||
"created_at": "TEXT",
|
||||
"updated_at": "TEXT",
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
@ -46,23 +54,29 @@ def mock_embedding_factory():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_connector():
|
||||
return MagicMock(spec=ChromaStore)
|
||||
def mock_table_vector_store_connector():
|
||||
mock_connector = MagicMock(spec=VectorStoreConnector)
|
||||
mock_connector.vector_store_config.name = "table_vector_store_name"
|
||||
mock_connector.current_embeddings = DefaultEmbeddings()
|
||||
return mock_connector
|
||||
|
||||
|
||||
def test_load_knowledge(
|
||||
mock_db_connection,
|
||||
mock_chunk_parameters,
|
||||
mock_embedding_factory,
|
||||
mock_vector_store_connector,
|
||||
mock_table_vector_store_connector,
|
||||
):
|
||||
mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
|
||||
mock_chunk_parameters.text_splitter = CharacterTextSplitter()
|
||||
mock_chunk_parameters.text_splitter = RDBTextSplitter(
|
||||
separator="--table-field-separator--"
|
||||
)
|
||||
mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
|
||||
assembler = DBSchemaAssembler(
|
||||
connector=mock_db_connection,
|
||||
chunk_parameters=mock_chunk_parameters,
|
||||
embeddings=mock_embedding_factory.create(),
|
||||
index_store=mock_vector_store_connector,
|
||||
table_vector_store_connector=mock_table_vector_store_connector,
|
||||
max_seq_length=10,
|
||||
)
|
||||
assert len(assembler._chunks) == 1
|
||||
assert len(assembler._chunks) > 1
|
||||
|
@ -5,7 +5,7 @@ from dbgpt.core import Document
|
||||
from dbgpt.datasource import BaseConnector
|
||||
|
||||
from ..summary.gdbms_db_summary import _parse_db_summary as _parse_gdb_summary
|
||||
from ..summary.rdbms_db_summary import _parse_db_summary
|
||||
from ..summary.rdbms_db_summary import _parse_db_summary_with_metadata
|
||||
from .base import ChunkStrategy, DocumentType, Knowledge, KnowledgeType
|
||||
|
||||
|
||||
@ -15,9 +15,11 @@ class DatasourceKnowledge(Knowledge):
|
||||
def __init__(
|
||||
self,
|
||||
connector: BaseConnector,
|
||||
summary_template: str = "{table_name}({columns})",
|
||||
summary_template: str = "table_name: {table_name}",
|
||||
separator: str = "--table-field-separator--",
|
||||
knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
|
||||
metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||
model_dimension: int = 512,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create Datasource Knowledge with Knowledge arguments.
|
||||
@ -25,11 +27,17 @@ class DatasourceKnowledge(Knowledge):
|
||||
Args:
|
||||
connector(BaseConnector): connector
|
||||
summary_template(str, optional): summary template
|
||||
separator(str, optional): separator used to separate
|
||||
table's basic info and fields.
|
||||
defaults `-- table-field-separator--`
|
||||
knowledge_type(KnowledgeType, optional): knowledge type
|
||||
metadata(Dict[str, Union[str, List[str]], optional): metadata
|
||||
model_dimension(int, optional): The threshold for splitting field string
|
||||
"""
|
||||
self._separator = separator
|
||||
self._connector = connector
|
||||
self._summary_template = summary_template
|
||||
self._model_dimension = model_dimension
|
||||
super().__init__(knowledge_type=knowledge_type, metadata=metadata, **kwargs)
|
||||
|
||||
def _load(self) -> List[Document]:
|
||||
@ -37,13 +45,23 @@ class DatasourceKnowledge(Knowledge):
|
||||
docs = []
|
||||
if self._connector.is_graph_type():
|
||||
db_summary = _parse_gdb_summary(self._connector, self._summary_template)
|
||||
for table_summary in db_summary:
|
||||
metadata = {"source": "database"}
|
||||
docs.append(Document(content=table_summary, metadata=metadata))
|
||||
else:
|
||||
db_summary = _parse_db_summary(self._connector, self._summary_template)
|
||||
for table_summary in db_summary:
|
||||
metadata = {"source": "database"}
|
||||
if self._metadata:
|
||||
metadata.update(self._metadata) # type: ignore
|
||||
docs.append(Document(content=table_summary, metadata=metadata))
|
||||
db_summary_with_metadata = _parse_db_summary_with_metadata(
|
||||
self._connector,
|
||||
self._summary_template,
|
||||
self._separator,
|
||||
self._model_dimension,
|
||||
)
|
||||
for summary, table_metadata in db_summary_with_metadata:
|
||||
metadata = {"source": "database"}
|
||||
|
||||
if self._metadata:
|
||||
metadata.update(self._metadata) # type: ignore
|
||||
table_metadata.update(metadata)
|
||||
docs.append(Document(content=summary, metadata=table_metadata))
|
||||
return docs
|
||||
|
||||
@classmethod
|
||||
|
@ -1,14 +1,15 @@
|
||||
"""The DBSchema Retriever Operator."""
|
||||
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core.interface.operators.retriever import RetrieverOperator
|
||||
from dbgpt.datasource.base import BaseConnector
|
||||
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||
|
||||
from ...storage.vector_store.base import VectorStoreConfig
|
||||
from ..assembler.db_schema import DBSchemaAssembler
|
||||
from ..chunk_manager import ChunkParameters
|
||||
from ..index.base import IndexStoreBase
|
||||
from ..retriever.db_schema import DBSchemaRetriever
|
||||
from .assembler import AssemblerOperator
|
||||
|
||||
@ -19,13 +20,14 @@ class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]):
|
||||
Args:
|
||||
connector (BaseConnector): The connection.
|
||||
top_k (int, optional): The top k. Defaults to 4.
|
||||
index_store (IndexStoreBase, optional): The vector store
|
||||
vector_store_connector (VectorStoreConnector, optional): The vector store
|
||||
connector. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_store: IndexStoreBase,
|
||||
table_vector_store_connector: VectorStoreConnector,
|
||||
field_vector_store_connector: VectorStoreConnector,
|
||||
top_k: int = 4,
|
||||
connector: Optional[BaseConnector] = None,
|
||||
**kwargs
|
||||
@ -35,7 +37,8 @@ class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]):
|
||||
self._retriever = DBSchemaRetriever(
|
||||
top_k=top_k,
|
||||
connector=connector,
|
||||
index_store=index_store,
|
||||
table_vector_store_connector=table_vector_store_connector,
|
||||
field_vector_store_connector=field_vector_store_connector,
|
||||
)
|
||||
|
||||
def retrieve(self, query: str) -> List[Chunk]:
|
||||
@ -53,7 +56,8 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
|
||||
def __init__(
|
||||
self,
|
||||
connector: BaseConnector,
|
||||
index_store: IndexStoreBase,
|
||||
table_vector_store_connector: VectorStoreConnector,
|
||||
field_vector_store_connector: VectorStoreConnector = None,
|
||||
chunk_parameters: Optional[ChunkParameters] = None,
|
||||
**kwargs
|
||||
):
|
||||
@ -61,14 +65,26 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
|
||||
|
||||
Args:
|
||||
connector (BaseConnector): The connection.
|
||||
index_store (IndexStoreBase): The Storage IndexStoreBase.
|
||||
vector_store_connector (VectorStoreConnector): The vector store connector.
|
||||
chunk_parameters (Optional[ChunkParameters], optional): The chunk
|
||||
parameters.
|
||||
"""
|
||||
if not chunk_parameters:
|
||||
chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
|
||||
self._chunk_parameters = chunk_parameters
|
||||
self._index_store = index_store
|
||||
self._table_vector_store_connector = table_vector_store_connector
|
||||
|
||||
field_vector_store_config = VectorStoreConfig(
|
||||
name=table_vector_store_connector.vector_store_config.name + "_field"
|
||||
)
|
||||
self._field_vector_store_connector = (
|
||||
field_vector_store_connector
|
||||
or VectorStoreConnector.from_default(
|
||||
os.getenv("VECTOR_STORE_TYPE", "Chroma"),
|
||||
self._table_vector_store_connector.current_embeddings,
|
||||
vector_store_config=field_vector_store_config,
|
||||
)
|
||||
)
|
||||
self._connector = connector
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@ -84,7 +100,8 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
|
||||
assembler = DBSchemaAssembler.load_from_connection(
|
||||
connector=self._connector,
|
||||
chunk_parameters=self._chunk_parameters,
|
||||
index_store=self._index_store,
|
||||
table_vector_store_connector=self._table_vector_store_connector,
|
||||
field_vector_store_connector=self._field_vector_store_connector,
|
||||
)
|
||||
assembler.persist()
|
||||
return assembler.get_chunks()
|
||||
|
@ -1,18 +1,23 @@
|
||||
"""DBSchema retriever."""
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from functools import reduce
|
||||
from typing import List, Optional, cast
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.datasource.base import BaseConnector
|
||||
from dbgpt.rag.index.base import IndexStoreBase
|
||||
from dbgpt.rag.retriever.base import BaseRetriever
|
||||
from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
from dbgpt.util.chat_util import run_async_tasks
|
||||
from dbgpt.rag.summary.gdbms_db_summary import _parse_db_summary
|
||||
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters
|
||||
from dbgpt.util.chat_util import run_tasks
|
||||
from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
|
||||
from dbgpt.util.tracer import root_tracer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class DBSchemaRetriever(BaseRetriever):
|
||||
@ -20,7 +25,9 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_store: IndexStoreBase,
|
||||
table_vector_store_connector: VectorStoreConnector,
|
||||
field_vector_store_connector: VectorStoreConnector = None,
|
||||
separator: str = "--table-field-separator--",
|
||||
top_k: int = 4,
|
||||
connector: Optional[BaseConnector] = None,
|
||||
query_rewrite: bool = False,
|
||||
@ -30,7 +37,11 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
"""Create DBSchemaRetriever.
|
||||
|
||||
Args:
|
||||
index_store(IndexStore): index connector
|
||||
table_vector_store_connector: VectorStoreConnector
|
||||
to load and retrieve table info.
|
||||
field_vector_store_connector: VectorStoreConnector
|
||||
to load and retrieve field info.
|
||||
separator: field/table separator
|
||||
top_k (int): top k
|
||||
connector (Optional[BaseConnector]): RDBMSConnector.
|
||||
query_rewrite (bool): query rewrite
|
||||
@ -70,34 +81,42 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
|
||||
|
||||
connector = _create_temporary_connection()
|
||||
vector_store_config = ChromaVectorConfig(name="vector_store_name")
|
||||
embedding_model_path = "{your_embedding_model_path}"
|
||||
embedding_fn = embedding_factory.create(model_name=embedding_model_path)
|
||||
config = ChromaVectorConfig(
|
||||
persist_path=PILOT_PATH,
|
||||
name="dbschema_rag_test",
|
||||
embedding_fn=DefaultEmbeddingFactory(
|
||||
default_model_name=os.path.join(
|
||||
MODEL_PATH, "text2vec-large-chinese"
|
||||
),
|
||||
).create(),
|
||||
vector_connector = VectorStoreConnector.from_default(
|
||||
"Chroma",
|
||||
vector_store_config=vector_store_config,
|
||||
embedding_fn=embedding_fn,
|
||||
)
|
||||
|
||||
vector_store = ChromaStore(config)
|
||||
# get db struct retriever
|
||||
retriever = DBSchemaRetriever(
|
||||
top_k=3,
|
||||
index_store=vector_store,
|
||||
vector_store_connector=vector_connector,
|
||||
connector=connector,
|
||||
)
|
||||
chunks = retriever.retrieve("show columns from table")
|
||||
result = [chunk.content for chunk in chunks]
|
||||
print(f"db struct rag example results:{result}")
|
||||
"""
|
||||
self._separator = separator
|
||||
self._top_k = top_k
|
||||
self._connector = connector
|
||||
self._query_rewrite = query_rewrite
|
||||
self._index_store = index_store
|
||||
self._table_vector_store_connector = table_vector_store_connector
|
||||
field_vector_store_config = VectorStoreConfig(
|
||||
name=table_vector_store_connector.vector_store_config.name + "_field"
|
||||
)
|
||||
self._field_vector_store_connector = (
|
||||
field_vector_store_connector
|
||||
or VectorStoreConnector.from_default(
|
||||
os.getenv("VECTOR_STORE_TYPE", "Chroma"),
|
||||
self._table_vector_store_connector.current_embeddings,
|
||||
vector_store_config=field_vector_store_config,
|
||||
)
|
||||
)
|
||||
self._need_embeddings = False
|
||||
if self._index_store:
|
||||
if self._table_vector_store_connector:
|
||||
self._need_embeddings = True
|
||||
self._rerank = rerank or DefaultRanker(self._top_k)
|
||||
|
||||
@ -114,15 +133,8 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
List[Chunk]: list of chunks
|
||||
"""
|
||||
if self._need_embeddings:
|
||||
queries = [query]
|
||||
candidates = [
|
||||
self._index_store.similar_search(query, self._top_k, filters)
|
||||
for query in queries
|
||||
]
|
||||
return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
|
||||
return self._similarity_search(query, filters)
|
||||
else:
|
||||
if not self._connector:
|
||||
raise RuntimeError("RDBMSConnector connection is required.")
|
||||
table_summaries = _parse_db_summary(self._connector)
|
||||
return [Chunk(content=table_summary) for table_summary in table_summaries]
|
||||
|
||||
@ -156,30 +168,11 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
Returns:
|
||||
List[Chunk]: list of chunks
|
||||
"""
|
||||
if self._need_embeddings:
|
||||
queries = [query]
|
||||
candidates = [
|
||||
self._similarity_search(
|
||||
query, filters, root_tracer.get_current_span_id()
|
||||
)
|
||||
for query in queries
|
||||
]
|
||||
result_candidates = await run_async_tasks(
|
||||
tasks=candidates, concurrency_limit=1
|
||||
)
|
||||
return cast(List[Chunk], reduce(lambda x, y: x + y, result_candidates))
|
||||
else:
|
||||
from dbgpt.rag.summary.rdbms_db_summary import ( # noqa: F401
|
||||
_parse_db_summary,
|
||||
)
|
||||
|
||||
table_summaries = await run_async_tasks(
|
||||
tasks=[self._aparse_db_summary(root_tracer.get_current_span_id())],
|
||||
concurrency_limit=1,
|
||||
)
|
||||
return [
|
||||
Chunk(content=table_summary) for table_summary in table_summaries[0]
|
||||
]
|
||||
return await blocking_func_to_async_no_executor(
|
||||
func=self._retrieve,
|
||||
query=query,
|
||||
filters=filters,
|
||||
)
|
||||
|
||||
async def _aretrieve_with_score(
|
||||
self,
|
||||
@ -196,34 +189,40 @@ class DBSchemaRetriever(BaseRetriever):
|
||||
"""
|
||||
return await self._aretrieve(query, filters)
|
||||
|
||||
async def _similarity_search(
|
||||
self,
|
||||
query,
|
||||
filters: Optional[MetadataFilters] = None,
|
||||
parent_span_id: Optional[str] = None,
|
||||
def _retrieve_field(self, table_chunk: Chunk, query) -> Chunk:
|
||||
metadata = table_chunk.metadata
|
||||
metadata["part"] = "field"
|
||||
filters = [MetadataFilter(key=k, value=v) for k, v in metadata.items()]
|
||||
field_chunks = self._field_vector_store_connector.similar_search_with_scores(
|
||||
query, self._top_k, 0, MetadataFilters(filters=filters)
|
||||
)
|
||||
field_contents = [chunk.content for chunk in field_chunks]
|
||||
table_chunk.content += "\n" + self._separator + "\n" + "\n".join(field_contents)
|
||||
return table_chunk
|
||||
|
||||
def _similarity_search(
|
||||
self, query, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Similar search."""
|
||||
with root_tracer.start_span(
|
||||
"dbgpt.rag.retriever.db_schema._similarity_search",
|
||||
parent_span_id,
|
||||
metadata={"query": query},
|
||||
):
|
||||
return await blocking_func_to_async_no_executor(
|
||||
self._index_store.similar_search, query, self._top_k, filters
|
||||
)
|
||||
table_chunks = self._table_vector_store_connector.similar_search_with_scores(
|
||||
query, self._top_k, 0, filters
|
||||
)
|
||||
|
||||
async def _aparse_db_summary(
|
||||
self, parent_span_id: Optional[str] = None
|
||||
) -> List[str]:
|
||||
"""Similar search."""
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
not_sep_chunks = [
|
||||
chunk for chunk in table_chunks if not chunk.metadata.get("separated")
|
||||
]
|
||||
separated_chunks = [
|
||||
chunk for chunk in table_chunks if chunk.metadata.get("separated")
|
||||
]
|
||||
if not separated_chunks:
|
||||
return not_sep_chunks
|
||||
|
||||
if not self._connector:
|
||||
raise RuntimeError("RDBMSConnector connection is required.")
|
||||
with root_tracer.start_span(
|
||||
"dbgpt.rag.retriever.db_schema._aparse_db_summary",
|
||||
parent_span_id,
|
||||
):
|
||||
return await blocking_func_to_async_no_executor(
|
||||
_parse_db_summary, self._connector
|
||||
)
|
||||
# Create tasks list
|
||||
tasks = [
|
||||
lambda c=chunk: self._retrieve_field(c, query) for chunk in separated_chunks
|
||||
]
|
||||
# Run tasks concurrently
|
||||
separated_result = run_tasks(tasks, concurrency_limit=3)
|
||||
|
||||
# Combine and return results
|
||||
return not_sep_chunks + separated_result
|
||||
|
@ -15,42 +15,53 @@ def mock_db_connection():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_connector():
|
||||
def mock_table_vector_store_connector():
|
||||
mock_connector = MagicMock()
|
||||
mock_connector.similar_search.return_value = [Chunk(content="Table summary")] * 4
|
||||
mock_connector.vector_store_config.name = "table_name"
|
||||
mock_connector.similar_search_with_scores.return_value = [
|
||||
Chunk(content="Table summary")
|
||||
] * 4
|
||||
return mock_connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_struct_retriever(mock_db_connection, mock_vector_store_connector):
|
||||
def mock_field_vector_store_connector():
|
||||
mock_connector = MagicMock()
|
||||
mock_connector.similar_search_with_scores.return_value = [
|
||||
Chunk(content="Field summary")
|
||||
] * 4
|
||||
return mock_connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dbstruct_retriever(
|
||||
mock_db_connection,
|
||||
mock_table_vector_store_connector,
|
||||
mock_field_vector_store_connector,
|
||||
):
|
||||
return DBSchemaRetriever(
|
||||
connector=mock_db_connection,
|
||||
index_store=mock_vector_store_connector,
|
||||
table_vector_store_connector=mock_table_vector_store_connector,
|
||||
field_vector_store_connector=mock_field_vector_store_connector,
|
||||
)
|
||||
|
||||
|
||||
def mock_parse_db_summary(conn) -> List[str]:
|
||||
def mock_parse_db_summary() -> str:
|
||||
"""Patch _parse_db_summary method."""
|
||||
return ["Table summary"]
|
||||
return "Table summary"
|
||||
|
||||
|
||||
# Mocking the _parse_db_summary method in your test function
|
||||
@patch.object(
|
||||
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
|
||||
)
|
||||
def test_retrieve_with_mocked_summary(db_struct_retriever):
|
||||
def test_retrieve_with_mocked_summary(dbstruct_retriever):
|
||||
query = "Table summary"
|
||||
chunks: List[Chunk] = db_struct_retriever._retrieve(query)
|
||||
chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
|
||||
assert isinstance(chunks[0], Chunk)
|
||||
assert chunks[0].content == "Table summary"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.object(
|
||||
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
|
||||
)
|
||||
async def test_aretrieve_with_mocked_summary(db_struct_retriever):
|
||||
query = "Table summary"
|
||||
chunks: List[Chunk] = await db_struct_retriever._aretrieve(query)
|
||||
assert isinstance(chunks[0], Chunk)
|
||||
assert chunks[0].content == "Table summary"
|
||||
async def async_mock_parse_db_summary() -> str:
|
||||
"""Asynchronous patch for _parse_db_summary method."""
|
||||
return "Table summary"
|
||||
|
@ -2,13 +2,15 @@
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
from typing import List
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
|
||||
from dbgpt.rag import ChunkParameters
|
||||
from dbgpt.rag.summary.gdbms_db_summary import GdbmsSummary
|
||||
from dbgpt.rag.summary.rdbms_db_summary import RdbmsSummary
|
||||
from dbgpt.rag.text_splitter.text_splitter import RDBTextSplitter
|
||||
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -47,22 +49,26 @@ class DBSummaryClient:
|
||||
|
||||
logger.info("db summary embedding success")
|
||||
|
||||
def get_db_summary(self, dbname, query, topk) -> List[str]:
|
||||
def get_db_summary(self, dbname, query, topk):
|
||||
"""Get user query related tables info."""
|
||||
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
|
||||
vector_store_config = VectorStoreConfig(name=dbname + "_profile")
|
||||
vector_connector = VectorStoreConnector.from_default(
|
||||
vector_store_name = dbname + "_profile"
|
||||
table_vector_store_config = VectorStoreConfig(name=vector_store_name)
|
||||
table_vector_connector = VectorStoreConnector.from_default(
|
||||
CFG.VECTOR_STORE_TYPE,
|
||||
embedding_fn=self.embeddings,
|
||||
vector_store_config=vector_store_config,
|
||||
self.embeddings,
|
||||
vector_store_config=table_vector_store_config,
|
||||
)
|
||||
|
||||
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
|
||||
|
||||
retriever = DBSchemaRetriever(
|
||||
top_k=topk, index_store=vector_connector.index_client
|
||||
top_k=topk,
|
||||
table_vector_store_connector=table_vector_connector,
|
||||
separator="--table-field-separator--",
|
||||
)
|
||||
|
||||
table_docs = retriever.retrieve(query)
|
||||
ans = [d.content for d in table_docs]
|
||||
return ans
|
||||
@ -92,18 +98,23 @@ class DBSummaryClient:
|
||||
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
|
||||
vector_store_config = VectorStoreConfig(name=vector_store_name)
|
||||
vector_connector = VectorStoreConnector.from_default(
|
||||
table_vector_store_config = VectorStoreConfig(name=vector_store_name)
|
||||
table_vector_connector = VectorStoreConnector.from_default(
|
||||
CFG.VECTOR_STORE_TYPE,
|
||||
self.embeddings,
|
||||
vector_store_config=vector_store_config,
|
||||
vector_store_config=table_vector_store_config,
|
||||
)
|
||||
if not vector_connector.vector_name_exists():
|
||||
if not table_vector_connector.vector_name_exists():
|
||||
from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
|
||||
|
||||
chunk_parameters = ChunkParameters(
|
||||
text_splitter=RDBTextSplitter(separator="--table-field-separator--")
|
||||
)
|
||||
db_assembler = DBSchemaAssembler.load_from_connection(
|
||||
connector=db_summary_client.db,
|
||||
index_store=vector_connector.index_client,
|
||||
table_vector_store_connector=table_vector_connector,
|
||||
chunk_parameters=chunk_parameters,
|
||||
max_seq_length=CFG.EMBEDDING_MODEL_MAX_SEQ_LEN,
|
||||
)
|
||||
|
||||
if len(db_assembler.get_chunks()) > 0:
|
||||
@ -115,16 +126,26 @@ class DBSummaryClient:
|
||||
def delete_db_profile(self, dbname):
|
||||
"""Delete db profile."""
|
||||
vector_store_name = dbname + "_profile"
|
||||
table_vector_store_name = dbname + "_profile"
|
||||
field_vector_store_name = dbname + "_profile_field"
|
||||
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||
|
||||
vector_store_config = VectorStoreConfig(name=vector_store_name)
|
||||
vector_connector = VectorStoreConnector.from_default(
|
||||
table_vector_store_config = VectorStoreConfig(name=vector_store_name)
|
||||
field_vector_store_config = VectorStoreConfig(name=field_vector_store_name)
|
||||
table_vector_connector = VectorStoreConnector.from_default(
|
||||
CFG.VECTOR_STORE_TYPE,
|
||||
self.embeddings,
|
||||
vector_store_config=vector_store_config,
|
||||
vector_store_config=table_vector_store_config,
|
||||
)
|
||||
vector_connector.delete_vector_name(vector_store_name)
|
||||
field_vector_connector = VectorStoreConnector.from_default(
|
||||
CFG.VECTOR_STORE_TYPE,
|
||||
self.embeddings,
|
||||
vector_store_config=field_vector_store_config,
|
||||
)
|
||||
|
||||
table_vector_connector.delete_vector_name(table_vector_store_name)
|
||||
field_vector_connector.delete_vector_name(field_vector_store_name)
|
||||
logger.info(f"delete db profile {dbname} success")
|
||||
|
||||
@staticmethod
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Summary for rdbms database."""
|
||||
import re
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.datasource import BaseConnector
|
||||
@ -80,6 +80,134 @@ def _parse_db_summary(
|
||||
return table_info_summaries
|
||||
|
||||
|
||||
def _parse_db_summary_with_metadata(
|
||||
conn: BaseConnector,
|
||||
summary_template: str = "table_name: {table_name}",
|
||||
separator: str = "--table-field-separator--",
|
||||
model_dimension: int = 512,
|
||||
) -> List[Tuple[str, Dict[str, Any]]]:
|
||||
"""Get db summary for database.
|
||||
|
||||
Args:
|
||||
conn (BaseConnector): database connection
|
||||
summary_template (str): summary template
|
||||
separator(str, optional): separator used to separate table's
|
||||
basic info and fields. defaults to `-- table-field-separator--`
|
||||
model_dimension(int, optional): The threshold for splitting field string
|
||||
"""
|
||||
tables = conn.get_table_names()
|
||||
table_info_summaries = [
|
||||
_parse_table_summary_with_metadata(
|
||||
conn, summary_template, separator, table_name, model_dimension
|
||||
)
|
||||
for table_name in tables
|
||||
]
|
||||
return table_info_summaries
|
||||
|
||||
|
||||
def _split_columns_str(columns: List[str], model_dimension: int):
|
||||
"""Split columns str.
|
||||
|
||||
Args:
|
||||
columns (List[str]): fields string
|
||||
model_dimension (int, optional): The threshold for splitting field string.
|
||||
"""
|
||||
result = []
|
||||
current_string = ""
|
||||
current_length = 0
|
||||
|
||||
for element_str in columns:
|
||||
element_length = len(element_str)
|
||||
|
||||
# If adding the current element's length would exceed the threshold,
|
||||
# add the current string to results and reset
|
||||
if current_length + element_length > model_dimension:
|
||||
result.append(current_string.strip()) # Remove trailing spaces
|
||||
current_string = element_str
|
||||
current_length = element_length
|
||||
else:
|
||||
# If current string is empty, add element directly
|
||||
if current_string:
|
||||
current_string += "," + element_str
|
||||
else:
|
||||
current_string = element_str
|
||||
current_length += element_length + 1 # Add length of space
|
||||
|
||||
# Handle the last string segment
|
||||
if current_string:
|
||||
result.append(current_string.strip())
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _parse_table_summary_with_metadata(
|
||||
conn: BaseConnector,
|
||||
summary_template: str,
|
||||
separator,
|
||||
table_name: str,
|
||||
model_dimension=512,
|
||||
) -> Tuple[str, Dict[str, Any]]:
|
||||
"""Get table summary for table.
|
||||
|
||||
Args:
|
||||
conn (BaseConnector): database connection
|
||||
summary_template (str): summary template
|
||||
separator(str, optional): separator used to separate table's
|
||||
basic info and fields. defaults to `-- table-field-separator--`
|
||||
model_dimension(int, optional): The threshold for splitting field string
|
||||
|
||||
Examples:
|
||||
metadata: {'table_name': 'asd', 'separated': 0/1}
|
||||
|
||||
table_name: table1
|
||||
table_comment: comment
|
||||
index_keys: keys
|
||||
--table-field-separator--
|
||||
(column1,comment), (column2, comment), (column3, comment)
|
||||
(column4,comment), (column5, comment), (column6, comment)
|
||||
"""
|
||||
columns = []
|
||||
metadata = {"table_name": table_name, "separated": 0}
|
||||
for column in conn.get_columns(table_name):
|
||||
if column.get("comment"):
|
||||
columns.append(f"{column['name']} ({column.get('comment')})")
|
||||
else:
|
||||
columns.append(f"{column['name']}")
|
||||
metadata.update({"field_num": len(columns)})
|
||||
separated_columns = _split_columns_str(columns, model_dimension=model_dimension)
|
||||
if len(separated_columns) > 1:
|
||||
metadata["separated"] = 1
|
||||
column_str = "\n".join(separated_columns)
|
||||
# Obtain index information
|
||||
index_keys = []
|
||||
raw_indexes = conn.get_indexes(table_name)
|
||||
for index in raw_indexes:
|
||||
if isinstance(index, tuple): # Process tuple type index information
|
||||
index_name, index_creation_command = index
|
||||
# Extract column names using re
|
||||
matched_columns = re.findall(r"\(([^)]+)\)", index_creation_command)
|
||||
if matched_columns:
|
||||
key_str = ", ".join(matched_columns)
|
||||
index_keys.append(f"{index_name}(`{key_str}`) ")
|
||||
else:
|
||||
key_str = ", ".join(index["column_names"])
|
||||
index_keys.append(f"{index['name']}(`{key_str}`) ")
|
||||
table_str = summary_template.format(table_name=table_name)
|
||||
|
||||
try:
|
||||
comment = conn.get_table_comment(table_name)
|
||||
except Exception:
|
||||
comment = dict(text=None)
|
||||
if comment.get("text"):
|
||||
table_str += f"\ntable_comment: {comment.get('text')}"
|
||||
|
||||
if len(index_keys) > 0:
|
||||
index_key_str = ", ".join(index_keys)
|
||||
table_str += f"\nindex_keys: {index_key_str}"
|
||||
table_str += f"\n{separator}\n{column_str}"
|
||||
return table_str, metadata
|
||||
|
||||
|
||||
def _parse_table_summary(
|
||||
conn: BaseConnector, summary_template: str, table_name: str
|
||||
) -> str:
|
||||
|
@ -912,3 +912,42 @@ class PageTextSplitter(TextSplitter):
|
||||
new_doc = Chunk(content=text, metadata=copy.deepcopy(_metadatas[i]))
|
||||
chunks.append(new_doc)
|
||||
return chunks
|
||||
|
||||
|
||||
class RDBTextSplitter(TextSplitter):
|
||||
"""Split relational database tables and fields."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Create a new TextSplitter."""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def split_text(self, text: str, **kwargs):
|
||||
"""Split text into a couple of parts."""
|
||||
pass
|
||||
|
||||
def split_documents(self, documents: Iterable[Document], **kwargs) -> List[Chunk]:
|
||||
"""Split document into chunks."""
|
||||
chunks = []
|
||||
for doc in documents:
|
||||
metadata = doc.metadata
|
||||
content = doc.content
|
||||
if metadata.get("separated"):
|
||||
# separate table and field
|
||||
parts = content.split(self._separator)
|
||||
table_part, field_part = parts[0], parts[1]
|
||||
table_metadata, field_metadata = copy.deepcopy(metadata), copy.deepcopy(
|
||||
metadata
|
||||
)
|
||||
table_metadata["part"] = "table" # identify of table_chunk
|
||||
field_metadata["part"] = "field" # identify of field_chunk
|
||||
table_chunk = Chunk(content=table_part, metadata=table_metadata)
|
||||
chunks.append(table_chunk)
|
||||
field_parts = field_part.split("\n")
|
||||
for i, sub_part in enumerate(field_parts):
|
||||
sub_metadata = copy.deepcopy(field_metadata)
|
||||
sub_metadata["part_index"] = i
|
||||
field_chunk = Chunk(content=sub_part, metadata=sub_metadata)
|
||||
chunks.append(field_chunk)
|
||||
else:
|
||||
chunks.append(Chunk(content=content, metadata=metadata))
|
||||
return chunks
|
||||
|
@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
from typing import Any, Coroutine, List
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Any, Callable, Coroutine, List
|
||||
|
||||
|
||||
async def llm_chat_response_nostream(chat_scene: str, **chat_param):
|
||||
@ -47,13 +48,34 @@ async def run_async_tasks(
|
||||
|
||||
|
||||
def run_tasks(
|
||||
tasks: List[Coroutine],
|
||||
tasks: List[Callable],
|
||||
concurrency_limit: int = None,
|
||||
) -> List[Any]:
|
||||
"""Run a list of async tasks."""
|
||||
tasks_to_execute: List[Any] = tasks
|
||||
"""
|
||||
Run a list of tasks concurrently using a thread pool.
|
||||
|
||||
async def _gather() -> List[Any]:
|
||||
return await asyncio.gather(*tasks_to_execute)
|
||||
Args:
|
||||
tasks: List of callable functions to execute
|
||||
concurrency_limit: Maximum number of concurrent threads (optional)
|
||||
|
||||
outputs: List[Any] = asyncio.run(_gather())
|
||||
return outputs
|
||||
Returns:
|
||||
List of results from all tasks in the order they were submitted
|
||||
"""
|
||||
max_workers = concurrency_limit if concurrency_limit else None
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit all tasks and get futures
|
||||
futures = [executor.submit(task) for task in tasks]
|
||||
|
||||
# Collect results in order, raising any exceptions
|
||||
results = []
|
||||
for future in futures:
|
||||
try:
|
||||
results.append(future.result())
|
||||
except Exception as e:
|
||||
# Cancel any pending futures
|
||||
for f in futures:
|
||||
f.cancel()
|
||||
raise e
|
||||
|
||||
return results
|
||||
|
317
docker/examples/sqls/case_3_order_wide_table_sqlite_wide.sql
Normal file
317
docker/examples/sqls/case_3_order_wide_table_sqlite_wide.sql
Normal file
@ -0,0 +1,317 @@
|
||||
CREATE TABLE order_wide_table (
|
||||
|
||||
-- order_base
|
||||
order_id TEXT, -- 订单ID
|
||||
order_no TEXT, -- 订单编号
|
||||
parent_order_no TEXT, -- 父订单编号
|
||||
order_type INTEGER, -- 订单类型:1实物2虚拟3混合
|
||||
order_status INTEGER, -- 订单状态
|
||||
order_source TEXT, -- 订单来源
|
||||
order_source_detail TEXT, -- 订单来源详情
|
||||
create_time DATETIME, -- 创建时间
|
||||
pay_time DATETIME, -- 支付时间
|
||||
finish_time DATETIME, -- 完成时间
|
||||
close_time DATETIME, -- 关闭时间
|
||||
cancel_time DATETIME, -- 取消时间
|
||||
cancel_reason TEXT, -- 取消原因
|
||||
order_remark TEXT, -- 订单备注
|
||||
seller_remark TEXT, -- 卖家备注
|
||||
buyer_remark TEXT, -- 买家备注
|
||||
is_deleted INTEGER, -- 是否删除
|
||||
delete_time DATETIME, -- 删除时间
|
||||
order_ip TEXT, -- 下单IP
|
||||
order_platform TEXT, -- 下单平台
|
||||
order_device TEXT, -- 下单设备
|
||||
order_app_version TEXT, -- APP版本号
|
||||
|
||||
-- order_amount
|
||||
currency TEXT, -- 货币类型
|
||||
exchange_rate REAL, -- 汇率
|
||||
original_amount REAL, -- 原始金额
|
||||
discount_amount REAL, -- 优惠金额
|
||||
coupon_amount REAL, -- 优惠券金额
|
||||
points_amount REAL, -- 积分抵扣金额
|
||||
shipping_amount REAL, -- 运费
|
||||
insurance_amount REAL, -- 保价费
|
||||
tax_amount REAL, -- 税费
|
||||
tariff_amount REAL, -- 关税
|
||||
payment_amount REAL, -- 实付金额
|
||||
commission_amount REAL, -- 佣金金额
|
||||
platform_fee REAL, -- 平台费用
|
||||
seller_income REAL, -- 卖家实收
|
||||
payment_currency TEXT, -- 支付货币
|
||||
payment_exchange_rate REAL, -- 支付汇率
|
||||
|
||||
-- user_info
|
||||
user_id TEXT, -- 用户ID
|
||||
user_name TEXT, -- 用户名
|
||||
user_nickname TEXT, -- 用户昵称
|
||||
user_level INTEGER, -- 用户等级
|
||||
user_type INTEGER, -- 用户类型
|
||||
register_time DATETIME, -- 注册时间
|
||||
register_source TEXT, -- 注册来源
|
||||
mobile TEXT, -- 手机号
|
||||
mobile_area TEXT, -- 手机号区号
|
||||
email TEXT, -- 邮箱
|
||||
is_vip INTEGER, -- 是否VIP
|
||||
vip_level INTEGER, -- VIP等级
|
||||
vip_expire_time DATETIME, -- VIP过期时间
|
||||
user_age INTEGER, -- 用户年龄
|
||||
user_gender INTEGER, -- 用户性别
|
||||
user_birthday DATE, -- 用户生日
|
||||
user_avatar TEXT, -- 用户头像
|
||||
user_province TEXT, -- 用户所在省
|
||||
user_city TEXT, -- 用户所在市
|
||||
user_district TEXT, -- 用户所在区
|
||||
last_login_time DATETIME, -- 最后登录时间
|
||||
last_login_ip TEXT, -- 最后登录IP
|
||||
user_credit_score INTEGER, -- 用户信用分
|
||||
total_order_count INTEGER, -- 历史订单数
|
||||
total_order_amount REAL, -- 历史订单金额
|
||||
|
||||
-- product_info
|
||||
product_id TEXT, -- 商品ID
|
||||
product_code TEXT, -- 商品编码
|
||||
product_name TEXT, -- 商品名称
|
||||
product_short_name TEXT, -- 商品短名称
|
||||
product_type INTEGER, -- 商品类型
|
||||
product_status INTEGER, -- 商品状态
|
||||
category_id TEXT, -- 类目ID
|
||||
category_name TEXT, -- 类目名称
|
||||
category_path TEXT, -- 类目路径
|
||||
brand_id TEXT, -- 品牌ID
|
||||
brand_name TEXT, -- 品牌名称
|
||||
brand_english_name TEXT, -- 品牌英文名
|
||||
seller_id TEXT, -- 卖家ID
|
||||
seller_name TEXT, -- 卖家名称
|
||||
seller_type INTEGER, -- 卖家类型
|
||||
shop_id TEXT, -- 店铺ID
|
||||
shop_name TEXT, -- 店铺名称
|
||||
product_price REAL, -- 商品价格
|
||||
market_price REAL, -- 市场价
|
||||
cost_price REAL, -- 成本价
|
||||
wholesale_price REAL, -- 批发价
|
||||
product_quantity INTEGER, -- 商品数量
|
||||
product_unit TEXT, -- 商品单位
|
||||
product_weight REAL, -- 商品重量(克)
|
||||
product_volume REAL, -- 商品体积(cm³)
|
||||
product_spec TEXT, -- 商品规格
|
||||
product_color TEXT, -- 商品颜色
|
||||
product_size TEXT, -- 商品尺寸
|
||||
product_material TEXT, -- 商品材质
|
||||
product_origin TEXT, -- 商品产地
|
||||
product_shelf_life INTEGER, -- 保质期(天)
|
||||
manufacture_date DATE, -- 生产日期
|
||||
expiry_date DATE, -- 过期日期
|
||||
batch_number TEXT, -- 批次号
|
||||
product_barcode TEXT, -- 商品条码
|
||||
warehouse_id TEXT, -- 发货仓库ID
|
||||
warehouse_name TEXT, -- 发货仓库名称
|
||||
|
||||
-- address_info
|
||||
receiver_name TEXT, -- 收货人姓名
|
||||
receiver_mobile TEXT, -- 收货人手机
|
||||
receiver_tel TEXT, -- 收货人电话
|
||||
receiver_email TEXT, -- 收货人邮箱
|
||||
receiver_country TEXT, -- 国家
|
||||
receiver_province TEXT, -- 省份
|
||||
receiver_city TEXT, -- 城市
|
||||
receiver_district TEXT, -- 区县
|
||||
receiver_street TEXT, -- 街道
|
||||
receiver_address TEXT, -- 详细地址
|
||||
receiver_zip TEXT, -- 邮编
|
||||
address_type INTEGER, -- 地址类型
|
||||
is_default INTEGER, -- 是否默认地址
|
||||
longitude REAL, -- 经度
|
||||
latitude REAL, -- 纬度
|
||||
address_label TEXT, -- 地址标签
|
||||
|
||||
-- shipping_info
|
||||
shipping_type INTEGER, -- 配送方式
|
||||
shipping_method TEXT, -- 配送方式名称
|
||||
shipping_company TEXT, -- 快递公司
|
||||
shipping_company_code TEXT, -- 快递公司编码
|
||||
shipping_no TEXT, -- 快递单号
|
||||
shipping_time DATETIME, -- 发货时间
|
||||
shipping_remark TEXT, -- 发货备注
|
||||
expect_receive_time DATETIME, -- 预计送达时间
|
||||
receive_time DATETIME, -- 收货时间
|
||||
sign_type INTEGER, -- 签收类型
|
||||
shipping_status INTEGER, -- 物流状态
|
||||
tracking_url TEXT, -- 物流跟踪URL
|
||||
is_free_shipping INTEGER, -- 是否包邮
|
||||
shipping_insurance REAL, -- 运费险金额
|
||||
shipping_distance REAL, -- 配送距离
|
||||
delivered_time DATETIME, -- 送达时间
|
||||
delivery_staff_id TEXT, -- 配送员ID
|
||||
delivery_staff_name TEXT, -- 配送员姓名
|
||||
delivery_staff_mobile TEXT, -- 配送员电话
|
||||
|
||||
-- payment_info
|
||||
payment_id TEXT, -- 支付ID
|
||||
payment_no TEXT, -- 支付单号
|
||||
payment_type INTEGER, -- 支付方式
|
||||
payment_method TEXT, -- 支付方式名称
|
||||
payment_status INTEGER, -- 支付状态
|
||||
payment_platform TEXT, -- 支付平台
|
||||
transaction_id TEXT, -- 交易流水号
|
||||
payment_time DATETIME, -- 支付时间
|
||||
payment_account TEXT, -- 支付账号
|
||||
payment_bank TEXT, -- 支付银行
|
||||
payment_card_type TEXT, -- 支付卡类型
|
||||
payment_card_no TEXT, -- 支付卡号
|
||||
payment_scene TEXT, -- 支付场景
|
||||
payment_client_ip TEXT, -- 支付IP
|
||||
payment_device TEXT, -- 支付设备
|
||||
payment_remark TEXT, -- 支付备注
|
||||
payment_voucher TEXT, -- 支付凭证
|
||||
|
||||
-- promotion_info
|
||||
promotion_id TEXT, -- 活动ID
|
||||
promotion_name TEXT, -- 活动名称
|
||||
promotion_type INTEGER, -- 活动类型
|
||||
promotion_desc TEXT, -- 活动描述
|
||||
promotion_start_time DATETIME, -- 活动开始时间
|
||||
promotion_end_time DATETIME, -- 活动结束时间
|
||||
coupon_id TEXT, -- 优惠券ID
|
||||
coupon_code TEXT, -- 优惠券码
|
||||
coupon_type INTEGER, -- 优惠券类型
|
||||
coupon_name TEXT, -- 优惠券名称
|
||||
coupon_desc TEXT, -- 优惠券描述
|
||||
points_used INTEGER, -- 使用积分
|
||||
points_gained INTEGER, -- 获得积分
|
||||
points_multiple REAL, -- 积分倍率
|
||||
is_first_order INTEGER, -- 是否首单
|
||||
is_new_customer INTEGER, -- 是否新客
|
||||
marketing_channel TEXT, -- 营销渠道
|
||||
marketing_source TEXT, -- 营销来源
|
||||
referral_code TEXT, -- 推荐码
|
||||
referral_user_id TEXT, -- 推荐人ID
|
||||
|
||||
-- after_sale_info
|
||||
refund_id TEXT, -- 退款ID
|
||||
refund_no TEXT, -- 退款单号
|
||||
refund_type INTEGER, -- 退款类型
|
||||
refund_status INTEGER, -- 退款状态
|
||||
refund_reason TEXT, -- 退款原因
|
||||
refund_desc TEXT, -- 退款描述
|
||||
refund_time DATETIME, -- 退款时间
|
||||
refund_amount REAL, -- 退款金额
|
||||
return_shipping_no TEXT, -- 退货快递单号
|
||||
return_shipping_company TEXT, -- 退货快递公司
|
||||
return_shipping_time DATETIME, -- 退货时间
|
||||
refund_evidence TEXT, -- 退款凭证
|
||||
complaint_id TEXT, -- 投诉ID
|
||||
complaint_type INTEGER, -- 投诉类型
|
||||
complaint_status INTEGER, -- 投诉状态
|
||||
complaint_content TEXT, -- 投诉内容
|
||||
complaint_time DATETIME, -- 投诉时间
|
||||
complaint_handle_time DATETIME, -- 投诉处理时间
|
||||
complaint_handle_result TEXT, -- 投诉处理结果
|
||||
evaluation_score INTEGER, -- 评价分数
|
||||
evaluation_content TEXT, -- 评价内容
|
||||
evaluation_time DATETIME, -- 评价时间
|
||||
evaluation_reply TEXT, -- 评价回复
|
||||
evaluation_reply_time DATETIME, -- 评价回复时间
|
||||
evaluation_images TEXT, -- 评价图片
|
||||
evaluation_videos TEXT, -- 评价视频
|
||||
is_anonymous INTEGER, -- 是否匿名评价
|
||||
|
||||
-- invoice_info
|
||||
invoice_type INTEGER, -- 发票类型
|
||||
invoice_title TEXT, -- 发票抬头
|
||||
invoice_content TEXT, -- 发票内容
|
||||
tax_no TEXT, -- 税号
|
||||
invoice_amount REAL, -- 发票金额
|
||||
invoice_status INTEGER, -- 发票状态
|
||||
invoice_time DATETIME, -- 开票时间
|
||||
invoice_number TEXT, -- 发票号码
|
||||
invoice_code TEXT, -- 发票代码
|
||||
company_name TEXT, -- 单位名称
|
||||
company_address TEXT, -- 单位地址
|
||||
company_tel TEXT, -- 单位电话
|
||||
company_bank TEXT, -- 开户银行
|
||||
company_account TEXT, -- 银行账号
|
||||
|
||||
-- delivery_time_info
|
||||
expect_delivery_time DATETIME, -- 期望配送时间
|
||||
delivery_period_type INTEGER, -- 配送时段类型
|
||||
delivery_period_start TEXT, -- 配送时段开始
|
||||
delivery_period_end TEXT, -- 配送时段结束
|
||||
delivery_priority INTEGER, -- 配送优先级
|
||||
|
||||
-- tag_info
|
||||
order_tags TEXT, -- 订单标签
|
||||
user_tags TEXT, -- 用户标签
|
||||
product_tags TEXT, -- 商品标签
|
||||
risk_level INTEGER, -- 风险等级
|
||||
risk_tags TEXT, -- 风险标签
|
||||
business_tags TEXT, -- 业务标签
|
||||
|
||||
-- commercial_info
|
||||
gross_profit REAL, -- 毛利
|
||||
gross_profit_rate REAL, -- 毛利率
|
||||
settlement_amount REAL, -- 结算金额
|
||||
settlement_time DATETIME, -- 结算时间
|
||||
settlement_cycle INTEGER, -- 结算周期
|
||||
settlement_status INTEGER, -- 结算状态
|
||||
commission_rate REAL, -- 佣金比例
|
||||
platform_service_fee REAL, -- 平台服务费
|
||||
ad_cost REAL, -- 广告费用
|
||||
promotion_cost REAL -- 推广费用
|
||||
);
|
||||
|
||||
-- 插入示例数据
|
||||
INSERT INTO order_wide_table (
|
||||
-- 基础订单信息
|
||||
order_id, order_no, order_type, order_status, create_time, order_source,
|
||||
-- 订单金额
|
||||
original_amount, payment_amount, shipping_amount,
|
||||
-- 用户信息
|
||||
user_id, user_name, user_level, mobile,
|
||||
-- 商品信息
|
||||
product_id, product_name, product_quantity, product_price,
|
||||
-- 收货信息
|
||||
receiver_name, receiver_mobile, receiver_address,
|
||||
-- 物流信息
|
||||
shipping_no, shipping_status,
|
||||
-- 支付信息
|
||||
payment_type, payment_status,
|
||||
-- 营销信息
|
||||
promotion_id, coupon_amount,
|
||||
-- 发票信息
|
||||
invoice_type, invoice_title
|
||||
) VALUES
|
||||
(
|
||||
'ORD20240101001', 'NO20240101001', 1, 2, '2024-01-01 10:00:00', 'APP',
|
||||
199.99, 188.88, 10.00,
|
||||
'USER001', '张三', 2, '13800138000',
|
||||
'PRD001', 'iPhone 15 手机壳', 2, 89.99,
|
||||
'李四', '13900139000', '北京市朝阳区XX路XX号',
|
||||
'SF123456789', 1,
|
||||
1, 1,
|
||||
'PROM001', 20.00,
|
||||
1, '个人'
|
||||
),
|
||||
(
|
||||
'ORD20240101002', 'NO20240101002', 1, 1, '2024-01-01 11:00:00', 'H5',
|
||||
299.99, 279.99, 0.00,
|
||||
'USER002', '王五', 3, '13700137000',
|
||||
'PRD002', 'AirPods Pro 保护套', 1, 299.99,
|
||||
'赵六', '13600136000', '上海市浦东新区XX路XX号',
|
||||
'YT987654321', 2,
|
||||
2, 2,
|
||||
'PROM002', 10.00,
|
||||
2, '上海科技有限公司'
|
||||
),
|
||||
(
|
||||
'ORD20240101003', 'NO20240101003', 2, 3, '2024-01-01 12:00:00', 'WEB',
|
||||
1999.99, 1899.99, 0.00,
|
||||
'USER003', '陈七', 4, '13500135000',
|
||||
'PRD003', 'MacBook Pro 电脑包', 1, 1999.99,
|
||||
'孙八', '13400134000', '广州市天河区XX路XX号',
|
||||
'JD123123123', 3,
|
||||
3, 1,
|
||||
'PROM003', 100.00,
|
||||
1, '个人'
|
||||
);
|
@ -4,7 +4,8 @@ from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
|
||||
from dbgpt.rag.assembler import DBSchemaAssembler
|
||||
from dbgpt.rag.embedding import DefaultEmbeddingFactory
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig
|
||||
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig
|
||||
|
||||
"""DB struct rag example.
|
||||
pre-requirements:
|
||||
@ -45,27 +46,26 @@ def _create_temporary_connection():
|
||||
|
||||
def _create_vector_connector():
|
||||
"""Create vector connector."""
|
||||
config = ChromaVectorConfig(
|
||||
persist_path=PILOT_PATH,
|
||||
name="dbschema_rag_test",
|
||||
return VectorStoreConnector.from_default(
|
||||
"Chroma",
|
||||
vector_store_config=ChromaVectorConfig(
|
||||
name="db_schema_vector_store_name",
|
||||
persist_path=os.path.join(PILOT_PATH, "data"),
|
||||
),
|
||||
embedding_fn=DefaultEmbeddingFactory(
|
||||
default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"),
|
||||
).create(),
|
||||
)
|
||||
|
||||
return ChromaStore(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connection = _create_temporary_connection()
|
||||
index_store = _create_vector_connector()
|
||||
vector_connector = _create_vector_connector()
|
||||
assembler = DBSchemaAssembler.load_from_connection(
|
||||
connector=connection,
|
||||
index_store=index_store,
|
||||
connector=connection, table_vector_store_connector=vector_connector
|
||||
)
|
||||
assembler.persist()
|
||||
# get db schema retriever
|
||||
retriever = assembler.as_retriever(top_k=1)
|
||||
chunks = retriever.retrieve("show columns from user")
|
||||
print(f"db schema rag example results:{[chunk.content for chunk in chunks]}")
|
||||
index_store.delete_vector_name("dbschema_rag_test")
|
||||
|
@ -15,6 +15,7 @@ fi
|
||||
DEFAULT_DB_FILE="DB-GPT/pilot/data/default_sqlite.db"
|
||||
DEFAULT_SQL_FILE="DB-GPT/docker/examples/sqls/*_sqlite.sql"
|
||||
DB_FILE="$WORK_DIR/pilot/data/default_sqlite.db"
|
||||
WIDE_DB_FILE="$WORK_DIR/pilot/data/wide_sqlite.db"
|
||||
SQL_FILE=""
|
||||
|
||||
usage () {
|
||||
@ -61,6 +62,12 @@ if [ -n $SQL_FILE ];then
|
||||
sqlite3 $DB_FILE < "$file"
|
||||
done
|
||||
|
||||
for file in $WORK_DIR/docker/examples/sqls/*_sqlite_wide.sql
|
||||
do
|
||||
echo "execute sql file: $file"
|
||||
sqlite3 $WIDE_DB_FILE < "$file"
|
||||
done
|
||||
|
||||
else
|
||||
echo "Execute SQL file ${SQL_FILE}"
|
||||
sqlite3 $DB_FILE < $SQL_FILE
|
||||
|
Loading…
Reference in New Issue
Block a user