From 28f5448dd467bb7c88d39f627134b8448e27b477 Mon Sep 17 00:00:00 2001
From: James Liounis
Date: Wed, 29 Apr 2026 17:51:50 -0400
Subject: [PATCH] feat(perplexity): add `PerplexityEmbeddings` (#37082)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Description

This PR adds a new `PerplexityEmbeddings` class to the `langchain-perplexity` partner package, providing first-class support for the Perplexity Embeddings API alongside the existing `ChatPerplexity`, `PerplexitySearchRetriever`, and `PerplexitySearchResults` integrations.

### What was added

- `langchain_perplexity/embeddings.py` — `PerplexityEmbeddings` class implementing `langchain_core.embeddings.Embeddings` with sync (`embed_documents`, `embed_query`) and async (`aembed_documents`, `aembed_query`) methods. Defaults to model `pplx-embed-v1-4b`, resolves the API key from `PPLX_API_KEY` / `PERPLEXITY_API_KEY` via `secret_from_env`, and builds the `Perplexity` / `AsyncPerplexity` SDK clients directly, forwarding `max_retries` and `timeout`.
- `__init__.py` exports `PerplexityEmbeddings` and adds it to `__all__`.
- Unit tests under `tests/unit_tests/test_embeddings.py` covering sync/async paths with mocked clients (no network).
- Integration tests under `tests/integration_tests/test_embeddings.py`, gated on `PPLX_API_KEY` / `PERPLEXITY_API_KEY` (matches the pattern in `test_search_api.py`).
- README updated to advertise the new component (not part of this patch's file list).
- Also in this patch: `ChatPerplexity` now forwards `max_retries` / `timeout` to the SDK clients and yields usage-only stream chunks instead of skipping them; the integration-test Makefile target runs with `-n 4 --retries 3 --retry-delay 5`; and `perplexityai` is bumped to `>=0.32.0,<1.0.0`.

### Why

LangChain users already get chat, search, and tool wrappers from `langchain-perplexity`, but had to drop down to the raw Perplexity SDK to use embeddings. This closes that gap.

### References

- Perplexity Embeddings docs: https://docs.perplexity.ai/docs/embeddings
- Perplexity Embeddings API reference: https://docs.perplexity.ai/api-reference/embeddings-post

### Issue

Closes #36726

## Testing

- `cd libs/partners/perplexity && make lint` — passes (ruff, format, mypy).
- `cd libs/partners/perplexity && make test` — all unit tests pass (59 passed, 1 skipped).
- Integration tests will run in CI with secrets; they exercise real `embed_documents` / `embed_query` / async variants against the live API and assert vector dimensionality consistency. --------- Co-authored-by: Claude Agent Co-authored-by: Mason Daugherty --- libs/partners/perplexity/Makefile | 3 +- .../langchain_perplexity/__init__.py | 2 + .../langchain_perplexity/chat_models.py | 87 +++++--- .../langchain_perplexity/embeddings.py | 184 ++++++++++++++++ libs/partners/perplexity/pyproject.toml | 2 +- .../integration_tests/test_embeddings.py | 56 +++++ .../test_embeddings_standard.py | 23 ++ .../tests/unit_tests/test_embeddings.py | 203 ++++++++++++++++++ .../unit_tests/test_embeddings_standard.py | 20 ++ .../tests/unit_tests/test_imports.py | 1 + libs/partners/perplexity/uv.lock | 10 +- 11 files changed, 550 insertions(+), 41 deletions(-) create mode 100644 libs/partners/perplexity/langchain_perplexity/embeddings.py create mode 100644 libs/partners/perplexity/tests/integration_tests/test_embeddings.py create mode 100644 libs/partners/perplexity/tests/integration_tests/test_embeddings_standard.py create mode 100644 libs/partners/perplexity/tests/unit_tests/test_embeddings.py create mode 100644 libs/partners/perplexity/tests/unit_tests/test_embeddings_standard.py diff --git a/libs/partners/perplexity/Makefile b/libs/partners/perplexity/Makefile index ca7b025aacc..f0dcbbc9cae 100644 --- a/libs/partners/perplexity/Makefile +++ b/libs/partners/perplexity/Makefile @@ -19,7 +19,8 @@ test_watch: uv run --group test ptw --snapshot-update --now . 
-- -vv $(TEST_FILE) integration_test integration_tests: - uv run --group test --group test_integration pytest -v --tb=short -n auto $(TEST_FILE) + uv run --group test --group test_integration pytest -v --tb=short -n 4 \ + --retries 3 --retry-delay 5 $(TEST_FILE) ###################### # LINTING AND FORMATTING diff --git a/libs/partners/perplexity/langchain_perplexity/__init__.py b/libs/partners/perplexity/langchain_perplexity/__init__.py index 5db46f6bc40..22447cfde40 100644 --- a/libs/partners/perplexity/langchain_perplexity/__init__.py +++ b/libs/partners/perplexity/langchain_perplexity/__init__.py @@ -1,6 +1,7 @@ """Perplexity AI integration for LangChain.""" from langchain_perplexity.chat_models import ChatPerplexity +from langchain_perplexity.embeddings import PerplexityEmbeddings from langchain_perplexity.output_parsers import ( ReasoningJsonOutputParser, ReasoningStructuredOutputParser, @@ -17,6 +18,7 @@ from langchain_perplexity.types import ( __all__ = [ "ChatPerplexity", + "PerplexityEmbeddings", "PerplexitySearchRetriever", "PerplexitySearchResults", "UserLocation", diff --git a/libs/partners/perplexity/langchain_perplexity/chat_models.py b/libs/partners/perplexity/langchain_perplexity/chat_models.py index 99e27ede01d..ef9970b359b 100644 --- a/libs/partners/perplexity/langchain_perplexity/chat_models.py +++ b/libs/partners/perplexity/langchain_perplexity/chat_models.py @@ -297,11 +297,18 @@ class ChatPerplexity(BaseChatModel): self.pplx_api_key.get_secret_value() if self.pplx_api_key else None ) + client_params: dict[str, Any] = { + "api_key": pplx_api_key, + "max_retries": self.max_retries, + } + if self.request_timeout is not None: + client_params["timeout"] = self.request_timeout + if not self.client: - self.client = Perplexity(api_key=pplx_api_key) + self.client = Perplexity(**client_params) if not self.async_client: - self.async_client = AsyncPerplexity(api_key=pplx_api_key) + self.async_client = AsyncPerplexity(**client_params) return self @@ 
-445,9 +452,30 @@ class ChatPerplexity(BaseChatModel): prev_total_usage = lc_total_usage else: usage_metadata = None - if len(chunk["choices"]) == 0: + generation_info = {} + if (model_name := chunk.get("model")) and not added_model_name: + generation_info["model_name"] = model_name + added_model_name = True + if total_usage := chunk.get("usage"): + if num_search_queries := total_usage.get("num_search_queries"): + if not added_search_queries: + generation_info["num_search_queries"] = num_search_queries + added_search_queries = True + if not added_search_context_size: + if search_context_size := total_usage.get("search_context_size"): + generation_info["search_context_size"] = search_context_size + added_search_context_size = True + + choices = chunk.get("choices") or [] + if len(choices) == 0: + # Usage-only or otherwise empty chunk: still yield so the stream + # is never empty and downstream callers receive usage metadata. + message = AIMessageChunk(content="", usage_metadata=usage_metadata) + yield ChatGenerationChunk( + message=message, generation_info=generation_info or None + ) continue - choice = chunk["choices"][0] + choice = choices[0] additional_kwargs = {} if first_chunk: @@ -462,21 +490,6 @@ class ChatPerplexity(BaseChatModel): if chunk.get("reasoning_steps"): additional_kwargs["reasoning_steps"] = chunk["reasoning_steps"] - generation_info = {} - if (model_name := chunk.get("model")) and not added_model_name: - generation_info["model_name"] = model_name - added_model_name = True - # Add num_search_queries to generation_info if present - if total_usage := chunk.get("usage"): - if num_search_queries := total_usage.get("num_search_queries"): - if not added_search_queries: - generation_info["num_search_queries"] = num_search_queries - added_search_queries = True - if not added_search_context_size: - if search_context_size := total_usage.get("search_context_size"): - generation_info["search_context_size"] = search_context_size - added_search_context_size = 
True - chunk = self._convert_delta_to_message_chunk( choice["delta"], default_chunk_class ) @@ -532,9 +545,28 @@ class ChatPerplexity(BaseChatModel): prev_total_usage = lc_total_usage else: usage_metadata = None - if len(chunk["choices"]) == 0: + generation_info = {} + if (model_name := chunk.get("model")) and not added_model_name: + generation_info["model_name"] = model_name + added_model_name = True + if total_usage := chunk.get("usage"): + if num_search_queries := total_usage.get("num_search_queries"): + if not added_search_queries: + generation_info["num_search_queries"] = num_search_queries + added_search_queries = True + if search_context_size := total_usage.get("search_context_size"): + generation_info["search_context_size"] = search_context_size + + choices = chunk.get("choices") or [] + if len(choices) == 0: + # Usage-only or otherwise empty chunk: still yield so the stream + # is never empty and downstream callers receive usage metadata. + message = AIMessageChunk(content="", usage_metadata=usage_metadata) + yield ChatGenerationChunk( + message=message, generation_info=generation_info or None + ) continue - choice = chunk["choices"][0] + choice = choices[0] additional_kwargs = {} if first_chunk: @@ -549,19 +581,6 @@ class ChatPerplexity(BaseChatModel): if chunk.get("reasoning_steps"): additional_kwargs["reasoning_steps"] = chunk["reasoning_steps"] - generation_info = {} - if (model_name := chunk.get("model")) and not added_model_name: - generation_info["model_name"] = model_name - added_model_name = True - - if total_usage := chunk.get("usage"): - if num_search_queries := total_usage.get("num_search_queries"): - if not added_search_queries: - generation_info["num_search_queries"] = num_search_queries - added_search_queries = True - if search_context_size := total_usage.get("search_context_size"): - generation_info["search_context_size"] = search_context_size - chunk = self._convert_delta_to_message_chunk( choice["delta"], default_chunk_class ) diff --git 
a/libs/partners/perplexity/langchain_perplexity/embeddings.py b/libs/partners/perplexity/langchain_perplexity/embeddings.py new file mode 100644 index 00000000000..8ee2dde7d1c --- /dev/null +++ b/libs/partners/perplexity/langchain_perplexity/embeddings.py @@ -0,0 +1,184 @@ +"""Wrapper around Perplexity Embeddings API.""" + +from __future__ import annotations + +import base64 +import struct +from typing import Any + +from langchain_core.embeddings import Embeddings +from langchain_core.utils import secret_from_env +from perplexity import AsyncPerplexity, Perplexity +from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator +from typing_extensions import Self + + +def _decode_int8_embedding(b64: str) -> list[float]: + """Decode a `base64_int8`-encoded Perplexity embedding into a list of floats.""" + raw = base64.b64decode(b64) + return [float(v) for v in struct.unpack(f"<{len(raw)}b", raw)] + + +class PerplexityEmbeddings(BaseModel, Embeddings): + """`Perplexity AI` embeddings. + + Setup: + Install the `perplexityai` package and set the `PPLX_API_KEY` + (or `PERPLEXITY_API_KEY`) environment variable, or pass the key as + the `pplx_api_key`/`api_key` argument. + + ```bash + pip install -U langchain-perplexity + export PPLX_API_KEY=your_api_key + ``` + + See the Perplexity Embeddings API reference: + https://docs.perplexity.ai/api-reference/embeddings-post + + Instantiate: + + ```python + from langchain_perplexity import PerplexityEmbeddings + + embeddings = PerplexityEmbeddings() + ``` + + Embed a single query: + + ```python + query_vector = embeddings.embed_query("hello world") + ``` + + Embed documents: + + ```python + doc_vectors = embeddings.embed_documents(["hello", "world"]) + ``` + + Select a specific model: + + ```python + embeddings = PerplexityEmbeddings(model="pplx-embed-v1-0.6b") + ``` + + !!! note + Perplexity returns base64-encoded signed int8 embeddings. This class + decodes them into `list[float]` values in the range [-128, 127]. 
The + magnitude is preserved from the API's quantized output; cosine + similarity is unaffected by the lack of unit-length normalization. + """ + + client: Any = Field(default=None, exclude=True) + """Perplexity SDK client (set automatically).""" + + async_client: Any = Field(default=None, exclude=True) + """Async Perplexity SDK client (set automatically).""" + + model: str = "pplx-embed-v1-4b" + """Name of the Perplexity embedding model to use. + + See the API reference for available identifiers, including + `pplx-embed-v1-0.6b` and `pplx-embed-v1-4b`. Contextualized variants are + served through a separate endpoint and are not exposed by this class. + """ + + pplx_api_key: SecretStr | None = Field( + default_factory=secret_from_env( + ["PPLX_API_KEY", "PERPLEXITY_API_KEY"], default=None + ), + alias="api_key", + ) + """Perplexity API key. Reads from `PPLX_API_KEY` or `PERPLEXITY_API_KEY`.""" + + request_timeout: float | tuple[float, float] | None = Field(None, alias="timeout") + """Timeout for requests to the Perplexity embeddings API.""" + + max_retries: int = 6 + """Maximum number of retries to make when calling the embeddings API.""" + + model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True) + + @property + def lc_secrets(self) -> dict[str, str]: + """Map secret field names to their environment variable names.""" + return {"pplx_api_key": "PPLX_API_KEY"} + + @model_validator(mode="after") + def validate_environment(self) -> Self: + """Initialize the Perplexity SDK clients.""" + if not self.pplx_api_key: + msg = ( + "Perplexity API key not provided. Pass `pplx_api_key` (or " + "`api_key`) to PerplexityEmbeddings, or set the `PPLX_API_KEY` " + "or `PERPLEXITY_API_KEY` environment variable." 
+ ) + raise ValueError(msg) + + api_key = self.pplx_api_key.get_secret_value() + client_params: dict[str, Any] = { + "api_key": api_key, + "max_retries": self.max_retries, + } + if self.request_timeout is not None: + client_params["timeout"] = self.request_timeout + + if self.client is None: + self.client = Perplexity(**client_params) + if self.async_client is None: + self.async_client = AsyncPerplexity(**client_params) + return self + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + """Embed a list of documents using the Perplexity embeddings API. + + Args: + texts: The list of texts to embed. + + Returns: + A list of embeddings, one per input text. An empty list is returned + when `texts` is empty. + """ + if not texts: + return [] + response = self.client.embeddings.create(model=self.model, input=texts) + return [_decode_int8_embedding(item.embedding) for item in response.data] + + def embed_query(self, text: str) -> list[float]: + """Embed a single query string using the Perplexity embeddings API. + + Args: + text: The text to embed. + + Returns: + The embedding vector for the input text. + """ + return self.embed_documents([text])[0] + + async def aembed_documents(self, texts: list[str]) -> list[list[float]]: + """Asynchronously embed a list of documents. + + Args: + texts: The list of texts to embed. + + Returns: + A list of embeddings, one per input text. An empty list is returned + when `texts` is empty. + """ + if not texts: + return [] + response = await self.async_client.embeddings.create( + model=self.model, input=texts + ) + return [_decode_int8_embedding(item.embedding) for item in response.data] + + async def aembed_query(self, text: str) -> list[float]: + """Asynchronously embed a single query string. + + Args: + text: The text to embed. + + Returns: + The embedding vector for the input text. 
+ """ + result = await self.aembed_documents([text]) + return result[0] diff --git a/libs/partners/perplexity/pyproject.toml b/libs/partners/perplexity/pyproject.toml index 85e07406797..9304a9e2047 100644 --- a/libs/partners/perplexity/pyproject.toml +++ b/libs/partners/perplexity/pyproject.toml @@ -24,7 +24,7 @@ version = "1.1.0" requires-python = ">=3.10.0,<4.0.0" dependencies = [ "langchain-core>=1.3.2,<2.0.0", - "perplexityai>=0.22.0", + "perplexityai>=0.32.0,<1.0.0", ] [project.urls] diff --git a/libs/partners/perplexity/tests/integration_tests/test_embeddings.py b/libs/partners/perplexity/tests/integration_tests/test_embeddings.py new file mode 100644 index 00000000000..9423f1eafe5 --- /dev/null +++ b/libs/partners/perplexity/tests/integration_tests/test_embeddings.py @@ -0,0 +1,56 @@ +"""Integration tests for Perplexity Embeddings API.""" + +import os + +import pytest + +from langchain_perplexity import PerplexityEmbeddings + + +@pytest.mark.skipif( + not (os.environ.get("PPLX_API_KEY") or os.environ.get("PERPLEXITY_API_KEY")), + reason="PPLX_API_KEY/PERPLEXITY_API_KEY not set", +) +class TestPerplexityEmbeddings: + def test_embed_documents(self) -> None: + """Test embedding a list of documents.""" + embeddings = PerplexityEmbeddings() + texts = ["hello world", "goodbye world"] + vectors = embeddings.embed_documents(texts) + + assert len(vectors) == len(texts) + assert all(isinstance(v, list) for v in vectors) + assert all(len(v) > 0 for v in vectors) + # All vectors should have the same dimensionality. 
+ assert len({len(v) for v in vectors}) == 1 + assert all(isinstance(x, float) for x in vectors[0]) + + def test_embed_query(self) -> None: + """Test embedding a single query.""" + embeddings = PerplexityEmbeddings() + vector = embeddings.embed_query("What is the capital of France?") + + assert isinstance(vector, list) + assert len(vector) > 0 + assert all(isinstance(x, float) for x in vector) + + def test_embed_query_matches_documents_dim(self) -> None: + """Embeddings from query and documents should share dimensionality.""" + embeddings = PerplexityEmbeddings() + query_vec = embeddings.embed_query("hello") + doc_vecs = embeddings.embed_documents(["hello"]) + assert len(query_vec) == len(doc_vecs[0]) + + async def test_aembed_documents(self) -> None: + """Test async embedding a list of documents.""" + embeddings = PerplexityEmbeddings() + vectors = await embeddings.aembed_documents(["hello", "world"]) + assert len(vectors) == 2 + assert all(len(v) > 0 for v in vectors) + + async def test_aembed_query(self) -> None: + """Test async embedding a single query.""" + embeddings = PerplexityEmbeddings() + vector = await embeddings.aembed_query("hello") + assert isinstance(vector, list) + assert len(vector) > 0 diff --git a/libs/partners/perplexity/tests/integration_tests/test_embeddings_standard.py b/libs/partners/perplexity/tests/integration_tests/test_embeddings_standard.py new file mode 100644 index 00000000000..c243b5f4f6e --- /dev/null +++ b/libs/partners/perplexity/tests/integration_tests/test_embeddings_standard.py @@ -0,0 +1,23 @@ +"""Standard integration tests for `PerplexityEmbeddings`.""" + +import os + +import pytest +from langchain_core.embeddings import Embeddings +from langchain_tests.integration_tests import EmbeddingsIntegrationTests + +from langchain_perplexity import PerplexityEmbeddings + + +@pytest.mark.skipif( + not (os.environ.get("PPLX_API_KEY") or os.environ.get("PERPLEXITY_API_KEY")), + reason="PPLX_API_KEY/PERPLEXITY_API_KEY not set", +) +class 
TestPerplexityEmbeddingsIntegration(EmbeddingsIntegrationTests): + @property + def embeddings_class(self) -> type[Embeddings]: + return PerplexityEmbeddings + + @property + def embedding_model_params(self) -> dict: + return {} diff --git a/libs/partners/perplexity/tests/unit_tests/test_embeddings.py b/libs/partners/perplexity/tests/unit_tests/test_embeddings.py new file mode 100644 index 00000000000..eec9725efc2 --- /dev/null +++ b/libs/partners/perplexity/tests/unit_tests/test_embeddings.py @@ -0,0 +1,203 @@ +"""Unit tests for `PerplexityEmbeddings`.""" + +import base64 +import struct +from unittest.mock import AsyncMock, MagicMock + +import pytest +from pydantic import SecretStr + +from langchain_perplexity import PerplexityEmbeddings + + +def _encode_int8(values: list[int]) -> str: + """Encode signed int8 values as base64 (matches Perplexity's wire format).""" + raw = struct.pack(f"<{len(values)}b", *values) + return base64.b64encode(raw).decode("ascii") + + +def _make_response(int8_vectors: list[list[int]]) -> MagicMock: + """Build a stand-in for `EmbeddingCreateResponse` with base64_int8 payloads.""" + response = MagicMock() + response.data = [] + for values in int8_vectors: + item = MagicMock() + item.embedding = _encode_int8(values) + response.data.append(item) + return response + + +def test_embeddings_initialization() -> None: + embeddings = PerplexityEmbeddings(pplx_api_key="test") + assert embeddings.pplx_api_key is not None + assert embeddings.pplx_api_key.get_secret_value() == "test" + assert embeddings.model == "pplx-embed-v1-4b" + assert embeddings.client is not None + assert embeddings.async_client is not None + + +def test_embeddings_custom_model() -> None: + embeddings = PerplexityEmbeddings(pplx_api_key="test", model="custom-model") + assert embeddings.model == "custom-model" + + +def test_api_key_alias() -> None: + """`api_key=` should be accepted via populate_by_name alias.""" + embeddings = PerplexityEmbeddings(api_key="aliased") + assert 
embeddings.pplx_api_key is not None + assert embeddings.pplx_api_key.get_secret_value() == "aliased" + + +def test_api_key_accepts_secret_str() -> None: + embeddings = PerplexityEmbeddings(pplx_api_key=SecretStr("typed")) + assert embeddings.pplx_api_key is not None + assert embeddings.pplx_api_key.get_secret_value() == "typed" + + +def test_lc_secrets() -> None: + embeddings = PerplexityEmbeddings(pplx_api_key="test") + assert embeddings.lc_secrets == {"pplx_api_key": "PPLX_API_KEY"} + + +def test_pplx_api_key_env_fallback(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("PERPLEXITY_API_KEY", raising=False) + monkeypatch.setenv("PPLX_API_KEY", "from_pplx_env") + embeddings = PerplexityEmbeddings() + assert embeddings.pplx_api_key is not None + assert embeddings.pplx_api_key.get_secret_value() == "from_pplx_env" + + +def test_perplexity_api_key_env_fallback(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("PPLX_API_KEY", raising=False) + monkeypatch.setenv("PERPLEXITY_API_KEY", "from_perp_env") + embeddings = PerplexityEmbeddings() + assert embeddings.pplx_api_key is not None + assert embeddings.pplx_api_key.get_secret_value() == "from_perp_env" + + +def test_pplx_takes_precedence_over_perplexity( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("PPLX_API_KEY", "primary") + monkeypatch.setenv("PERPLEXITY_API_KEY", "secondary") + embeddings = PerplexityEmbeddings() + assert embeddings.pplx_api_key is not None + assert embeddings.pplx_api_key.get_secret_value() == "primary" + + +def test_explicit_kwarg_overrides_env(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("PPLX_API_KEY", "from_env") + embeddings = PerplexityEmbeddings(pplx_api_key="explicit") + assert embeddings.pplx_api_key is not None + assert embeddings.pplx_api_key.get_secret_value() == "explicit" + + +def test_missing_api_key_raises(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("PPLX_API_KEY", raising=False) + 
monkeypatch.delenv("PERPLEXITY_API_KEY", raising=False) + with pytest.raises(ValueError, match="Perplexity API key not provided"): + PerplexityEmbeddings() + + +def test_embed_documents() -> None: + mock_client = MagicMock() + mock_client.embeddings.create.return_value = _make_response( + [[1, -2, 3], [4, 5, -6]] + ) + embeddings = PerplexityEmbeddings(pplx_api_key="test", client=mock_client) + + result = embeddings.embed_documents(["hello", "world"]) + + assert result == [[1.0, -2.0, 3.0], [4.0, 5.0, -6.0]] + mock_client.embeddings.create.assert_called_once_with( + model="pplx-embed-v1-4b", input=["hello", "world"] + ) + + +def test_embed_documents_empty_short_circuits() -> None: + mock_client = MagicMock() + embeddings = PerplexityEmbeddings(pplx_api_key="test", client=mock_client) + + assert embeddings.embed_documents([]) == [] + mock_client.embeddings.create.assert_not_called() + + +def test_embed_documents_propagates_errors() -> None: + mock_client = MagicMock() + mock_client.embeddings.create.side_effect = RuntimeError("boom") + embeddings = PerplexityEmbeddings(pplx_api_key="test", client=mock_client) + + with pytest.raises(RuntimeError, match="boom"): + embeddings.embed_documents(["x"]) + + +def test_embed_query() -> None: + mock_client = MagicMock() + mock_client.embeddings.create.return_value = _make_response([[7, 8, 9]]) + embeddings = PerplexityEmbeddings(pplx_api_key="test", client=mock_client) + + result = embeddings.embed_query("hello") + + assert result == [7.0, 8.0, 9.0] + mock_client.embeddings.create.assert_called_once_with( + model="pplx-embed-v1-4b", input=["hello"] + ) + + +def test_embed_documents_uses_custom_model() -> None: + mock_client = MagicMock() + mock_client.embeddings.create.return_value = _make_response([[0]]) + embeddings = PerplexityEmbeddings( + pplx_api_key="test", model="custom-model", client=mock_client + ) + + embeddings.embed_documents(["x"]) + + mock_client.embeddings.create.assert_called_once_with( + model="custom-model", 
input=["x"] + ) + + +async def test_aembed_documents() -> None: + mock_async_client = MagicMock() + mock_async_client.embeddings.create = AsyncMock( + return_value=_make_response([[1, 2], [3, 4]]) + ) + embeddings = PerplexityEmbeddings( + pplx_api_key="test", async_client=mock_async_client + ) + + result = await embeddings.aembed_documents(["a", "b"]) + + assert result == [[1.0, 2.0], [3.0, 4.0]] + mock_async_client.embeddings.create.assert_awaited_once_with( + model="pplx-embed-v1-4b", input=["a", "b"] + ) + + +async def test_aembed_documents_empty_short_circuits() -> None: + mock_async_client = MagicMock() + mock_async_client.embeddings.create = AsyncMock() + embeddings = PerplexityEmbeddings( + pplx_api_key="test", async_client=mock_async_client + ) + + assert await embeddings.aembed_documents([]) == [] + mock_async_client.embeddings.create.assert_not_awaited() + + +async def test_aembed_query() -> None: + mock_async_client = MagicMock() + mock_async_client.embeddings.create = AsyncMock( + return_value=_make_response([[5, 6]]) + ) + embeddings = PerplexityEmbeddings( + pplx_api_key="test", async_client=mock_async_client + ) + + result = await embeddings.aembed_query("hi") + + assert result == [5.0, 6.0] + mock_async_client.embeddings.create.assert_awaited_once_with( + model="pplx-embed-v1-4b", input=["hi"] + ) diff --git a/libs/partners/perplexity/tests/unit_tests/test_embeddings_standard.py b/libs/partners/perplexity/tests/unit_tests/test_embeddings_standard.py new file mode 100644 index 00000000000..ba66e954064 --- /dev/null +++ b/libs/partners/perplexity/tests/unit_tests/test_embeddings_standard.py @@ -0,0 +1,20 @@ +"""Standard unit tests for `PerplexityEmbeddings`.""" + +from langchain_core.embeddings import Embeddings +from langchain_tests.unit_tests import EmbeddingsUnitTests + +from langchain_perplexity import PerplexityEmbeddings + + +class TestPerplexityEmbeddingsStandard(EmbeddingsUnitTests): + @property + def embeddings_class(self) -> 
type[Embeddings]: + return PerplexityEmbeddings + + @property + def embedding_model_params(self) -> dict: + return {"pplx_api_key": "test"} + + @property + def init_from_env_params(self) -> tuple[dict, dict, dict]: + return ({"PPLX_API_KEY": "api_key"}, {}, {"pplx_api_key": "api_key"}) diff --git a/libs/partners/perplexity/tests/unit_tests/test_imports.py b/libs/partners/perplexity/tests/unit_tests/test_imports.py index 67793501c22..57168f50e48 100644 --- a/libs/partners/perplexity/tests/unit_tests/test_imports.py +++ b/libs/partners/perplexity/tests/unit_tests/test_imports.py @@ -2,6 +2,7 @@ from langchain_perplexity import __all__ EXPECTED_ALL = [ "ChatPerplexity", + "PerplexityEmbeddings", "PerplexitySearchRetriever", "PerplexitySearchResults", "UserLocation", diff --git a/libs/partners/perplexity/uv.lock b/libs/partners/perplexity/uv.lock index 2b7a96185de..351d13dd8d3 100644 --- a/libs/partners/perplexity/uv.lock +++ b/libs/partners/perplexity/uv.lock @@ -537,7 +537,7 @@ typing = [ [package.metadata] requires-dist = [ { name = "langchain-core", editable = "../../core" }, - { name = "perplexityai", specifier = ">=0.22.0" }, + { name = "perplexityai", specifier = ">=0.32.0,<1.0.0" }, ] [package.metadata.requires-dev] @@ -581,7 +581,7 @@ wheels = [ [[package]] name = "langchain-tests" -version = "1.1.6" +version = "1.1.7" source = { editable = "../../standard-tests" } dependencies = [ { name = "httpx" }, @@ -973,7 +973,7 @@ wheels = [ [[package]] name = "perplexityai" -version = "0.22.2" +version = "0.32.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -983,9 +983,9 @@ dependencies = [ { name = "sniffio" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/47/f0/4ead48afbe4c9d7b57e9d03e253bed2986ac20d2257ed918cac276949018/perplexityai-0.22.2.tar.gz", hash = "sha256:9c3cad307c95aa5e8967358547e548d58793d350318d8d1d4aa33a933cbed844", size = 113014, upload-time = 
"2025-12-17T19:05:25.572Z" } +sdist = { url = "https://files.pythonhosted.org/packages/09/02/73f460c85a5ec533a97fd1ff34fa729a009b4a217a4a87d8da946b6e1c52/perplexityai-0.32.1.tar.gz", hash = "sha256:b03503498591d06c4d50b666f7f7469875d3586f664c29416aae9012ae7a64d1", size = 135741, upload-time = "2026-04-21T04:35:40.345Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/cb/aa/fc3337fdb014b1584297fc212552e6365d22a6fb77850a56c9038cd47173/perplexityai-0.22.2-py3-none-any.whl", hash = "sha256:92d3dc7f4e110c879ac5009daf7263a04f413523f7d76fba871176516c253890", size = 96860, upload-time = "2025-12-17T19:05:24.292Z" }, + { url = "https://files.pythonhosted.org/packages/d6/11/5c164f114311bc2e2350202393e7c5bd25bb156b5230a1edf5a2b2f4ba04/perplexityai-0.32.1-py3-none-any.whl", hash = "sha256:e5017d245fd8966cf79657edc03a93078d867708542b491b38152618f91e369b", size = 130223, upload-time = "2026-04-21T04:35:38.786Z" }, ] [[package]]