Community: LlamaCppEmbeddings embed_documents and embed_query (#28827)

- **Description:** `embed_documents` and `embed_query` were raising the
error reported in the issue. The `Llama` client returns the embeddings as
a nested list, which the previous implementation did not account for, and
that is why the stated error was raised. Both methods now detect the
nested shape and flatten it to the expected output (a minimal sketch
follows below).
- **Issue:** #28813
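
A minimal sketch of the fix's core logic, using a hypothetical payload
shaped like the nested output newer llama-cpp-python versions can return
from `create_embedding` (the exact shape depends on the installed version):

    # Hypothetical response: "embedding" holds a nested list of vectors
    # rather than a single flat vector of floats.
    response = {"data": [{"embedding": [[0.1, 0.2, 0.3]]}]}

    final_embeddings = []
    for e in response["data"]:
        # The old code applied list(map(float, e["embedding"])) directly,
        # which raises TypeError when the entries are lists, not numbers.
        if isinstance(e["embedding"][0], list):
            for vector in e["embedding"]:
                final_embeddings.append(list(map(float, vector)))
        else:
            final_embeddings.append(list(map(float, e["embedding"])))

    print(final_embeddings)  # [[0.1, 0.2, 0.3]]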

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
Mohammad Mohtashim 2024-12-23 19:50:22 +05:00 committed by GitHub
parent 32917a0b98
commit 41b6a86bbe
2 changed files with 71 additions and 17 deletions

libs/community/langchain_community/embeddings/llamacpp.py

@@ -20,7 +20,7 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
     """

     client: Any = None  #: :meta private:
-    model_path: str
+    model_path: str = Field(default="")

     n_ctx: int = Field(512, alias="n_ctx")
     """Token context window."""
@@ -88,21 +88,22 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
         if self.n_gpu_layers is not None:
             model_params["n_gpu_layers"] = self.n_gpu_layers

-        try:
-            from llama_cpp import Llama
+        if not self.client:
+            try:
+                from llama_cpp import Llama

-            self.client = Llama(model_path, embedding=True, **model_params)
-        except ImportError:
-            raise ImportError(
-                "Could not import llama-cpp-python library. "
-                "Please install the llama-cpp-python library to "
-                "use this embedding model: pip install llama-cpp-python"
-            )
-        except Exception as e:
-            raise ValueError(
-                f"Could not load Llama model from path: {model_path}. "
-                f"Received error {e}"
-            )
+                self.client = Llama(model_path, embedding=True, **model_params)
+            except ImportError:
+                raise ImportError(
+                    "Could not import llama-cpp-python library. "
+                    "Please install the llama-cpp-python library to "
+                    "use this embedding model: pip install llama-cpp-python"
+                )
+            except Exception as e:
+                raise ValueError(
+                    f"Could not load Llama model from path: {model_path}. "
+                    f"Received error {e}"
+                )

         return self
@@ -116,7 +117,17 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
             List of embeddings, one for each text.
         """
         embeddings = self.client.create_embedding(texts)
-        return [list(map(float, e["embedding"])) for e in embeddings["data"]]
+        final_embeddings = []
+        for e in embeddings["data"]:
+            try:
+                if isinstance(e["embedding"][0], list):
+                    for data in e["embedding"]:
+                        final_embeddings.append(list(map(float, data)))
+                else:
+                    final_embeddings.append(list(map(float, e["embedding"])))
+            except (IndexError, TypeError):
+                final_embeddings.append(list(map(float, e["embedding"])))
+        return final_embeddings

     def embed_query(self, text: str) -> List[float]:
         """Embed a query using the Llama model.
@@ -128,4 +139,7 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
             Embeddings for the text.
         """
         embedding = self.client.embed(text)
-        return list(map(float, embedding))
+        if not isinstance(embedding, list):
+            return list(map(float, embedding))
+        else:
+            return list(map(float, embedding[0]))
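
With both changes applied, `embed_documents` always returns a flat
List[List[float]] and `embed_query` a flat List[float]. A quick usage
sketch; the model path is a placeholder, not part of this PR:

    from langchain_community.embeddings import LlamaCppEmbeddings

    # Placeholder path to a local GGUF model for llama-cpp-python.
    embedder = LlamaCppEmbeddings(model_path="/path/to/model.gguf")
    doc_vectors = embedder.embed_documents(["Hello world"])  # List[List[float]]
    query_vector = embedder.embed_query("Hello world")       # List[float]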

libs/community/tests/unit_tests/embeddings/test_llamacpp.py (new file)

@@ -0,0 +1,40 @@
+from typing import Generator
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from langchain_community.embeddings.llamacpp import LlamaCppEmbeddings
+
+
+@pytest.fixture
+def mock_llama_client() -> Generator[MagicMock, None, None]:
+    with patch(
+        "langchain_community.embeddings.llamacpp.LlamaCppEmbeddings"
+    ) as MockLlama:
+        mock_client = MagicMock()
+        MockLlama.return_value = mock_client
+        yield mock_client
+
+
+def test_initialization(mock_llama_client: MagicMock) -> None:
+    embeddings = LlamaCppEmbeddings(client=mock_llama_client)  # type: ignore[call-arg]
+    assert embeddings.client is not None
+
+
+def test_embed_documents(mock_llama_client: MagicMock) -> None:
+    mock_llama_client.create_embedding.return_value = {
+        "data": [{"embedding": [[0.1, 0.2, 0.3]]}, {"embedding": [[0.4, 0.5, 0.6]]}]
+    }
+    embeddings = LlamaCppEmbeddings(client=mock_llama_client)  # type: ignore[call-arg]
+    texts = ["Hello world", "Test document"]
+    result = embeddings.embed_documents(texts)
+    expected = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+    assert result == expected
+
+
+def test_embed_query(mock_llama_client: MagicMock) -> None:
+    mock_llama_client.embed.return_value = [[0.1, 0.2, 0.3]]
+    embeddings = LlamaCppEmbeddings(client=mock_llama_client)  # type: ignore[call-arg]
+    result = embeddings.embed_query("Sample query")
+    expected = [0.1, 0.2, 0.3]
+    assert result == expected
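
These are ordinary pytest unit tests; a sketch of invoking them through the
pytest API, assuming the file lands at the conventional
tests/unit_tests/embeddings/ location inside libs/community:

    import pytest

    # Path is an assumption based on the langchain-community test layout;
    # adjust to where the file lives in your checkout.
    pytest.main(["tests/unit_tests/embeddings/test_llamacpp.py", "-q"])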