langchain/libs/community/tests/unit_tests/embeddings/test_llamacpp.py
Julien Elkaim e586bffe51
community: Repair embeddings/llamacpp's embed_query method (#29935)
**Description:** As noted in a comment on commit
[41b6a86](41b6a86bbe), that commit
introduced a bug when an embedding request returns a non-nested list
instead of a list of lists. This is typically the case for the
**_nomic-embed-text_** model.

- I added a unit test and ran `make format`, `make lint`, and `make
test` from the `community` package.
- No new dependency.

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
2025-02-23 19:32:17 +00:00

49 lines
1.7 KiB
Python

from typing import Generator
from unittest.mock import MagicMock, patch
import pytest
from langchain_community.embeddings.llamacpp import LlamaCppEmbeddings
@pytest.fixture
def mock_llama_client() -> Generator[MagicMock, None, None]:
    """Yield a ``MagicMock`` that stands in for the llama.cpp client.

    While the fixture is active, the ``LlamaCppEmbeddings`` attribute of
    ``langchain_community.embeddings.llamacpp`` is replaced by a mock class
    whose call returns the yielded client mock.

    NOTE(review): this test module binds ``LlamaCppEmbeddings`` via its own
    top-level import, so the patch does not change the class the tests
    instantiate; the tests only rely on passing the yielded mock as
    ``client=``. Consider dropping the patch.
    """
    fake_client = MagicMock()
    target = "langchain_community.embeddings.llamacpp.LlamaCppEmbeddings"
    # ``return_value=`` configures the replacement mock class so that
    # instantiating it hands back ``fake_client``.
    with patch(target, return_value=fake_client):
        yield fake_client
def test_initialization(mock_llama_client: MagicMock) -> None:
    """The client passed to the constructor is the one stored on the instance.

    Asserting identity (rather than merely ``is not None``) verifies that the
    supplied client is kept, not just that some client exists.
    """
    embeddings = LlamaCppEmbeddings(client=mock_llama_client)  # type: ignore[call-arg]
    assert embeddings.client is mock_llama_client
def test_embed_documents(mock_llama_client: MagicMock) -> None:
    """Nested per-document embeddings from ``create_embedding`` are flattened.

    Each ``"embedding"`` entry in the mocked response wraps a single vector
    in an extra list; ``embed_documents`` should yield one flat vector per
    input text.
    """
    nested_response = {
        "data": [
            {"embedding": [[0.1, 0.2, 0.3]]},
            {"embedding": [[0.4, 0.5, 0.6]]},
        ]
    }
    mock_llama_client.create_embedding.return_value = nested_response
    embeddings = LlamaCppEmbeddings(client=mock_llama_client)  # type: ignore[call-arg]

    vectors = embeddings.embed_documents(["Hello world", "Test document"])

    assert vectors == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
def test_embed_query(mock_llama_client: MagicMock) -> None:
    """A nested ``[[...]]`` result from ``client.embed`` is unwrapped."""
    nested_vector = [[0.1, 0.2, 0.3]]
    mock_llama_client.embed.return_value = nested_vector
    embeddings = LlamaCppEmbeddings(client=mock_llama_client)  # type: ignore[call-arg]

    assert embeddings.embed_query("Sample query") == [0.1, 0.2, 0.3]
def test_embed_query_not_nested(mock_llama_client: MagicMock) -> None:
    """A flat (non-nested) result from ``client.embed`` is returned as-is.

    Regression check for clients whose ``embed`` returns a plain vector
    instead of a list containing one vector.
    """
    flat_vector = [0.1, 0.2, 0.3]
    mock_llama_client.embed.return_value = flat_vector
    embeddings = LlamaCppEmbeddings(client=mock_llama_client)  # type: ignore[call-arg]

    assert embeddings.embed_query("Sample query") == [0.1, 0.2, 0.3]