Feat rdb summary wide table (#2035)

Co-authored-by: dongzhancai1 <dongzhancai1@jd.com>
Co-authored-by: dong <dongzhancai@iie2.com>
This commit is contained in:
Cooper
2024-12-18 20:34:21 +08:00
committed by GitHub
parent 7f4b5e79cf
commit 9b0161e521
17 changed files with 948 additions and 243 deletions

View File

@@ -15,42 +15,53 @@ def mock_db_connection():
@pytest.fixture
def mock_vector_store_connector():
def mock_table_vector_store_connector():
mock_connector = MagicMock()
mock_connector.similar_search.return_value = [Chunk(content="Table summary")] * 4
mock_connector.vector_store_config.name = "table_name"
mock_connector.similar_search_with_scores.return_value = [
Chunk(content="Table summary")
] * 4
return mock_connector
@pytest.fixture
def db_struct_retriever(mock_db_connection, mock_vector_store_connector):
def mock_field_vector_store_connector():
mock_connector = MagicMock()
mock_connector.similar_search_with_scores.return_value = [
Chunk(content="Field summary")
] * 4
return mock_connector
@pytest.fixture
def dbstruct_retriever(
mock_db_connection,
mock_table_vector_store_connector,
mock_field_vector_store_connector,
):
return DBSchemaRetriever(
connector=mock_db_connection,
index_store=mock_vector_store_connector,
table_vector_store_connector=mock_table_vector_store_connector,
field_vector_store_connector=mock_field_vector_store_connector,
)
def mock_parse_db_summary(conn) -> List[str]:
def mock_parse_db_summary() -> str:
"""Patch _parse_db_summary method."""
return ["Table summary"]
return "Table summary"
# Mocking the _parse_db_summary method in your test function
@patch.object(
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
)
def test_retrieve_with_mocked_summary(db_struct_retriever):
def test_retrieve_with_mocked_summary(dbstruct_retriever):
query = "Table summary"
chunks: List[Chunk] = db_struct_retriever._retrieve(query)
chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
assert isinstance(chunks[0], Chunk)
assert chunks[0].content == "Table summary"
@pytest.mark.asyncio
@patch.object(
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
)
async def test_aretrieve_with_mocked_summary(db_struct_retriever):
query = "Table summary"
chunks: List[Chunk] = await db_struct_retriever._aretrieve(query)
assert isinstance(chunks[0], Chunk)
assert chunks[0].content == "Table summary"
async def async_mock_parse_db_summary() -> str:
"""Asynchronous patch for _parse_db_summary method."""
return "Table summary"