refactor: RAG Refactor (#985)

Co-authored-by: Aralhi <xiaoping0501@gmail.com>
Co-authored-by: csunny <cfqsunny@163.com>
This commit is contained in:
Aries-ckt
2024-01-03 09:45:26 +08:00
committed by GitHub
parent 90775aad50
commit 9ad70a2961
206 changed files with 5766 additions and 2419 deletions

View File

View File

@@ -0,0 +1,49 @@
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from typing import List
import dbgpt
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.retriever.db_struct import DBStructRetriever
from dbgpt.rag.summary.rdbms_db_summary import _parse_db_summary
@pytest.fixture
def mock_db_connection():
    """Return a ``MagicMock`` used as the database connection in these tests."""
    return MagicMock()
@pytest.fixture
def mock_vector_store_connector():
    """Return a mocked vector store connector with a canned search result.

    ``similar_search`` is configured to return a list of four chunks whose
    content is ``"Table summary"``.
    """
    mock_connector = MagicMock()
    # NOTE: list multiplication repeats the same Chunk instance four times.
    mock_connector.similar_search.return_value = [Chunk(content="Table summary")] * 4
    return mock_connector
@pytest.fixture
def dbstruct_retriever(mock_db_connection, mock_vector_store_connector):
    """Build a ``DBStructRetriever`` wired to the mocked connection and store."""
    return DBStructRetriever(
        connection=mock_db_connection,
        vector_store_connector=mock_vector_store_connector,
    )
def mock_parse_db_summary(*args, **kwargs) -> str:
    """Stand-in for ``_parse_db_summary`` that returns a fixed summary.

    Any positional or keyword arguments are accepted and ignored, so the stub
    can be called with whatever arguments the patched function receives.

    Returns:
        The constant string ``"Table summary"``.
    """
    return "Table summary"
# Replace _parse_db_summary on the summary module with the stub for the
# duration of this test.
@patch.object(
    dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary
)
def test_retrieve_with_mocked_summary(dbstruct_retriever):
    """Verify the first chunk from ``_retrieve`` has the expected content."""
    query = "Table summary"
    chunks: List[Chunk] = dbstruct_retriever._retrieve(query)
    assert isinstance(chunks[0], Chunk)
    assert chunks[0].content == "Table summary"
async def async_mock_parse_db_summary(*args, **kwargs) -> str:
    """Coroutine stand-in for ``_parse_db_summary`` returning a fixed summary.

    Any positional or keyword arguments are accepted and ignored, so the stub
    can be awaited with whatever arguments the patched function receives.

    Returns:
        The constant string ``"Table summary"``.
    """
    return "Table summary"

View File

@@ -0,0 +1,39 @@
from unittest.mock import MagicMock
import pytest
from dbgpt.rag.chunk import Chunk
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
@pytest.fixture
def top_k():
    """Return the ``top_k`` value (4) passed to the retriever under test."""
    return 4
@pytest.fixture
def query():
    """Return the query string used by the retrieval test."""
    return "test query"
@pytest.fixture
def mock_vector_store_connector():
    """Return a bare ``MagicMock`` used as the vector store connector.

    The test configures ``similar_search`` on it as needed.
    """
    return MagicMock()
@pytest.fixture
def embedding_retriever(top_k, mock_vector_store_connector):
    """Build an ``EmbeddingRetriever`` backed by the mocked connector.

    Query rewriting is disabled via ``query_rewrite=False``.
    """
    return EmbeddingRetriever(
        top_k=top_k,
        query_rewrite=False,
        vector_store_connector=mock_vector_store_connector,
    )
def test_retrieve(query, top_k, mock_vector_store_connector, embedding_retriever):
    """Verify ``_retrieve`` calls ``similar_search`` once and returns ``top_k`` items."""
    # Arrange: the mocked store hands back exactly top_k chunks.
    expected_chunks = [Chunk() for _ in range(top_k)]
    mock_vector_store_connector.similar_search.return_value = expected_chunks

    # Act.
    retrieved_chunks = embedding_retriever._retrieve(query)

    # Assert: a single search call with (query, top_k) and a top_k-sized result.
    mock_vector_store_connector.similar_search.assert_called_once_with(query, top_k)
    assert len(retrieved_chunks) == top_k