multiple: support strict and method in with_structured_output (#30385)

ccurme 2025-03-20 13:17:07 -04:00 committed by GitHub
parent 1103bdfaf1
commit b86cd8270c
11 changed files with 316 additions and 31 deletions

View File

@ -1,21 +1,27 @@
"""DeepSeek chat models."""
from json import JSONDecodeError
from typing import Any, Dict, Iterator, List, Optional, Type, Union
from typing import Any, Dict, Iterator, List, Literal, Optional, Type, TypeVar, Union
import openai
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.utils import from_env, secret_from_env
from langchain_openai.chat_models.base import BaseChatOpenAI
from pydantic import ConfigDict, Field, SecretStr, model_validator
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self
DEFAULT_API_BASE = "https://api.deepseek.com/v1"
_BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]
_DictOrPydantic = Union[Dict, _BM]
class ChatDeepSeek(BaseChatOpenAI):
"""DeepSeek chat model integration to access models hosted in DeepSeek's API.
@ -197,14 +203,15 @@ class ChatDeepSeek(BaseChatOpenAI):
if not (self.client or None):
sync_specific: dict = {"http_client": self.http_client}
self.client = openai.OpenAI(
**client_params, **sync_specific
).chat.completions
self.root_client = openai.OpenAI(**client_params, **sync_specific)
self.client = self.root_client.chat.completions
if not (self.async_client or None):
async_specific: dict = {"http_client": self.http_async_client}
self.async_client = openai.AsyncOpenAI(
**client_params, **async_specific
).chat.completions
self.root_async_client = openai.AsyncOpenAI(
**client_params,
**async_specific,
)
self.async_client = self.root_async_client.chat.completions
return self
def _create_chat_result(
@ -281,3 +288,73 @@ class ChatDeepSeek(BaseChatOpenAI):
e.doc,
e.pos,
) from e
def with_structured_output(
self,
schema: Optional[_DictOrPydanticClass] = None,
*,
method: Literal[
"function_calling", "json_mode", "json_schema"
] = "function_calling",
include_raw: bool = False,
strict: Optional[bool] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
"""Model wrapper that returns outputs formatted to match the given schema.
Args:
schema:
The output schema. Can be passed in as:
- an OpenAI function/tool schema,
- a JSON Schema,
- a TypedDict class (support added in 0.1.20),
- or a Pydantic class.
If ``schema`` is a Pydantic class then the model output will be a
Pydantic instance of that class, and the model-generated fields will be
validated by the Pydantic class. Otherwise the model output will be a
dict and will not be validated. See :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`
for more on how to properly specify types and descriptions of
schema fields when specifying a Pydantic or TypedDict class.
method: The method for steering model generation, one of:
- "function_calling":
Uses DeepSeek's `tool-calling features <https://api-docs.deepseek.com/guides/function_calling>`_.
- "json_mode":
Uses DeepSeek's `JSON mode feature <https://api-docs.deepseek.com/guides/json_mode>`_.
.. versionchanged:: 0.1.3
Added support for ``"json_mode"``.
include_raw:
If False then only the parsed structured output is returned. If
an error occurs during model output parsing it will be raised. If True
then both the raw model response (a BaseMessage) and the parsed model
response will be returned. If an error occurs during output parsing it
will be caught and returned as well. The final output is always a dict
with keys "raw", "parsed", and "parsing_error".
kwargs: Additional keyword args aren't supported.
Returns:
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
| If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs an instance of ``schema`` (i.e., a Pydantic object). Otherwise, if ``include_raw`` is False then Runnable outputs a dict.
| If ``include_raw`` is True, then Runnable outputs a dict with keys:
- "raw": BaseMessage
- "parsed": None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
- "parsing_error": Optional[BaseException]
""" # noqa: E501
# Some applications require that incompatible parameters (e.g., unsupported
# methods) be handled.
if method == "json_schema":
method = "function_calling"
return super().with_structured_output(
schema, method=method, include_raw=include_raw, strict=strict, **kwargs
)
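
For illustration, a minimal usage sketch of the override above (assumes the ``langchain-deepseek`` package, a configured ``DEEPSEEK_API_KEY``, and the ``deepseek-chat`` model; note that ``method="json_schema"`` is silently routed to ``"function_calling"``):

from pydantic import BaseModel, Field
from langchain_deepseek import ChatDeepSeek

class AnswerWithJustification(BaseModel):
    """An answer to the user's question, with justification."""
    answer: str = Field(description="the answer")
    justification: str = Field(description="reasoning behind the answer")

llm = ChatDeepSeek(model="deepseek-chat", temperature=0)
structured_llm = llm.with_structured_output(AnswerWithJustification)
structured_llm.invoke(
    "What weighs more, a pound of bricks or a pound of feathers?"
)
# -> AnswerWithJustification(answer='...', justification='...')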

View File

@ -24,6 +24,11 @@ class TestChatDeepSeek(ChatModelIntegrationTests):
"temperature": 0,
}
@property
def supports_json_mode(self) -> bool:
"""(bool) whether the chat model supports JSON mode."""
return True
@pytest.mark.xfail(reason="Not yet supported.")
def test_tool_message_histories_list_content(
self, model: BaseChatModel, my_adder_tool: BaseTool

View File

@ -75,6 +75,7 @@ from langchain_core.utils import (
get_pydantic_field_names,
)
from langchain_core.utils.function_calling import (
convert_to_json_schema,
convert_to_openai_function,
convert_to_openai_tool,
)
@ -737,7 +738,9 @@ class ChatFireworks(BaseChatModel):
self,
schema: Optional[Union[Dict, Type[BaseModel]]] = None,
*,
method: Literal["function_calling", "json_mode"] = "function_calling",
method: Literal[
"function_calling", "json_mode", "json_schema"
] = "function_calling",
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
@ -761,13 +764,19 @@ class ChatFireworks(BaseChatModel):
Added support for TypedDict class.
method:
The method for steering model generation, either "function_calling"
or "json_mode". If "function_calling" then the schema will be converted
to an OpenAI function and the returned model will make use of the
function-calling API. If "json_mode" then OpenAI's JSON mode will be
used. Note that if using "json_mode" then you must include instructions
for formatting the output into the desired schema into the model call.
method: The method for steering model generation, one of:
- "function_calling":
Uses Fireworks's `tool-calling features <https://docs.fireworks.ai/guides/function-calling>`_.
- "json_schema":
Uses Fireworks's `structured output feature <https://docs.fireworks.ai/structured-responses/structured-response-formatting>`_.
- "json_mode":
Uses Fireworks's `JSON mode feature <https://docs.fireworks.ai/structured-responses/structured-response-formatting>`_.
.. versionchanged:: 0.2.8
Added support for ``"json_schema"``.
include_raw:
If False then only the parsed structured output is returned. If
an error occurs during model output parsing it will be raised. If True
@ -928,11 +937,11 @@ class ChatFireworks(BaseChatModel):
structured_llm.invoke(
"Answer the following question. "
"Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
"Make sure to return a JSON blob with keys 'answer' and 'justification'. "
"What's heavier a pound of bricks or a pound of feathers?"
)
# -> {
# 'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
# 'raw': AIMessage(content='{"answer": "They are both the same weight.", "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight."}'),
# 'parsed': AnswerWithJustification(answer='They are both the same weight.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'),
# 'parsing_error': None
# }
@ -944,11 +953,11 @@ class ChatFireworks(BaseChatModel):
structured_llm.invoke(
"Answer the following question. "
"Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
"Make sure to return a JSON blob with keys 'answer' and 'justification'. "
"What's heavier a pound of bricks or a pound of feathers?"
)
# -> {
# 'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
# 'raw': AIMessage(content='{"answer": "They are both the same weight.", "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight."}'),
# 'parsed': {
# 'answer': 'They are both the same weight.',
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'
@ -956,6 +965,7 @@ class ChatFireworks(BaseChatModel):
# 'parsing_error': None
# }
""" # noqa: E501
_ = kwargs.pop("strict", None)
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = _is_pydantic_class(schema)
@ -984,6 +994,25 @@ class ChatFireworks(BaseChatModel):
output_parser = JsonOutputKeyToolsParser(
key_name=tool_name, first_tool_only=True
)
elif method == "json_schema":
if schema is None:
raise ValueError(
"schema must be specified when method is 'json_schema'. "
"Received None."
)
formatted_schema = convert_to_json_schema(schema)
llm = self.bind(
response_format={"type": "json_object", "schema": formatted_schema},
ls_structured_output_format={
"kwargs": {"method": "json_schema"},
"schema": schema,
},
)
output_parser = (
PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
if is_pydantic_schema
else JsonOutputParser()
)
elif method == "json_mode":
llm = self.bind(
response_format={"type": "json_object"},

View File

@ -7,7 +7,7 @@ authors = []
license = { text = "MIT" }
requires-python = "<4.0,>=3.9"
dependencies = [
"langchain-core<1.0.0,>=0.3.33",
"langchain-core<1.0.0,>=0.3.46",
"fireworks-ai>=0.13.0",
"openai<2.0.0,>=1.10.0",
"requests<3,>=2",

View File

@ -4,10 +4,12 @@ You will need FIREWORKS_API_KEY set in your environment to run these tests.
"""
import json
from typing import Optional
from typing import Any, Literal, Optional
import pytest
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk
from pydantic import BaseModel
from pydantic import BaseModel, Field
from typing_extensions import Annotated, TypedDict
from langchain_fireworks import ChatFireworks
@ -161,3 +163,54 @@ def test_invoke() -> None:
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result.content, str)
def _get_joke_class(
schema_type: Literal["pydantic", "typeddict", "json_schema"],
) -> Any:
class Joke(BaseModel):
"""Joke to tell user."""
setup: str = Field(description="question to set up a joke")
punchline: str = Field(description="answer to resolve the joke")
def validate_joke(result: Any) -> bool:
return isinstance(result, Joke)
class JokeDict(TypedDict):
"""Joke to tell user."""
setup: Annotated[str, ..., "question to set up a joke"]
punchline: Annotated[str, ..., "answer to resolve the joke"]
def validate_joke_dict(result: Any) -> bool:
return all(key in ["setup", "punchline"] for key in result.keys())
if schema_type == "pydantic":
return Joke, validate_joke
elif schema_type == "typeddict":
return JokeDict, validate_joke_dict
elif schema_type == "json_schema":
return Joke.model_json_schema(), validate_joke_dict
else:
raise ValueError("Invalid schema type")
@pytest.mark.parametrize("schema_type", ["pydantic", "typeddict", "json_schema"])
def test_structured_output_json_schema(schema_type: str) -> None:
llm = ChatFireworks(model="accounts/fireworks/models/llama-v3p1-70b-instruct")
schema, validation_function = _get_joke_class(schema_type) # type: ignore[arg-type]
chat = llm.with_structured_output(schema, method="json_schema")
# Test invoke
result = chat.invoke("Tell me a joke about cats.")
validation_function(result)
# Test stream
chunks = []
for chunk in chat.stream("Tell me a joke about cats."):
validation_function(chunk)
chunks.append(chunk)
assert chunk

View File

@ -635,7 +635,7 @@ wheels = [
[[package]]
name = "langchain-core"
version = "0.3.35"
version = "0.3.46"
source = { editable = "../../core" }
dependencies = [
{ name = "jsonpatch" },
@ -667,7 +667,7 @@ dev = [
]
lint = [{ name = "ruff", specifier = ">=0.9.2,<1.0.0" }]
test = [
{ name = "blockbuster", specifier = "~=1.5.11" },
{ name = "blockbuster", specifier = "~=1.5.18" },
{ name = "freezegun", specifier = ">=1.2.2,<2.0.0" },
{ name = "grandalf", specifier = ">=0.8,<1.0" },
{ name = "langchain-tests", directory = "../../standard-tests" },
@ -763,7 +763,7 @@ typing = [
[[package]]
name = "langchain-tests"
version = "0.3.11"
version = "0.3.14"
source = { editable = "../../standard-tests" }
dependencies = [
{ name = "httpx" },
@ -780,8 +780,7 @@ dependencies = [
requires-dist = [
{ name = "httpx", specifier = ">=0.25.0,<1" },
{ name = "langchain-core", editable = "../../core" },
{ name = "numpy", marker = "python_full_version < '3.12'", specifier = ">=1.24.0,<2.0.0" },
{ name = "numpy", marker = "python_full_version >= '3.12'", specifier = ">=1.26.2,<3" },
{ name = "numpy", specifier = ">=1.26.2,<3" },
{ name = "pytest", specifier = ">=7,<9" },
{ name = "pytest-asyncio", specifier = ">=0.20,<1" },
{ name = "pytest-socket", specifier = ">=0.6.0,<1" },

View File

@ -460,6 +460,24 @@ class ChatGroq(BaseChatModel):
ls_params["ls_stop"] = ls_stop if isinstance(ls_stop, list) else [ls_stop]
return ls_params
def _should_stream(
self,
*,
async_api: bool,
run_manager: Optional[
Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
] = None,
**kwargs: Any,
) -> bool:
"""Determine if a given model call should hit the streaming API."""
base_should_stream = super()._should_stream(
async_api=async_api, run_manager=run_manager, **kwargs
)
if base_should_stream and ("response_format" in kwargs):
# Streaming not supported in JSON mode.
return kwargs["response_format"] != {"type": "json_object"}
return base_should_stream
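
A sketch of the effect of this override (assumes the ``langchain-groq`` package and a Groq-hosted model; the model name is illustrative):

from langchain_groq import ChatGroq

llm = ChatGroq(model="llama-3.1-8b-instant")
json_llm = llm.bind(response_format={"type": "json_object"})

# Because _should_stream returns False when JSON mode is requested, .stream()
# falls back to a single non-streaming generation and yields one chunk.
for chunk in json_llm.stream(
    "Return a JSON object with keys 'setup' and 'punchline'."
):
    print(chunk.content)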
def _generate(
self,
messages: List[BaseMessage],
@ -987,9 +1005,14 @@ class ChatGroq(BaseChatModel):
# 'parsing_error': None
# }
""" # noqa: E501
_ = kwargs.pop("strict", None)
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = _is_pydantic_class(schema)
if method == "json_schema":
# Some applications require that incompatible parameters (e.g., unsupported
# methods) be handled.
method = "function_calling"
if method == "function_calling":
if schema is None:
raise ValueError(

View File

@ -47,4 +47,4 @@ class TestGroqLlama(BaseTestGroq):
@property
def supports_json_mode(self) -> bool:
return False # Not supported in streaming mode
return True

View File

@ -945,6 +945,7 @@ class ChatMistralAI(BaseChatModel):
# }
""" # noqa: E501
_ = kwargs.pop("strict", None)
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)

View File

@ -4,16 +4,28 @@ from typing import (
Any,
Dict,
List,
Literal,
Optional,
Type,
TypeVar,
Union,
)
import openai
from langchain_core.language_models.chat_models import LangSmithParams
from langchain_core.language_models.chat_models import (
LangSmithParams,
LanguageModelInput,
)
from langchain_core.runnables import Runnable
from langchain_core.utils import secret_from_env
from langchain_openai.chat_models.base import BaseChatOpenAI
from pydantic import ConfigDict, Field, SecretStr, model_validator
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self
_BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]
_DictOrPydantic = Union[Dict, _BM]
class ChatXAI(BaseChatOpenAI): # type: ignore[override]
r"""ChatXAI chat model.
@ -359,3 +371,83 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
**async_specific,
)
return self
def with_structured_output(
self,
schema: Optional[_DictOrPydanticClass] = None,
*,
method: Literal[
"function_calling", "json_mode", "json_schema"
] = "function_calling",
include_raw: bool = False,
strict: Optional[bool] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
"""Model wrapper that returns outputs formatted to match the given schema.
Args:
schema:
The output schema. Can be passed in as:
- an OpenAI function/tool schema,
- a JSON Schema,
- a TypedDict class (support added in 0.1.20),
- or a Pydantic class.
If ``schema`` is a Pydantic class then the model output will be a
Pydantic instance of that class, and the model-generated fields will be
validated by the Pydantic class. Otherwise the model output will be a
dict and will not be validated. See :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`
for more on how to properly specify types and descriptions of
schema fields when specifying a Pydantic or TypedDict class.
method: The method for steering model generation, one of:
- "function_calling":
Uses xAI's `tool-calling features <https://docs.x.ai/docs/guides/function-calling>`_.
- "json_schema":
Uses xAI's `structured output feature <https://docs.x.ai/docs/guides/structured-outputs>`_.
- "json_mode":
Uses xAI's JSON mode feature.
include_raw:
If False then only the parsed structured output is returned. If
an error occurs during model output parsing it will be raised. If True
then both the raw model response (a BaseMessage) and the parsed model
response will be returned. If an error occurs during output parsing it
will be caught and returned as well. The final output is always a dict
with keys "raw", "parsed", and "parsing_error".
strict:
- True:
Model output is guaranteed to exactly match the schema.
The input schema will also be validated according to
https://platform.openai.com/docs/guides/structured-outputs/supported-schemas
- False:
Input schema will not be validated and model output will not be
validated.
- None:
``strict`` argument will not be passed to the model.
kwargs: Additional keyword args aren't supported.
Returns:
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
| If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs an instance of ``schema`` (i.e., a Pydantic object). Otherwise, if ``include_raw`` is False then Runnable outputs a dict.
| If ``include_raw`` is True, then Runnable outputs a dict with keys:
- "raw": BaseMessage
- "parsed": None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
- "parsing_error": Optional[BaseException]
""" # noqa: E501
# Some applications require that incompatible parameters (e.g., unsupported
# methods) be handled.
if method == "function_calling" and strict:
strict = None
return super().with_structured_output(
schema, method=method, include_raw=include_raw, strict=strict, **kwargs
)
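
For illustration, a usage sketch of the override above (assumes the ``langchain-xai`` package and an ``XAI_API_KEY``; the model name is illustrative):

from pydantic import BaseModel, Field
from langchain_xai import ChatXAI

class AnswerWithJustification(BaseModel):
    answer: str = Field(description="the answer")
    justification: str = Field(description="reasoning behind the answer")

llm = ChatXAI(model="grok-2")
structured_llm = llm.with_structured_output(
    AnswerWithJustification, method="json_schema", strict=True
)
structured_llm.invoke(
    "What weighs more, a pound of bricks or a pound of feathers?"
)
# -> AnswerWithJustification(answer='...', justification='...')
# Note: per the override above, strict is dropped when method="function_calling".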

View File

@ -626,6 +626,12 @@ class ChatModelUnitTests(ChatModelTests):
return
assert model.with_structured_output(schema) is not None
for method in ["json_schema", "function_calling", "json_mode"]:
strict_values = [None, False, True] if method != "json_mode" else [None]
for strict in strict_values:
assert model.with_structured_output(
schema, method=method, strict=strict
)
def test_standard_params(self, model: BaseChatModel) -> None:
"""Test that model properly generates standard parameters. These are used