multiple: support strict and method in with_structured_output (#30385)

This commit is contained in:
parent 1103bdfaf1
commit b86cd8270c
@@ -1,21 +1,27 @@
"""DeepSeek chat models."""

from json import JSONDecodeError
from typing import Any, Dict, Iterator, List, Optional, Type, Union
from typing import Any, Dict, Iterator, List, Literal, Optional, Type, TypeVar, Union

import openai
from langchain_core.callbacks import (
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.utils import from_env, secret_from_env
from langchain_openai.chat_models.base import BaseChatOpenAI
from pydantic import ConfigDict, Field, SecretStr, model_validator
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

DEFAULT_API_BASE = "https://api.deepseek.com/v1"

_BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]
_DictOrPydantic = Union[Dict, _BM]


class ChatDeepSeek(BaseChatOpenAI):
    """DeepSeek chat model integration to access models hosted in DeepSeek's API.
@@ -197,14 +203,15 @@ class ChatDeepSeek(BaseChatOpenAI):

        if not (self.client or None):
            sync_specific: dict = {"http_client": self.http_client}
            self.client = openai.OpenAI(
                **client_params, **sync_specific
            ).chat.completions
            self.root_client = openai.OpenAI(**client_params, **sync_specific)
            self.client = self.root_client.chat.completions
        if not (self.async_client or None):
            async_specific: dict = {"http_client": self.http_async_client}
            self.async_client = openai.AsyncOpenAI(
                **client_params, **async_specific
            ).chat.completions
            self.root_async_client = openai.AsyncOpenAI(
                **client_params,
                **async_specific,
            )
            self.async_client = self.root_async_client.chat.completions
        return self

    def _create_chat_result(
@@ -281,3 +288,73 @@ class ChatDeepSeek(BaseChatOpenAI):
                e.doc,
                e.pos,
            ) from e

    def with_structured_output(
        self,
        schema: Optional[_DictOrPydanticClass] = None,
        *,
        method: Literal[
            "function_calling", "json_mode", "json_schema"
        ] = "function_calling",
        include_raw: bool = False,
        strict: Optional[bool] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, _DictOrPydantic]:
        """Model wrapper that returns outputs formatted to match the given schema.

        Args:
            schema:
                The output schema. Can be passed in as:

                - an OpenAI function/tool schema,
                - a JSON Schema,
                - a TypedDict class (support added in 0.1.20),
                - or a Pydantic class.

                If ``schema`` is a Pydantic class then the model output will be a
                Pydantic instance of that class, and the model-generated fields will be
                validated by the Pydantic class. Otherwise the model output will be a
                dict and will not be validated. See :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`
                for more on how to properly specify types and descriptions of
                schema fields when specifying a Pydantic or TypedDict class.

            method: The method for steering model generation, one of:

                - "function_calling":
                    Uses DeepSeek's `tool-calling features <https://api-docs.deepseek.com/guides/function_calling>`_.
                - "json_mode":
                    Uses DeepSeek's `JSON mode feature <https://api-docs.deepseek.com/guides/json_mode>`_.

                .. versionchanged:: 0.1.3

                    Added support for ``"json_mode"``.

            include_raw:
                If False then only the parsed structured output is returned. If
                an error occurs during model output parsing it will be raised. If True
                then both the raw model response (a BaseMessage) and the parsed model
                response will be returned. If an error occurs during output parsing it
                will be caught and returned as well. The final output is always a dict
                with keys "raw", "parsed", and "parsing_error".

            kwargs: Additional keyword args aren't supported.

        Returns:
            A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.

            | If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs an instance of ``schema`` (i.e., a Pydantic object). Otherwise, if ``include_raw`` is False then Runnable outputs a dict.

            | If ``include_raw`` is True, then Runnable outputs a dict with keys:

            - "raw": BaseMessage
            - "parsed": None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
            - "parsing_error": Optional[BaseException]

        """  # noqa: E501
        # Some applications require that incompatible parameters (e.g., unsupported
        # methods) be handled.
        if method == "json_schema":
            method = "function_calling"
        return super().with_structured_output(
            schema, method=method, include_raw=include_raw, strict=strict, **kwargs
        )
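
For reference, a minimal usage sketch of the new ChatDeepSeek signature (not part of the diff; the `Person` schema, prompt, and model name are illustrative):

```python
from pydantic import BaseModel, Field

from langchain_deepseek import ChatDeepSeek


class Person(BaseModel):
    """Illustrative schema, not from this commit."""

    name: str = Field(description="The person's name")
    birth_year: int = Field(description="Year the person was born")


llm = ChatDeepSeek(model="deepseek-chat")  # assumes DEEPSEEK_API_KEY is set

# `method` and `strict` are now accepted; per the override above, "json_schema"
# is coerced to "function_calling" before delegating to BaseChatOpenAI.
structured_llm = llm.with_structured_output(Person, method="json_schema")
structured_llm.invoke("Ada Lovelace was born in 1815.")
# -> Person(name='Ada Lovelace', birth_year=1815)
```
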
@@ -24,6 +24,11 @@ class TestChatDeepSeek(ChatModelIntegrationTests):
            "temperature": 0,
        }

    @property
    def supports_json_mode(self) -> bool:
        """(bool) whether the chat model supports JSON mode."""
        return True

    @pytest.mark.xfail(reason="Not yet supported.")
    def test_tool_message_histories_list_content(
        self, model: BaseChatModel, my_adder_tool: BaseTool
@@ -75,6 +75,7 @@ from langchain_core.utils import (
    get_pydantic_field_names,
)
from langchain_core.utils.function_calling import (
    convert_to_json_schema,
    convert_to_openai_function,
    convert_to_openai_tool,
)
@@ -737,7 +738,9 @@ class ChatFireworks(BaseChatModel):
        self,
        schema: Optional[Union[Dict, Type[BaseModel]]] = None,
        *,
        method: Literal["function_calling", "json_mode"] = "function_calling",
        method: Literal[
            "function_calling", "json_mode", "json_schema"
        ] = "function_calling",
        include_raw: bool = False,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
@@ -761,13 +764,19 @@ class ChatFireworks(BaseChatModel):

                Added support for TypedDict class.

            method:
                The method for steering model generation, either "function_calling"
                or "json_mode". If "function_calling" then the schema will be converted
                to an OpenAI function and the returned model will make use of the
                function-calling API. If "json_mode" then OpenAI's JSON mode will be
                used. Note that if using "json_mode" then you must include instructions
                for formatting the output into the desired schema into the model call.
            method: The method for steering model generation, one of:

                - "function_calling":
                    Uses Fireworks's `tool-calling features <https://docs.fireworks.ai/guides/function-calling>`_.
                - "json_schema":
                    Uses Fireworks's `structured output feature <https://docs.fireworks.ai/structured-responses/structured-response-formatting>`_.
                - "json_mode":
                    Uses Fireworks's `JSON mode feature <https://docs.fireworks.ai/structured-responses/structured-response-formatting>`_.

                .. versionchanged:: 0.2.8

                    Added support for ``"json_schema"``.

            include_raw:
                If False then only the parsed structured output is returned. If
                an error occurs during model output parsing it will be raised. If True
@@ -928,11 +937,11 @@ class ChatFireworks(BaseChatModel):

            structured_llm.invoke(
                "Answer the following question. "
                "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
                "Make sure to return a JSON blob with keys 'answer' and 'justification'. "
                "What's heavier a pound of bricks or a pound of feathers?"
            )
            # -> {
            #     'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
            #     'raw': AIMessage(content='{"answer": "They are both the same weight.", "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight."}'),
            #     'parsed': AnswerWithJustification(answer='They are both the same weight.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'),
            #     'parsing_error': None
            # }
@@ -944,11 +953,11 @@ class ChatFireworks(BaseChatModel):

            structured_llm.invoke(
                "Answer the following question. "
                "Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
                "Make sure to return a JSON blob with keys 'answer' and 'justification'. "
                "What's heavier a pound of bricks or a pound of feathers?"
            )
            # -> {
            #     'raw': AIMessage(content='{\n "answer": "They are both the same weight.",\n "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight." \n}'),
            #     'raw': AIMessage(content='{"answer": "They are both the same weight.", "justification": "Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight."}'),
            #     'parsed': {
            #         'answer': 'They are both the same weight.',
            #         'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The difference lies in the volume and density of the materials, not the weight.'
@@ -956,6 +965,7 @@ class ChatFireworks(BaseChatModel):
            #     'parsing_error': None
            # }
        """  # noqa: E501
        _ = kwargs.pop("strict", None)
        if kwargs:
            raise ValueError(f"Received unsupported arguments {kwargs}")
        is_pydantic_schema = _is_pydantic_class(schema)
@@ -984,6 +994,25 @@ class ChatFireworks(BaseChatModel):
            output_parser = JsonOutputKeyToolsParser(
                key_name=tool_name, first_tool_only=True
            )
        elif method == "json_schema":
            if schema is None:
                raise ValueError(
                    "schema must be specified when method is 'json_schema'. "
                    "Received None."
                )
            formatted_schema = convert_to_json_schema(schema)
            llm = self.bind(
                response_format={"type": "json_object", "schema": formatted_schema},
                ls_structured_output_format={
                    "kwargs": {"method": "json_schema"},
                    "schema": schema,
                },
            )
            output_parser = (
                PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
                if is_pydantic_schema
                else JsonOutputParser()
            )
        elif method == "json_mode":
            llm = self.bind(
                response_format={"type": "json_object"},
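
A hedged usage sketch of the new Fireworks ``json_schema`` path (not part of the diff; the schema and prompt are illustrative, and the model name is borrowed from the integration test later in this diff):

```python
from pydantic import BaseModel, Field

from langchain_fireworks import ChatFireworks


class AnswerWithJustification(BaseModel):
    """Illustrative schema mirroring the docstring example above."""

    answer: str
    justification: str = Field(description="Why the answer is correct")


llm = ChatFireworks(model="accounts/fireworks/models/llama-v3p1-70b-instruct")

# The "json_schema" branch above binds Fireworks's structured-output response_format
# and parses the reply with a Pydantic (or JSON) output parser.
structured_llm = llm.with_structured_output(AnswerWithJustification, method="json_schema")
structured_llm.invoke("What's heavier, a pound of bricks or a pound of feathers?")
# -> AnswerWithJustification(answer=..., justification=...)
```
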
@@ -7,7 +7,7 @@ authors = []
license = { text = "MIT" }
requires-python = "<4.0,>=3.9"
dependencies = [
    "langchain-core<1.0.0,>=0.3.33",
    "langchain-core<1.0.0,>=0.3.46",
    "fireworks-ai>=0.13.0",
    "openai<2.0.0,>=1.10.0",
    "requests<3,>=2",
@@ -4,10 +4,12 @@ You will need FIREWORKS_API_KEY set in your environment to run these tests.
"""

import json
from typing import Optional
from typing import Any, Literal, Optional

import pytest
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessageChunk
from pydantic import BaseModel
from pydantic import BaseModel, Field
from typing_extensions import Annotated, TypedDict

from langchain_fireworks import ChatFireworks

@@ -161,3 +163,54 @@ def test_invoke() -> None:

    result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
    assert isinstance(result.content, str)


def _get_joke_class(
    schema_type: Literal["pydantic", "typeddict", "json_schema"],
) -> Any:
    class Joke(BaseModel):
        """Joke to tell user."""

        setup: str = Field(description="question to set up a joke")
        punchline: str = Field(description="answer to resolve the joke")

    def validate_joke(result: Any) -> bool:
        return isinstance(result, Joke)

    class JokeDict(TypedDict):
        """Joke to tell user."""

        setup: Annotated[str, ..., "question to set up a joke"]
        punchline: Annotated[str, ..., "answer to resolve the joke"]

    def validate_joke_dict(result: Any) -> bool:
        return all(key in ["setup", "punchline"] for key in result.keys())

    if schema_type == "pydantic":
        return Joke, validate_joke

    elif schema_type == "typeddict":
        return JokeDict, validate_joke_dict

    elif schema_type == "json_schema":
        return Joke.model_json_schema(), validate_joke_dict
    else:
        raise ValueError("Invalid schema type")


@pytest.mark.parametrize("schema_type", ["pydantic", "typeddict", "json_schema"])
def test_structured_output_json_schema(schema_type: str) -> None:
    llm = ChatFireworks(model="accounts/fireworks/models/llama-v3p1-70b-instruct")
    schema, validation_function = _get_joke_class(schema_type)  # type: ignore[arg-type]
    chat = llm.with_structured_output(schema, method="json_schema")

    # Test invoke
    result = chat.invoke("Tell me a joke about cats.")
    validation_function(result)

    # Test stream
    chunks = []
    for chunk in chat.stream("Tell me a joke about cats."):
        validation_function(chunk)
        chunks.append(chunk)
    assert chunk
@@ -635,7 +635,7 @@ wheels = [

[[package]]
name = "langchain-core"
version = "0.3.35"
version = "0.3.46"
source = { editable = "../../core" }
dependencies = [
    { name = "jsonpatch" },
@@ -667,7 +667,7 @@ dev = [
]
lint = [{ name = "ruff", specifier = ">=0.9.2,<1.0.0" }]
test = [
    { name = "blockbuster", specifier = "~=1.5.11" },
    { name = "blockbuster", specifier = "~=1.5.18" },
    { name = "freezegun", specifier = ">=1.2.2,<2.0.0" },
    { name = "grandalf", specifier = ">=0.8,<1.0" },
    { name = "langchain-tests", directory = "../../standard-tests" },
@@ -763,7 +763,7 @@ typing = [

[[package]]
name = "langchain-tests"
version = "0.3.11"
version = "0.3.14"
source = { editable = "../../standard-tests" }
dependencies = [
    { name = "httpx" },
@@ -780,8 +780,7 @@ dependencies = [
requires-dist = [
    { name = "httpx", specifier = ">=0.25.0,<1" },
    { name = "langchain-core", editable = "../../core" },
    { name = "numpy", marker = "python_full_version < '3.12'", specifier = ">=1.24.0,<2.0.0" },
    { name = "numpy", marker = "python_full_version >= '3.12'", specifier = ">=1.26.2,<3" },
    { name = "numpy", specifier = ">=1.26.2,<3" },
    { name = "pytest", specifier = ">=7,<9" },
    { name = "pytest-asyncio", specifier = ">=0.20,<1" },
    { name = "pytest-socket", specifier = ">=0.6.0,<1" },
@@ -460,6 +460,24 @@ class ChatGroq(BaseChatModel):
            ls_params["ls_stop"] = ls_stop if isinstance(ls_stop, list) else [ls_stop]
        return ls_params

    def _should_stream(
        self,
        *,
        async_api: bool,
        run_manager: Optional[
            Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
        ] = None,
        **kwargs: Any,
    ) -> bool:
        """Determine if a given model call should hit the streaming API."""
        base_should_stream = super()._should_stream(
            async_api=async_api, run_manager=run_manager, **kwargs
        )
        if base_should_stream and ("response_format" in kwargs):
            # Streaming not supported in JSON mode.
            return kwargs["response_format"] != {"type": "json_object"}
        return base_should_stream

    def _generate(
        self,
        messages: List[BaseMessage],
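
For context, a sketch of what the new ``_should_stream`` override means for callers (not part of the diff; the model name is illustrative, and the single-chunk fallback is inferred from the check above, so treat it as an assumption):

```python
from langchain_groq import ChatGroq

llm = ChatGroq(model="llama-3.1-8b-instant")  # illustrative model; needs GROQ_API_KEY

# Binding Groq's JSON mode response_format trips the new _should_stream check:
# streaming is reported as unsupported for {"type": "json_object"}, so .stream()
# should fall back to a single non-streamed generation instead of token chunks.
json_llm = llm.bind(response_format={"type": "json_object"})
chunks = list(
    json_llm.stream("Return a JSON object with key 'answer'. What color is the sky?")
)
print(len(chunks))  # expected to be a single chunk rather than one per token
```
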
@@ -987,9 +1005,14 @@ class ChatGroq(BaseChatModel):
            #     'parsing_error': None
            # }
        """  # noqa: E501
        _ = kwargs.pop("strict", None)
        if kwargs:
            raise ValueError(f"Received unsupported arguments {kwargs}")
        is_pydantic_schema = _is_pydantic_class(schema)
        if method == "json_schema":
            # Some applications require that incompatible parameters (e.g., unsupported
            # methods) be handled.
            method = "function_calling"
        if method == "function_calling":
            if schema is None:
                raise ValueError(
@@ -47,4 +47,4 @@ class TestGroqLlama(BaseTestGroq):

    @property
    def supports_json_mode(self) -> bool:
        return False  # Not supported in streaming mode
        return True
@@ -945,6 +945,7 @@ class ChatMistralAI(BaseChatModel):
            # }

        """  # noqa: E501
        _ = kwargs.pop("strict", None)
        if kwargs:
            raise ValueError(f"Received unsupported arguments {kwargs}")
        is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
@@ -4,16 +4,28 @@ from typing import (
    Any,
    Dict,
    List,
    Literal,
    Optional,
    Type,
    TypeVar,
    Union,
)

import openai
from langchain_core.language_models.chat_models import LangSmithParams
from langchain_core.language_models.chat_models import (
    LangSmithParams,
    LanguageModelInput,
)
from langchain_core.runnables import Runnable
from langchain_core.utils import secret_from_env
from langchain_openai.chat_models.base import BaseChatOpenAI
from pydantic import ConfigDict, Field, SecretStr, model_validator
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

_BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]
_DictOrPydantic = Union[Dict, _BM]


class ChatXAI(BaseChatOpenAI):  # type: ignore[override]
    r"""ChatXAI chat model.
@@ -359,3 +371,83 @@ class ChatXAI(BaseChatOpenAI): # type: ignore[override]
                **async_specific,
            )
        return self

    def with_structured_output(
        self,
        schema: Optional[_DictOrPydanticClass] = None,
        *,
        method: Literal[
            "function_calling", "json_mode", "json_schema"
        ] = "function_calling",
        include_raw: bool = False,
        strict: Optional[bool] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, _DictOrPydantic]:
        """Model wrapper that returns outputs formatted to match the given schema.

        Args:
            schema:
                The output schema. Can be passed in as:

                - an OpenAI function/tool schema,
                - a JSON Schema,
                - a TypedDict class (support added in 0.1.20),
                - or a Pydantic class.

                If ``schema`` is a Pydantic class then the model output will be a
                Pydantic instance of that class, and the model-generated fields will be
                validated by the Pydantic class. Otherwise the model output will be a
                dict and will not be validated. See :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`
                for more on how to properly specify types and descriptions of
                schema fields when specifying a Pydantic or TypedDict class.

            method: The method for steering model generation, one of:

                - "function_calling":
                    Uses xAI's `tool-calling features <https://docs.x.ai/docs/guides/function-calling>`_.
                - "json_schema":
                    Uses xAI's `structured output feature <https://docs.x.ai/docs/guides/structured-outputs>`_.
                - "json_mode":
                    Uses xAI's JSON mode feature.

            include_raw:
                If False then only the parsed structured output is returned. If
                an error occurs during model output parsing it will be raised. If True
                then both the raw model response (a BaseMessage) and the parsed model
                response will be returned. If an error occurs during output parsing it
                will be caught and returned as well. The final output is always a dict
                with keys "raw", "parsed", and "parsing_error".

            strict:

                - True:
                    Model output is guaranteed to exactly match the schema.
                    The input schema will also be validated according to
                    https://platform.openai.com/docs/guides/structured-outputs/supported-schemas
                - False:
                    Input schema will not be validated and model output will not be
                    validated.
                - None:
                    ``strict`` argument will not be passed to the model.

            kwargs: Additional keyword args aren't supported.

        Returns:
            A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.

            | If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs an instance of ``schema`` (i.e., a Pydantic object). Otherwise, if ``include_raw`` is False then Runnable outputs a dict.

            | If ``include_raw`` is True, then Runnable outputs a dict with keys:

            - "raw": BaseMessage
            - "parsed": None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
            - "parsing_error": Optional[BaseException]

        """  # noqa: E501
        # Some applications require that incompatible parameters (e.g., unsupported
        # methods) be handled.
        if method == "function_calling" and strict:
            strict = None
        return super().with_structured_output(
            schema, method=method, include_raw=include_raw, strict=strict, **kwargs
        )
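
A usage sketch for the new ChatXAI signature (not part of the diff; schema, prompt, and model name are illustrative):

```python
from pydantic import BaseModel, Field

from langchain_xai import ChatXAI


class Movie(BaseModel):
    """Illustrative schema, not from this commit."""

    title: str = Field(description="Movie title")
    year: int = Field(description="Release year")


llm = ChatXAI(model="grok-2")  # illustrative model; needs XAI_API_KEY

# "json_schema" uses xAI's structured-output feature and honors strict=True.
# With "function_calling", the override above drops strict (sets it to None)
# before delegating to BaseChatOpenAI.
structured_llm = llm.with_structured_output(Movie, method="json_schema", strict=True)
structured_llm.invoke("The Matrix came out in 1999.")
# -> Movie(title='The Matrix', year=1999)
```
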
@@ -626,6 +626,12 @@ class ChatModelUnitTests(ChatModelTests):
            return

        assert model.with_structured_output(schema) is not None
        for method in ["json_schema", "function_calling", "json_mode"]:
            strict_values = [None, False, True] if method != "json_mode" else [None]
            for strict in strict_values:
                assert model.with_structured_output(
                    schema, method=method, strict=strict
                )

    def test_standard_params(self, model: BaseChatModel) -> None:
        """Test that model properly generates standard parameters. These are used
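
A minimal sketch of how a partner integration picks up the expanded standard test (not part of the diff; the DeepSeek-based subclass and its params are illustrative):

```python
from typing import Type

from langchain_core.language_models import BaseChatModel
from langchain_tests.unit_tests import ChatModelUnitTests

from langchain_deepseek import ChatDeepSeek


class TestChatDeepSeekUnit(ChatModelUnitTests):
    """Any subclass now exercises every (method, strict) combination shown above."""

    @property
    def chat_model_class(self) -> Type[BaseChatModel]:
        return ChatDeepSeek

    @property
    def chat_model_params(self) -> dict:
        # Illustrative params; a dummy key is enough for unit-level construction.
        return {"model": "deepseek-chat", "api_key": "test-key"}
```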