mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-28 20:05:58 +00:00
community: Repair embeddings/llamacpp's embed_query method (#29935)
**Description:** As commented on the commit [41b6a86
](41b6a86bbe
) it introduced a bug for when we do an embedding request and the model returns a non-nested list. Typically it's the case for model **_nomic-embed-text_**. - I added the unit test, and ran `make format`, `make lint` and `make test` from the `community` package. - No new dependency. --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent
5ca4933b9d
commit
e586bffe51
@ -139,7 +139,7 @@ class LlamaCppEmbeddings(BaseModel, Embeddings):
|
|||||||
Embeddings for the text.
|
Embeddings for the text.
|
||||||
"""
|
"""
|
||||||
embedding = self.client.embed(text)
|
embedding = self.client.embed(text)
|
||||||
if not isinstance(embedding, list):
|
if embedding and isinstance(embedding, list) and isinstance(embedding[0], list):
|
||||||
return list(map(float, embedding))
|
|
||||||
else:
|
|
||||||
return list(map(float, embedding[0]))
|
return list(map(float, embedding[0]))
|
||||||
|
else:
|
||||||
|
return list(map(float, embedding))
|
||||||
|
@ -38,3 +38,11 @@ def test_embed_query(mock_llama_client: MagicMock) -> None:
|
|||||||
result = embeddings.embed_query("Sample query")
|
result = embeddings.embed_query("Sample query")
|
||||||
expected = [0.1, 0.2, 0.3]
|
expected = [0.1, 0.2, 0.3]
|
||||||
assert result == expected
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_embed_query_not_nested(mock_llama_client: MagicMock) -> None:
|
||||||
|
mock_llama_client.embed.return_value = [0.1, 0.2, 0.3]
|
||||||
|
embeddings = LlamaCppEmbeddings(client=mock_llama_client) # type: ignore[call-arg]
|
||||||
|
result = embeddings.embed_query("Sample query")
|
||||||
|
expected = [0.1, 0.2, 0.3]
|
||||||
|
assert result == expected
|
||||||
|
Loading…
Reference in New Issue
Block a user