Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-11 13:55:03 +00:00).
**Description:** As noted in a comment on commit [41b6a86](41b6a86bbe), that commit introduced a bug when an embedding request is made and the model returns a non-nested list. This is typically the case for the model **_nomic-embed-text_**. - I added a unit test and ran `make format`, `make lint` and `make test` from the `community` package. - No new dependency. --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
49 lines · 1.7 KiB · Python
from typing import Generator
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from langchain_community.embeddings.llamacpp import LlamaCppEmbeddings
|
|
|
|
|
|
@pytest.fixture
def mock_llama_client() -> Generator[MagicMock, None, None]:
    """Yield a ``MagicMock`` that stands in for the llama.cpp client.

    While the fixture is active, the ``LlamaCppEmbeddings`` attribute of
    ``langchain_community.embeddings.llamacpp`` is patched. Its
    ``return_value`` is the same mock that is handed to the test.
    """
    # NOTE(review): the tests in this module construct LlamaCppEmbeddings via
    # the name imported at module load time. Patching the module attribute
    # therefore does not intercept those constructions. The tests only rely on
    # the yielded mock being passed explicitly as ``client=``. Confirm whether
    # the patch is still wanted.
    patch_target = "langchain_community.embeddings.llamacpp.LlamaCppEmbeddings"
    fake_client = MagicMock()
    with patch(patch_target) as patched_cls:
        patched_cls.return_value = fake_client
        yield fake_client
|
|
|
|
|
|
def test_initialization(mock_llama_client: MagicMock) -> None:
    """Constructing with an explicit client leaves ``client`` populated."""
    instance = LlamaCppEmbeddings(client=mock_llama_client)  # type: ignore[call-arg]
    assert instance.client is not None
|
|
|
|
|
|
def test_embed_documents(mock_llama_client: MagicMock) -> None:
    """Each item's nested ``[[...]]`` embedding yields one output vector."""
    rows = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
    # Every "data" entry wraps its vector in an extra list (nested response).
    # Copy each row so the expected value shares no objects with the mock.
    mock_llama_client.create_embedding.return_value = {
        "data": [{"embedding": [list(row)]} for row in rows]
    }
    embedder = LlamaCppEmbeddings(client=mock_llama_client)  # type: ignore[call-arg]

    vectors = embedder.embed_documents(["Hello world", "Test document"])

    assert vectors == rows
|
|
|
|
|
|
def test_embed_query(mock_llama_client: MagicMock) -> None:
    """A nested ``[[...]]`` result from ``embed`` is unwrapped to one vector."""
    mock_llama_client.embed.return_value = [[0.1, 0.2, 0.3]]
    embedder = LlamaCppEmbeddings(client=mock_llama_client)  # type: ignore[call-arg]

    assert embedder.embed_query("Sample query") == [0.1, 0.2, 0.3]
|
|
|
|
|
|
def test_embed_query_not_nested(mock_llama_client: MagicMock) -> None:
    """A flat ``[...]`` result from ``embed`` comes back as the same vector.

    Regression case: some models (e.g. nomic-embed-text) return a
    non-nested list instead of ``[[...]]``.
    """
    flat_vector = [0.1, 0.2, 0.3]
    mock_llama_client.embed.return_value = list(flat_vector)
    embedder = LlamaCppEmbeddings(client=mock_llama_client)  # type: ignore[call-arg]

    assert embedder.embed_query("Sample query") == flat_vector
|