diff --git a/libs/community/langchain_community/embeddings/llamacpp.py b/libs/community/langchain_community/embeddings/llamacpp.py index 4adfeb0e527..e4ebe33b33c 100644 --- a/libs/community/langchain_community/embeddings/llamacpp.py +++ b/libs/community/langchain_community/embeddings/llamacpp.py @@ -139,7 +139,7 @@ class LlamaCppEmbeddings(BaseModel, Embeddings): Embeddings for the text. """ embedding = self.client.embed(text) - if not isinstance(embedding, list): - return list(map(float, embedding)) - else: + if embedding and isinstance(embedding, list) and isinstance(embedding[0], list): return list(map(float, embedding[0])) + else: + return list(map(float, embedding)) diff --git a/libs/community/tests/unit_tests/embeddings/test_llamacpp.py b/libs/community/tests/unit_tests/embeddings/test_llamacpp.py index ca2bd758216..552ccd87d06 100644 --- a/libs/community/tests/unit_tests/embeddings/test_llamacpp.py +++ b/libs/community/tests/unit_tests/embeddings/test_llamacpp.py @@ -38,3 +38,11 @@ def test_embed_query(mock_llama_client: MagicMock) -> None: result = embeddings.embed_query("Sample query") expected = [0.1, 0.2, 0.3] assert result == expected + + +def test_embed_query_not_nested(mock_llama_client: MagicMock) -> None: + mock_llama_client.embed.return_value = [0.1, 0.2, 0.3] + embeddings = LlamaCppEmbeddings(client=mock_llama_client) # type: ignore[call-arg] + result = embeddings.embed_query("Sample query") + expected = [0.1, 0.2, 0.3] + assert result == expected