mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-23 04:12:13 +00:00
40 lines
852 B
Python
40 lines
852 B
Python
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
from dbgpt.core import Chunk
|
|
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
|
|
|
|
|
|
@pytest.fixture
def top_k():
    """Provide the ``top_k`` value (4) shared by the fixtures and test below."""
    chunk_limit = 4
    return chunk_limit
|
|
|
|
|
|
@pytest.fixture
def query():
    """Provide the fixed query string handed to the retriever in the test."""
    query_text = "test query"
    return query_text
|
|
|
|
|
|
@pytest.fixture
def mock_vector_store_connector():
    """Provide a fresh ``MagicMock`` that stands in for the index store."""
    store_double = MagicMock()
    return store_double
|
|
|
|
|
|
@pytest.fixture
def embedding_retriever(top_k, mock_vector_store_connector):
    """Build the ``EmbeddingRetriever`` under test.

    The retriever is created with the ``top_k`` fixture value, no query
    rewriter, and the mocked store passed as its ``index_store``.
    """
    retriever_kwargs = {
        "top_k": top_k,
        "query_rewrite": None,
        "index_store": mock_vector_store_connector,
    }
    return EmbeddingRetriever(**retriever_kwargs)
|
|
|
|
|
|
def test_retrieve(query, top_k, mock_vector_store_connector, embedding_retriever):
    """Check that ``_retrieve`` returns ``top_k`` chunks for the test query."""
    # Stub the mocked store's ``similar_search`` to return ``top_k`` default
    # chunks (presumably the call ``_retrieve`` relies on -- confirm against
    # the retriever implementation).
    stubbed_chunks = [Chunk() for _ in range(top_k)]
    search_stub = mock_vector_store_connector.similar_search
    search_stub.return_value = stubbed_chunks

    results = embedding_retriever._retrieve(query)

    assert len(results) == top_k
|