"""Wrapper around Perplexity APIs.""" from __future__ import annotations import logging from collections.abc import Iterator, Mapping from operator import itemgetter from typing import Any, Literal, TypeAlias import openai from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models import LanguageModelInput from langchain_core.language_models.chat_models import ( BaseChatModel, generate_from_stream, ) from langchain_core.messages import ( AIMessage, AIMessageChunk, BaseMessage, BaseMessageChunk, ChatMessage, ChatMessageChunk, FunctionMessageChunk, HumanMessage, HumanMessageChunk, SystemMessage, SystemMessageChunk, ToolMessageChunk, ) from langchain_core.messages.ai import UsageMetadata, subtract_usage from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.utils import get_pydantic_field_names, secret_from_env from langchain_core.utils.function_calling import convert_to_json_schema from langchain_core.utils.pydantic import is_basemodel_subclass from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator from typing_extensions import Self _DictOrPydanticClass: TypeAlias = dict[str, Any] | type[BaseModel] _DictOrPydantic: TypeAlias = dict | BaseModel logger = logging.getLogger(__name__) def _is_pydantic_class(obj: Any) -> bool: return isinstance(obj, type) and is_basemodel_subclass(obj) def _create_usage_metadata(token_usage: dict) -> UsageMetadata: input_tokens = token_usage.get("prompt_tokens", 0) output_tokens = token_usage.get("completion_tokens", 0) total_tokens = token_usage.get("total_tokens", input_tokens + output_tokens) return UsageMetadata( input_tokens=input_tokens, output_tokens=output_tokens, total_tokens=total_tokens, ) class ChatPerplexity(BaseChatModel): """`Perplexity AI` Chat models API. Setup: To use, you should have the environment variable `PPLX_API_KEY` set to your API key. Any parameters that are valid to be passed to the openai.create call can be passed in, even if not explicitly saved on this class. ```bash export PPLX_API_KEY=your_api_key ``` Key init args - completion params: model: Name of the model to use. e.g. "sonar" temperature: Sampling temperature to use. max_tokens: Maximum number of tokens to generate. streaming: Whether to stream the results or not. Key init args - client params: pplx_api_key: API key for PerplexityChat API. request_timeout: Timeout for requests to PerplexityChat completion API. max_retries: Maximum number of retries to make when generating. See full list of supported init args and their descriptions in the params section. 


class ChatPerplexity(BaseChatModel):
    """`Perplexity AI` Chat models API.

    Setup:
        To use, you should have the environment variable `PPLX_API_KEY` set to
        your API key. Any parameters that are valid to be passed to the
        `openai.create` call can be passed in, even if not explicitly saved on
        this class.

        ```bash
        export PPLX_API_KEY=your_api_key
        ```

    Key init args - completion params:
        model: Name of the model to use, e.g. "sonar".
        temperature: Sampling temperature to use.
        max_tokens: Maximum number of tokens to generate.
        streaming: Whether to stream the results or not.

    Key init args - client params:
        pplx_api_key: API key for PerplexityChat API.
        request_timeout: Timeout for requests to PerplexityChat completion API.
        max_retries: Maximum number of retries to make when generating.

    See full list of supported init args and their descriptions in the params
    section.

    Instantiate:
        ```python
        from langchain_perplexity import ChatPerplexity

        model = ChatPerplexity(model="sonar", temperature=0.7)
        ```

    Invoke:
        ```python
        messages = [("system", "You are a chatbot."), ("user", "Hello!")]
        model.invoke(messages)
        ```

    Invoke with structured output:
        ```python
        from pydantic import BaseModel


        class StructuredOutput(BaseModel):
            role: str
            content: str


        structured_model = model.with_structured_output(StructuredOutput)
        structured_model.invoke(messages)
        ```

    Invoke with perplexity-specific params:
        ```python
        model.invoke(messages, extra_body={"search_recency_filter": "week"})
        ```

    Stream:
        ```python
        for chunk in model.stream(messages):
            print(chunk.content)
        ```

    Token usage:
        ```python
        response = model.invoke(messages)
        response.usage_metadata
        ```

    Response metadata:
        ```python
        response = model.invoke(messages)
        response.response_metadata
        ```

    """  # noqa: E501

    client: Any = None  #: :meta private:
    model: str = "sonar"
    """Model name."""
    temperature: float = 0.7
    """What sampling temperature to use."""
    model_kwargs: dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    pplx_api_key: SecretStr | None = Field(
        default_factory=secret_from_env("PPLX_API_KEY", default=None), alias="api_key"
    )
    """Perplexity API key. Automatically read from the `PPLX_API_KEY`
    environment variable if not provided."""
    request_timeout: float | tuple[float, float] | None = Field(None, alias="timeout")
    """Timeout for requests to PerplexityChat completion API."""
    max_retries: int = 6
    """Maximum number of retries to make when generating."""
    streaming: bool = False
    """Whether to stream the results or not."""
    max_tokens: int | None = None
    """Maximum number of tokens to generate."""

    model_config = ConfigDict(populate_by_name=True)

    @property
    def lc_secrets(self) -> dict[str, str]:
        return {"pplx_api_key": "PPLX_API_KEY"}

    @model_validator(mode="before")
    @classmethod
    def build_extra(cls, values: dict[str, Any]) -> Any:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                logger.warning(
                    f"WARNING! {field_name} is not a default parameter. "
                    f"{field_name} was transferred to model_kwargs. "
                    f"Please confirm that {field_name} is what you intended."
                )
                extra[field_name] = values.pop(field_name)

        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values
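
    # Illustrative example of the validator above (the `top_p` value is an
    # assumption, merely passed through to the API): unknown constructor
    # kwargs are routed into `model_kwargs`, so
    # `ChatPerplexity(model="sonar", top_p=0.9)` behaves like
    # `ChatPerplexity(model="sonar", model_kwargs={"top_p": 0.9})`.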

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validate that the API key and the `openai` package are available."""
        try:
            self.client = openai.OpenAI(
                api_key=self.pplx_api_key.get_secret_value()
                if self.pplx_api_key
                else None,
                base_url="https://api.perplexity.ai",
            )
        except AttributeError:
            raise ValueError(
                "`openai` has no `OpenAI` attribute, this is likely "
                "due to an old version of the openai package. Try upgrading it "
                "with `pip install --upgrade openai`."
            )
        return self

    @property
    def _default_params(self) -> dict[str, Any]:
        """Get the default parameters for calling PerplexityChat API."""
        return {
            "max_tokens": self.max_tokens,
            "stream": self.streaming,
            "temperature": self.temperature,
            **self.model_kwargs,
        }

    def _convert_message_to_dict(self, message: BaseMessage) -> dict[str, Any]:
        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        else:
            raise TypeError(f"Got unknown type {message}")
        return message_dict

    def _create_message_dicts(
        self, messages: list[BaseMessage], stop: list[str] | None
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        params = dict(self._invocation_params)
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        message_dicts = [self._convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def _convert_delta_to_message_chunk(
        self, _dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
    ) -> BaseMessageChunk:
        role = _dict.get("role")
        content = _dict.get("content") or ""
        additional_kwargs: dict = {}
        if _dict.get("function_call"):
            function_call = dict(_dict["function_call"])
            if "name" in function_call and function_call["name"] is None:
                function_call["name"] = ""
            additional_kwargs["function_call"] = function_call
        if _dict.get("tool_calls"):
            additional_kwargs["tool_calls"] = _dict["tool_calls"]

        if role == "user" or default_class == HumanMessageChunk:
            return HumanMessageChunk(content=content)
        elif role == "assistant" or default_class == AIMessageChunk:
            return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
        elif role == "system" or default_class == SystemMessageChunk:
            return SystemMessageChunk(content=content)
        elif role == "function" or default_class == FunctionMessageChunk:
            return FunctionMessageChunk(content=content, name=_dict["name"])
        elif role == "tool" or default_class == ToolMessageChunk:
            return ToolMessageChunk(
                content=content, tool_call_id=_dict["tool_call_id"]
            )
        elif role or default_class == ChatMessageChunk:
            return ChatMessageChunk(content=content, role=role)  # type: ignore[arg-type]
        else:
            return default_class(content=content)  # type: ignore[call-arg]
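
    # Note on streaming usage accounting: Perplexity reports *cumulative*
    # token usage on each streamed chunk, so `_stream` below converts the
    # running totals into per-chunk deltas via `subtract_usage`. Illustrative
    # numbers (not taken from a real response):
    #   chunk 1 usage (10 in, 2 out) -> delta (10 in, 2 out)
    #   chunk 2 usage (10 in, 5 out) -> delta (0 in, 3 out)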

    def _stream(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        default_chunk_class = AIMessageChunk
        # `stream` is passed explicitly to the client call below.
        params.pop("stream", None)
        if stop:
            params["stop_sequences"] = stop
        stream_resp = self.client.chat.completions.create(
            messages=message_dicts, stream=True, **params
        )
        first_chunk = True
        prev_total_usage: UsageMetadata | None = None

        added_model_name: bool = False
        for chunk in stream_resp:
            if not isinstance(chunk, dict):
                chunk = chunk.model_dump()
            # Collect standard usage metadata (transform from aggregate to delta)
            if total_usage := chunk.get("usage"):
                lc_total_usage = _create_usage_metadata(total_usage)
                if prev_total_usage:
                    usage_metadata: UsageMetadata | None = subtract_usage(
                        lc_total_usage, prev_total_usage
                    )
                else:
                    usage_metadata = lc_total_usage
                prev_total_usage = lc_total_usage
            else:
                usage_metadata = None

            if len(chunk["choices"]) == 0:
                continue
            choice = chunk["choices"][0]

            additional_kwargs = {}
            if first_chunk:
                additional_kwargs["citations"] = chunk.get("citations", [])
                for attr in ["images", "related_questions", "search_results"]:
                    if attr in chunk:
                        additional_kwargs[attr] = chunk[attr]

            generation_info = {}
            if (model_name := chunk.get("model")) and not added_model_name:
                generation_info["model_name"] = model_name
                added_model_name = True

            chunk = self._convert_delta_to_message_chunk(
                choice["delta"], default_chunk_class
            )

            if isinstance(chunk, AIMessageChunk) and usage_metadata:
                chunk.usage_metadata = usage_metadata

            if first_chunk:
                chunk.additional_kwargs |= additional_kwargs
                first_chunk = False

            if finish_reason := choice.get("finish_reason"):
                generation_info["finish_reason"] = finish_reason

            default_chunk_class = chunk.__class__
            chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        response = self.client.chat.completions.create(
            messages=message_dicts, **params
        )
        if usage := getattr(response, "usage", None):
            usage_metadata = _create_usage_metadata(usage.model_dump())
        else:
            usage_metadata = None

        additional_kwargs = {}
        for attr in ["citations", "images", "related_questions", "search_results"]:
            if hasattr(response, attr):
                additional_kwargs[attr] = getattr(response, attr)

        message = AIMessage(
            content=response.choices[0].message.content,
            additional_kwargs=additional_kwargs,
            usage_metadata=usage_metadata,
            response_metadata={"model_name": getattr(response, "model", self.model)},
        )
        return ChatResult(generations=[ChatGeneration(message=message)])

    @property
    def _invocation_params(self) -> Mapping[str, Any]:
        """Get the parameters used to invoke the model."""
        pplx_creds: dict[str, Any] = {"model": self.model}
        return {**pplx_creds, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "perplexitychat"

    def with_structured_output(
        self,
        schema: _DictOrPydanticClass | None = None,
        *,
        method: Literal["json_schema"] = "json_schema",
        include_raw: bool = False,
        strict: bool | None = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, _DictOrPydantic]:
        """Model wrapper that returns outputs formatted to match the given schema.

        Currently, Perplexity only supports the `'json_schema'` method for
        structured output, as per their
        [official documentation](https://docs.perplexity.ai/guides/structured-outputs).

        Args:
            schema: The output schema. Can be passed in as:

                - a JSON Schema,
                - a `TypedDict` class,
                - or a Pydantic class

            method: The method for steering model generation. Currently only
                `'json_schema'` is supported: use the JSON Schema to parse the
                model output.
            include_raw: If `False` then only the parsed structured output is
                returned. If an error occurs during model output parsing it
                will be raised. If `True` then both the raw model response (a
                `BaseMessage`) and the parsed model response will be returned.
                If an error occurs during output parsing it will be caught and
                returned as well. The final output is always a `dict` with
                keys `'raw'`, `'parsed'`, and `'parsing_error'`.
            strict: Unsupported: whether to enable strict schema adherence
                when generating the output. This parameter is included for
                compatibility with other chat models, but is currently
                ignored.
            kwargs: Additional keyword args are not supported.

        Returns:
            A `Runnable` that takes the same inputs as a
            `langchain_core.language_models.chat.BaseChatModel`.

            If `include_raw` is `False` and `schema` is a Pydantic class, the
            `Runnable` outputs an instance of `schema` (i.e., a Pydantic
            object). Otherwise, if `include_raw` is `False`, the `Runnable`
            outputs a `dict`.

            If `include_raw` is `True`, then the `Runnable` outputs a `dict`
            with keys:

            - `'raw'`: `BaseMessage`
            - `'parsed'`: `None` if there was a parsing error, otherwise the
              type depends on the `schema` as described above.
            - `'parsing_error'`: `BaseException | None`
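
        Example:
            A minimal sketch. `Answer` here is a hypothetical schema, and a
            valid `PPLX_API_KEY` is assumed to be set in the environment:

            ```python
            from pydantic import BaseModel


            class Answer(BaseModel):
                answer: str
                confidence: float


            structured_model = ChatPerplexity(model="sonar").with_structured_output(
                Answer
            )
            result = structured_model.invoke("What is the capital of France?")
            # `result` is an `Answer` instance, e.g. Answer(answer="Paris", ...)
            ```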
        """  # noqa: E501
        if method in ("function_calling", "json_mode"):
            method = "json_schema"
        if method == "json_schema":
            if schema is None:
                raise ValueError(
                    "schema must be specified when method is 'json_schema'. "
                    "Received None."
                )
            is_pydantic_schema = _is_pydantic_class(schema)
            response_format = convert_to_json_schema(schema)
            llm = self.bind(
                response_format={
                    "type": "json_schema",
                    "json_schema": {"schema": response_format},
                },
                ls_structured_output_format={
                    "kwargs": {"method": method},
                    "schema": response_format,
                },
            )
            output_parser = (
                PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
                if is_pydantic_schema
                else JsonOutputParser()
            )
        else:
            raise ValueError(
                "Unrecognized method argument. Expected 'json_schema'. "
                f"Received: '{method}'"
            )

        if include_raw:
            parser_assign = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
            )
            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
            parser_with_fallback = parser_assign.with_fallbacks(
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback
        else:
            return llm | output_parser