feat(perplexity): add PerplexityEmbeddings (#37082)

## Description

This PR adds a new `PerplexityEmbeddings` class to the
`langchain-perplexity` partner package, providing first-class support
for the Perplexity Embeddings API alongside the existing
`ChatPerplexity`, `PerplexitySearchRetriever`, and
`PerplexitySearchResults` integrations.

### What was added

- `langchain_perplexity/embeddings.py` — `PerplexityEmbeddings` class
implementing `langchain_core.embeddings.Embeddings` with sync
(`embed_documents`, `embed_query`) and async (`aembed_documents`,
`aembed_query`) methods. Defaults to model `pplx-embed-v1-4b` and reuses
the existing `_utils.initialize_client` helper for API key resolution
(`PPLX_API_KEY` / `PERPLEXITY_API_KEY`).
- `__init__.py` exports `PerplexityEmbeddings` and adds it to `__all__`.
- Unit tests under `tests/unit_tests/test_embeddings.py` covering
sync/async paths with mocked clients (no network).
- Integration tests under `tests/integration_tests/test_embeddings.py`,
gated on `PPLX_API_KEY` (matches the pattern in `test_search_api.py`).
- README updated to advertise the new component.

### Why

LangChain users already get chat, search, and tool wrappers from
`langchain-perplexity`, but have had to drop down to the raw Perplexity SDK
to use embeddings. This PR closes that gap.

### References

- Perplexity Embeddings docs: https://docs.perplexity.ai/docs/embeddings
- Perplexity Embeddings API reference:
https://docs.perplexity.ai/api-reference/embeddings-post

### Issue

Closes #36726

## Testing

- `cd libs/partners/perplexity && make lint` — passes (ruff, format,
mypy).
- `cd libs/partners/perplexity && make test` — all unit tests pass (59
passed, 1 skipped).
- Integration tests will run in CI with secrets; they exercise real
`embed_documents` / `embed_query` / async variants against the live API
and assert vector dimensionality consistency.

---------

Co-authored-by: Claude Agent <agent@anthropic.com>
Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
James Liounis
2026-04-29 17:51:50 -04:00
committed by GitHub
parent 90b0047270
commit 28f5448dd4
11 changed files with 550 additions and 41 deletions

View File

@@ -0,0 +1,56 @@
"""Integration tests for Perplexity Embeddings API."""
import os
import pytest
from langchain_perplexity import PerplexityEmbeddings
@pytest.mark.skipif(
    not any(
        os.environ.get(name) for name in ("PPLX_API_KEY", "PERPLEXITY_API_KEY")
    ),
    reason="PPLX_API_KEY/PERPLEXITY_API_KEY not set",
)
class TestPerplexityEmbeddings:
    """Integration checks for `PerplexityEmbeddings`; skipped without an API key."""

    def test_embed_documents(self) -> None:
        """Embedding several documents yields one non-empty float vector each."""
        inputs = ["hello world", "goodbye world"]
        result = PerplexityEmbeddings().embed_documents(inputs)

        assert len(result) == len(inputs)
        assert all(isinstance(vec, list) for vec in result)
        assert all(len(vec) > 0 for vec in result)
        # Collapsing the lengths into a set must leave exactly one dimension.
        dimensions = {len(vec) for vec in result}
        assert len(dimensions) == 1
        first_vector = result[0]
        assert all(isinstance(component, float) for component in first_vector)

    def test_embed_query(self) -> None:
        """Embedding one query yields a non-empty list of floats."""
        result = PerplexityEmbeddings().embed_query("What is the capital of France?")

        assert isinstance(result, list)
        assert len(result) > 0
        assert all(isinstance(component, float) for component in result)

    def test_embed_query_matches_documents_dim(self) -> None:
        """Query and document embeddings must have the same length."""
        embedder = PerplexityEmbeddings()
        from_query = embedder.embed_query("hello")
        from_documents = embedder.embed_documents(["hello"])

        assert len(from_query) == len(from_documents[0])

    async def test_aembed_documents(self) -> None:
        """The async document path returns one non-empty vector per input."""
        embedder = PerplexityEmbeddings()
        result = await embedder.aembed_documents(["hello", "world"])

        assert len(result) == 2
        assert all(len(vec) > 0 for vec in result)

    async def test_aembed_query(self) -> None:
        """The async query path returns a non-empty list."""
        embedder = PerplexityEmbeddings()
        result = await embedder.aembed_query("hello")

        assert isinstance(result, list)
        assert len(result) > 0

View File

@@ -0,0 +1,23 @@
"""Standard integration tests for `PerplexityEmbeddings`."""
import os
import pytest
from langchain_core.embeddings import Embeddings
from langchain_tests.integration_tests import EmbeddingsIntegrationTests
from langchain_perplexity import PerplexityEmbeddings
@pytest.mark.skipif(
    not any(
        os.environ.get(name) for name in ("PPLX_API_KEY", "PERPLEXITY_API_KEY")
    ),
    reason="PPLX_API_KEY/PERPLEXITY_API_KEY not set",
)
class TestPerplexityEmbeddingsIntegration(EmbeddingsIntegrationTests):
    """Bind `EmbeddingsIntegrationTests` to `PerplexityEmbeddings`.

    The whole class is skipped when neither API key variable is set.
    """

    @property
    def embeddings_class(self) -> type[Embeddings]:
        """Return the embeddings class under test."""
        implementation: type[Embeddings] = PerplexityEmbeddings
        return implementation

    @property
    def embedding_model_params(self) -> dict:
        """Return the model parameters for the suite: an empty dict."""
        params: dict = {}
        return params

View File

@@ -0,0 +1,203 @@
"""Unit tests for `PerplexityEmbeddings`."""
import base64
import struct
from unittest.mock import AsyncMock, MagicMock
import pytest
from pydantic import SecretStr
from langchain_perplexity import PerplexityEmbeddings
def _encode_int8(values: list[int]) -> str:
"""Encode signed int8 values as base64 (matches Perplexity's wire format)."""
raw = struct.pack(f"<{len(values)}b", *values)
return base64.b64encode(raw).decode("ascii")
def _make_response(int8_vectors: list[list[int]]) -> MagicMock:
    """Return a mock `EmbeddingCreateResponse` carrying base64_int8 payloads.

    Each inner list becomes one entry in `response.data` whose `embedding`
    attribute holds the base64-encoded int8 string for that vector.
    """
    response = MagicMock()
    # Keyword arguments to MagicMock configure attributes on the new mock.
    response.data = [
        MagicMock(embedding=_encode_int8(vector)) for vector in int8_vectors
    ]
    return response
def test_embeddings_initialization() -> None:
    """An explicit key is stored, the default model is set, both clients exist."""
    instance = PerplexityEmbeddings(pplx_api_key="test")

    stored_key = instance.pplx_api_key
    assert stored_key is not None
    assert stored_key.get_secret_value() == "test"
    assert instance.model == "pplx-embed-v1-4b"
    assert instance.client is not None
    assert instance.async_client is not None
def test_embeddings_custom_model() -> None:
    """A `model` passed to the constructor is kept on the instance."""
    init_kwargs = {"pplx_api_key": "test", "model": "custom-model"}
    instance = PerplexityEmbeddings(**init_kwargs)

    assert instance.model == "custom-model"
def test_api_key_alias() -> None:
    """`api_key=` should be accepted via populate_by_name alias."""
    stored_key = PerplexityEmbeddings(api_key="aliased").pplx_api_key

    assert stored_key is not None
    assert stored_key.get_secret_value() == "aliased"
def test_api_key_accepts_secret_str() -> None:
    """A `SecretStr` key is accepted and its value is preserved."""
    typed_key = SecretStr("typed")
    stored_key = PerplexityEmbeddings(pplx_api_key=typed_key).pplx_api_key

    assert stored_key is not None
    assert stored_key.get_secret_value() == "typed"
def test_lc_secrets() -> None:
    """`lc_secrets` maps the `pplx_api_key` field to `PPLX_API_KEY`."""
    expected = {"pplx_api_key": "PPLX_API_KEY"}
    instance = PerplexityEmbeddings(pplx_api_key="test")

    assert instance.lc_secrets == expected
def test_pplx_api_key_env_fallback(monkeypatch: pytest.MonkeyPatch) -> None:
    """With only `PPLX_API_KEY` set, the key is read from that variable."""
    monkeypatch.delenv("PERPLEXITY_API_KEY", raising=False)
    monkeypatch.setenv("PPLX_API_KEY", "from_pplx_env")

    stored_key = PerplexityEmbeddings().pplx_api_key

    assert stored_key is not None
    assert stored_key.get_secret_value() == "from_pplx_env"
def test_perplexity_api_key_env_fallback(monkeypatch: pytest.MonkeyPatch) -> None:
    """With only `PERPLEXITY_API_KEY` set, the key is read from that variable."""
    monkeypatch.delenv("PPLX_API_KEY", raising=False)
    monkeypatch.setenv("PERPLEXITY_API_KEY", "from_perp_env")

    stored_key = PerplexityEmbeddings().pplx_api_key

    assert stored_key is not None
    assert stored_key.get_secret_value() == "from_perp_env"
def test_pplx_takes_precedence_over_perplexity(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With both variables set, `PPLX_API_KEY` wins over `PERPLEXITY_API_KEY`."""
    env_values = {"PPLX_API_KEY": "primary", "PERPLEXITY_API_KEY": "secondary"}
    for name, value in env_values.items():
        monkeypatch.setenv(name, value)

    stored_key = PerplexityEmbeddings().pplx_api_key

    assert stored_key is not None
    assert stored_key.get_secret_value() == "primary"
def test_explicit_kwarg_overrides_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """A key passed as a keyword argument beats the `PPLX_API_KEY` variable."""
    monkeypatch.setenv("PPLX_API_KEY", "from_env")

    stored_key = PerplexityEmbeddings(pplx_api_key="explicit").pplx_api_key

    assert stored_key is not None
    assert stored_key.get_secret_value() == "explicit"
def test_missing_api_key_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    """Constructing without any key source raises a `ValueError`."""
    for variable in ("PPLX_API_KEY", "PERPLEXITY_API_KEY"):
        monkeypatch.delenv(variable, raising=False)

    with pytest.raises(ValueError, match="Perplexity API key not provided"):
        PerplexityEmbeddings()
def test_embed_documents() -> None:
    """`embed_documents` returns float vectors and calls the client once."""
    client = MagicMock()
    create = client.embeddings.create
    create.return_value = _make_response([[1, -2, 3], [4, 5, -6]])

    embedder = PerplexityEmbeddings(pplx_api_key="test", client=client)
    vectors = embedder.embed_documents(["hello", "world"])

    assert vectors == [[1.0, -2.0, 3.0], [4.0, 5.0, -6.0]]
    create.assert_called_once_with(
        model="pplx-embed-v1-4b", input=["hello", "world"]
    )
def test_embed_documents_empty_short_circuits() -> None:
    """An empty input list returns `[]` without calling the client."""
    client = MagicMock()
    embedder = PerplexityEmbeddings(pplx_api_key="test", client=client)

    vectors = embedder.embed_documents([])

    assert vectors == []
    client.embeddings.create.assert_not_called()
def test_embed_documents_propagates_errors() -> None:
    """An exception raised by the client surfaces from `embed_documents`."""
    client = MagicMock()
    create = client.embeddings.create
    create.side_effect = RuntimeError("boom")
    embedder = PerplexityEmbeddings(pplx_api_key="test", client=client)

    with pytest.raises(RuntimeError, match="boom"):
        embedder.embed_documents(["x"])
def test_embed_query() -> None:
    """`embed_query` returns one float vector and sends the text as a list."""
    client = MagicMock()
    create = client.embeddings.create
    create.return_value = _make_response([[7, 8, 9]])

    embedder = PerplexityEmbeddings(pplx_api_key="test", client=client)
    vector = embedder.embed_query("hello")

    assert vector == [7.0, 8.0, 9.0]
    create.assert_called_once_with(model="pplx-embed-v1-4b", input=["hello"])
def test_embed_documents_uses_custom_model() -> None:
    """The configured `model` is forwarded to the client call."""
    client = MagicMock()
    create = client.embeddings.create
    create.return_value = _make_response([[0]])

    embedder = PerplexityEmbeddings(
        pplx_api_key="test",
        model="custom-model",
        client=client,
    )
    embedder.embed_documents(["x"])

    create.assert_called_once_with(model="custom-model", input=["x"])
async def test_aembed_documents() -> None:
    """`aembed_documents` returns float vectors and awaits the async client once."""
    create = AsyncMock(return_value=_make_response([[1, 2], [3, 4]]))
    async_client = MagicMock()
    async_client.embeddings.create = create

    embedder = PerplexityEmbeddings(pplx_api_key="test", async_client=async_client)
    vectors = await embedder.aembed_documents(["a", "b"])

    assert vectors == [[1.0, 2.0], [3.0, 4.0]]
    create.assert_awaited_once_with(model="pplx-embed-v1-4b", input=["a", "b"])
async def test_aembed_documents_empty_short_circuits() -> None:
    """An empty input list returns `[]` without awaiting the async client."""
    create = AsyncMock()
    async_client = MagicMock()
    async_client.embeddings.create = create
    embedder = PerplexityEmbeddings(pplx_api_key="test", async_client=async_client)

    vectors = await embedder.aembed_documents([])

    assert vectors == []
    create.assert_not_awaited()
async def test_aembed_query() -> None:
    """`aembed_query` returns one float vector and sends the text as a list."""
    create = AsyncMock(return_value=_make_response([[5, 6]]))
    async_client = MagicMock()
    async_client.embeddings.create = create

    embedder = PerplexityEmbeddings(pplx_api_key="test", async_client=async_client)
    vector = await embedder.aembed_query("hi")

    assert vector == [5.0, 6.0]
    create.assert_awaited_once_with(model="pplx-embed-v1-4b", input=["hi"])

View File

@@ -0,0 +1,20 @@
"""Standard unit tests for `PerplexityEmbeddings`."""
from langchain_core.embeddings import Embeddings
from langchain_tests.unit_tests import EmbeddingsUnitTests
from langchain_perplexity import PerplexityEmbeddings
class TestPerplexityEmbeddingsStandard(EmbeddingsUnitTests):
    """Bind the `EmbeddingsUnitTests` suite to `PerplexityEmbeddings`."""

    @property
    def embeddings_class(self) -> type[Embeddings]:
        """Return the embeddings class under test."""
        implementation: type[Embeddings] = PerplexityEmbeddings
        return implementation

    @property
    def embedding_model_params(self) -> dict:
        """Return the model parameters for the suite: a placeholder API key."""
        params: dict = {"pplx_api_key": "test"}
        return params

    @property
    def init_from_env_params(self) -> tuple[dict, dict, dict]:
        """Return the three dicts used for the init-from-environment check.

        NOTE(review): presumably (environment variables, extra init arguments,
        expected instance attributes) — confirm against `langchain_tests`.
        """
        first: dict = {"PPLX_API_KEY": "api_key"}
        second: dict = {}
        third: dict = {"pplx_api_key": "api_key"}
        return (first, second, third)

View File

@@ -2,6 +2,7 @@ from langchain_perplexity import __all__
EXPECTED_ALL = [
"ChatPerplexity",
"PerplexityEmbeddings",
"PerplexitySearchRetriever",
"PerplexitySearchResults",
"UserLocation",