mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 18:50:33 +00:00
feat(perplexity): add PerplexityEmbeddings (#37082)
## Description This PR adds a new `PerplexityEmbeddings` class to the `langchain-perplexity` partner package, providing first-class support for the Perplexity Embeddings API alongside the existing `ChatPerplexity`, `PerplexitySearchRetriever`, and `PerplexitySearchResults` integrations. ### What was added - `langchain_perplexity/embeddings.py` — `PerplexityEmbeddings` class implementing `langchain_core.embeddings.Embeddings` with sync (`embed_documents`, `embed_query`) and async (`aembed_documents`, `aembed_query`) methods. Defaults to model `pplx-embed-v1-4b` and reuses the existing `_utils.initialize_client` helper for API key resolution (`PPLX_API_KEY` / `PERPLEXITY_API_KEY`). - `__init__.py` exports `PerplexityEmbeddings` and adds it to `__all__`. - Unit tests under `tests/unit_tests/test_embeddings.py` covering sync/async paths with mocked clients (no network). - Integration tests under `tests/integration_tests/test_embeddings.py`, gated on `PPLX_API_KEY` (matches the pattern in `test_search_api.py`). - README updated to advertise the new component. ### Why LangChain users already get chat, search, and tool wrappers from `langchain-perplexity`, but had to drop down to the raw Perplexity SDK to use embeddings. This closes that gap. ### References - Perplexity Embeddings docs: https://docs.perplexity.ai/docs/embeddings - Perplexity Embeddings API reference: https://docs.perplexity.ai/api-reference/embeddings-post ### Issue Closes #36726 ## Testing - `cd libs/partners/perplexity && make lint` — passes (ruff, format, mypy). - `cd libs/partners/perplexity && make test` — all unit tests pass (59 passed, 1 skipped). - Integration tests will run in CI with secrets; they exercise real `embed_documents` / `embed_query` / async variants against the live API and assert vector dimensionality consistency. --------- Co-authored-by: Claude Agent <agent@anthropic.com> Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
@@ -19,7 +19,8 @@ test_watch:
|
|||||||
uv run --group test ptw --snapshot-update --now . -- -vv $(TEST_FILE)
|
uv run --group test ptw --snapshot-update --now . -- -vv $(TEST_FILE)
|
||||||
|
|
||||||
integration_test integration_tests:
|
integration_test integration_tests:
|
||||||
uv run --group test --group test_integration pytest -v --tb=short -n auto $(TEST_FILE)
|
uv run --group test --group test_integration pytest -v --tb=short -n 4 \
|
||||||
|
--retries 3 --retry-delay 5 $(TEST_FILE)
|
||||||
|
|
||||||
######################
|
######################
|
||||||
# LINTING AND FORMATTING
|
# LINTING AND FORMATTING
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Perplexity AI integration for LangChain."""
|
"""Perplexity AI integration for LangChain."""
|
||||||
|
|
||||||
from langchain_perplexity.chat_models import ChatPerplexity
|
from langchain_perplexity.chat_models import ChatPerplexity
|
||||||
|
from langchain_perplexity.embeddings import PerplexityEmbeddings
|
||||||
from langchain_perplexity.output_parsers import (
|
from langchain_perplexity.output_parsers import (
|
||||||
ReasoningJsonOutputParser,
|
ReasoningJsonOutputParser,
|
||||||
ReasoningStructuredOutputParser,
|
ReasoningStructuredOutputParser,
|
||||||
@@ -17,6 +18,7 @@ from langchain_perplexity.types import (
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ChatPerplexity",
|
"ChatPerplexity",
|
||||||
|
"PerplexityEmbeddings",
|
||||||
"PerplexitySearchRetriever",
|
"PerplexitySearchRetriever",
|
||||||
"PerplexitySearchResults",
|
"PerplexitySearchResults",
|
||||||
"UserLocation",
|
"UserLocation",
|
||||||
|
|||||||
@@ -297,11 +297,18 @@ class ChatPerplexity(BaseChatModel):
|
|||||||
self.pplx_api_key.get_secret_value() if self.pplx_api_key else None
|
self.pplx_api_key.get_secret_value() if self.pplx_api_key else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
client_params: dict[str, Any] = {
|
||||||
|
"api_key": pplx_api_key,
|
||||||
|
"max_retries": self.max_retries,
|
||||||
|
}
|
||||||
|
if self.request_timeout is not None:
|
||||||
|
client_params["timeout"] = self.request_timeout
|
||||||
|
|
||||||
if not self.client:
|
if not self.client:
|
||||||
self.client = Perplexity(api_key=pplx_api_key)
|
self.client = Perplexity(**client_params)
|
||||||
|
|
||||||
if not self.async_client:
|
if not self.async_client:
|
||||||
self.async_client = AsyncPerplexity(api_key=pplx_api_key)
|
self.async_client = AsyncPerplexity(**client_params)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@@ -445,9 +452,30 @@ class ChatPerplexity(BaseChatModel):
|
|||||||
prev_total_usage = lc_total_usage
|
prev_total_usage = lc_total_usage
|
||||||
else:
|
else:
|
||||||
usage_metadata = None
|
usage_metadata = None
|
||||||
if len(chunk["choices"]) == 0:
|
generation_info = {}
|
||||||
|
if (model_name := chunk.get("model")) and not added_model_name:
|
||||||
|
generation_info["model_name"] = model_name
|
||||||
|
added_model_name = True
|
||||||
|
if total_usage := chunk.get("usage"):
|
||||||
|
if num_search_queries := total_usage.get("num_search_queries"):
|
||||||
|
if not added_search_queries:
|
||||||
|
generation_info["num_search_queries"] = num_search_queries
|
||||||
|
added_search_queries = True
|
||||||
|
if not added_search_context_size:
|
||||||
|
if search_context_size := total_usage.get("search_context_size"):
|
||||||
|
generation_info["search_context_size"] = search_context_size
|
||||||
|
added_search_context_size = True
|
||||||
|
|
||||||
|
choices = chunk.get("choices") or []
|
||||||
|
if len(choices) == 0:
|
||||||
|
# Usage-only or otherwise empty chunk: still yield so the stream
|
||||||
|
# is never empty and downstream callers receive usage metadata.
|
||||||
|
message = AIMessageChunk(content="", usage_metadata=usage_metadata)
|
||||||
|
yield ChatGenerationChunk(
|
||||||
|
message=message, generation_info=generation_info or None
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
choice = chunk["choices"][0]
|
choice = choices[0]
|
||||||
|
|
||||||
additional_kwargs = {}
|
additional_kwargs = {}
|
||||||
if first_chunk:
|
if first_chunk:
|
||||||
@@ -462,21 +490,6 @@ class ChatPerplexity(BaseChatModel):
|
|||||||
if chunk.get("reasoning_steps"):
|
if chunk.get("reasoning_steps"):
|
||||||
additional_kwargs["reasoning_steps"] = chunk["reasoning_steps"]
|
additional_kwargs["reasoning_steps"] = chunk["reasoning_steps"]
|
||||||
|
|
||||||
generation_info = {}
|
|
||||||
if (model_name := chunk.get("model")) and not added_model_name:
|
|
||||||
generation_info["model_name"] = model_name
|
|
||||||
added_model_name = True
|
|
||||||
# Add num_search_queries to generation_info if present
|
|
||||||
if total_usage := chunk.get("usage"):
|
|
||||||
if num_search_queries := total_usage.get("num_search_queries"):
|
|
||||||
if not added_search_queries:
|
|
||||||
generation_info["num_search_queries"] = num_search_queries
|
|
||||||
added_search_queries = True
|
|
||||||
if not added_search_context_size:
|
|
||||||
if search_context_size := total_usage.get("search_context_size"):
|
|
||||||
generation_info["search_context_size"] = search_context_size
|
|
||||||
added_search_context_size = True
|
|
||||||
|
|
||||||
chunk = self._convert_delta_to_message_chunk(
|
chunk = self._convert_delta_to_message_chunk(
|
||||||
choice["delta"], default_chunk_class
|
choice["delta"], default_chunk_class
|
||||||
)
|
)
|
||||||
@@ -532,9 +545,28 @@ class ChatPerplexity(BaseChatModel):
|
|||||||
prev_total_usage = lc_total_usage
|
prev_total_usage = lc_total_usage
|
||||||
else:
|
else:
|
||||||
usage_metadata = None
|
usage_metadata = None
|
||||||
if len(chunk["choices"]) == 0:
|
generation_info = {}
|
||||||
|
if (model_name := chunk.get("model")) and not added_model_name:
|
||||||
|
generation_info["model_name"] = model_name
|
||||||
|
added_model_name = True
|
||||||
|
if total_usage := chunk.get("usage"):
|
||||||
|
if num_search_queries := total_usage.get("num_search_queries"):
|
||||||
|
if not added_search_queries:
|
||||||
|
generation_info["num_search_queries"] = num_search_queries
|
||||||
|
added_search_queries = True
|
||||||
|
if search_context_size := total_usage.get("search_context_size"):
|
||||||
|
generation_info["search_context_size"] = search_context_size
|
||||||
|
|
||||||
|
choices = chunk.get("choices") or []
|
||||||
|
if len(choices) == 0:
|
||||||
|
# Usage-only or otherwise empty chunk: still yield so the stream
|
||||||
|
# is never empty and downstream callers receive usage metadata.
|
||||||
|
message = AIMessageChunk(content="", usage_metadata=usage_metadata)
|
||||||
|
yield ChatGenerationChunk(
|
||||||
|
message=message, generation_info=generation_info or None
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
choice = chunk["choices"][0]
|
choice = choices[0]
|
||||||
|
|
||||||
additional_kwargs = {}
|
additional_kwargs = {}
|
||||||
if first_chunk:
|
if first_chunk:
|
||||||
@@ -549,19 +581,6 @@ class ChatPerplexity(BaseChatModel):
|
|||||||
if chunk.get("reasoning_steps"):
|
if chunk.get("reasoning_steps"):
|
||||||
additional_kwargs["reasoning_steps"] = chunk["reasoning_steps"]
|
additional_kwargs["reasoning_steps"] = chunk["reasoning_steps"]
|
||||||
|
|
||||||
generation_info = {}
|
|
||||||
if (model_name := chunk.get("model")) and not added_model_name:
|
|
||||||
generation_info["model_name"] = model_name
|
|
||||||
added_model_name = True
|
|
||||||
|
|
||||||
if total_usage := chunk.get("usage"):
|
|
||||||
if num_search_queries := total_usage.get("num_search_queries"):
|
|
||||||
if not added_search_queries:
|
|
||||||
generation_info["num_search_queries"] = num_search_queries
|
|
||||||
added_search_queries = True
|
|
||||||
if search_context_size := total_usage.get("search_context_size"):
|
|
||||||
generation_info["search_context_size"] = search_context_size
|
|
||||||
|
|
||||||
chunk = self._convert_delta_to_message_chunk(
|
chunk = self._convert_delta_to_message_chunk(
|
||||||
choice["delta"], default_chunk_class
|
choice["delta"], default_chunk_class
|
||||||
)
|
)
|
||||||
|
|||||||
184
libs/partners/perplexity/langchain_perplexity/embeddings.py
Normal file
184
libs/partners/perplexity/langchain_perplexity/embeddings.py
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
"""Wrapper around Perplexity Embeddings API."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import struct
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.embeddings import Embeddings
|
||||||
|
from langchain_core.utils import secret_from_env
|
||||||
|
from perplexity import AsyncPerplexity, Perplexity
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_int8_embedding(b64: str) -> list[float]:
|
||||||
|
"""Decode a `base64_int8`-encoded Perplexity embedding into a list of floats."""
|
||||||
|
raw = base64.b64decode(b64)
|
||||||
|
return [float(v) for v in struct.unpack(f"<{len(raw)}b", raw)]
|
||||||
|
|
||||||
|
|
||||||
|
class PerplexityEmbeddings(BaseModel, Embeddings):
    """`Perplexity AI` embedding model integration.

    Setup:
        Install `langchain-perplexity` and provide an API key, either through
        the `PPLX_API_KEY` (or `PERPLEXITY_API_KEY`) environment variable or
        through the `pplx_api_key`/`api_key` constructor argument.

        ```bash
        pip install -U langchain-perplexity
        export PPLX_API_KEY=your_api_key
        ```

        API reference for the underlying endpoint:
        https://docs.perplexity.ai/api-reference/embeddings-post

    Instantiate:

        ```python
        from langchain_perplexity import PerplexityEmbeddings

        embeddings = PerplexityEmbeddings()
        ```

    Embed a single query:

        ```python
        query_vector = embeddings.embed_query("hello world")
        ```

    Embed documents:

        ```python
        doc_vectors = embeddings.embed_documents(["hello", "world"])
        ```

    Select a specific model:

        ```python
        embeddings = PerplexityEmbeddings(model="pplx-embed-v1-0.6b")
        ```

    !!! note
        The API answers with base64-encoded signed int8 vectors. They are
        decoded here into `list[float]` with every component in [-128, 127].
        No rescaling to unit length is applied; cosine similarity does not
        depend on vector magnitude, so it is unaffected.
    """

    client: Any = Field(default=None, exclude=True)
    """Synchronous Perplexity SDK client; created during validation if unset."""

    async_client: Any = Field(default=None, exclude=True)
    """Asynchronous Perplexity SDK client; created during validation if unset."""

    model: str = "pplx-embed-v1-4b"
    """Identifier of the Perplexity embedding model to call.

    The API reference lists the available identifiers, among them
    `pplx-embed-v1-0.6b` and `pplx-embed-v1-4b`. Contextualized variants live
    behind a separate endpoint and are not reachable through this class.
    """

    pplx_api_key: SecretStr | None = Field(
        default_factory=secret_from_env(
            ["PPLX_API_KEY", "PERPLEXITY_API_KEY"], default=None
        ),
        alias="api_key",
    )
    """Perplexity API key. Falls back to `PPLX_API_KEY`, then `PERPLEXITY_API_KEY`."""

    request_timeout: float | tuple[float, float] | None = Field(
        default=None, alias="timeout"
    )
    """Request timeout handed to the SDK clients when set."""

    max_retries: int = 6
    """Retry budget handed to the SDK clients."""

    model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)

    @property
    def lc_secrets(self) -> dict[str, str]:
        """Return the secret-field to environment-variable mapping."""
        secret_env_names = {"pplx_api_key": "PPLX_API_KEY"}
        return secret_env_names

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Create any SDK client that was not supplied by the caller.

        Raises:
            ValueError: If no API key was passed and none was found in the
                environment.
        """
        if not self.pplx_api_key:
            raise ValueError(
                "Perplexity API key not provided. Pass `pplx_api_key` (or "
                "`api_key`) to PerplexityEmbeddings, or set the `PPLX_API_KEY` "
                "or `PERPLEXITY_API_KEY` environment variable."
            )

        sdk_kwargs: dict[str, Any] = {
            "api_key": self.pplx_api_key.get_secret_value(),
            "max_retries": self.max_retries,
        }
        timeout = self.request_timeout
        # Only forward a timeout the caller actually configured.
        if timeout is not None:
            sdk_kwargs["timeout"] = timeout

        # Caller-supplied clients are left untouched.
        if self.client is None:
            self.client = Perplexity(**sdk_kwargs)
        if self.async_client is None:
            self.async_client = AsyncPerplexity(**sdk_kwargs)
        return self

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed several texts with a single call to the embeddings endpoint.

        Args:
            texts: The texts to embed.

        Returns:
            One decoded vector per item in the response. When `texts` is empty
            the API is not called and an empty list is returned.
        """
        if not texts:
            return []
        response = self.client.embeddings.create(model=self.model, input=texts)
        vectors: list[list[float]] = []
        for entry in response.data:
            vectors.append(_decode_int8_embedding(entry.embedding))
        return vectors

    def embed_query(self, text: str) -> list[float]:
        """Embed one query string.

        Args:
            text: The text to embed.

        Returns:
            The vector produced for `text`.
        """
        vectors = self.embed_documents([text])
        return vectors[0]

    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed several texts using the async SDK client.

        Args:
            texts: The texts to embed.

        Returns:
            One decoded vector per item in the response. When `texts` is empty
            the API is not called and an empty list is returned.
        """
        if not texts:
            return []
        response = await self.async_client.embeddings.create(
            model=self.model, input=texts
        )
        vectors: list[list[float]] = []
        for entry in response.data:
            vectors.append(_decode_int8_embedding(entry.embedding))
        return vectors

    async def aembed_query(self, text: str) -> list[float]:
        """Embed one query string using the async SDK client.

        Args:
            text: The text to embed.

        Returns:
            The vector produced for `text`.
        """
        return (await self.aembed_documents([text]))[0]
|
||||||
@@ -24,7 +24,7 @@ version = "1.1.0"
|
|||||||
requires-python = ">=3.10.0,<4.0.0"
|
requires-python = ">=3.10.0,<4.0.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"langchain-core>=1.3.2,<2.0.0",
|
"langchain-core>=1.3.2,<2.0.0",
|
||||||
"perplexityai>=0.22.0",
|
"perplexityai>=0.32.0,<1.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
|
|||||||
@@ -0,0 +1,56 @@
|
|||||||
|
"""Integration tests for Perplexity Embeddings API."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain_perplexity import PerplexityEmbeddings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    not os.environ.get("PPLX_API_KEY") and not os.environ.get("PERPLEXITY_API_KEY"),
    reason="PPLX_API_KEY/PERPLEXITY_API_KEY not set",
)
class TestPerplexityEmbeddings:
    """Live-API checks for `PerplexityEmbeddings`; skipped without an API key."""

    def test_embed_documents(self) -> None:
        """Document embedding returns one non-empty float vector per text."""
        model = PerplexityEmbeddings()
        inputs = ["hello world", "goodbye world"]
        outputs = model.embed_documents(inputs)

        assert len(outputs) == len(inputs)
        assert all(isinstance(vec, list) for vec in outputs)
        assert all(len(vec) > 0 for vec in outputs)
        # Every vector must share a single dimensionality.
        dimensions = {len(vec) for vec in outputs}
        assert len(dimensions) == 1
        assert all(isinstance(component, float) for component in outputs[0])

    def test_embed_query(self) -> None:
        """Query embedding returns a non-empty list of floats."""
        model = PerplexityEmbeddings()
        output = model.embed_query("What is the capital of France?")

        assert isinstance(output, list)
        assert len(output) > 0
        assert all(isinstance(component, float) for component in output)

    def test_embed_query_matches_documents_dim(self) -> None:
        """Query and document vectors have the same length."""
        model = PerplexityEmbeddings()
        from_query = model.embed_query("hello")
        from_documents = model.embed_documents(["hello"])
        assert len(query_dim := from_query) == len(from_documents[0])
        assert query_dim is from_query

    async def test_aembed_documents(self) -> None:
        """Async document embedding returns one non-empty vector per text."""
        model = PerplexityEmbeddings()
        outputs = await model.aembed_documents(["hello", "world"])
        assert len(outputs) == 2
        assert all(len(vec) > 0 for vec in outputs)

    async def test_aembed_query(self) -> None:
        """Async query embedding returns a non-empty list."""
        model = PerplexityEmbeddings()
        output = await model.aembed_query("hello")
        assert isinstance(output, list)
        assert len(output) > 0
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
"""Standard integration tests for `PerplexityEmbeddings`."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.embeddings import Embeddings
|
||||||
|
from langchain_tests.integration_tests import EmbeddingsIntegrationTests
|
||||||
|
|
||||||
|
from langchain_perplexity import PerplexityEmbeddings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    not os.environ.get("PPLX_API_KEY") and not os.environ.get("PERPLEXITY_API_KEY"),
    reason="PPLX_API_KEY/PERPLEXITY_API_KEY not set",
)
class TestPerplexityEmbeddingsIntegration(EmbeddingsIntegrationTests):
    """Runs the shared embeddings integration suite against Perplexity."""

    @property
    def embeddings_class(self) -> type[Embeddings]:
        """Class under test."""
        target: type[Embeddings] = PerplexityEmbeddings
        return target

    @property
    def embedding_model_params(self) -> dict:
        """Constructor keyword arguments; none are supplied."""
        init_kwargs: dict = {}
        return init_kwargs
|
||||||
203
libs/partners/perplexity/tests/unit_tests/test_embeddings.py
Normal file
203
libs/partners/perplexity/tests/unit_tests/test_embeddings.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
"""Unit tests for `PerplexityEmbeddings`."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import struct
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from langchain_perplexity import PerplexityEmbeddings
|
||||||
|
|
||||||
|
|
||||||
|
def _encode_int8(values: list[int]) -> str:
|
||||||
|
"""Encode signed int8 values as base64 (matches Perplexity's wire format)."""
|
||||||
|
raw = struct.pack(f"<{len(values)}b", *values)
|
||||||
|
return base64.b64encode(raw).decode("ascii")
|
||||||
|
|
||||||
|
|
||||||
|
def _make_response(int8_vectors: list[list[int]]) -> MagicMock:
    """Fabricate an `EmbeddingCreateResponse` stand-in.

    Each entry of `.data` is a mock whose `.embedding` attribute holds the
    base64_int8 encoding of the corresponding input vector.
    """
    fake = MagicMock()
    fake.data = [
        MagicMock(embedding=_encode_int8(vector)) for vector in int8_vectors
    ]
    return fake
|
||||||
|
|
||||||
|
|
||||||
|
def test_embeddings_initialization() -> None:
    """Construction stores the key, the default model, and both clients."""
    model = PerplexityEmbeddings(pplx_api_key="test")
    stored_key = model.pplx_api_key
    assert stored_key is not None
    assert stored_key.get_secret_value() == "test"
    assert model.model == "pplx-embed-v1-4b"
    for sdk_client in (model.client, model.async_client):
        assert sdk_client is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_embeddings_custom_model() -> None:
    """A `model` argument overrides the default identifier."""
    init_kwargs = {"pplx_api_key": "test", "model": "custom-model"}
    configured = PerplexityEmbeddings(**init_kwargs)
    assert configured.model == "custom-model"
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_key_alias() -> None:
    """The `api_key=` alias populates `pplx_api_key`."""
    model = PerplexityEmbeddings(api_key="aliased")
    stored_key = model.pplx_api_key
    assert stored_key is not None
    assert stored_key.get_secret_value() == "aliased"
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_key_accepts_secret_str() -> None:
    """A ready-made `SecretStr` is accepted as the key."""
    secret = SecretStr("typed")
    model = PerplexityEmbeddings(pplx_api_key=secret)
    stored_key = model.pplx_api_key
    assert stored_key is not None
    assert stored_key.get_secret_value() == "typed"
|
||||||
|
|
||||||
|
|
||||||
|
def test_lc_secrets() -> None:
    """`lc_secrets` maps the key field to `PPLX_API_KEY`."""
    expected = {"pplx_api_key": "PPLX_API_KEY"}
    model = PerplexityEmbeddings(pplx_api_key="test")
    assert model.lc_secrets == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_pplx_api_key_env_fallback(monkeypatch: pytest.MonkeyPatch) -> None:
    """With only `PPLX_API_KEY` set, the key is read from it."""
    monkeypatch.delenv("PERPLEXITY_API_KEY", raising=False)
    monkeypatch.setenv("PPLX_API_KEY", "from_pplx_env")

    stored_key = PerplexityEmbeddings().pplx_api_key

    assert stored_key is not None
    assert stored_key.get_secret_value() == "from_pplx_env"
|
||||||
|
|
||||||
|
|
||||||
|
def test_perplexity_api_key_env_fallback(monkeypatch: pytest.MonkeyPatch) -> None:
    """With only `PERPLEXITY_API_KEY` set, the key is read from it."""
    monkeypatch.delenv("PPLX_API_KEY", raising=False)
    monkeypatch.setenv("PERPLEXITY_API_KEY", "from_perp_env")

    stored_key = PerplexityEmbeddings().pplx_api_key

    assert stored_key is not None
    assert stored_key.get_secret_value() == "from_perp_env"
|
||||||
|
|
||||||
|
|
||||||
|
def test_pplx_takes_precedence_over_perplexity(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """When both variables are set, `PPLX_API_KEY` wins."""
    environment = (
        ("PPLX_API_KEY", "primary"),
        ("PERPLEXITY_API_KEY", "secondary"),
    )
    for name, value in environment:
        monkeypatch.setenv(name, value)

    stored_key = PerplexityEmbeddings().pplx_api_key

    assert stored_key is not None
    assert stored_key.get_secret_value() == "primary"
|
||||||
|
|
||||||
|
|
||||||
|
def test_explicit_kwarg_overrides_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """A key passed to the constructor beats the environment value."""
    monkeypatch.setenv("PPLX_API_KEY", "from_env")

    stored_key = PerplexityEmbeddings(pplx_api_key="explicit").pplx_api_key

    assert stored_key is not None
    assert stored_key.get_secret_value() == "explicit"
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_api_key_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    """Construction fails with a `ValueError` when no key is available."""
    for name in ("PPLX_API_KEY", "PERPLEXITY_API_KEY"):
        monkeypatch.delenv(name, raising=False)

    expected_message = "Perplexity API key not provided"
    with pytest.raises(ValueError, match=expected_message):
        PerplexityEmbeddings()
|
||||||
|
|
||||||
|
|
||||||
|
def test_embed_documents() -> None:
    """Sync document embedding decodes each payload and calls the SDK once."""
    sdk = MagicMock()
    sdk.embeddings.create.return_value = _make_response([[1, -2, 3], [4, 5, -6]])
    model = PerplexityEmbeddings(pplx_api_key="test", client=sdk)

    vectors = model.embed_documents(["hello", "world"])

    assert vectors == [[1.0, -2.0, 3.0], [4.0, 5.0, -6.0]]
    expected_call = {"model": "pplx-embed-v1-4b", "input": ["hello", "world"]}
    sdk.embeddings.create.assert_called_once_with(**expected_call)
|
||||||
|
|
||||||
|
|
||||||
|
def test_embed_documents_empty_short_circuits() -> None:
    """An empty input list returns `[]` without touching the SDK."""
    sdk = MagicMock()
    model = PerplexityEmbeddings(pplx_api_key="test", client=sdk)

    vectors = model.embed_documents([])

    assert vectors == []
    sdk.embeddings.create.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_embed_documents_propagates_errors() -> None:
    """Exceptions raised by the SDK surface to the caller unchanged."""
    sdk = MagicMock()
    sdk.embeddings.create.side_effect = RuntimeError("boom")
    model = PerplexityEmbeddings(pplx_api_key="test", client=sdk)

    with pytest.raises(RuntimeError, match="boom"):
        model.embed_documents(["x"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_embed_query() -> None:
    """Sync query embedding sends a one-item list and returns its vector."""
    sdk = MagicMock()
    sdk.embeddings.create.return_value = _make_response([[7, 8, 9]])
    model = PerplexityEmbeddings(pplx_api_key="test", client=sdk)

    vector = model.embed_query("hello")

    assert vector == [7.0, 8.0, 9.0]
    expected_call = {"model": "pplx-embed-v1-4b", "input": ["hello"]}
    sdk.embeddings.create.assert_called_once_with(**expected_call)
|
||||||
|
|
||||||
|
|
||||||
|
def test_embed_documents_uses_custom_model() -> None:
    """The configured model name is what gets sent to the SDK."""
    sdk = MagicMock()
    sdk.embeddings.create.return_value = _make_response([[0]])
    init_kwargs = {"pplx_api_key": "test", "model": "custom-model", "client": sdk}
    model = PerplexityEmbeddings(**init_kwargs)

    model.embed_documents(["x"])

    expected_call = {"model": "custom-model", "input": ["x"]}
    sdk.embeddings.create.assert_called_once_with(**expected_call)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_aembed_documents() -> None:
    """Async document embedding decodes payloads and awaits the SDK once."""
    async_sdk = MagicMock()
    canned = _make_response([[1, 2], [3, 4]])
    async_sdk.embeddings.create = AsyncMock(return_value=canned)
    model = PerplexityEmbeddings(pplx_api_key="test", async_client=async_sdk)

    vectors = await model.aembed_documents(["a", "b"])

    assert vectors == [[1.0, 2.0], [3.0, 4.0]]
    expected_call = {"model": "pplx-embed-v1-4b", "input": ["a", "b"]}
    async_sdk.embeddings.create.assert_awaited_once_with(**expected_call)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_aembed_documents_empty_short_circuits() -> None:
    """An empty input list returns `[]` without awaiting the SDK."""
    async_sdk = MagicMock()
    async_sdk.embeddings.create = AsyncMock()
    model = PerplexityEmbeddings(pplx_api_key="test", async_client=async_sdk)

    vectors = await model.aembed_documents([])

    assert vectors == []
    async_sdk.embeddings.create.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_aembed_query() -> None:
    """Async query embedding sends a one-item list and returns its vector."""
    async_sdk = MagicMock()
    canned = _make_response([[5, 6]])
    async_sdk.embeddings.create = AsyncMock(return_value=canned)
    model = PerplexityEmbeddings(pplx_api_key="test", async_client=async_sdk)

    vector = await model.aembed_query("hi")

    assert vector == [5.0, 6.0]
    expected_call = {"model": "pplx-embed-v1-4b", "input": ["hi"]}
    async_sdk.embeddings.create.assert_awaited_once_with(**expected_call)
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
"""Standard unit tests for `PerplexityEmbeddings`."""
|
||||||
|
|
||||||
|
from langchain_core.embeddings import Embeddings
|
||||||
|
from langchain_tests.unit_tests import EmbeddingsUnitTests
|
||||||
|
|
||||||
|
from langchain_perplexity import PerplexityEmbeddings
|
||||||
|
|
||||||
|
|
||||||
|
class TestPerplexityEmbeddingsStandard(EmbeddingsUnitTests):
    """Runs the shared embeddings unit-test suite against Perplexity."""

    @property
    def embeddings_class(self) -> type[Embeddings]:
        """Class under test."""
        target: type[Embeddings] = PerplexityEmbeddings
        return target

    @property
    def embedding_model_params(self) -> dict:
        """Constructor keyword arguments used by the suite."""
        init_kwargs = {"pplx_api_key": "test"}
        return init_kwargs

    @property
    def init_from_env_params(self) -> tuple[dict, dict, dict]:
        """Environment variables, extra init kwargs, and expected attributes."""
        env_vars = {"PPLX_API_KEY": "api_key"}
        extra_init_kwargs: dict = {}
        expected_attrs = {"pplx_api_key": "api_key"}
        return (env_vars, extra_init_kwargs, expected_attrs)
|
||||||
@@ -2,6 +2,7 @@ from langchain_perplexity import __all__
|
|||||||
|
|
||||||
EXPECTED_ALL = [
|
EXPECTED_ALL = [
|
||||||
"ChatPerplexity",
|
"ChatPerplexity",
|
||||||
|
"PerplexityEmbeddings",
|
||||||
"PerplexitySearchRetriever",
|
"PerplexitySearchRetriever",
|
||||||
"PerplexitySearchResults",
|
"PerplexitySearchResults",
|
||||||
"UserLocation",
|
"UserLocation",
|
||||||
|
|||||||
10
libs/partners/perplexity/uv.lock
generated
10
libs/partners/perplexity/uv.lock
generated
@@ -537,7 +537,7 @@ typing = [
|
|||||||
[package.metadata]
|
[package.metadata]
|
||||||
requires-dist = [
|
requires-dist = [
|
||||||
{ name = "langchain-core", editable = "../../core" },
|
{ name = "langchain-core", editable = "../../core" },
|
||||||
{ name = "perplexityai", specifier = ">=0.22.0" },
|
{ name = "perplexityai", specifier = ">=0.32.0,<1.0.0" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.metadata.requires-dev]
|
[package.metadata.requires-dev]
|
||||||
@@ -581,7 +581,7 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "langchain-tests"
|
name = "langchain-tests"
|
||||||
version = "1.1.6"
|
version = "1.1.7"
|
||||||
source = { editable = "../../standard-tests" }
|
source = { editable = "../../standard-tests" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "httpx" },
|
{ name = "httpx" },
|
||||||
@@ -973,7 +973,7 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "perplexityai"
|
name = "perplexityai"
|
||||||
version = "0.22.2"
|
version = "0.32.1"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "anyio" },
|
{ name = "anyio" },
|
||||||
@@ -983,9 +983,9 @@ dependencies = [
|
|||||||
{ name = "sniffio" },
|
{ name = "sniffio" },
|
||||||
{ name = "typing-extensions" },
|
{ name = "typing-extensions" },
|
||||||
]
|
]
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/47/f0/4ead48afbe4c9d7b57e9d03e253bed2986ac20d2257ed918cac276949018/perplexityai-0.22.2.tar.gz", hash = "sha256:9c3cad307c95aa5e8967358547e548d58793d350318d8d1d4aa33a933cbed844", size = 113014, upload-time = "2025-12-17T19:05:25.572Z" }
|
sdist = { url = "https://files.pythonhosted.org/packages/09/02/73f460c85a5ec533a97fd1ff34fa729a009b4a217a4a87d8da946b6e1c52/perplexityai-0.32.1.tar.gz", hash = "sha256:b03503498591d06c4d50b666f7f7469875d3586f664c29416aae9012ae7a64d1", size = 135741, upload-time = "2026-04-21T04:35:40.345Z" }
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/cb/aa/fc3337fdb014b1584297fc212552e6365d22a6fb77850a56c9038cd47173/perplexityai-0.22.2-py3-none-any.whl", hash = "sha256:92d3dc7f4e110c879ac5009daf7263a04f413523f7d76fba871176516c253890", size = 96860, upload-time = "2025-12-17T19:05:24.292Z" },
|
{ url = "https://files.pythonhosted.org/packages/d6/11/5c164f114311bc2e2350202393e7c5bd25bb156b5230a1edf5a2b2f4ba04/perplexityai-0.32.1-py3-none-any.whl", hash = "sha256:e5017d245fd8966cf79657edc03a93078d867708542b491b38152618f91e369b", size = 130223, upload-time = "2026-04-21T04:35:38.786Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|||||||
Reference in New Issue
Block a user