mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-15 05:59:59 +00:00
Feat rdb summary wide table (#2035)
Co-authored-by: dongzhancai1 <dongzhancai1@jd.com> Co-authored-by: dong <dongzhancai@iie2.com>
This commit is contained in:
@@ -1,12 +1,15 @@
|
||||
"""DBSchemaAssembler."""
|
||||
import os
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core import Chunk, Embeddings
|
||||
from dbgpt.datasource.base import BaseConnector
|
||||
|
||||
from ...serve.rag.connector import VectorStoreConnector
|
||||
from ...storage.vector_store.base import VectorStoreConfig
|
||||
from ..assembler.base import BaseAssembler
|
||||
from ..chunk_manager import ChunkParameters
|
||||
from ..index.base import IndexStoreBase
|
||||
from ..embedding.embedding_factory import DefaultEmbeddingFactory
|
||||
from ..knowledge.datasource import DatasourceKnowledge
|
||||
from ..retriever.db_schema import DBSchemaRetriever
|
||||
|
||||
@@ -35,23 +38,64 @@ class DBSchemaAssembler(BaseAssembler):
|
||||
def __init__(
|
||||
self,
|
||||
connector: BaseConnector,
|
||||
index_store: IndexStoreBase,
|
||||
table_vector_store_connector: VectorStoreConnector,
|
||||
field_vector_store_connector: VectorStoreConnector = None,
|
||||
chunk_parameters: Optional[ChunkParameters] = None,
|
||||
embedding_model: Optional[str] = None,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
max_seq_length: int = 512,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize with Embedding Assembler arguments.
|
||||
|
||||
Args:
|
||||
connector: (BaseConnector) BaseConnector connection.
|
||||
index_store: (IndexStoreBase) IndexStoreBase to use.
|
||||
table_vector_store_connector: VectorStoreConnector to load
|
||||
and retrieve table info.
|
||||
field_vector_store_connector: VectorStoreConnector to load
|
||||
and retrieve field info.
|
||||
chunk_manager: (Optional[ChunkManager]) ChunkManager to use for chunking.
|
||||
embedding_model: (Optional[str]) Embedding model to use.
|
||||
embeddings: (Optional[Embeddings]) Embeddings to use.
|
||||
"""
|
||||
knowledge = DatasourceKnowledge(connector)
|
||||
self._connector = connector
|
||||
self._index_store = index_store
|
||||
self._table_vector_store_connector = table_vector_store_connector
|
||||
|
||||
field_vector_store_config = VectorStoreConfig(
|
||||
name=table_vector_store_connector.vector_store_config.name + "_field"
|
||||
)
|
||||
self._field_vector_store_connector = (
|
||||
field_vector_store_connector
|
||||
or VectorStoreConnector.from_default(
|
||||
os.getenv("VECTOR_STORE_TYPE", "Chroma"),
|
||||
self._table_vector_store_connector.current_embeddings,
|
||||
vector_store_config=field_vector_store_config,
|
||||
)
|
||||
)
|
||||
|
||||
self._embedding_model = embedding_model
|
||||
if self._embedding_model and not embeddings:
|
||||
embeddings = DefaultEmbeddingFactory(
|
||||
default_model_name=self._embedding_model
|
||||
).create(self._embedding_model)
|
||||
|
||||
if (
|
||||
embeddings
|
||||
and self._table_vector_store_connector.vector_store_config.embedding_fn
|
||||
is None
|
||||
):
|
||||
self._table_vector_store_connector.vector_store_config.embedding_fn = (
|
||||
embeddings
|
||||
)
|
||||
if (
|
||||
embeddings
|
||||
and self._field_vector_store_connector.vector_store_config.embedding_fn
|
||||
is None
|
||||
):
|
||||
self._field_vector_store_connector.vector_store_config.embedding_fn = (
|
||||
embeddings
|
||||
)
|
||||
knowledge = DatasourceKnowledge(connector, model_dimension=max_seq_length)
|
||||
super().__init__(
|
||||
knowledge=knowledge,
|
||||
chunk_parameters=chunk_parameters,
|
||||
@@ -62,23 +106,36 @@ class DBSchemaAssembler(BaseAssembler):
|
||||
def load_from_connection(
|
||||
cls,
|
||||
connector: BaseConnector,
|
||||
index_store: IndexStoreBase,
|
||||
table_vector_store_connector: VectorStoreConnector,
|
||||
field_vector_store_connector: VectorStoreConnector = None,
|
||||
chunk_parameters: Optional[ChunkParameters] = None,
|
||||
embedding_model: Optional[str] = None,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
max_seq_length: int = 512,
|
||||
) -> "DBSchemaAssembler":
|
||||
"""Load document embedding into vector store from path.
|
||||
|
||||
Args:
|
||||
connector: (BaseConnector) BaseConnector connection.
|
||||
index_store: (IndexStoreBase) IndexStoreBase to use.
|
||||
table_vector_store_connector: used to load table chunks.
|
||||
field_vector_store_connector: used to load field chunks
|
||||
if field in table is too much.
|
||||
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
|
||||
chunking.
|
||||
embedding_model: (Optional[str]) Embedding model to use.
|
||||
embeddings: (Optional[Embeddings]) Embeddings to use.
|
||||
max_seq_length: Embedding model max sequence length
|
||||
Returns:
|
||||
DBSchemaAssembler
|
||||
"""
|
||||
return cls(
|
||||
connector=connector,
|
||||
index_store=index_store,
|
||||
table_vector_store_connector=table_vector_store_connector,
|
||||
field_vector_store_connector=field_vector_store_connector,
|
||||
embedding_model=embedding_model,
|
||||
chunk_parameters=chunk_parameters,
|
||||
embeddings=embeddings,
|
||||
max_seq_length=max_seq_length,
|
||||
)
|
||||
|
||||
def get_chunks(self) -> List[Chunk]:
|
||||
@@ -91,7 +148,19 @@ class DBSchemaAssembler(BaseAssembler):
|
||||
Returns:
|
||||
List[str]: List of chunk ids.
|
||||
"""
|
||||
return self._index_store.load_document(self._chunks)
|
||||
table_chunks, field_chunks = [], []
|
||||
for chunk in self._chunks:
|
||||
metadata = chunk.metadata
|
||||
if metadata.get("separated"):
|
||||
if metadata.get("part") == "table":
|
||||
table_chunks.append(chunk)
|
||||
else:
|
||||
field_chunks.append(chunk)
|
||||
else:
|
||||
table_chunks.append(chunk)
|
||||
|
||||
self._field_vector_store_connector.load_document(field_chunks)
|
||||
return self._table_vector_store_connector.load_document(table_chunks)
|
||||
|
||||
def _extract_info(self, chunks) -> List[Chunk]:
|
||||
"""Extract info from chunks."""
|
||||
@@ -110,5 +179,6 @@ class DBSchemaAssembler(BaseAssembler):
|
||||
top_k=top_k,
|
||||
connector=self._connector,
|
||||
is_embeddings=True,
|
||||
index_store=self._index_store,
|
||||
table_vector_store_connector=self._table_vector_store_connector,
|
||||
field_vector_store_connector=self._field_vector_store_connector,
|
||||
)
|
||||
|
@@ -1,76 +1,117 @@
|
||||
from unittest.mock import MagicMock
|
||||
from typing import List
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
|
||||
from dbgpt.rag.assembler.embedding import EmbeddingAssembler
|
||||
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.rag.knowledge.base import Knowledge
|
||||
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaStore
|
||||
import dbgpt
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_connection():
|
||||
"""Create a temporary database connection for testing."""
|
||||
connect = SQLiteTempConnector.create_temporary_db()
|
||||
connect.create_temp_tables(
|
||||
{
|
||||
"user": {
|
||||
"columns": {
|
||||
"id": "INTEGER PRIMARY KEY",
|
||||
"name": "TEXT",
|
||||
"age": "INTEGER",
|
||||
},
|
||||
"data": [
|
||||
(1, "Tom", 10),
|
||||
(2, "Jerry", 16),
|
||||
(3, "Jack", 18),
|
||||
(4, "Alice", 20),
|
||||
(5, "Bob", 22),
|
||||
],
|
||||
}
|
||||
}
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_table_vector_store_connector():
|
||||
mock_connector = MagicMock()
|
||||
mock_connector.vector_store_config.name = "table_name"
|
||||
chunk = Chunk(
|
||||
content="table_name: user\ncomment: user about dbgpt",
|
||||
metadata={
|
||||
"field_num": 6,
|
||||
"part": "table",
|
||||
"separated": 1,
|
||||
"table_name": "user",
|
||||
},
|
||||
)
|
||||
return connect
|
||||
mock_connector.similar_search_with_scores = MagicMock(return_value=[chunk])
|
||||
return mock_connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_chunk_parameters():
|
||||
return MagicMock(spec=ChunkParameters)
|
||||
def mock_field_vector_store_connector():
|
||||
mock_connector = MagicMock()
|
||||
chunk1 = Chunk(
|
||||
content="name,age",
|
||||
metadata={
|
||||
"field_num": 6,
|
||||
"part": "field",
|
||||
"part_index": 0,
|
||||
"separated": 1,
|
||||
"table_name": "user",
|
||||
},
|
||||
)
|
||||
chunk2 = Chunk(
|
||||
content="address,gender",
|
||||
metadata={
|
||||
"field_num": 6,
|
||||
"part": "field",
|
||||
"part_index": 1,
|
||||
"separated": 1,
|
||||
"table_name": "user",
|
||||
},
|
||||
)
|
||||
chunk3 = Chunk(
|
||||
content="mail,phone",
|
||||
metadata={
|
||||
"field_num": 6,
|
||||
"part": "field",
|
||||
"part_index": 2,
|
||||
"separated": 1,
|
||||
"table_name": "user",
|
||||
},
|
||||
)
|
||||
mock_connector.similar_search_with_scores = MagicMock(
|
||||
return_value=[chunk1, chunk2, chunk3]
|
||||
)
|
||||
return mock_connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embedding_factory():
|
||||
return MagicMock(spec=EmbeddingFactory)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_connector():
|
||||
return MagicMock(spec=ChromaStore)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_knowledge():
|
||||
return MagicMock(spec=Knowledge)
|
||||
|
||||
|
||||
def test_load_knowledge(
|
||||
def dbstruct_retriever(
|
||||
mock_db_connection,
|
||||
mock_knowledge,
|
||||
mock_chunk_parameters,
|
||||
mock_embedding_factory,
|
||||
mock_vector_store_connector,
|
||||
mock_table_vector_store_connector,
|
||||
mock_field_vector_store_connector,
|
||||
):
|
||||
mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
|
||||
mock_chunk_parameters.text_splitter = CharacterTextSplitter()
|
||||
mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
|
||||
assembler = EmbeddingAssembler(
|
||||
knowledge=mock_knowledge,
|
||||
chunk_parameters=mock_chunk_parameters,
|
||||
embeddings=mock_embedding_factory.create(),
|
||||
index_store=mock_vector_store_connector,
|
||||
return DBSchemaRetriever(
|
||||
connector=mock_db_connection,
|
||||
table_vector_store_connector=mock_table_vector_store_connector,
|
||||
field_vector_store_connector=mock_field_vector_store_connector,
|
||||
separator="--table-field-separator--",
|
||||
)
|
||||
|
||||
|
||||
def mock_parse_db_summary() -> str:
|
||||
"""Patch _parse_db_summary method."""
|
||||
return (
|
||||
"table_name: user\ncomment: user about dbgpt\n"
|
||||
"--table-field-separator--\n"
|
||||
"name,age\naddress,gender\nmail,phone"
|
||||
)
|
||||
|
||||
|
||||
# Mocking the _parse_db_summary method in your test function
|
||||
@patch.object(
|
||||
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
|
||||
)
|
||||
def test_retrieve_with_mocked_summary(dbstruct_retriever):
|
||||
query = "Table summary"
|
||||
chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
|
||||
assert isinstance(chunks[0], Chunk)
|
||||
assert chunks[0].content == (
|
||||
"table_name: user\ncomment: user about dbgpt\n"
|
||||
"--table-field-separator--\n"
|
||||
"name,age\naddress,gender\nmail,phone"
|
||||
)
|
||||
|
||||
|
||||
def async_mock_parse_db_summary() -> str:
|
||||
"""Asynchronous patch for _parse_db_summary method."""
|
||||
return (
|
||||
"table_name: user\ncomment: user about dbgpt\n"
|
||||
"--table-field-separator--\n"
|
||||
"name,age\naddress,gender\nmail,phone"
|
||||
)
|
||||
assembler.load_knowledge(knowledge=mock_knowledge)
|
||||
assert len(assembler._chunks) == 0
|
||||
|
@@ -5,9 +5,9 @@ import pytest
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector
|
||||
from dbgpt.rag.assembler.db_schema import DBSchemaAssembler
|
||||
from dbgpt.rag.chunk_manager import ChunkParameters, SplitterType
|
||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||
from dbgpt.rag.text_splitter.text_splitter import CharacterTextSplitter
|
||||
from dbgpt.storage.vector_store.chroma_store import ChromaStore
|
||||
from dbgpt.rag.embedding.embedding_factory import DefaultEmbeddings, EmbeddingFactory
|
||||
from dbgpt.rag.text_splitter.text_splitter import RDBTextSplitter
|
||||
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -21,14 +21,22 @@ def mock_db_connection():
|
||||
"id": "INTEGER PRIMARY KEY",
|
||||
"name": "TEXT",
|
||||
"age": "INTEGER",
|
||||
},
|
||||
"data": [
|
||||
(1, "Tom", 10),
|
||||
(2, "Jerry", 16),
|
||||
(3, "Jack", 18),
|
||||
(4, "Alice", 20),
|
||||
(5, "Bob", 22),
|
||||
],
|
||||
"address": "TEXT",
|
||||
"phone": "TEXT",
|
||||
"email": "TEXT",
|
||||
"gender": "TEXT",
|
||||
"birthdate": "TEXT",
|
||||
"occupation": "TEXT",
|
||||
"education": "TEXT",
|
||||
"marital_status": "TEXT",
|
||||
"nationality": "TEXT",
|
||||
"height": "REAL",
|
||||
"weight": "REAL",
|
||||
"blood_type": "TEXT",
|
||||
"emergency_contact": "TEXT",
|
||||
"created_at": "TEXT",
|
||||
"updated_at": "TEXT",
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -46,23 +54,29 @@ def mock_embedding_factory():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_connector():
|
||||
return MagicMock(spec=ChromaStore)
|
||||
def mock_table_vector_store_connector():
|
||||
mock_connector = MagicMock(spec=VectorStoreConnector)
|
||||
mock_connector.vector_store_config.name = "table_vector_store_name"
|
||||
mock_connector.current_embeddings = DefaultEmbeddings()
|
||||
return mock_connector
|
||||
|
||||
|
||||
def test_load_knowledge(
|
||||
mock_db_connection,
|
||||
mock_chunk_parameters,
|
||||
mock_embedding_factory,
|
||||
mock_vector_store_connector,
|
||||
mock_table_vector_store_connector,
|
||||
):
|
||||
mock_chunk_parameters.chunk_strategy = "CHUNK_BY_SIZE"
|
||||
mock_chunk_parameters.text_splitter = CharacterTextSplitter()
|
||||
mock_chunk_parameters.text_splitter = RDBTextSplitter(
|
||||
separator="--table-field-separator--"
|
||||
)
|
||||
mock_chunk_parameters.splitter_type = SplitterType.USER_DEFINE
|
||||
assembler = DBSchemaAssembler(
|
||||
connector=mock_db_connection,
|
||||
chunk_parameters=mock_chunk_parameters,
|
||||
embeddings=mock_embedding_factory.create(),
|
||||
index_store=mock_vector_store_connector,
|
||||
table_vector_store_connector=mock_table_vector_store_connector,
|
||||
max_seq_length=10,
|
||||
)
|
||||
assert len(assembler._chunks) == 1
|
||||
assert len(assembler._chunks) > 1
|
||||
|
Reference in New Issue
Block a user