mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-23 20:26:15 +00:00
57 lines
1.6 KiB
Python
57 lines
1.6 KiB
Python
from typing import List
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
import dbgpt
|
|
from dbgpt.core import Chunk
|
|
from dbgpt.rag.retriever.db_schema import DBSchemaRetriever
|
|
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
|
|
|
|
|
@pytest.fixture
def mock_db_connection():
    """Provide a ``MagicMock`` to use in place of a database connection."""
    connection = MagicMock()
    return connection
|
|
|
|
|
|
@pytest.fixture
def mock_vector_store_connector():
    """Provide a mock whose ``similar_search`` returns four summary chunks.

    The returned list holds the same ``Chunk`` instance four times, with the
    content ``"Table summary"``.
    """
    summary_chunk = Chunk(content="Table summary")
    store = MagicMock()
    store.similar_search.return_value = [summary_chunk] * 4
    return store
|
|
|
|
|
|
@pytest.fixture
def db_struct_retriever(mock_db_connection, mock_vector_store_connector):
    """Build a ``DBSchemaRetriever`` from the two mock fixtures.

    The mocked connection is passed as ``connector`` and the mocked vector
    store as ``index_store``.
    """
    retriever_kwargs = {
        "connector": mock_db_connection,
        "index_store": mock_vector_store_connector,
    }
    return DBSchemaRetriever(**retriever_kwargs)
|
|
|
|
|
|
def mock_parse_db_summary(conn) -> List[str]:
    """Stand in for ``_parse_db_summary`` when it is patched in the tests.

    The ``conn`` argument is accepted but ignored. Each call returns a new
    single-element list containing ``"Table summary"``.
    """
    summaries = ["Table summary"]
    return summaries
|
|
|
|
|
|
# Patch _parse_db_summary with mock_parse_db_summary for the tests below
|
|
@patch.object(
    dbgpt.rag.summary.rdbms_db_summary,
    "_parse_db_summary",
    new=mock_parse_db_summary,
)
def test_retrieve_with_mocked_summary(db_struct_retriever):
    """Check that ``_retrieve`` returns chunks and the first holds the summary."""
    expected_text = "Table summary"
    chunks: List[Chunk] = db_struct_retriever._retrieve(expected_text)
    first_chunk = chunks[0]
    assert isinstance(first_chunk, Chunk)
    assert first_chunk.content == expected_text
|
|
|
|
|
|
@pytest.mark.asyncio
@patch.object(
    dbgpt.rag.summary.rdbms_db_summary,
    "_parse_db_summary",
    new=mock_parse_db_summary,
)
async def test_aretrieve_with_mocked_summary(db_struct_retriever):
    """Check that ``_aretrieve`` returns chunks and the first holds the summary."""
    expected_text = "Table summary"
    chunks: List[Chunk] = await db_struct_retriever._aretrieve(expected_text)
    first_chunk = chunks[0]
    assert isinstance(first_chunk, Chunk)
    assert first_chunk.content == expected_text
|