multiple: support strict and method in with_structured_output (#30385)

This commit is contained in:
ccurme
2025-03-20 13:17:07 -04:00
committed by GitHub
parent 1103bdfaf1
commit b86cd8270c
11 changed files with 316 additions and 31 deletions

View File

@@ -4,16 +4,28 @@ from typing import (
Any,
Dict,
List,
Literal,
Optional,
Type,
TypeVar,
Union,
)
import openai
from langchain_core.language_models.chat_models import LangSmithParams
from langchain_core.language_models.chat_models import (
LangSmithParams,
LanguageModelInput,
)
from langchain_core.runnables import Runnable
from langchain_core.utils import secret_from_env
from langchain_openai.chat_models.base import BaseChatOpenAI
from pydantic import ConfigDict, Field, SecretStr, model_validator
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self
_BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]
_DictOrPydantic = Union[Dict, _BM]
class ChatXAI(BaseChatOpenAI): # type: ignore[override]
r"""ChatXAI chat model.
@@ -359,3 +371,83 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
**async_specific,
)
return self
def with_structured_output(
self,
schema: Optional[_DictOrPydanticClass] = None,
*,
method: Literal[
"function_calling", "json_mode", "json_schema"
] = "function_calling",
include_raw: bool = False,
strict: Optional[bool] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
"""Model wrapper that returns outputs formatted to match the given schema.
Args:
schema:
The output schema. Can be passed in as:
- an OpenAI function/tool schema,
- a JSON Schema,
- a TypedDict class (support added in 0.1.20),
- or a Pydantic class.
If ``schema`` is a Pydantic class then the model output will be a
Pydantic instance of that class, and the model-generated fields will be
validated by the Pydantic class. Otherwise the model output will be a
dict and will not be validated. See :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`
for more on how to properly specify types and descriptions of
schema fields when specifying a Pydantic or TypedDict class.
method: The method for steering model generation, one of:
- "function_calling":
Uses xAI's `tool-calling features <https://docs.x.ai/docs/guides/function-calling>`_.
- "json_schema":
Uses xAI's `structured output feature <https://docs.x.ai/docs/guides/structured-outputs>`_.
- "json_mode":
Uses xAI's JSON mode feature.
include_raw:
If False then only the parsed structured output is returned. If
an error occurs during model output parsing it will be raised. If True
then both the raw model response (a BaseMessage) and the parsed model
response will be returned. If an error occurs during output parsing it
will be caught and returned as well. The final output is always a dict
with keys "raw", "parsed", and "parsing_error".
strict:
- True:
Model output is guaranteed to exactly match the schema.
The input schema will also be validated according to
https://platform.openai.com/docs/guides/structured-outputs/supported-schemas
- False:
Input schema will not be validated and model output will not be
validated.
- None:
``strict`` argument will not be passed to the model.
kwargs: Additional keyword args aren't supported.
Returns:
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
| If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs an instance of ``schema`` (i.e., a Pydantic object). Otherwise, if ``include_raw`` is False then Runnable outputs a dict.
| If ``include_raw`` is True, then Runnable outputs a dict with keys:
- "raw": BaseMessage
- "parsed": None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
- "parsing_error": Optional[BaseException]
""" # noqa: E501
# Some applications require that incompatible parameters (e.g., unsupported
# methods) be handled.
if method == "function_calling" and strict:
strict = None
return super().with_structured_output(
schema, method=method, include_raw=include_raw, strict=strict, **kwargs
)