mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-14 21:51:25 +00:00
fix(rag): Fix db schema aretriever bug (#1755)
This commit is contained in:
@@ -22,29 +22,35 @@ def mock_vector_store_connector():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dbstruct_retriever(mock_db_connection, mock_vector_store_connector):
|
||||
def db_struct_retriever(mock_db_connection, mock_vector_store_connector):
|
||||
return DBSchemaRetriever(
|
||||
connector=mock_db_connection,
|
||||
index_store=mock_vector_store_connector,
|
||||
)
|
||||
|
||||
|
||||
def mock_parse_db_summary() -> str:
|
||||
def mock_parse_db_summary(conn) -> List[str]:
|
||||
"""Patch _parse_db_summary method."""
|
||||
return "Table summary"
|
||||
return ["Table summary"]
|
||||
|
||||
|
||||
# Mocking the _parse_db_summary method in your test function
|
||||
@patch.object(
|
||||
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
|
||||
)
|
||||
def test_retrieve_with_mocked_summary(dbstruct_retriever):
|
||||
def test_retrieve_with_mocked_summary(db_struct_retriever):
|
||||
query = "Table summary"
|
||||
chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
|
||||
chunks: List[Chunk] = db_struct_retriever._retrieve(query)
|
||||
assert isinstance(chunks[0], Chunk)
|
||||
assert chunks[0].content == "Table summary"
|
||||
|
||||
|
||||
async def async_mock_parse_db_summary() -> str:
|
||||
"""Asynchronous patch for _parse_db_summary method."""
|
||||
return "Table summary"
|
||||
@pytest.mark.asyncio
|
||||
@patch.object(
|
||||
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
|
||||
)
|
||||
async def test_aretrieve_with_mocked_summary(db_struct_retriever):
|
||||
query = "Table summary"
|
||||
chunks: List[Chunk] = await db_struct_retriever._aretrieve(query)
|
||||
assert isinstance(chunks[0], Chunk)
|
||||
assert chunks[0].content == "Table summary"
|
||||
|
Reference in New Issue
Block a user