mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-25 04:49:17 +00:00
Community: LlamaCppEmbeddings embed_documents
and embed_query
(#28827)
- **Description:** `embed_documents` and `embed_query` was throwing off the error as stated in the issue. The issue was that `Llama` client is returning the embeddings in a nested list which is not being accounted for in the current implementation and therefore the stated error is being raised. - **Issue:** #28813 --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
committed by
GitHub
parent
32917a0b98
commit
41b6a86bbe
40
libs/community/tests/unit_tests/embeddings/test_llamacpp.py
Normal file
40
libs/community/tests/unit_tests/embeddings/test_llamacpp.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from typing import Generator
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_community.embeddings.llamacpp import LlamaCppEmbeddings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llama_client() -> Generator[MagicMock, None, None]:
|
||||
with patch(
|
||||
"langchain_community.embeddings.llamacpp.LlamaCppEmbeddings"
|
||||
) as MockLlama:
|
||||
mock_client = MagicMock()
|
||||
MockLlama.return_value = mock_client
|
||||
yield mock_client
|
||||
|
||||
|
||||
def test_initialization(mock_llama_client: MagicMock) -> None:
|
||||
embeddings = LlamaCppEmbeddings(client=mock_llama_client) # type: ignore[call-arg]
|
||||
assert embeddings.client is not None
|
||||
|
||||
|
||||
def test_embed_documents(mock_llama_client: MagicMock) -> None:
|
||||
mock_llama_client.create_embedding.return_value = {
|
||||
"data": [{"embedding": [[0.1, 0.2, 0.3]]}, {"embedding": [[0.4, 0.5, 0.6]]}]
|
||||
}
|
||||
embeddings = LlamaCppEmbeddings(client=mock_llama_client) # type: ignore[call-arg]
|
||||
texts = ["Hello world", "Test document"]
|
||||
result = embeddings.embed_documents(texts)
|
||||
expected = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_embed_query(mock_llama_client: MagicMock) -> None:
|
||||
mock_llama_client.embed.return_value = [[0.1, 0.2, 0.3]]
|
||||
embeddings = LlamaCppEmbeddings(client=mock_llama_client) # type: ignore[call-arg]
|
||||
result = embeddings.embed_query("Sample query")
|
||||
expected = [0.1, 0.2, 0.3]
|
||||
assert result == expected
|
Reference in New Issue
Block a user