Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-13 21:21:08 +00:00)
Feat rdb summary wide table (#2035)
Co-authored-by: dongzhancai1 <dongzhancai1@jd.com>
Co-authored-by: dong <dongzhancai@iie2.com>
@@ -264,6 +264,9 @@ class Config(metaclass=Singleton):
         # EMBEDDING Configuration
         self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
+        self.EMBEDDING_MODEL_MAX_SEQ_LEN = int(
+            os.getenv("EMBEDDING_MODEL_MAX_SEQ_LEN", 512)
+        )
         # Rerank model configuration
         self.RERANK_MODEL = os.getenv("RERANK_MODEL")
         self.RERANK_MODEL_PATH = os.getenv("RERANK_MODEL_PATH")
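Note: the new setting is consumed further down in this same change, where DBSummaryClient passes max_seq_length=CFG.EMBEDDING_MODEL_MAX_SEQ_LEN into the schema assembler. A minimal sketch of the read pattern above (standard library only; the 512 default comes from this diff):

    import os

    # Falls back to 512 when the variable is unset; int() raises if the
    # value is set but not numeric.
    max_seq_len = int(os.getenv("EMBEDDING_MODEL_MAX_SEQ_LEN", 512))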
@@ -55,9 +55,6 @@ class ChatWithDbQA(BaseChat):
         if self.db_name:
             client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
             try:
-                # table_infos = client.get_db_summary(
-                #     dbname=self.db_name, query=self.current_user_input, topk=self.top_k
-                # )
                 table_infos = await blocking_func_to_async(
                     self._executor,
                     client.get_db_summary,
@@ -1,12 +1,15 @@
 """DBSchemaAssembler."""
 import os
 from typing import Any, List, Optional

-from dbgpt.core import Chunk
+from dbgpt.core import Chunk, Embeddings
 from dbgpt.datasource.base import BaseConnector

+from ...serve.rag.connector import VectorStoreConnector
+from ...storage.vector_store.base import VectorStoreConfig
 from ..assembler.base import BaseAssembler
 from ..chunk_manager import ChunkParameters
-from ..index.base import IndexStoreBase
+from ..embedding.embedding_factory import DefaultEmbeddingFactory
 from ..knowledge.datasource import DatasourceKnowledge
 from ..retriever.db_schema import DBSchemaRetriever
@@ -35,23 +38,64 @@ class DBSchemaAssembler(BaseAssembler):
     def __init__(
         self,
         connector: BaseConnector,
-        index_store: IndexStoreBase,
+        table_vector_store_connector: VectorStoreConnector,
+        field_vector_store_connector: VectorStoreConnector = None,
         chunk_parameters: Optional[ChunkParameters] = None,
         embedding_model: Optional[str] = None,
         embeddings: Optional[Embeddings] = None,
+        max_seq_length: int = 512,
         **kwargs: Any,
     ) -> None:
         """Initialize with Embedding Assembler arguments.

         Args:
             connector: (BaseConnector) BaseConnector connection.
-            index_store: (IndexStoreBase) IndexStoreBase to use.
+            table_vector_store_connector: VectorStoreConnector to load
+                and retrieve table info.
+            field_vector_store_connector: VectorStoreConnector to load
+                and retrieve field info.
             chunk_parameters: (Optional[ChunkParameters]) ChunkParameters to use for
                 chunking.
             embedding_model: (Optional[str]) Embedding model to use.
             embeddings: (Optional[Embeddings]) Embeddings to use.
         """
-        knowledge = DatasourceKnowledge(connector)
         self._connector = connector
-        self._index_store = index_store
+        self._table_vector_store_connector = table_vector_store_connector
+
+        field_vector_store_config = VectorStoreConfig(
+            name=table_vector_store_connector.vector_store_config.name + "_field"
+        )
+        self._field_vector_store_connector = (
+            field_vector_store_connector
+            or VectorStoreConnector.from_default(
+                os.getenv("VECTOR_STORE_TYPE", "Chroma"),
+                self._table_vector_store_connector.current_embeddings,
+                vector_store_config=field_vector_store_config,
+            )
+        )

         self._embedding_model = embedding_model
         if self._embedding_model and not embeddings:
             embeddings = DefaultEmbeddingFactory(
                 default_model_name=self._embedding_model
             ).create(self._embedding_model)

+        if (
+            embeddings
+            and self._table_vector_store_connector.vector_store_config.embedding_fn
+            is None
+        ):
+            self._table_vector_store_connector.vector_store_config.embedding_fn = (
+                embeddings
+            )
+        if (
+            embeddings
+            and self._field_vector_store_connector.vector_store_config.embedding_fn
+            is None
+        ):
+            self._field_vector_store_connector.vector_store_config.embedding_fn = (
+                embeddings
+            )
+        knowledge = DatasourceKnowledge(connector, model_dimension=max_seq_length)
         super().__init__(
             knowledge=knowledge,
             chunk_parameters=chunk_parameters,
@@ -62,23 +106,36 @@ class DBSchemaAssembler(BaseAssembler):
     def load_from_connection(
         cls,
         connector: BaseConnector,
-        index_store: IndexStoreBase,
+        table_vector_store_connector: VectorStoreConnector,
+        field_vector_store_connector: VectorStoreConnector = None,
         chunk_parameters: Optional[ChunkParameters] = None,
         embedding_model: Optional[str] = None,
         embeddings: Optional[Embeddings] = None,
+        max_seq_length: int = 512,
     ) -> "DBSchemaAssembler":
         """Load document embedding into vector store from path.

         Args:
             connector: (BaseConnector) BaseConnector connection.
-            index_store: (IndexStoreBase) IndexStoreBase to use.
+            table_vector_store_connector: used to load table chunks.
+            field_vector_store_connector: used to load field chunks
+                when a table has too many fields.
             chunk_parameters: (Optional[ChunkParameters]) ChunkParameters to use for
                 chunking.
             embedding_model: (Optional[str]) Embedding model to use.
             embeddings: (Optional[Embeddings]) Embeddings to use.
+            max_seq_length: Embedding model max sequence length.

         Returns:
             DBSchemaAssembler
         """
         return cls(
             connector=connector,
-            index_store=index_store,
+            table_vector_store_connector=table_vector_store_connector,
+            field_vector_store_connector=field_vector_store_connector,
             embedding_model=embedding_model,
             chunk_parameters=chunk_parameters,
             embeddings=embeddings,
+            max_seq_length=max_seq_length,
         )

     def get_chunks(self) -> List[Chunk]:
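Taken together, a hedged usage sketch of the new constructor (the connector and table_vector_connector objects are assumed to be set up as in DBSummaryClient further down; the separator value comes from this diff):

    chunk_parameters = ChunkParameters(
        text_splitter=RDBTextSplitter(separator="--table-field-separator--")
    )
    assembler = DBSchemaAssembler.load_from_connection(
        connector=connector,
        table_vector_store_connector=table_vector_connector,
        chunk_parameters=chunk_parameters,
        max_seq_length=512,
    )
    assembler.persist()  # table chunks and field chunks land in separate stores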
@@ -91,7 +148,19 @@ class DBSchemaAssembler(BaseAssembler):
         Returns:
             List[str]: List of chunk ids.
         """
-        return self._index_store.load_document(self._chunks)
+        table_chunks, field_chunks = [], []
+        for chunk in self._chunks:
+            metadata = chunk.metadata
+            if metadata.get("separated"):
+                if metadata.get("part") == "table":
+                    table_chunks.append(chunk)
+                else:
+                    field_chunks.append(chunk)
+            else:
+                table_chunks.append(chunk)
+
+        self._field_vector_store_connector.load_document(field_chunks)
+        return self._table_vector_store_connector.load_document(table_chunks)

     def _extract_info(self, chunks) -> List[Chunk]:
         """Extract info from chunks."""
@@ -110,5 +179,6 @@ class DBSchemaAssembler(BaseAssembler):
             top_k=top_k,
             connector=self._connector,
-            is_embeddings=True,
-            index_store=self._index_store,
+            table_vector_store_connector=self._table_vector_store_connector,
+            field_vector_store_connector=self._field_vector_store_connector,
         )
@@ -1,76 +1,117 @@
-from unittest.mock import MagicMock
+from typing import List
+from unittest.mock import MagicMock, patch

 import pytest

+import dbgpt
+from dbgpt.core import Chunk
 from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
-from dbgpt.rag.assembler.embedding import EmbeddingAssembler
-from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
-from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
-from dbgpt.rag.knowledge.base import Knowledge
-from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
-from dbgpt.storage.vector_store.chroma_store import ChromaStore
+from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
+from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary


 @pytest.fixture
 def mock_db_connection():
     """Create a temporary database connection for testing."""
     connect = SQLiteTempConnector.create_temporary_db()
     connect.create_temp_tables(
         {
             "user": {
                 "columns": {
                     "id": "INTEGER PRIMARY KEY",
                     "name": "TEXT",
                     "age": "INTEGER",
                 },
                 "data": [
                     (1, "Tom", 10),
                     (2, "Jerry", 16),
                     (3, "Jack", 18),
                     (4, "Alice", 20),
                     (5, "Bob", 22),
                 ],
             }
         }
     )
-    return MagicMock()
+    return connect


 @pytest.fixture
+def mock_table_vector_store_connector():
+    mock_connector = MagicMock()
+    mock_connector.vector_store_config.name = "table_name"
+    chunk = Chunk(
+        content="table_name: user\ncomment: user about dbgpt",
+        metadata={
+            "field_num": 6,
+            "part": "table",
+            "separated": 1,
+            "table_name": "user",
+        },
+    )
+    mock_connector.similar_search_with_scores = MagicMock(return_value=[chunk])
+    return mock_connector
+
+
+@pytest.fixture
-def mock_chunk_parameters():
-    return MagicMock(spec=ChunkParameters)
+def mock_field_vector_store_connector():
+    mock_connector = MagicMock()
+    chunk1 = Chunk(
+        content="name,age",
+        metadata={
+            "field_num": 6,
+            "part": "field",
+            "part_index": 0,
+            "separated": 1,
+            "table_name": "user",
+        },
+    )
+    chunk2 = Chunk(
+        content="address,gender",
+        metadata={
+            "field_num": 6,
+            "part": "field",
+            "part_index": 1,
+            "separated": 1,
+            "table_name": "user",
+        },
+    )
+    chunk3 = Chunk(
+        content="mail,phone",
+        metadata={
+            "field_num": 6,
+            "part": "field",
+            "part_index": 2,
+            "separated": 1,
+            "table_name": "user",
+        },
+    )
+    mock_connector.similar_search_with_scores = MagicMock(
+        return_value=[chunk1, chunk2, chunk3]
+    )
+    return mock_connector


 @pytest.fixture
-def mock_embedding_factory():
-    return MagicMock(spec=EmbeddingFactory)
-
-
-@pytest.fixture
-def mock_vector_store_connector():
-    return MagicMock(spec=ChromaStore)
-
-
-@pytest.fixture
-def mock_knowledge():
-    return MagicMock(spec=Knowledge)
-
-
-def test_load_knowledge(
+def dbstruct_retriever(
     mock_db_connection,
-    mock_knowledge,
-    mock_chunk_parameters,
-    mock_embedding_factory,
-    mock_vector_store_connector,
+    mock_table_vector_store_connector,
+    mock_field_vector_store_connector,
 ):
-    mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
-    mock_chunk_parameters.text_splitter = CharacterTextSplitter()
-    mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
-    assembler = EmbeddingAssembler(
-        knowledge=mock_knowledge,
-        chunk_parameters=mock_chunk_parameters,
-        embeddings=mock_embedding_factory.create(),
-        index_store=mock_vector_store_connector,
-    )
-    assembler.load_knowledge(knowledge=mock_knowledge)
-    assert len(assembler._chunks) == 0
+    return DBSchemaRetriever(
+        connector=mock_db_connection,
+        table_vector_store_connector=mock_table_vector_store_connector,
+        field_vector_store_connector=mock_field_vector_store_connector,
+        separator="--table-field-separator--",
+    )
+
+
+def mock_parse_db_summary() -> str:
+    """Patch _parse_db_summary method."""
+    return (
+        "table_name: user\ncomment: user about dbgpt\n"
+        "--table-field-separator--\n"
+        "name,age\naddress,gender\nmail,phone"
+    )
+
+
+# Mocking the _parse_db_summary method in your test function
+@patch.object(
+    dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
+)
+def test_retrieve_with_mocked_summary(dbstruct_retriever):
+    query = "Table summary"
+    chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
+    assert isinstance(chunks[0], Chunk)
+    assert chunks[0].content == (
+        "table_name: user\ncomment: user about dbgpt\n"
+        "--table-field-separator--\n"
+        "name,age\naddress,gender\nmail,phone"
+    )
+
+
+def async_mock_parse_db_summary() -> str:
+    """Asynchronous patch for _parse_db_summary method."""
+    return (
+        "table_name: user\ncomment: user about dbgpt\n"
+        "--table-field-separator--\n"
+        "name,age\naddress,gender\nmail,phone"
+    )
@@ -5,9 +5,9 @@ import pytest
 from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
 from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
 from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
-from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
-from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
-from dbgpt.storage.vector_store.chroma_store import ChromaStore
+from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddings, EmbeddingFactory
+from dbgpt.rag.text_splitter.text_splitter import RDBTextSplitter
+from dbgpt.serve.rag.connector import VectorStoreConnector


 @pytest.fixture
@@ -21,14 +21,22 @@ def mock_db_connection():
                     "id": "INTEGER PRIMARY KEY",
                     "name": "TEXT",
                     "age": "INTEGER",
-                },
-                "data": [
-                    (1, "Tom", 10),
-                    (2, "Jerry", 16),
-                    (3, "Jack", 18),
-                    (4, "Alice", 20),
-                    (5, "Bob", 22),
-                ],
+                    "address": "TEXT",
+                    "phone": "TEXT",
+                    "email": "TEXT",
+                    "gender": "TEXT",
+                    "birthdate": "TEXT",
+                    "occupation": "TEXT",
+                    "education": "TEXT",
+                    "marital_status": "TEXT",
+                    "nationality": "TEXT",
+                    "height": "REAL",
+                    "weight": "REAL",
+                    "blood_type": "TEXT",
+                    "emergency_contact": "TEXT",
+                    "created_at": "TEXT",
+                    "updated_at": "TEXT",
+                }
             }
         }
     )
@@ -46,23 +54,29 @@ def mock_embedding_factory():


 @pytest.fixture
-def mock_vector_store_connector():
-    return MagicMock(spec=ChromaStore)
+def mock_table_vector_store_connector():
+    mock_connector = MagicMock(spec=VectorStoreConnector)
+    mock_connector.vector_store_config.name = "table_vector_store_name"
+    mock_connector.current_embeddings = DefaultEmbeddings()
+    return mock_connector


 def test_load_knowledge(
     mock_db_connection,
     mock_chunk_parameters,
     mock_embedding_factory,
-    mock_vector_store_connector,
+    mock_table_vector_store_connector,
 ):
     mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
-    mock_chunk_parameters.text_splitter = CharacterTextSplitter()
+    mock_chunk_parameters.text_splitter = RDBTextSplitter(
+        separator="--table-field-separator--"
+    )
     mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
     assembler = DBSchemaAssembler(
         connector=mock_db_connection,
         chunk_parameters=mock_chunk_parameters,
         embeddings=mock_embedding_factory.create(),
-        index_store=mock_vector_store_connector,
+        table_vector_store_connector=mock_table_vector_store_connector,
+        max_seq_length=10,
     )
-    assert len(assembler._chunks) == 1
+    assert len(assembler._chunks) > 1
@@ -5,7 +5,7 @@ from dbgpt.core import Document
 from dbgpt.datasource import BaseConnector

 from ..summary.gdbms_db_summary import _parse_db_summary as _parse_gdb_summary
-from ..summary.rdbms_db_summary import _parse_db_summary
+from ..summary.rdbms_db_summary import _parse_db_summary_with_metadata
 from .base import ChunkStrategy, DocumentType, Knowledge, KnowledgeType
@@ -15,9 +15,11 @@ class DatasourceKnowledge(Knowledge):
     def __init__(
         self,
         connector: BaseConnector,
-        summary_template: str = "{table_name}({columns})",
+        summary_template: str = "table_name: {table_name}",
+        separator: str = "--table-field-separator--",
         knowledge_type: Optional[KnowledgeType] = KnowledgeType.DOCUMENT,
         metadata: Optional[Dict[str, Union[str, List[str]]]] = None,
+        model_dimension: int = 512,
         **kwargs: Any,
     ) -> None:
         """Create Datasource Knowledge with Knowledge arguments.
@@ -25,11 +27,17 @@ class DatasourceKnowledge(Knowledge):

         Args:
             connector(BaseConnector): connector
             summary_template(str, optional): summary template
+            separator(str, optional): separator used to separate
+                table's basic info and fields.
+                Defaults to `--table-field-separator--`.
             knowledge_type(KnowledgeType, optional): knowledge type
             metadata(Dict[str, Union[str, List[str]]], optional): metadata
+            model_dimension(int, optional): The threshold for splitting field string
         """
+        self._separator = separator
         self._connector = connector
         self._summary_template = summary_template
+        self._model_dimension = model_dimension
         super().__init__(knowledge_type=knowledge_type, metadata=metadata, **kwargs)

     def _load(self) -> List[Document]:
@@ -37,13 +45,23 @@ class DatasourceKnowledge(Knowledge):
         docs = []
         if self._connector.is_graph_type():
             db_summary = _parse_gdb_summary(self._connector, self._summary_template)
             for table_summary in db_summary:
                 metadata = {"source": "database"}
                 docs.append(Document(content=table_summary, metadata=metadata))
         else:
-            db_summary = _parse_db_summary(self._connector, self._summary_template)
-            for table_summary in db_summary:
-                metadata = {"source": "database"}
-                if self._metadata:
-                    metadata.update(self._metadata)  # type: ignore
-                docs.append(Document(content=table_summary, metadata=metadata))
+            db_summary_with_metadata = _parse_db_summary_with_metadata(
+                self._connector,
+                self._summary_template,
+                self._separator,
+                self._model_dimension,
+            )
+            for summary, table_metadata in db_summary_with_metadata:
+                metadata = {"source": "database"}
+
+                if self._metadata:
+                    metadata.update(self._metadata)  # type: ignore
+                table_metadata.update(metadata)
+                docs.append(Document(content=summary, metadata=table_metadata))
         return docs

     @classmethod
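A hedged construction sketch of the reworked knowledge source (the connector is assumed to be any BaseConnector, e.g. the SQLite temp connector from the tests, and the base Knowledge class is assumed to expose load()):

    knowledge = DatasourceKnowledge(
        connector,
        separator="--table-field-separator--",
        model_dimension=512,
    )
    # One Document per table; a wide table is emitted with metadata
    # {"separated": 1, ...} so RDBTextSplitter can later split it into one
    # table chunk plus several field chunks.
    docs = knowledge.load()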
@@ -1,14 +1,15 @@
 """The DBSchema Retriever Operator."""

+import os
 from typing import List, Optional

 from dbgpt.core import Chunk
 from dbgpt.core.interface.operators.retriever import RetrieverOperator
 from dbgpt.datasource.base import BaseConnector
+from dbgpt.serve.rag.connector import VectorStoreConnector

+from ...storage.vector_store.base import VectorStoreConfig
 from ..assembler.db_schema import DBSchemaAssembler
 from ..chunk_manager import ChunkParameters
 from ..index.base import IndexStoreBase
 from ..retriever.db_schema import DBSchemaRetriever
 from .assembler import AssemblerOperator
@@ -19,13 +20,14 @@ class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]):

     Args:
         connector (BaseConnector): The connection.
         top_k (int, optional): The top k. Defaults to 4.
-        index_store (IndexStoreBase, optional): The vector store
         vector_store_connector (VectorStoreConnector, optional): The vector store
             connector. Defaults to None.
     """

     def __init__(
         self,
-        index_store: IndexStoreBase,
+        table_vector_store_connector: VectorStoreConnector,
+        field_vector_store_connector: VectorStoreConnector,
         top_k: int = 4,
         connector: Optional[BaseConnector] = None,
         **kwargs
@@ -35,7 +37,8 @@ class DBSchemaRetrieverOperator(RetrieverOperator[str, List[Chunk]]):
         self._retriever = DBSchemaRetriever(
             top_k=top_k,
             connector=connector,
-            index_store=index_store,
+            table_vector_store_connector=table_vector_store_connector,
+            field_vector_store_connector=field_vector_store_connector,
         )

     def retrieve(self, query: str) -> List[Chunk]:
@@ -53,7 +56,8 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
     def __init__(
         self,
         connector: BaseConnector,
-        index_store: IndexStoreBase,
+        table_vector_store_connector: VectorStoreConnector,
+        field_vector_store_connector: VectorStoreConnector = None,
         chunk_parameters: Optional[ChunkParameters] = None,
         **kwargs
     ):
@@ -61,14 +65,26 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):

         Args:
             connector (BaseConnector): The connection.
             index_store (IndexStoreBase): The Storage IndexStoreBase.
             vector_store_connector (VectorStoreConnector): The vector store connector.
             chunk_parameters (Optional[ChunkParameters], optional): The chunk
                 parameters.
         """
         if not chunk_parameters:
             chunk_parameters = ChunkParameters(chunk_strategy="CHUNK_BY_SIZE")
         self._chunk_parameters = chunk_parameters
-        self._index_store = index_store
+        self._table_vector_store_connector = table_vector_store_connector
+
+        field_vector_store_config = VectorStoreConfig(
+            name=table_vector_store_connector.vector_store_config.name + "_field"
+        )
+        self._field_vector_store_connector = (
+            field_vector_store_connector
+            or VectorStoreConnector.from_default(
+                os.getenv("VECTOR_STORE_TYPE", "Chroma"),
+                self._table_vector_store_connector.current_embeddings,
+                vector_store_config=field_vector_store_config,
+            )
+        )
         self._connector = connector
         super().__init__(**kwargs)
@@ -84,7 +100,8 @@ class DBSchemaAssemblerOperator(AssemblerOperator[BaseConnector, List[Chunk]]):
         assembler = DBSchemaAssembler.load_from_connection(
             connector=self._connector,
             chunk_parameters=self._chunk_parameters,
-            index_store=self._index_store,
+            table_vector_store_connector=self._table_vector_store_connector,
+            field_vector_store_connector=self._field_vector_store_connector,
         )
         assembler.persist()
         return assembler.get_chunks()
@@ -1,18 +1,23 @@
 """DBSchema retriever."""
 import logging
-from functools import reduce
-from typing import List, Optional, cast
+import os
+from typing import List, Optional

 from dbgpt._private.config import Config
 from dbgpt.core import Chunk
 from dbgpt.datasource.base import BaseConnector
-from dbgpt.rag.index.base import IndexStoreBase
 from dbgpt.rag.retriever.base import BaseRetriever
 from dbgpt.rag.retriever.rerank import DefaultRanker, Ranker
-from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
-from dbgpt.storage.vector_store.filters import MetadataFilters
-from dbgpt.util.chat_util import run_async_tasks
+from dbgpt.rag.summary.gdbms_db_summary import _parse_db_summary
+from dbgpt.serve.rag.connector import VectorStoreConnector
+from dbgpt.storage.vector_store.base import VectorStoreConfig
+from dbgpt.storage.vector_store.filters import MetadataFilter, MetadataFilters
+from dbgpt.util.chat_util import run_tasks
 from dbgpt.util.executor_utils import blocking_func_to_async_no_executor
 from dbgpt.util.tracer import root_tracer

 logger = logging.getLogger(__name__)

 CFG = Config()


 class DBSchemaRetriever(BaseRetriever):
@@ -20,7 +25,9 @@ class DBSchemaRetriever(BaseRetriever):

     def __init__(
         self,
-        index_store: IndexStoreBase,
+        table_vector_store_connector: VectorStoreConnector,
+        field_vector_store_connector: VectorStoreConnector = None,
+        separator: str = "--table-field-separator--",
         top_k: int = 4,
         connector: Optional[BaseConnector] = None,
         query_rewrite: bool = False,
@@ -30,7 +37,11 @@ class DBSchemaRetriever(BaseRetriever):
         """Create DBSchemaRetriever.

         Args:
-            index_store(IndexStore): index connector
+            table_vector_store_connector: VectorStoreConnector
+                to load and retrieve table info.
+            field_vector_store_connector: VectorStoreConnector
+                to load and retrieve field info.
+            separator: field/table separator
             top_k (int): top k
             connector (Optional[BaseConnector]): RDBMSConnector.
             query_rewrite (bool): query rewrite
@@ -70,34 +81,42 @@ class DBSchemaRetriever(BaseRetriever):

                 connector = _create_temporary_connection()
-                config = ChromaVectorConfig(
-                    persist_path=PILOT_PATH,
-                    name="dbschema_rag_test",
-                    embedding_fn=DefaultEmbeddingFactory(
-                        default_model_name=os.path.join(
-                            MODEL_PATH, "text2vec-large-chinese"
-                        ),
-                    ).create(),
-                )
-                vector_store = ChromaStore(config)
+                vector_store_config = ChromaVectorConfig(name="vector_store_name")
+                embedding_model_path = "{your_embedding_model_path}"
+                embedding_fn = embedding_factory.create(model_name=embedding_model_path)
+                vector_connector = VectorStoreConnector.from_default(
+                    "Chroma",
+                    vector_store_config=vector_store_config,
+                    embedding_fn=embedding_fn,
+                )
                 # get db struct retriever
                 retriever = DBSchemaRetriever(
                     top_k=3,
-                    index_store=vector_store,
+                    table_vector_store_connector=vector_connector,
                     connector=connector,
                 )
                 chunks = retriever.retrieve("show columns from table")
                 result = [chunk.content for chunk in chunks]
                 print(f"db struct rag example results:{result}")
             """
+        self._separator = separator
         self._top_k = top_k
         self._connector = connector
         self._query_rewrite = query_rewrite
-        self._index_store = index_store
+        self._table_vector_store_connector = table_vector_store_connector
+        field_vector_store_config = VectorStoreConfig(
+            name=table_vector_store_connector.vector_store_config.name + "_field"
+        )
+        self._field_vector_store_connector = (
+            field_vector_store_connector
+            or VectorStoreConnector.from_default(
+                os.getenv("VECTOR_STORE_TYPE", "Chroma"),
+                self._table_vector_store_connector.current_embeddings,
+                vector_store_config=field_vector_store_config,
+            )
+        )
         self._need_embeddings = False
-        if self._index_store:
+        if self._table_vector_store_connector:
             self._need_embeddings = True
         self._rerank = rerank or DefaultRanker(self._top_k)
@@ -114,15 +133,8 @@ class DBSchemaRetriever(BaseRetriever):
             List[Chunk]: list of chunks
         """
         if self._need_embeddings:
-            queries = [query]
-            candidates = [
-                self._index_store.similar_search(query, self._top_k, filters)
-                for query in queries
-            ]
-            return cast(List[Chunk], reduce(lambda x, y: x + y, candidates))
+            return self._similarity_search(query, filters)
         else:
             if not self._connector:
                 raise RuntimeError("RDBMSConnector connection is required.")
             table_summaries = _parse_db_summary(self._connector)
             return [Chunk(content=table_summary) for table_summary in table_summaries]
@@ -156,30 +168,11 @@ class DBSchemaRetriever(BaseRetriever):
         Returns:
             List[Chunk]: list of chunks
         """
-        if self._need_embeddings:
-            queries = [query]
-            candidates = [
-                self._similarity_search(
-                    query, filters, root_tracer.get_current_span_id()
-                )
-                for query in queries
-            ]
-            result_candidates = await run_async_tasks(
-                tasks=candidates, concurrency_limit=1
-            )
-            return cast(List[Chunk], reduce(lambda x, y: x + y, result_candidates))
-        else:
-            from dbgpt.rag.summary.rdbms_db_summary import (  # noqa: F401
-                _parse_db_summary,
-            )
-
-            table_summaries = await run_async_tasks(
-                tasks=[self._aparse_db_summary(root_tracer.get_current_span_id())],
-                concurrency_limit=1,
-            )
-            return [
-                Chunk(content=table_summary) for table_summary in table_summaries[0]
-            ]
+        return await blocking_func_to_async_no_executor(
+            func=self._retrieve,
+            query=query,
+            filters=filters,
+        )

     async def _aretrieve_with_score(
         self,
@@ -196,34 +189,40 @@ class DBSchemaRetriever(BaseRetriever):
         """
         return await self._aretrieve(query, filters)

-    async def _similarity_search(
-        self,
-        query,
-        filters: Optional[MetadataFilters] = None,
-        parent_span_id: Optional[str] = None,
-    ) -> List[Chunk]:
-        """Similar search."""
-        with root_tracer.start_span(
-            "dbgpt.rag.retriever.db_schema._similarity_search",
-            parent_span_id,
-            metadata={"query": query},
-        ):
-            return await blocking_func_to_async_no_executor(
-                self._index_store.similar_search, query, self._top_k, filters
-            )
-
-    async def _aparse_db_summary(
-        self, parent_span_id: Optional[str] = None
-    ) -> List[str]:
-        """Similar search."""
-        from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
-
-        with root_tracer.start_span(
-            "dbgpt.rag.retriever.db_schema._aparse_db_summary",
-            parent_span_id,
-        ):
-            return await blocking_func_to_async_no_executor(
-                _parse_db_summary, self._connector
-            )
+    def _retrieve_field(self, table_chunk: Chunk, query) -> Chunk:
+        metadata = table_chunk.metadata
+        metadata["part"] = "field"
+        filters = [MetadataFilter(key=k, value=v) for k, v in metadata.items()]
+        field_chunks = self._field_vector_store_connector.similar_search_with_scores(
+            query, self._top_k, 0, MetadataFilters(filters=filters)
+        )
+        field_contents = [chunk.content for chunk in field_chunks]
+        table_chunk.content += "\n" + self._separator + "\n" + "\n".join(field_contents)
+        return table_chunk
+
+    def _similarity_search(
+        self, query, filters: Optional[MetadataFilters] = None
+    ) -> List[Chunk]:
+        """Similar search."""
+        table_chunks = self._table_vector_store_connector.similar_search_with_scores(
+            query, self._top_k, 0, filters
+        )
+
+        not_sep_chunks = [
+            chunk for chunk in table_chunks if not chunk.metadata.get("separated")
+        ]
+        separated_chunks = [
+            chunk for chunk in table_chunks if chunk.metadata.get("separated")
+        ]
+        if not separated_chunks:
+            return not_sep_chunks
+
+        if not self._connector:
+            raise RuntimeError("RDBMSConnector connection is required.")
+
+        # Create tasks list
+        tasks = [
+            lambda c=chunk: self._retrieve_field(c, query) for chunk in separated_chunks
+        ]
+        # Run tasks concurrently
+        separated_result = run_tasks(tasks, concurrency_limit=3)
+
+        # Combine and return results
+        return not_sep_chunks + separated_result
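Net effect of the retriever changes: _similarity_search first queries the table store, then for each hit whose metadata says it was separated, fetches the matching field chunks from the field store and stitches them back together with the separator. A hedged driving sketch (connector setup as in DBSummaryClient below):

    retriever = DBSchemaRetriever(
        top_k=4,
        table_vector_store_connector=table_vector_connector,
        separator="--table-field-separator--",
    )
    # Each returned chunk is a table summary with its field lines
    # re-attached after the separator.
    chunks = retriever.retrieve("show columns from user")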
@@ -15,42 +15,53 @@ def mock_db_connection():


 @pytest.fixture
-def mock_vector_store_connector():
+def mock_table_vector_store_connector():
     mock_connector = MagicMock()
-    mock_connector.similar_search.return_value = [Chunk(content="Table summary")] * 4
+    mock_connector.vector_store_config.name = "table_name"
+    mock_connector.similar_search_with_scores.return_value = [
+        Chunk(content="Table summary")
+    ] * 4
     return mock_connector


 @pytest.fixture
-def db_struct_retriever(mock_db_connection, mock_vector_store_connector):
+def mock_field_vector_store_connector():
+    mock_connector = MagicMock()
+    mock_connector.similar_search_with_scores.return_value = [
+        Chunk(content="Field summary")
+    ] * 4
+    return mock_connector
+
+
+@pytest.fixture
+def dbstruct_retriever(
+    mock_db_connection,
+    mock_table_vector_store_connector,
+    mock_field_vector_store_connector,
+):
     return DBSchemaRetriever(
         connector=mock_db_connection,
-        index_store=mock_vector_store_connector,
+        table_vector_store_connector=mock_table_vector_store_connector,
+        field_vector_store_connector=mock_field_vector_store_connector,
     )


-def mock_parse_db_summary(conn) -> List[str]:
+def mock_parse_db_summary() -> str:
     """Patch _parse_db_summary method."""
-    return ["Table summary"]
+    return "Table summary"


 # Mocking the _parse_db_summary method in your test function
 @patch.object(
     dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
 )
-def test_retrieve_with_mocked_summary(db_struct_retriever):
+def test_retrieve_with_mocked_summary(dbstruct_retriever):
     query = "Table summary"
-    chunks: List[Chunk] = db_struct_retriever._retrieve(query)
+    chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
     assert isinstance(chunks[0], Chunk)
     assert chunks[0].content == "Table summary"


-@pytest.mark.asyncio
-@patch.object(
-    dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
-)
-async def test_aretrieve_with_mocked_summary(db_struct_retriever):
-    query = "Table summary"
-    chunks: List[Chunk] = await db_struct_retriever._aretrieve(query)
-    assert isinstance(chunks[0], Chunk)
-    assert chunks[0].content == "Table summary"
+async def async_mock_parse_db_summary() -> str:
+    """Asynchronous patch for _parse_db_summary method."""
+    return "Table summary"
@@ -2,13 +2,15 @@

 import logging
 import traceback
 from typing import List

 from dbgpt._private.config import Config
 from dbgpt.component import SystemApp
 from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
 from dbgpt.rag import ChunkParameters
 from dbgpt.rag.summary.gdbms_db_summary import GdbmsSummary
 from dbgpt.rag.summary.rdbms_db_summary import RdbmsSummary
+from dbgpt.rag.text_splitter.text_splitter import RDBTextSplitter
+from dbgpt.serve.rag.connector import VectorStoreConnector

 logger = logging.getLogger(__name__)
@@ -47,22 +49,26 @@ class DBSummaryClient:

         logger.info("db summary embedding success")

-    def get_db_summary(self, dbname, query, topk) -> List[str]:
+    def get_db_summary(self, dbname, query, topk):
         """Get user query related tables info."""
         from dbgpt.serve.rag.connector import VectorStoreConnector
         from dbgpt.storage.vector_store.base import VectorStoreConfig

-        vector_store_config = VectorStoreConfig(name=dbname + "_profile")
-        vector_connector = VectorStoreConnector.from_default(
+        vector_store_name = dbname + "_profile"
+        table_vector_store_config = VectorStoreConfig(name=vector_store_name)
+        table_vector_connector = VectorStoreConnector.from_default(
             CFG.VECTOR_STORE_TYPE,
-            embedding_fn=self.embeddings,
-            vector_store_config=vector_store_config,
+            self.embeddings,
+            vector_store_config=table_vector_store_config,
         )

         from dbgpt.rag.retriever.db_schema import DBSchemaRetriever

         retriever = DBSchemaRetriever(
-            top_k=topk, index_store=vector_connector.index_client
+            top_k=topk,
+            table_vector_store_connector=table_vector_connector,
+            separator="--table-field-separator--",
         )
         table_docs = retriever.retrieve(query)
         ans = [d.content for d in table_docs]
         return ans
@@ -92,18 +98,23 @@ class DBSummaryClient:
         from dbgpt.serve.rag.connector import VectorStoreConnector
         from dbgpt.storage.vector_store.base import VectorStoreConfig

-        vector_store_config = VectorStoreConfig(name=vector_store_name)
-        vector_connector = VectorStoreConnector.from_default(
+        table_vector_store_config = VectorStoreConfig(name=vector_store_name)
+        table_vector_connector = VectorStoreConnector.from_default(
             CFG.VECTOR_STORE_TYPE,
             self.embeddings,
-            vector_store_config=vector_store_config,
+            vector_store_config=table_vector_store_config,
         )
-        if not vector_connector.vector_name_exists():
+        if not table_vector_connector.vector_name_exists():
             from dbgpt.rag.assembler.db_schema import DBSchemaAssembler

+            chunk_parameters = ChunkParameters(
+                text_splitter=RDBTextSplitter(separator="--table-field-separator--")
+            )
             db_assembler = DBSchemaAssembler.load_from_connection(
                 connector=db_summary_client.db,
-                index_store=vector_connector.index_client,
+                table_vector_store_connector=table_vector_connector,
+                chunk_parameters=chunk_parameters,
+                max_seq_length=CFG.EMBEDDING_MODEL_MAX_SEQ_LEN,
             )

             if len(db_assembler.get_chunks()) > 0:
@@ -115,16 +126,26 @@ class DBSummaryClient:
     def delete_db_profile(self, dbname):
         """Delete db profile."""
-        vector_store_name = dbname + "_profile"
+        table_vector_store_name = dbname + "_profile"
+        field_vector_store_name = dbname + "_profile_field"
         from dbgpt.serve.rag.connector import VectorStoreConnector
         from dbgpt.storage.vector_store.base import VectorStoreConfig

-        vector_store_config = VectorStoreConfig(name=vector_store_name)
-        vector_connector = VectorStoreConnector.from_default(
+        table_vector_store_config = VectorStoreConfig(name=table_vector_store_name)
+        field_vector_store_config = VectorStoreConfig(name=field_vector_store_name)
+        table_vector_connector = VectorStoreConnector.from_default(
             CFG.VECTOR_STORE_TYPE,
             self.embeddings,
-            vector_store_config=vector_store_config,
+            vector_store_config=table_vector_store_config,
         )
-        vector_connector.delete_vector_name(vector_store_name)
+        field_vector_connector = VectorStoreConnector.from_default(
+            CFG.VECTOR_STORE_TYPE,
+            self.embeddings,
+            vector_store_config=field_vector_store_config,
+        )
+
+        table_vector_connector.delete_vector_name(table_vector_store_name)
+        field_vector_connector.delete_vector_name(field_vector_store_name)
         logger.info(f"delete db profile {dbname} success")

     @staticmethod
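For callers such as ChatWithDbQA above, the entry point is unchanged. A hedged sketch (the dbname and query values are illustrative):

    client = DBSummaryClient(system_app=CFG.SYSTEM_APP)
    # Returns the content of the top-k matching table chunks; wide tables
    # come back with their fields re-joined after the separator.
    table_infos = client.get_db_summary(dbname="my_db", query="user age", topk=5)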
@@ -1,6 +1,6 @@
 """Summary for rdbms database."""
-from typing import TYPE_CHECKING, List, Optional
+import re
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

 from dbgpt._private.config import Config
 from dbgpt.datasource import BaseConnector
@@ -80,6 +80,134 @@ def _parse_db_summary(
     return table_info_summaries


+def _parse_db_summary_with_metadata(
+    conn: BaseConnector,
+    summary_template: str = "table_name: {table_name}",
+    separator: str = "--table-field-separator--",
+    model_dimension: int = 512,
+) -> List[Tuple[str, Dict[str, Any]]]:
+    """Get db summary for database.
+
+    Args:
+        conn (BaseConnector): database connection
+        summary_template (str): summary template
+        separator (str, optional): separator used to separate table's
+            basic info and fields. Defaults to `--table-field-separator--`.
+        model_dimension (int, optional): The threshold for splitting field string
+    """
+    tables = conn.get_table_names()
+    table_info_summaries = [
+        _parse_table_summary_with_metadata(
+            conn, summary_template, separator, table_name, model_dimension
+        )
+        for table_name in tables
+    ]
+    return table_info_summaries
+def _split_columns_str(columns: List[str], model_dimension: int):
+    """Split columns str.
+
+    Args:
+        columns (List[str]): fields string
+        model_dimension (int, optional): The threshold for splitting field string.
+    """
+    result = []
+    current_string = ""
+    current_length = 0
+
+    for element_str in columns:
+        element_length = len(element_str)
+
+        # If adding the current element's length would exceed the threshold,
+        # add the current string to results and reset
+        if current_length + element_length > model_dimension:
+            result.append(current_string.strip())  # Remove trailing spaces
+            current_string = element_str
+            current_length = element_length
+        else:
+            # Otherwise append to the current segment, inserting a comma
+            # between elements
+            if current_string:
+                current_string += "," + element_str
+            else:
+                current_string = element_str
+            current_length += element_length + 1  # Account for the comma

+    # Handle the last string segment
+    if current_string:
+        result.append(current_string.strip())
+
+    return result
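A quick worked trace of the packing behaviour (values chosen for illustration; the +1 per element accounts for the comma):

    from dbgpt.rag.summary.rdbms_db_summary import _split_columns_str

    # "id" and "name" pack into one segment; adding "age" would push the
    # running length past 10, so a new segment starts.
    print(_split_columns_str(["id", "name", "age", "address"], model_dimension=10))
    # ['id,name', 'age,address']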
+def _parse_table_summary_with_metadata(
+    conn: BaseConnector,
+    summary_template: str,
+    separator,
+    table_name: str,
+    model_dimension=512,
+) -> Tuple[str, Dict[str, Any]]:
+    """Get table summary for table.
+
+    Args:
+        conn (BaseConnector): database connection
+        summary_template (str): summary template
+        separator (str, optional): separator used to separate table's
+            basic info and fields. Defaults to `--table-field-separator--`.
+        model_dimension (int, optional): The threshold for splitting field string
+
+    Examples:
+        metadata: {'table_name': 'asd', 'separated': 0/1}
+
+        table_name: table1
+        table_comment: comment
+        index_keys: keys
+        --table-field-separator--
+        (column1, comment), (column2, comment), (column3, comment)
+        (column4, comment), (column5, comment), (column6, comment)
+    """
+    columns = []
+    metadata = {"table_name": table_name, "separated": 0}
+    for column in conn.get_columns(table_name):
+        if column.get("comment"):
+            columns.append(f"{column['name']} ({column.get('comment')})")
+        else:
+            columns.append(f"{column['name']}")
+    metadata.update({"field_num": len(columns)})
+    separated_columns = _split_columns_str(columns, model_dimension=model_dimension)
+    if len(separated_columns) > 1:
+        metadata["separated"] = 1
+    column_str = "\n".join(separated_columns)
+    # Obtain index information
+    index_keys = []
+    raw_indexes = conn.get_indexes(table_name)
+    for index in raw_indexes:
+        if isinstance(index, tuple):  # Process tuple type index information
+            index_name, index_creation_command = index
+            # Extract column names using a regular expression
+            matched_columns = re.findall(r"\(([^)]+)\)", index_creation_command)
+            if matched_columns:
+                key_str = ", ".join(matched_columns)
+                index_keys.append(f"{index_name}(`{key_str}`) ")
+        else:
+            key_str = ", ".join(index["column_names"])
+            index_keys.append(f"{index['name']}(`{key_str}`) ")
+    table_str = summary_template.format(table_name=table_name)
+
+    try:
+        comment = conn.get_table_comment(table_name)
+    except Exception:
+        comment = dict(text=None)
+    if comment.get("text"):
+        table_str += f"\ntable_comment: {comment.get('text')}"
+
+    if len(index_keys) > 0:
+        index_key_str = ", ".join(index_keys)
+        table_str += f"\nindex_keys: {index_key_str}"
+    table_str += f"\n{separator}\n{column_str}"
+    return table_str, metadata


 def _parse_table_summary(
     conn: BaseConnector, summary_template: str, table_name: str
 ) -> str:
@@ -912,3 +912,42 @@ class PageTextSplitter(TextSplitter):
         new_doc = Chunk(content=text, metadata=copy.deepcopy(_metadatas[i]))
         chunks.append(new_doc)
     return chunks
+
+
+class RDBTextSplitter(TextSplitter):
+    """Split relational database tables and fields."""
+
+    def __init__(self, separator: str = "--table-field-separator--", **kwargs):
+        """Create a new RDBTextSplitter."""
+        super().__init__(**kwargs)
+        # Keep the separator used to split a summary into its table part
+        # and its field part.
+        self._separator = separator
+
+    def split_text(self, text: str, **kwargs):
+        """Split text into a couple of parts."""
+        pass
+
+    def split_documents(self, documents: Iterable[Document], **kwargs) -> List[Chunk]:
+        """Split documents into chunks."""
+        chunks = []
+        for doc in documents:
+            metadata = doc.metadata
+            content = doc.content
+            if metadata.get("separated"):
+                # separate table and field parts
+                parts = content.split(self._separator)
+                table_part, field_part = parts[0], parts[1]
+                table_metadata, field_metadata = copy.deepcopy(metadata), copy.deepcopy(
+                    metadata
+                )
+                table_metadata["part"] = "table"  # identifies a table chunk
+                field_metadata["part"] = "field"  # identifies a field chunk
+                table_chunk = Chunk(content=table_part, metadata=table_metadata)
+                chunks.append(table_chunk)
+                field_parts = field_part.split("\n")
+                for i, sub_part in enumerate(field_parts):
+                    sub_metadata = copy.deepcopy(field_metadata)
+                    sub_metadata["part_index"] = i
+                    field_chunk = Chunk(content=sub_part, metadata=sub_metadata)
+                    chunks.append(field_chunk)
+            else:
+                chunks.append(Chunk(content=content, metadata=metadata))
+        return chunks
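A hedged sketch of the splitter on a summary shaped like the output of _parse_table_summary_with_metadata above (content and metadata fields taken from this diff):

    from dbgpt.core import Document

    doc = Document(
        content=(
            "table_name: user\n"
            "--table-field-separator--\n"
            "id,name,age\n"
            "address,phone"
        ),
        metadata={"table_name": "user", "separated": 1, "field_num": 5},
    )
    splitter = RDBTextSplitter(separator="--table-field-separator--")
    # Produces one chunk with part="table" plus field chunks (part="field")
    # for the newline-delimited segments after the separator, each tagged
    # with a part_index.
    chunks = splitter.split_documents([doc])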
@@ -1,5 +1,6 @@
 import asyncio
-from typing import Any, Coroutine, List
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Any, Callable, Coroutine, List


 async def llm_chat_response_nostream(chat_scene: str, **chat_param):
@@ -47,13 +48,34 @@ async def run_async_tasks(


 def run_tasks(
-    tasks: List[Coroutine],
+    tasks: List[Callable],
     concurrency_limit: int = None,
 ) -> List[Any]:
-    """Run a list of async tasks."""
-    tasks_to_execute: List[Any] = tasks
-
-    async def _gather() -> List[Any]:
-        return await asyncio.gather(*tasks_to_execute)
-
-    outputs: List[Any] = asyncio.run(_gather())
-    return outputs
+    """Run a list of tasks concurrently using a thread pool.
+
+    Args:
+        tasks: List of callable functions to execute
+        concurrency_limit: Maximum number of concurrent threads (optional)
+
+    Returns:
+        List of results from all tasks in the order they were submitted
+    """
+    max_workers = concurrency_limit if concurrency_limit else None
+
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        # Submit all tasks and get futures
+        futures = [executor.submit(task) for task in tasks]
+
+        # Collect results in order, raising any exceptions
+        results = []
+        for future in futures:
+            try:
+                results.append(future.result())
+            except Exception as e:
+                # Cancel any pending futures
+                for f in futures:
+                    f.cancel()
+                raise e
+
+        return results
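A hedged usage sketch of the reworked helper (the callables here are illustrative):

    from dbgpt.util.chat_util import run_tasks

    def make_task(n):
        return lambda: n * n

    # Executes the callables on a thread pool, at most two at a time,
    # preserving submission order in the results.
    results = run_tasks([make_task(n) for n in range(3)], concurrency_limit=2)
    assert results == [0, 1, 4]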