mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 22:03:52 +00:00
(Community): Adding Structured Support for ChatPerplexity (#29361)
- **Description:** Adding Structured Support for ChatPerplexity - **Issue:** #29357 - This is implemented as per the Perplexity official docs: https://docs.perplexity.ai/guides/structured-outputs --------- Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
parent
994c5465e0
commit
9f3bcee30a
@ -3,19 +3,23 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from operator import itemgetter
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Dict,
|
Dict,
|
||||||
Iterator,
|
Iterator,
|
||||||
List,
|
List,
|
||||||
|
Literal,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||||
|
from langchain_core.language_models import LanguageModelInput
|
||||||
from langchain_core.language_models.chat_models import (
|
from langchain_core.language_models.chat_models import (
|
||||||
BaseChatModel,
|
BaseChatModel,
|
||||||
generate_from_stream,
|
generate_from_stream,
|
||||||
@ -34,17 +38,27 @@ from langchain_core.messages import (
|
|||||||
SystemMessageChunk,
|
SystemMessageChunk,
|
||||||
ToolMessageChunk,
|
ToolMessageChunk,
|
||||||
)
|
)
|
||||||
|
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
from langchain_core.utils import (
|
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||||
from_env,
|
from langchain_core.utils import from_env, get_pydantic_field_names
|
||||||
get_pydantic_field_names,
|
from langchain_core.utils.pydantic import (
|
||||||
|
is_basemodel_subclass,
|
||||||
)
|
)
|
||||||
from pydantic import ConfigDict, Field, model_validator
|
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, model_validator
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
_BM = TypeVar("_BM", bound=BaseModel)
|
||||||
|
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]
|
||||||
|
_DictOrPydantic = Union[Dict, _BM]
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_pydantic_class(obj: Any) -> bool:
|
||||||
|
return isinstance(obj, type) and is_basemodel_subclass(obj)
|
||||||
|
|
||||||
|
|
||||||
class ChatPerplexity(BaseChatModel):
|
class ChatPerplexity(BaseChatModel):
|
||||||
"""`Perplexity AI` Chat models API.
|
"""`Perplexity AI` Chat models API.
|
||||||
|
|
||||||
@ -282,3 +296,99 @@ class ChatPerplexity(BaseChatModel):
|
|||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
"""Return type of chat model."""
|
"""Return type of chat model."""
|
||||||
return "perplexitychat"
|
return "perplexitychat"
|
||||||
|
|
||||||
|
def with_structured_output(
|
||||||
|
self,
|
||||||
|
schema: Optional[_DictOrPydanticClass] = None,
|
||||||
|
*,
|
||||||
|
method: Literal["json_schema"] = "json_schema",
|
||||||
|
include_raw: bool = False,
|
||||||
|
strict: Optional[bool] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
|
||||||
|
"""Model wrapper that returns outputs formatted to match the given schema for Preplexity.
|
||||||
|
Currently, Preplexity only supports "json_schema" method for structured output
|
||||||
|
as per their official documentation: https://docs.perplexity.ai/guides/structured-outputs
|
||||||
|
|
||||||
|
Args:
|
||||||
|
schema:
|
||||||
|
The output schema. Can be passed in as:
|
||||||
|
|
||||||
|
- a JSON Schema,
|
||||||
|
- a TypedDict class,
|
||||||
|
- or a Pydantic class
|
||||||
|
|
||||||
|
method: The method for steering model generation, currently only support:
|
||||||
|
|
||||||
|
- "json_schema": Use the JSON Schema to parse the model output
|
||||||
|
|
||||||
|
|
||||||
|
include_raw:
|
||||||
|
If False then only the parsed structured output is returned. If
|
||||||
|
an error occurs during model output parsing it will be raised. If True
|
||||||
|
then both the raw model response (a BaseMessage) and the parsed model
|
||||||
|
response will be returned. If an error occurs during output parsing it
|
||||||
|
will be caught and returned as well. The final output is always a dict
|
||||||
|
with keys "raw", "parsed", and "parsing_error".
|
||||||
|
|
||||||
|
kwargs: Additional keyword args aren't supported.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
|
||||||
|
|
||||||
|
| If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs an instance of ``schema`` (i.e., a Pydantic object). Otherwise, if ``include_raw`` is False then Runnable outputs a dict.
|
||||||
|
|
||||||
|
| If ``include_raw`` is True, then Runnable outputs a dict with keys:
|
||||||
|
|
||||||
|
- "raw": BaseMessage
|
||||||
|
- "parsed": None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
|
||||||
|
- "parsing_error": Optional[BaseException]
|
||||||
|
|
||||||
|
""" # noqa: E501
|
||||||
|
if method == "json_schema":
|
||||||
|
if schema is None:
|
||||||
|
raise ValueError(
|
||||||
|
"schema must be specified when method is not 'json_schema'. "
|
||||||
|
"Received None."
|
||||||
|
)
|
||||||
|
is_pydantic_schema = _is_pydantic_class(schema)
|
||||||
|
if is_pydantic_schema and hasattr(
|
||||||
|
schema, "model_json_schema"
|
||||||
|
): # accounting for pydantic v1 and v2
|
||||||
|
response_format = schema.model_json_schema() # type: ignore[union-attr]
|
||||||
|
elif is_pydantic_schema:
|
||||||
|
response_format = schema.schema() # type: ignore[union-attr]
|
||||||
|
elif isinstance(schema, dict):
|
||||||
|
response_format = schema
|
||||||
|
elif type(schema).__name__ == "_TypedDictMeta":
|
||||||
|
adapter = TypeAdapter(schema) # if use passes typeddict
|
||||||
|
response_format = adapter.json_schema()
|
||||||
|
|
||||||
|
llm = self.bind(
|
||||||
|
response_format={
|
||||||
|
"type": "json_schema",
|
||||||
|
"json_schema": {"schema": response_format},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
output_parser = (
|
||||||
|
PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
|
||||||
|
if is_pydantic_schema
|
||||||
|
else JsonOutputParser()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unrecognized method argument. Expected 'json_schema' Received:\
|
||||||
|
'{method}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
if include_raw:
|
||||||
|
parser_assign = RunnablePassthrough.assign(
|
||||||
|
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
|
||||||
|
)
|
||||||
|
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
|
||||||
|
parser_with_fallback = parser_assign.with_fallbacks(
|
||||||
|
[parser_none], exception_key="parsing_error"
|
||||||
|
)
|
||||||
|
return RunnableMap(raw=llm) | parser_with_fallback
|
||||||
|
else:
|
||||||
|
return llm | output_parser
|
||||||
|
Loading…
Reference in New Issue
Block a user