fix(rag): Fix db schema aretriever bug (#1755)

This commit is contained in:
Fangyin Cheng
2024-07-30 10:30:42 +08:00
committed by GitHub
parent 55c8b39e2e
commit 25d7d94b89
3 changed files with 18 additions and 30 deletions

View File

@@ -22,29 +22,35 @@ def mock_vector_store_connector():
@pytest.fixture
def dbstruct_retriever(mock_db_connection, mock_vector_store_connector):
def db_struct_retriever(mock_db_connection, mock_vector_store_connector):
return DBSchemaRetriever(
connector=mock_db_connection,
index_store=mock_vector_store_connector,
)
def mock_parse_db_summary() -> str:
def mock_parse_db_summary(conn) -> List[str]:
"""Patch _parse_db_summary method."""
return "Table summary"
return ["Table summary"]
# Mocking the _parse_db_summary method in your test function
@patch.object(
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
)
def test_retrieve_with_mocked_summary(dbstruct_retriever):
def test_retrieve_with_mocked_summary(db_struct_retriever):
query = "Table summary"
chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
chunks: List[Chunk] = db_struct_retriever._retrieve(query)
assert isinstance(chunks[0], Chunk)
assert chunks[0].content == "Table summary"
async def async_mock_parse_db_summary() -> str:
"""Asynchronous patch for _parse_db_summary method."""
return "Table summary"
@pytest.mark.asyncio
@patch.object(
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
)
async def test_aretrieve_with_mocked_summary(db_struct_retriever):
query = "Table summary"
chunks: List[Chunk] = await db_struct_retriever._aretrieve(query)
assert isinstance(chunks[0], Chunk)
assert chunks[0].content == "Table summary"