partners: Add Perplexity Chat Integration (#30618)
Perplexity's importance in the space has been growing, so we think it's time to add an official integration!

Note: following the release of `langchain-perplexity` to PyPI, we should be able to add `perplexity` as an extra in `libs/langchain/pyproject.toml`, but we're blocked by a circular import for now.

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent 87c02a1aff
commit 3814bd1ea7
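Before the file-by-file diff, here is a minimal end-to-end sketch of what the new package enables, based on the README and docstrings added below (assumes `langchain-perplexity` is installed; the model name is the one used throughout this PR):

```python
import getpass
import os

# PPLX_API_KEY is the environment variable the new integration reads.
if not os.environ.get("PPLX_API_KEY"):
    os.environ["PPLX_API_KEY"] = getpass.getpass("Enter API key for Perplexity: ")

from langchain_perplexity import ChatPerplexity

llm = ChatPerplexity(model="llama-3.1-sonar-small-128k-online", temperature=0.7)
print(llm.invoke("Hello, world!").content)
```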
.github/workflows/_integration_test.yml (vendored, 1 addition)

@@ -76,6 +76,7 @@ jobs:
   COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
   UPSTAGE_API_KEY: ${{ secrets.UPSTAGE_API_KEY }}
   XAI_API_KEY: ${{ secrets.XAI_API_KEY }}
+  PPLX_API_KEY: ${{ secrets.PPLX_API_KEY }}
   run: |
     make integration_tests
.github/workflows/_release.yml (vendored, 1 addition)

@@ -327,6 +327,7 @@ jobs:
   FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }}
   XAI_API_KEY: ${{ secrets.XAI_API_KEY }}
   DEEPSEEK_API_KEY: ${{ secrets.DEEPSEEK_API_KEY }}
+  PPLX_API_KEY: ${{ secrets.PPLX_API_KEY }}
   run: make integration_tests
   working-directory: ${{ inputs.working-directory }}
.github/workflows/scheduled_test.yml (vendored, 1 addition)

@@ -145,6 +145,7 @@ jobs:
   GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }}
   GOOGLE_SEARCH_API_KEY: ${{ secrets.GOOGLE_SEARCH_API_KEY }}
   GOOGLE_CSE_ID: ${{ secrets.GOOGLE_CSE_ID }}
+  PPLX_API_KEY: ${{ secrets.PPLX_API_KEY }}
   run: |
     cd langchain/${{ matrix.working-directory }}
     make integration_tests
(hunk from the docs chat-model tabs component; the file header was lost in this view)

@@ -218,6 +218,13 @@ ${llmVarName} = ChatWatsonx(
     apiKeyName: "XAI_API_KEY",
     packageName: "langchain-xai",
   },
+  {
+    value: "perplexity",
+    label: "Perplexity",
+    model: "llama-3.1-sonar-small-128k-online",
+    apiKeyName: "PPLX_API_KEY",
+    packageName: "langchain-perplexity",
+  }
 ].map((item) => ({
   ...item,
   ...overrideParams?.[item.value],
(hunk from the docs feature tables; the file header was lost in this view)

@@ -237,6 +237,17 @@ const FEATURE_TABLES = {
       "local": false,
       "apiLink": "https://python.langchain.com/api_reference/xai/chat_models/langchain_xai.chat_models.ChatXAI.html"
     },
+    {
+      "name": "ChatPerplexity",
+      "package": "langchain-perplexity",
+      "link": "perplexity",
+      "structured_output": true,
+      "tool_calling": false,
+      "json_mode": true,
+      "multimodal": true,
+      "local": false,
+      "apiLink": "https://python.langchain.com/api_reference/perplexity/chat_models/langchain_perplexity.chat_models.ChatPerplexity.html"
+    }
   ],
 },
 llms: {
(hunk adding the package to the serialization namespace allow-list; the file header was lost in this view)

@@ -28,6 +28,7 @@ DEFAULT_NAMESPACES = [
     "langchain_fireworks",
     "langchain_xai",
     "langchain_sambanova",
+    "langchain_perplexity",
 ]
 # Namespaces for which only deserializing via the SERIALIZABLE_MAPPING is allowed.
 # Load by path is not allowed.
(three hunks around `init_chat_model` and its helpers; the file header was lost in this view)

@@ -125,6 +125,7 @@ def init_chat_model(
     - 'ibm' -> langchain-ibm
     - 'nvidia' -> langchain-nvidia-ai-endpoints
     - 'xai' -> langchain-xai
+    - 'perplexity' -> langchain-perplexity

     Will attempt to infer model_provider from model if not specified. The
     following providers will be inferred based on these model prefixes:

@@ -453,6 +454,11 @@ def _init_chat_model_helper(
         from langchain_xai import ChatXAI

         return ChatXAI(model=model, **kwargs)
+    elif model_provider == "perplexity":
+        _check_pkg("langchain_perplexity")
+        from langchain_perplexity import ChatPerplexity
+
+        return ChatPerplexity(model=model, **kwargs)
     else:
         supported = ", ".join(_SUPPORTED_PROVIDERS)
         raise ValueError(

@@ -481,6 +487,7 @@ _SUPPORTED_PROVIDERS = {
     "deepseek",
     "ibm",
     "xai",
+    "perplexity",
 }
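With the helper and provider registry wired up above, the generic entry point also resolves the new provider; a short sketch (assumes the package is installed and `PPLX_API_KEY` is set):

```python
from langchain.chat_models import init_chat_model

# model_provider="perplexity" now dispatches to langchain_perplexity.ChatPerplexity.
llm = init_chat_model(
    "llama-3.1-sonar-small-128k-online", model_provider="perplexity"
)
print(llm.invoke("Hello, world!").content)
```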
libs/partners/perplexity/.gitignore (vendored, new file, 1 line)

__pycache__
libs/partners/perplexity/LICENSE (new file, 21 lines)

MIT License

Copyright (c) 2023 LangChain, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
libs/partners/perplexity/Makefile (new file, 65 lines)

.PHONY: all format lint test tests integration_tests docker_tests help extended_tests

# Default target executed when no arguments are given to make.
all: help

.EXPORT_ALL_VARIABLES:
UV_FROZEN = true

# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/

integration_test integration_tests: TEST_FILE=tests/integration_tests/

test tests:
	uv run --group test pytest --disable-socket --allow-unix-socket $(TEST_FILE)

test_watch:
	uv run --group test ptw --snapshot-update --now . -- -vv $(TEST_FILE)

integration_test integration_tests:
	uv run --group test --group test_integration pytest $(TEST_FILE)

######################
# LINTING AND FORMATTING
######################

# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/perplexity --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_perplexity
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
	[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff check $(PYTHON_FILES)
	[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff format $(PYTHON_FILES) --diff
	[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && uv run --all-groups mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
	[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff format $(PYTHON_FILES)
	[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff check --select I --fix $(PYTHON_FILES)

spell_check:
	uv run --all-groups codespell --toml pyproject.toml

spell_fix:
	uv run --all-groups codespell --toml pyproject.toml -w

check_imports: $(shell find langchain_perplexity -name '*.py')
	uv run --all-groups python ./scripts/check_imports.py $^

######################
# HELP
######################

help:
	@echo '----'
	@echo 'check_imports - check imports'
	@echo 'format - run code formatters'
	@echo 'lint - run linters'
	@echo 'test - run unit tests'
	@echo 'tests - run unit tests'
	@echo 'test TEST_FILE=<test_file> - run all tests in file'
libs/partners/perplexity/README.md (new file, 29 lines)

# langchain-perplexity

This package contains the LangChain integration with Perplexity.

## Installation

```bash
pip install -U langchain-perplexity
```

You should [configure your Perplexity credentials](https://docs.perplexity.ai/guides/getting-started)
and then set the `PPLX_API_KEY` environment variable.

## Usage

This package contains the `ChatPerplexity` class, which is the recommended way to interface with Perplexity chat models.

```python
import getpass
import os

if not os.environ.get("PPLX_API_KEY"):
    os.environ["PPLX_API_KEY"] = getpass.getpass("Enter API key for Perplexity: ")

from langchain.chat_models import init_chat_model

llm = init_chat_model("llama-3.1-sonar-small-128k-online", model_provider="perplexity")
llm.invoke("Hello, world!")
```
(new file, 5 lines; the file header was lost in this view, evidently the package `__init__.py`)

"""This package provides the Perplexity integration for LangChain."""

from langchain_perplexity.chat_models import ChatPerplexity

__all__ = ["ChatPerplexity"]
libs/partners/perplexity/langchain_perplexity/chat_models.py (new file, 504 lines)

"""Wrapper around Perplexity APIs."""

from __future__ import annotations

import logging
from operator import itemgetter
from typing import (
    Any,
    Dict,
    Iterator,
    List,
    Literal,
    Mapping,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
)

import openai
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    FunctionMessageChunk,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
    SystemMessageChunk,
    ToolMessageChunk,
)
from langchain_core.messages.ai import UsageMetadata, subtract_usage
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.utils import get_pydantic_field_names, secret_from_env
from langchain_core.utils.function_calling import convert_to_json_schema
from langchain_core.utils.pydantic import is_basemodel_subclass
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

_BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]
_DictOrPydantic = Union[Dict, _BM]

logger = logging.getLogger(__name__)


def _is_pydantic_class(obj: Any) -> bool:
    return isinstance(obj, type) and is_basemodel_subclass(obj)


def _create_usage_metadata(token_usage: dict) -> UsageMetadata:
    input_tokens = token_usage.get("prompt_tokens", 0)
    output_tokens = token_usage.get("completion_tokens", 0)
    total_tokens = token_usage.get("total_tokens", input_tokens + output_tokens)
    return UsageMetadata(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
    )


class ChatPerplexity(BaseChatModel):
    """`Perplexity AI` Chat models API.

    Setup:
        To use, you should have the ``openai`` python package installed, and the
        environment variable ``PPLX_API_KEY`` set to your API key.
        Any parameters that are valid to be passed to the openai.create call
        can be passed in, even if not explicitly saved on this class.

        .. code-block:: bash

            pip install openai
            export PPLX_API_KEY=your_api_key

    Key init args - completion params:
        model: str
            Name of the model to use, e.g. "llama-3.1-sonar-small-128k-online".
        temperature: float
            Sampling temperature to use. Default is 0.7.
        max_tokens: Optional[int]
            Maximum number of tokens to generate.
        streaming: bool
            Whether to stream the results or not.

    Key init args - client params:
        pplx_api_key: Optional[str]
            API key for PerplexityChat API. Default is None.
        request_timeout: Optional[Union[float, Tuple[float, float]]]
            Timeout for requests to PerplexityChat completion API. Default is None.
        max_retries: int
            Maximum number of retries to make when generating.

    See full list of supported init args and their descriptions in the params section.

    Instantiate:
        .. code-block:: python

            from langchain_perplexity import ChatPerplexity

            llm = ChatPerplexity(
                model="llama-3.1-sonar-small-128k-online", temperature=0.7
            )

    Invoke:
        .. code-block:: python

            messages = [("system", "You are a chatbot."), ("user", "Hello!")]
            llm.invoke(messages)

    Invoke with structured output:
        .. code-block:: python

            from pydantic import BaseModel


            class StructuredOutput(BaseModel):
                role: str
                content: str


            llm.with_structured_output(StructuredOutput)
            llm.invoke(messages)

    Invoke with perplexity-specific params:
        .. code-block:: python

            llm.invoke(messages, extra_body={"search_recency_filter": "week"})

    Stream:
        .. code-block:: python

            for chunk in llm.stream(messages):
                print(chunk.content)

    Token usage:
        .. code-block:: python

            response = llm.invoke(messages)
            response.usage_metadata

    Response metadata:
        .. code-block:: python

            response = llm.invoke(messages)
            response.response_metadata

    """  # noqa: E501

    client: Any = None  #: :meta private:
    model: str = "llama-3.1-sonar-small-128k-online"
    """Model name."""
    temperature: float = 0.7
    """What sampling temperature to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    pplx_api_key: Optional[SecretStr] = Field(
        default_factory=secret_from_env("PPLX_API_KEY", default=None), alias="api_key"
    )
    """Perplexity API key. Automatically read from the ``PPLX_API_KEY``
    environment variable if not provided."""
    request_timeout: Optional[Union[float, Tuple[float, float]]] = Field(
        None, alias="timeout"
    )
    """Timeout for requests to PerplexityChat completion API. Default is None."""
    max_retries: int = 6
    """Maximum number of retries to make when generating."""
    streaming: bool = False
    """Whether to stream the results or not."""
    max_tokens: Optional[int] = None
    """Maximum number of tokens to generate."""

    model_config = ConfigDict(populate_by_name=True)

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"pplx_api_key": "PPLX_API_KEY"}

    @model_validator(mode="before")
    @classmethod
    def build_extra(cls, values: Dict[str, Any]) -> Any:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                logger.warning(
                    f"""WARNING! {field_name} is not a default parameter.
                    {field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)

        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validate that api key and python package exists in environment."""
        try:
            self.client = openai.OpenAI(
                api_key=self.pplx_api_key.get_secret_value()
                if self.pplx_api_key
                else None,
                base_url="https://api.perplexity.ai",
            )
        except AttributeError:
            raise ValueError(
                "`openai` has no `OpenAI` attribute, this is likely "
                "due to an old version of the openai package. Try upgrading it "
                "with `pip install --upgrade openai`."
            )
        return self

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling PerplexityChat API."""
        return {
            "max_tokens": self.max_tokens,
            "stream": self.streaming,
            "temperature": self.temperature,
            **self.model_kwargs,
        }

    def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]:
        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        else:
            raise TypeError(f"Got unknown type {message}")
        return message_dict

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        params = dict(self._invocation_params)
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        message_dicts = [self._convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def _convert_delta_to_message_chunk(
        self, _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
    ) -> BaseMessageChunk:
        role = _dict.get("role")
        content = _dict.get("content") or ""
        additional_kwargs: Dict = {}
        if _dict.get("function_call"):
            function_call = dict(_dict["function_call"])
            if "name" in function_call and function_call["name"] is None:
                function_call["name"] = ""
            additional_kwargs["function_call"] = function_call
        if _dict.get("tool_calls"):
            additional_kwargs["tool_calls"] = _dict["tool_calls"]

        if role == "user" or default_class == HumanMessageChunk:
            return HumanMessageChunk(content=content)
        elif role == "assistant" or default_class == AIMessageChunk:
            return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
        elif role == "system" or default_class == SystemMessageChunk:
            return SystemMessageChunk(content=content)
        elif role == "function" or default_class == FunctionMessageChunk:
            return FunctionMessageChunk(content=content, name=_dict["name"])
        elif role == "tool" or default_class == ToolMessageChunk:
            return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
        elif role or default_class == ChatMessageChunk:
            return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
        else:
            return default_class(content=content)  # type: ignore[call-arg]

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        default_chunk_class = AIMessageChunk
        params.pop("stream", None)
        if stop:
            params["stop_sequences"] = stop
        stream_resp = self.client.chat.completions.create(
            messages=message_dicts, stream=True, **params
        )
        first_chunk = True
        prev_total_usage: Optional[UsageMetadata] = None

        added_model_name: bool = False
        for chunk in stream_resp:
            if not isinstance(chunk, dict):
                chunk = chunk.model_dump()
            # Collect standard usage metadata (transform from aggregate to delta)
            if total_usage := chunk.get("usage"):
                lc_total_usage = _create_usage_metadata(total_usage)
                if prev_total_usage:
                    usage_metadata: Optional[UsageMetadata] = subtract_usage(
                        lc_total_usage, prev_total_usage
                    )
                else:
                    usage_metadata = lc_total_usage
                prev_total_usage = lc_total_usage
            else:
                usage_metadata = None
            if len(chunk["choices"]) == 0:
                continue
            choice = chunk["choices"][0]

            additional_kwargs = {}
            if first_chunk:
                additional_kwargs["citations"] = chunk.get("citations", [])
                for attr in ["images", "related_questions"]:
                    if attr in chunk:
                        additional_kwargs[attr] = chunk[attr]

            generation_info = {}
            if (model_name := chunk.get("model")) and not added_model_name:
                generation_info["model_name"] = model_name
                added_model_name = True

            chunk = self._convert_delta_to_message_chunk(
                choice["delta"], default_chunk_class
            )

            if isinstance(chunk, AIMessageChunk) and usage_metadata:
                chunk.usage_metadata = usage_metadata

            if first_chunk:
                chunk.additional_kwargs |= additional_kwargs
                first_chunk = False

            if finish_reason := choice.get("finish_reason"):
                generation_info["finish_reason"] = finish_reason

            default_chunk_class = chunk.__class__
            chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            if stream_iter:
                return generate_from_stream(stream_iter)
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        response = self.client.chat.completions.create(messages=message_dicts, **params)
        if usage := getattr(response, "usage", None):
            usage_metadata = _create_usage_metadata(usage.model_dump())
        else:
            usage_metadata = None

        additional_kwargs = {"citations": response.citations}
        for attr in ["images", "related_questions"]:
            if hasattr(response, attr):
                additional_kwargs[attr] = getattr(response, attr)

        message = AIMessage(
            content=response.choices[0].message.content,
            additional_kwargs=additional_kwargs,
            usage_metadata=usage_metadata,
            response_metadata={"model_name": getattr(response, "model", self.model)},
        )
        return ChatResult(generations=[ChatGeneration(message=message)])

    @property
    def _invocation_params(self) -> Mapping[str, Any]:
        """Get the parameters used to invoke the model."""
        pplx_creds: Dict[str, Any] = {"model": self.model}
        return {**pplx_creds, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "perplexitychat"

    def with_structured_output(
        self,
        schema: Optional[_DictOrPydanticClass] = None,
        *,
        method: Literal["json_schema"] = "json_schema",
        include_raw: bool = False,
        strict: Optional[bool] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, _DictOrPydantic]:
        """Model wrapper that returns outputs formatted to match the given schema for Perplexity.

        Currently, Perplexity only supports the "json_schema" method for structured output,
        as per their official documentation: https://docs.perplexity.ai/guides/structured-outputs

        Args:
            schema:
                The output schema. Can be passed in as:

                - a JSON Schema,
                - a TypedDict class,
                - or a Pydantic class

            method: The method for steering model generation; currently only supports:

                - "json_schema": Use the JSON Schema to parse the model output

            include_raw:
                If False then only the parsed structured output is returned. If
                an error occurs during model output parsing it will be raised. If True
                then both the raw model response (a BaseMessage) and the parsed model
                response will be returned. If an error occurs during output parsing it
                will be caught and returned as well. The final output is always a dict
                with keys "raw", "parsed", and "parsing_error".

            kwargs: Additional keyword args aren't supported.

        Returns:
            A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.

            | If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs an instance of ``schema`` (i.e., a Pydantic object). Otherwise, if ``include_raw`` is False then Runnable outputs a dict.

            | If ``include_raw`` is True, then Runnable outputs a dict with keys:

            - "raw": BaseMessage
            - "parsed": None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
            - "parsing_error": Optional[BaseException]

        """  # noqa: E501
        if method in ("function_calling", "json_mode"):
            method = "json_schema"
        if method == "json_schema":
            if schema is None:
                raise ValueError(
                    "schema must be specified when method is 'json_schema'. "
                    "Received None."
                )
            is_pydantic_schema = _is_pydantic_class(schema)
            response_format = convert_to_json_schema(schema)
            llm = self.bind(
                response_format={
                    "type": "json_schema",
                    "json_schema": {"schema": response_format},
                },
                ls_structured_output_format={
                    "kwargs": {"method": method},
                    "schema": response_format,
                },
            )
            output_parser = (
                PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
                if is_pydantic_schema
                else JsonOutputParser()
            )
        else:
            raise ValueError(
                f"Unrecognized method argument. Expected 'json_schema'; "
                f"received '{method}'."
            )

        if include_raw:
            parser_assign = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
            )
            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
            parser_with_fallback = parser_assign.with_fallbacks(
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback
        else:
            return llm | output_parser
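To make the `with_structured_output` path above concrete, a sketch using a hypothetical `Answer` schema (Perplexity only supports the `json_schema` method, per the docstring; assumes a live `PPLX_API_KEY`):

```python
from pydantic import BaseModel, Field

from langchain_perplexity import ChatPerplexity


class Answer(BaseModel):
    """Hypothetical schema for illustration."""

    summary: str = Field(description="One-sentence answer")
    confidence: float = Field(description="Self-reported confidence in [0, 1]")


llm = ChatPerplexity(model="sonar", temperature=0)

# Binds response_format={"type": "json_schema", ...} and parses the output
# with PydanticOutputParser, as implemented above.
structured_llm = llm.with_structured_output(Answer)
result = structured_llm.invoke("What is the capital of France?")
print(result.summary, result.confidence)
```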
libs/partners/perplexity/pyproject.toml (new file, 78 lines)

[build-system]
requires = ["pdm-backend"]
build-backend = "pdm.backend"

[project]
authors = []
license = { text = "MIT" }
requires-python = "<4.0,>=3.9"
dependencies = [
    "langchain-core<1.0.0,>=0.3.49",
    "openai<2.0.0,>=1.68.2",
]
name = "langchain-perplexity"
version = "0.1.0"
description = "An integration package connecting Perplexity and LangChain"
readme = "README.md"

[project.urls]
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/perplexity"
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-perplexity%3D%3D0%22&expanded=true"
repository = "https://github.com/langchain-ai/langchain"

[dependency-groups]
test = [
    "pytest<8.0.0,>=7.3.0",
    "freezegun<2.0.0,>=1.2.2",
    "pytest-mock<4.0.0,>=3.10.0",
    "syrupy<5.0.0,>=4.0.2",
    "pytest-watcher<1.0.0,>=0.3.4",
    "pytest-asyncio<1.0.0,>=0.21.1",
    "pytest-cov<5.0.0,>=4.1.0",
    "pytest-retry<1.8.0,>=1.7.0",
    "pytest-socket<1.0.0,>=0.6.0",
    "pytest-xdist<4.0.0,>=3.6.1",
    "langchain-core",
    "langchain-tests",
]
codespell = ["codespell<3.0.0,>=2.2.0"]
lint = ["ruff<1.0,>=0.5"]
dev = ["langchain-core"]
test_integration = [
    "httpx<1.0.0,>=0.27.0",
    "pillow<11.0.0,>=10.3.0",
]
typing = ["mypy<2.0,>=1.10", "types-tqdm<5.0.0.0,>=4.66.0.5", "langchain-core"]

[tool.uv.sources]
langchain-core = { path = "../../core", editable = true }
langchain-tests = { path = "../../standard-tests", editable = true }

[tool.mypy]
disallow_untyped_defs = "True"
plugins = ['pydantic.mypy']

[[tool.mypy.overrides]]
module = "transformers"
ignore_missing_imports = true

[tool.ruff.lint]
select = ["E", "F", "I", "T201"]

[tool.ruff.format]
docstring-code-format = true
skip-magic-trailing-comma = true

[tool.coverage.run]
omit = ["tests/*"]

[tool.pytest.ini_options]
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5 --cov=langchain_perplexity"
markers = [
    "requires: mark tests as requiring a specific library",
    "compile: mark placeholder test used to compile integration tests without running them",
    "scheduled: mark tests to run in scheduled testing",
]
asyncio_mode = "auto"
filterwarnings = [
    "ignore::langchain_core._api.beta_decorator.LangChainBetaWarning",
]
libs/partners/perplexity/scripts/check_imports.py (new file, 17 lines)

import sys
import traceback
from importlib.machinery import SourceFileLoader

if __name__ == "__main__":
    files = sys.argv[1:]
    has_failure = False
    for file in files:
        try:
            SourceFileLoader("x", file).load_module()
        except Exception:
            has_failure = True
            print(file)  # noqa: T201
            traceback.print_exc()
            print()  # noqa: T201

    sys.exit(1 if has_failure else 0)
libs/partners/perplexity/scripts/lint_imports.sh (new executable file, 17 lines)

#!/bin/bash

set -eu

# Initialize a variable to keep track of errors
errors=0

# make sure not importing from langchain or langchain_experimental
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))

# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
    exit 1
else
    exit 0
fi
libs/partners/perplexity/tests/__init__.py (new empty file)
(new file, 27 lines; the file header was lost in this view)

"""Standard LangChain interface tests."""

from typing import Type

import pytest
from langchain_core.language_models import BaseChatModel
from langchain_tests.integration_tests import ChatModelIntegrationTests

from langchain_perplexity import ChatPerplexity


class TestPerplexityStandard(ChatModelIntegrationTests):
    @property
    def chat_model_class(self) -> Type[BaseChatModel]:
        return ChatPerplexity

    @property
    def chat_model_params(self) -> dict:
        return {"model": "sonar"}

    @pytest.mark.xfail(reason="TODO: handle in integration.")
    def test_double_messages_conversation(self, model: BaseChatModel) -> None:
        super().test_double_messages_conversation(model)

    @pytest.mark.xfail(reason="Raises 400: Custom stop words not supported.")
    def test_stop_sequence(self, model: BaseChatModel) -> None:
        super().test_stop_sequence(model)
(new file, 7 lines; the file header was lost in this view)

import pytest  # type: ignore[import-not-found]


@pytest.mark.compile
def test_placeholder() -> None:
    """Used for compiling integration tests without running any real tests."""
    pass
libs/partners/perplexity/tests/unit_tests/__init__.py (new file, 3 lines)

import os

os.environ["PPLX_API_KEY"] = "test"
libs/partners/perplexity/tests/unit_tests/test_chat_models.py (new file, 195 lines)

from typing import Any, Dict, List, Optional
from unittest.mock import MagicMock

from langchain_core.messages import AIMessageChunk, BaseMessageChunk
from pytest_mock import MockerFixture

from langchain_perplexity import ChatPerplexity


def test_perplexity_model_name_param() -> None:
    llm = ChatPerplexity(model="foo")
    assert llm.model == "foo"


def test_perplexity_model_kwargs() -> None:
    llm = ChatPerplexity(model="test", model_kwargs={"foo": "bar"})
    assert llm.model_kwargs == {"foo": "bar"}


def test_perplexity_initialization() -> None:
    """Test Perplexity initialization."""
    # Verify that ChatPerplexity can be initialized using a secret key provided
    # as a parameter rather than an environment variable.
    for model in [
        ChatPerplexity(
            model="test", timeout=1, api_key="test", temperature=0.7, verbose=True
        ),
        ChatPerplexity(
            model="test",
            request_timeout=1,
            pplx_api_key="test",
            temperature=0.7,
            verbose=True,
        ),
    ]:
        assert model.request_timeout == 1
        assert (
            model.pplx_api_key is not None
            and model.pplx_api_key.get_secret_value() == "test"
        )


def test_perplexity_stream_includes_citations(mocker: MockerFixture) -> None:
    """Test that the stream method includes citations in the additional_kwargs."""
    llm = ChatPerplexity(model="test", timeout=30, verbose=True)
    mock_chunk_0 = {
        "choices": [{"delta": {"content": "Hello "}, "finish_reason": None}],
        "citations": ["example.com", "example2.com"],
    }
    mock_chunk_1 = {
        "choices": [{"delta": {"content": "Perplexity"}, "finish_reason": None}],
        "citations": ["example.com", "example2.com"],
    }
    mock_chunks: List[Dict[str, Any]] = [mock_chunk_0, mock_chunk_1]
    mock_stream = MagicMock()
    mock_stream.__iter__.return_value = mock_chunks
    patcher = mocker.patch.object(
        llm.client.chat.completions, "create", return_value=mock_stream
    )
    stream = llm.stream("Hello langchain")
    full: Optional[BaseMessageChunk] = None
    for i, chunk in enumerate(stream):
        full = chunk if full is None else full + chunk
        assert chunk.content == mock_chunks[i]["choices"][0]["delta"]["content"]
        if i == 0:
            assert chunk.additional_kwargs["citations"] == [
                "example.com",
                "example2.com",
            ]
        else:
            assert "citations" not in chunk.additional_kwargs
    assert isinstance(full, AIMessageChunk)
    assert full.content == "Hello Perplexity"
    assert full.additional_kwargs == {"citations": ["example.com", "example2.com"]}

    patcher.assert_called_once()


def test_perplexity_stream_includes_citations_and_images(mocker: MockerFixture) -> None:
    """Test that the stream method includes citations and images in the additional_kwargs."""
    llm = ChatPerplexity(model="test", timeout=30, verbose=True)
    mock_chunk_0 = {
        "choices": [{"delta": {"content": "Hello "}, "finish_reason": None}],
        "citations": ["example.com", "example2.com"],
        "images": [
            {
                "image_url": "mock_image_url",
                "origin_url": "mock_origin_url",
                "height": 100,
                "width": 100,
            }
        ],
    }
    mock_chunk_1 = {
        "choices": [{"delta": {"content": "Perplexity"}, "finish_reason": None}],
        "citations": ["example.com", "example2.com"],
        "images": [
            {
                "image_url": "mock_image_url",
                "origin_url": "mock_origin_url",
                "height": 100,
                "width": 100,
            }
        ],
    }
    mock_chunks: List[Dict[str, Any]] = [mock_chunk_0, mock_chunk_1]
    mock_stream = MagicMock()
    mock_stream.__iter__.return_value = mock_chunks
    patcher = mocker.patch.object(
        llm.client.chat.completions, "create", return_value=mock_stream
    )
    stream = llm.stream("Hello langchain")
    full: Optional[BaseMessageChunk] = None
    for i, chunk in enumerate(stream):
        full = chunk if full is None else full + chunk
        assert chunk.content == mock_chunks[i]["choices"][0]["delta"]["content"]
        if i == 0:
            assert chunk.additional_kwargs["citations"] == [
                "example.com",
                "example2.com",
            ]
            assert chunk.additional_kwargs["images"] == [
                {
                    "image_url": "mock_image_url",
                    "origin_url": "mock_origin_url",
                    "height": 100,
                    "width": 100,
                }
            ]
        else:
            assert "citations" not in chunk.additional_kwargs
            assert "images" not in chunk.additional_kwargs
    assert isinstance(full, AIMessageChunk)
    assert full.content == "Hello Perplexity"
    assert full.additional_kwargs == {
        "citations": ["example.com", "example2.com"],
        "images": [
            {
                "image_url": "mock_image_url",
                "origin_url": "mock_origin_url",
                "height": 100,
                "width": 100,
            }
        ],
    }

    patcher.assert_called_once()


def test_perplexity_stream_includes_citations_and_related_questions(
    mocker: MockerFixture,
) -> None:
    """Test that the stream method includes citations and related questions in the additional_kwargs."""
    llm = ChatPerplexity(model="test", timeout=30, verbose=True)
    mock_chunk_0 = {
        "choices": [{"delta": {"content": "Hello "}, "finish_reason": None}],
        "citations": ["example.com", "example2.com"],
        "related_questions": ["example_question_1", "example_question_2"],
    }
    mock_chunk_1 = {
        "choices": [{"delta": {"content": "Perplexity"}, "finish_reason": None}],
        "citations": ["example.com", "example2.com"],
        "related_questions": ["example_question_1", "example_question_2"],
    }
    mock_chunks: List[Dict[str, Any]] = [mock_chunk_0, mock_chunk_1]
    mock_stream = MagicMock()
    mock_stream.__iter__.return_value = mock_chunks
    patcher = mocker.patch.object(
        llm.client.chat.completions, "create", return_value=mock_stream
    )
    stream = llm.stream("Hello langchain")
    full: Optional[BaseMessageChunk] = None
    for i, chunk in enumerate(stream):
        full = chunk if full is None else full + chunk
        assert chunk.content == mock_chunks[i]["choices"][0]["delta"]["content"]
        if i == 0:
            assert chunk.additional_kwargs["citations"] == [
                "example.com",
                "example2.com",
            ]
            assert chunk.additional_kwargs["related_questions"] == [
                "example_question_1",
                "example_question_2",
            ]
        else:
            assert "citations" not in chunk.additional_kwargs
            assert "related_questions" not in chunk.additional_kwargs
    assert isinstance(full, AIMessageChunk)
    assert full.content == "Hello Perplexity"
    assert full.additional_kwargs == {
        "citations": ["example.com", "example2.com"],
        "related_questions": ["example_question_1", "example_question_2"],
    }

    patcher.assert_called_once()
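Outside of mocks, the behavior these tests pin down looks like this in application code; a sketch assuming a live `PPLX_API_KEY` (the API attaches `citations`, and optionally `images`/`related_questions`, to the first streamed chunk only):

```python
from langchain_perplexity import ChatPerplexity

llm = ChatPerplexity(model="sonar")

full = None
for chunk in llm.stream("What is LangChain?"):
    # Aggregating chunks preserves additional_kwargs from the first chunk.
    full = chunk if full is None else full + chunk
    print(chunk.content, end="", flush=True)

print()
print(full.additional_kwargs.get("citations", []))
```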
(new file, 18 lines; the file header was lost in this view)

"""Test Perplexity Chat API wrapper."""

from typing import Tuple, Type

from langchain_core.language_models import BaseChatModel
from langchain_tests.unit_tests import ChatModelUnitTests

from langchain_perplexity import ChatPerplexity


class TestPerplexityStandard(ChatModelUnitTests):
    @property
    def chat_model_class(self) -> Type[BaseChatModel]:
        return ChatPerplexity

    @property
    def init_from_env_params(self) -> Tuple[dict, dict, dict]:
        return ({"PPLX_API_KEY": "api_key"}, {}, {"pplx_api_key": "api_key"})
(new file, 7 lines; the file header was lost in this view)

from langchain_perplexity import __all__

EXPECTED_ALL = ["ChatPerplexity"]


def test_all_imports() -> None:
    assert sorted(EXPECTED_ALL) == sorted(__all__)
(new file, 8 lines; the file header was lost in this view)

from langchain_perplexity import ChatPerplexity


def test_chat_perplexity_secrets() -> None:
    model = ChatPerplexity(
        model="llama-3.1-sonar-small-128k-online", pplx_api_key="foo"
    )
    assert "foo" not in str(model)
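The masking this test relies on comes from `pplx_api_key` being a Pydantic `SecretStr`; a minimal sketch of the behavior:

```python
from langchain_perplexity import ChatPerplexity

model = ChatPerplexity(model="llama-3.1-sonar-small-128k-online", pplx_api_key="foo")

print(model.pplx_api_key)                     # prints ********** (masked)
print(model.pplx_api_key.get_secret_value())  # prints foo (explicit opt-in)
```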
libs/partners/perplexity/uv.lock (new file, 1562 lines)

(Diff suppressed because it is too large.)
uv.lock (4 changes)

@@ -2453,7 +2453,7 @@ typing = [

 [[package]]
 name = "langchain-core"
-version = "0.3.49"
+version = "0.3.50"
 source = { editable = "libs/core" }
 dependencies = [
     { name = "jsonpatch" },

@@ -2735,7 +2735,7 @@ typing = []

 [[package]]
 name = "langchain-openai"
-version = "0.3.11"
+version = "0.3.12"
 source = { editable = "libs/partners/openai" }
 dependencies = [
     { name = "langchain-core" },
|
Loading…
Reference in New Issue
Block a user