Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-20 05:43:55 +00:00)
(Community): Adding Structured Output Support for ChatPerplexity (#29361)

- **Description:** Adds structured output support for ChatPerplexity.
- **Issue:** #29357
- Implemented as described in the Perplexity official docs: https://docs.perplexity.ai/guides/structured-outputs

Co-authored-by: ccurme <chester.curme@gmail.com>
commit 9f3bcee30a
parent 994c5465e0
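For orientation, this is roughly how the feature added by the diff below is used once merged. It is a minimal sketch only: the model name "sonar" and the AnswerWithSources schema are illustrative assumptions rather than part of the commit, and a valid PPLX_API_KEY is assumed to be set in the environment.

# Minimal usage sketch of with_structured_output on ChatPerplexity.
# "sonar" and the schema below are placeholders for illustration.
from pydantic import BaseModel, Field

from langchain_community.chat_models import ChatPerplexity


class AnswerWithSources(BaseModel):
    """Hypothetical output schema used only for this example."""

    answer: str = Field(description="Short answer to the question")
    sources: list[str] = Field(description="URLs that support the answer")


chat = ChatPerplexity(model="sonar")

# Binds Perplexity's json_schema response_format and parses the reply
# back into the Pydantic class.
structured_chat = chat.with_structured_output(AnswerWithSources)
result = structured_chat.invoke("Who proposed the transformer architecture?")
print(result.answer, result.sources)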
@@ -3,19 +3,23 @@
from __future__ import annotations

import logging
from operator import itemgetter
from typing import (
    Any,
    Dict,
    Iterator,
    List,
    Literal,
    Mapping,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
)

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    generate_from_stream,
@@ -34,17 +38,27 @@ from langchain_core.messages import (
    SystemMessageChunk,
    ToolMessageChunk,
)
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.utils import (
    from_env,
    get_pydantic_field_names,
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.utils import from_env, get_pydantic_field_names
from langchain_core.utils.pydantic import (
    is_basemodel_subclass,
)
from pydantic import ConfigDict, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, model_validator
from typing_extensions import Self

_BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]
_DictOrPydantic = Union[Dict, _BM]

logger = logging.getLogger(__name__)


def _is_pydantic_class(obj: Any) -> bool:
    return isinstance(obj, type) and is_basemodel_subclass(obj)


class ChatPerplexity(BaseChatModel):
    """`Perplexity AI` Chat models API.

@@ -282,3 +296,99 @@ class ChatPerplexity(BaseChatModel):
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "perplexitychat"

    def with_structured_output(
        self,
        schema: Optional[_DictOrPydanticClass] = None,
        *,
        method: Literal["json_schema"] = "json_schema",
        include_raw: bool = False,
        strict: Optional[bool] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, _DictOrPydantic]:
"""Model wrapper that returns outputs formatted to match the given schema for Preplexity.
|
||||
Currently, Preplexity only supports "json_schema" method for structured output
|
||||
as per their official documentation: https://docs.perplexity.ai/guides/structured-outputs
|
||||
|
||||
Args:
|
||||
schema:
|
||||
The output schema. Can be passed in as:
|
||||
|
||||
- a JSON Schema,
|
||||
- a TypedDict class,
|
||||
- or a Pydantic class
|
||||
|
||||
method: The method for steering model generation, currently only support:
|
||||
|
||||
- "json_schema": Use the JSON Schema to parse the model output
|
||||
|
||||
|
||||
include_raw:
|
||||
If False then only the parsed structured output is returned. If
|
||||
an error occurs during model output parsing it will be raised. If True
|
||||
then both the raw model response (a BaseMessage) and the parsed model
|
||||
response will be returned. If an error occurs during output parsing it
|
||||
will be caught and returned as well. The final output is always a dict
|
||||
with keys "raw", "parsed", and "parsing_error".
|
||||
|
||||
kwargs: Additional keyword args aren't supported.
|
||||
|
||||
Returns:
|
||||
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
|
||||
|
||||
| If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs an instance of ``schema`` (i.e., a Pydantic object). Otherwise, if ``include_raw`` is False then Runnable outputs a dict.
|
||||
|
||||
| If ``include_raw`` is True, then Runnable outputs a dict with keys:
|
||||
|
||||
- "raw": BaseMessage
|
||||
- "parsed": None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
|
||||
- "parsing_error": Optional[BaseException]
|
||||
|
||||
""" # noqa: E501
|
||||
if method == "json_schema":
|
||||
if schema is None:
|
||||
raise ValueError(
|
||||
"schema must be specified when method is not 'json_schema'. "
|
||||
"Received None."
|
||||
)
|
||||
is_pydantic_schema = _is_pydantic_class(schema)
|
||||
if is_pydantic_schema and hasattr(
|
||||
schema, "model_json_schema"
|
||||
): # accounting for pydantic v1 and v2
|
||||
response_format = schema.model_json_schema() # type: ignore[union-attr]
|
||||
elif is_pydantic_schema:
|
||||
response_format = schema.schema() # type: ignore[union-attr]
|
||||
elif isinstance(schema, dict):
|
||||
response_format = schema
|
||||
elif type(schema).__name__ == "_TypedDictMeta":
|
||||
adapter = TypeAdapter(schema) # if use passes typeddict
|
||||
response_format = adapter.json_schema()
|
||||
|
||||
llm = self.bind(
|
||||
response_format={
|
||||
"type": "json_schema",
|
||||
"json_schema": {"schema": response_format},
|
||||
}
|
||||
)
|
||||
output_parser = (
|
||||
PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
|
||||
if is_pydantic_schema
|
||||
else JsonOutputParser()
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unrecognized method argument. Expected 'json_schema' Received:\
|
||||
'{method}'"
|
||||
)
|
||||
|
||||
if include_raw:
|
||||
parser_assign = RunnablePassthrough.assign(
|
||||
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
|
||||
)
|
||||
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
|
||||
parser_with_fallback = parser_assign.with_fallbacks(
|
||||
[parser_none], exception_key="parsing_error"
|
||||
)
|
||||
return RunnableMap(raw=llm) | parser_with_fallback
|
||||
else:
|
||||
return llm | output_parser
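To round out the diff, here is a brief sketch of the include_raw=True branch implemented above when a plain JSON Schema dict is passed. The model name "sonar" and the schema contents are assumptions for illustration only, and a valid PPLX_API_KEY is assumed to be set.

# Sketch of the include_raw=True path with a plain JSON Schema dict.
# "sonar" and the schema below are illustrative placeholders.
from langchain_community.chat_models import ChatPerplexity

joke_schema = {
    "title": "Joke",
    "type": "object",
    "properties": {
        "setup": {"type": "string"},
        "punchline": {"type": "string"},
    },
    "required": ["setup", "punchline"],
}

chat = ChatPerplexity(model="sonar")
structured = chat.with_structured_output(joke_schema, include_raw=True)

out = structured.invoke("Tell me a short programming joke.")
# With include_raw=True the pipeline above returns a dict:
#   out["raw"]           -> the original AIMessage
#   out["parsed"]        -> dict parsed by JsonOutputParser, or None on failure
#   out["parsing_error"] -> the exception raised during parsing, else None
print(out["parsed"])

Passing a Pydantic class instead of a dict would route through PydanticOutputParser, so out["parsed"] would be a model instance rather than a dict.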