multiple: support strict and method in with_structured_output (#30385)

This commit is contained in:
ccurme
2025-03-20 13:17:07 -04:00
committed by GitHub
parent 1103bdfaf1
commit b86cd8270c
11 changed files with 316 additions and 31 deletions

View File

@@ -1,21 +1,27 @@
"""DeepSeek chat models."""
from json import JSONDecodeError
from typing import Any, Dict, Iterator, List, Optional, Type, Union
from typing import Any, Dict, Iterator, List, Literal, Optional, Type, TypeVar, Union
import openai
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.utils import from_env, secret_from_env
from langchain_openai.chat_models.base import BaseChatOpenAI
from pydantic import ConfigDict, Field, SecretStr, model_validator
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self
DEFAULT_API_BASE = "https://api.deepseek.com/v1"
_BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]
_DictOrPydantic = Union[Dict, _BM]
class ChatDeepSeek(BaseChatOpenAI):
"""DeepSeek chat model integration to access models hosted in DeepSeek's API.
@@ -197,14 +203,15 @@ class ChatDeepSeek(BaseChatOpenAI):
if not (self.client or None):
sync_specific: dict = {"http_client": self.http_client}
self.client = openai.OpenAI(
**client_params, **sync_specific
).chat.completions
self.root_client = openai.OpenAI(**client_params, **sync_specific)
self.client = self.root_client.chat.completions
if not (self.async_client or None):
async_specific: dict = {"http_client": self.http_async_client}
self.async_client = openai.AsyncOpenAI(
**client_params, **async_specific
).chat.completions
self.root_async_client = openai.AsyncOpenAI(
**client_params,
**async_specific,
)
self.async_client = self.root_async_client.chat.completions
return self
def _create_chat_result(
@@ -281,3 +288,73 @@ class ChatDeepSeek(BaseChatOpenAI):
e.doc,
e.pos,
) from e
def with_structured_output(
    self,
    schema: Optional[_DictOrPydanticClass] = None,
    *,
    method: Literal[
        "function_calling", "json_mode", "json_schema"
    ] = "function_calling",
    include_raw: bool = False,
    strict: Optional[bool] = None,
    **kwargs: Any,
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
    """Model wrapper that returns outputs formatted to match the given schema.

    Args:
        schema:
            The output schema. Can be passed in as:

            - an OpenAI function/tool schema,
            - a JSON Schema,
            - a TypedDict class (support added in 0.1.20),
            - or a Pydantic class.

            If ``schema`` is a Pydantic class then the model output will be a
            Pydantic instance of that class, and the model-generated fields will be
            validated by the Pydantic class. Otherwise the model output will be a
            dict and will not be validated. See :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`
            for more on how to properly specify types and descriptions of
            schema fields when specifying a Pydantic or TypedDict class.
        method: The method for steering model generation, one of:

            - "function_calling":
                Uses DeepSeek's `tool-calling features <https://api-docs.deepseek.com/guides/function_calling>`_.
            - "json_mode":
                Uses DeepSeek's `JSON mode feature <https://api-docs.deepseek.com/guides/json_mode>`_.
            - "json_schema":
                Accepted for compatibility with other providers, but silently
                coerced to ``"function_calling"`` (see note in the body).

            .. versionchanged:: 0.1.3

                Added support for ``"json_mode"``.
        include_raw:
            If False then only the parsed structured output is returned. If
            an error occurs during model output parsing it will be raised. If True
            then both the raw model response (a BaseMessage) and the parsed model
            response will be returned. If an error occurs during output parsing it
            will be caught and returned as well. The final output is always a dict
            with keys "raw", "parsed", and "parsing_error".
        strict:
            Whether to enforce strict adherence to the schema; forwarded
            unchanged to the parent class implementation.
        kwargs: Additional keyword args are forwarded to the parent class
            implementation.

    Returns:
        A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.

        | If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs an instance of ``schema`` (i.e., a Pydantic object). Otherwise, if ``include_raw`` is False then Runnable outputs a dict.

        | If ``include_raw`` is True, then Runnable outputs a dict with keys:

        - "raw": BaseMessage
        - "parsed": None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
        - "parsing_error": Optional[BaseException]

    """  # noqa: E501
    # Some applications require that incompatible parameters (e.g., unsupported
    # methods) be handled. "json_schema" appears unsupported by the DeepSeek
    # API, so it is silently downgraded to tool/function calling rather than
    # raising — TODO(review): confirm against DeepSeek API capabilities.
    if method == "json_schema":
        method = "function_calling"
    return super().with_structured_output(
        schema, method=method, include_raw=include_raw, strict=strict, **kwargs
    )