mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-14 05:31:40 +00:00
refactor: RAG Refactor (#985)
Co-authored-by: Aralhi <xiaoping0501@gmail.com> Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
0
dbgpt/rag/retriever/tests/__init__.py
Normal file
0
dbgpt/rag/retriever/tests/__init__.py
Normal file
49
dbgpt/rag/retriever/tests/test_db_struct.py
Normal file
49
dbgpt/rag/retriever/tests/test_db_struct.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
from typing import List
|
||||
|
||||
import dbgpt
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
from dbgpt.rag.retriever.db_struct import DBStructRetriever
|
||||
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_connection():
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_connector():
|
||||
mock_connector = MagicMock()
|
||||
mock_connector.similar_search.return_value = [Chunk(content="Table summary")] * 4
|
||||
return mock_connector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dbstruct_retriever(mock_db_connection, mock_vector_store_connector):
|
||||
return DBStructRetriever(
|
||||
connection=mock_db_connection,
|
||||
vector_store_connector=mock_vector_store_connector,
|
||||
)
|
||||
|
||||
|
||||
def mock_parse_db_summary() -> str:
|
||||
"""Patch _parse_db_summary method."""
|
||||
return "Table summary"
|
||||
|
||||
|
||||
# Mocking the _parse_db_summary method in your test function
|
||||
@patch.object(
|
||||
dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
|
||||
)
|
||||
def test_retrieve_with_mocked_summary(dbstruct_retriever):
|
||||
query = "Table summary"
|
||||
chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
|
||||
assert isinstance(chunks[0], Chunk)
|
||||
assert chunks[0].content == "Table summary"
|
||||
|
||||
|
||||
async def async_mock_parse_db_summary() -> str:
|
||||
"""Asynchronous patch for _parse_db_summary method."""
|
||||
return "Table summary"
|
39
dbgpt/rag/retriever/tests/test_embedding.py
Normal file
39
dbgpt/rag/retriever/tests/test_embedding.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from unittest.mock import MagicMock
|
||||
import pytest
|
||||
|
||||
from dbgpt.rag.chunk import Chunk
|
||||
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def top_k():
|
||||
return 4
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def query():
|
||||
return "test query"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_connector():
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def embedding_retriever(top_k, mock_vector_store_connector):
|
||||
return EmbeddingRetriever(
|
||||
top_k=top_k,
|
||||
query_rewrite=False,
|
||||
vector_store_connector=mock_vector_store_connector,
|
||||
)
|
||||
|
||||
|
||||
def test_retrieve(query, top_k, mock_vector_store_connector, embedding_retriever):
|
||||
expected_chunks = [Chunk() for _ in range(top_k)]
|
||||
mock_vector_store_connector.similar_search.return_value = expected_chunks
|
||||
|
||||
retrieved_chunks = embedding_retriever._retrieve(query)
|
||||
|
||||
mock_vector_store_connector.similar_search.assert_called_once_with(query, top_k)
|
||||
assert len(retrieved_chunks) == top_k
|
Reference in New Issue
Block a user