DB-GPT/dbgpt/rag/retriever/tests/test_embedding.py
Aries-ckt 58d08780d6
feat(ChatKnowledge): ChatKnowledge Support Keyword Retrieve (#1624)
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com>
2024-06-13 13:49:17 +08:00

40 lines
852 B
Python

from unittest.mock import MagicMock
import pytest
from dbgpt.core import Chunk
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
@pytest.fixture
def top_k():
return 4
@pytest.fixture
def query():
return "test query"
@pytest.fixture
def mock_vector_store_connector():
return MagicMock()
@pytest.fixture
def embedding_retriever(top_k, mock_vector_store_connector):
return EmbeddingRetriever(
top_k=top_k,
query_rewrite=None,
index_store=mock_vector_store_connector,
)
def test_retrieve(query, top_k, mock_vector_store_connector, embedding_retriever):
expected_chunks = [Chunk() for _ in range(top_k)]
mock_vector_store_connector.similar_search.return_value = expected_chunks
retrieved_chunks = embedding_retriever._retrieve(query)
assert len(retrieved_chunks) == top_k