Community: LlamaCppEmbeddings embed_documents and embed_query (#28827)
- **Description:** `embed_documents` and `embed_query` were raising the error described in the issue. The `Llama` client returns the embeddings as a nested list, which the current implementation did not account for, so the stated error was raised.
- **Issue:** #28813

---------

Co-authored-by: Chester Curme <chester.curme@gmail.com>
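For context, here is a minimal, standalone sketch of the flattening behaviour this commit adds. The payload below is hypothetical; it simply mirrors the nested and flat `"embedding"` shapes that the fix and the new unit tests cover, and is not part of the commit itself.

```python
# Sketch of the flattening logic introduced by this commit, run against a
# hand-written payload shaped like the response llama-cpp-python can return.
from typing import Any, Dict, List

# Hypothetical example payload: one nested vector and one flat vector.
response: Dict[str, Any] = {
    "data": [{"embedding": [[0.1, 0.2, 0.3]]}, {"embedding": [0.4, 0.5, 0.6]}]
}

final_embeddings: List[List[float]] = []
for e in response["data"]:
    if isinstance(e["embedding"][0], list):
        # Nested case: an extra list level wraps the actual vector(s).
        for data in e["embedding"]:
            final_embeddings.append(list(map(float, data)))
    else:
        # Flat case: the vector is already a list of floats.
        final_embeddings.append(list(map(float, e["embedding"])))

print(final_embeddings)  # [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
```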
parent 32917a0b98
commit 41b6a86bbe
@@ -20,7 +20,7 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
     """
 
     client: Any = None  #: :meta private:
-    model_path: str
+    model_path: str = Field(default="")
 
     n_ctx: int = Field(512, alias="n_ctx")
     """Token context window."""
@@ -88,21 +88,22 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
         if self.n_gpu_layers is not None:
             model_params["n_gpu_layers"] = self.n_gpu_layers
 
-        try:
-            from llama_cpp import Llama
-
-            self.client = Llama(model_path, embedding=True, **model_params)
-        except ImportError:
-            raise ImportError(
-                "Could not import llama-cpp-python library. "
-                "Please install the llama-cpp-python library to "
-                "use this embedding model: pip install llama-cpp-python"
-            )
-        except Exception as e:
-            raise ValueError(
-                f"Could not load Llama model from path: {model_path}. "
-                f"Received error {e}"
-            )
+        if not self.client:
+            try:
+                from llama_cpp import Llama
+
+                self.client = Llama(model_path, embedding=True, **model_params)
+            except ImportError:
+                raise ImportError(
+                    "Could not import llama-cpp-python library. "
+                    "Please install the llama-cpp-python library to "
+                    "use this embedding model: pip install llama-cpp-python"
+                )
+            except Exception as e:
+                raise ValueError(
+                    f"Could not load Llama model from path: {model_path}. "
+                    f"Received error {e}"
+                )
 
         return self
@@ -116,7 +117,17 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
             List of embeddings, one for each text.
         """
         embeddings = self.client.create_embedding(texts)
-        return [list(map(float, e["embedding"])) for e in embeddings["data"]]
+        final_embeddings = []
+        for e in embeddings["data"]:
+            try:
+                if isinstance(e["embedding"][0], list):
+                    for data in e["embedding"]:
+                        final_embeddings.append(list(map(float, data)))
+                else:
+                    final_embeddings.append(list(map(float, e["embedding"])))
+            except (IndexError, TypeError):
+                final_embeddings.append(list(map(float, e["embedding"])))
+        return final_embeddings
 
     def embed_query(self, text: str) -> List[float]:
         """Embed a query using the Llama model.
@@ -128,4 +139,7 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
             Embeddings for the text.
         """
         embedding = self.client.embed(text)
-        return list(map(float, embedding))
+        if not isinstance(embedding, list):
+            return list(map(float, embedding))
+        else:
+            return list(map(float, embedding[0]))
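For reference, a minimal usage sketch of the patched `langchain_community.embeddings.llamacpp` methods; the model path and example texts below are placeholders, not part of the commit:

```python
from langchain_community.embeddings import LlamaCppEmbeddings

# Hypothetical local model path; replace with a real GGUF/GGML file on disk.
embeddings = LlamaCppEmbeddings(model_path="/path/to/model.gguf")

# With the fix, both calls return flat lists of floats even when the
# underlying Llama client wraps results in an extra list level.
doc_vectors = embeddings.embed_documents(["Hello world", "Test document"])
query_vector = embeddings.embed_query("Sample query")
```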
libs/community/tests/unit_tests/embeddings/test_llamacpp.py (new file, 40 lines)
@@ -0,0 +1,40 @@
+from typing import Generator
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from langchain_community.embeddings.llamacpp import LlamaCppEmbeddings
+
+
+@pytest.fixture
+def mock_llama_client() -> Generator[MagicMock, None, None]:
+    with patch(
+        "langchain_community.embeddings.llamacpp.LlamaCppEmbeddings"
+    ) as MockLlama:
+        mock_client = MagicMock()
+        MockLlama.return_value = mock_client
+        yield mock_client
+
+
+def test_initialization(mock_llama_client: MagicMock) -> None:
+    embeddings = LlamaCppEmbeddings(client=mock_llama_client)  # type: ignore[call-arg]
+    assert embeddings.client is not None
+
+
+def test_embed_documents(mock_llama_client: MagicMock) -> None:
+    mock_llama_client.create_embedding.return_value = {
+        "data": [{"embedding": [[0.1, 0.2, 0.3]]}, {"embedding": [[0.4, 0.5, 0.6]]}]
+    }
+    embeddings = LlamaCppEmbeddings(client=mock_llama_client)  # type: ignore[call-arg]
+    texts = ["Hello world", "Test document"]
+    result = embeddings.embed_documents(texts)
+    expected = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+    assert result == expected
+
+
+def test_embed_query(mock_llama_client: MagicMock) -> None:
+    mock_llama_client.embed.return_value = [[0.1, 0.2, 0.3]]
+    embeddings = LlamaCppEmbeddings(client=mock_llama_client)  # type: ignore[call-arg]
+    result = embeddings.embed_query("Sample query")
+    expected = [0.1, 0.2, 0.3]
+    assert result == expected