From e586bffe512789192f93f7a450e463a804bbebf9 Mon Sep 17 00:00:00 2001 From: Julien Elkaim <32833946+JulienElkaim@users.noreply.github.com> Date: Mon, 24 Feb 2025 03:32:17 +0800 Subject: [PATCH] community: Repair embeddings/llamacpp's embed_query method (#29935) **Description:** As noted in the comments on commit [41b6a86](https://github.com/langchain-ai/langchain/commit/41b6a86bbe030291cf8ee284ed0cd70dd493152b), that commit introduced a bug that occurs when an embedding request is made and the model returns a non-nested list. This is typically the case for the **_nomic-embed-text_** model. - I added a unit test, and ran `make format`, `make lint` and `make test` from the `community` package. - No new dependencies. --------- Co-authored-by: Chester Curme --- libs/community/langchain_community/embeddings/llamacpp.py | 6 +++--- .../tests/unit_tests/embeddings/test_llamacpp.py | 8 ++++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/libs/community/langchain_community/embeddings/llamacpp.py b/libs/community/langchain_community/embeddings/llamacpp.py index 4adfeb0e527..e4ebe33b33c 100644 --- a/libs/community/langchain_community/embeddings/llamacpp.py +++ b/libs/community/langchain_community/embeddings/llamacpp.py @@ -139,7 +139,7 @@ class LlamaCppEmbeddings(BaseModel, Embeddings): Embeddings for the text. 
""" embedding = self.client.embed(text) - if not isinstance(embedding, list): - return list(map(float, embedding)) - else: + if embedding and isinstance(embedding, list) and isinstance(embedding[0], list): return list(map(float, embedding[0])) + else: + return list(map(float, embedding)) diff --git a/libs/community/tests/unit_tests/embeddings/test_llamacpp.py b/libs/community/tests/unit_tests/embeddings/test_llamacpp.py index ca2bd758216..552ccd87d06 100644 --- a/libs/community/tests/unit_tests/embeddings/test_llamacpp.py +++ b/libs/community/tests/unit_tests/embeddings/test_llamacpp.py @@ -38,3 +38,11 @@ def test_embed_query(mock_llama_client: MagicMock) -> None: result = embeddings.embed_query("Sample query") expected = [0.1, 0.2, 0.3] assert result == expected + + +def test_embed_query_not_nested(mock_llama_client: MagicMock) -> None: + mock_llama_client.embed.return_value = [0.1, 0.2, 0.3] + embeddings = LlamaCppEmbeddings(client=mock_llama_client) # type: ignore[call-arg] + result = embeddings.embed_query("Sample query") + expected = [0.1, 0.2, 0.3] + assert result == expected