multiple: structured output tracing standard metadata (#29421)
Co-authored-by: Chester Curme <chester.curme@gmail.com>
commit 8f95da4eb1 (parent 284c935b08)
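
This commit threads a new `structured_output_format` entry through chat model
binding kwargs so that `with_structured_output` implementations across
providers report the requested output schema and method to tracing callbacks.
The core BaseChatModel entry points pop the entry out of kwargs, normalize the
schema with `convert_to_openai_tool`, and merge the result into the
inheritable metadata handed to `on_chat_model_start`.

A minimal sketch of how the metadata surfaces, modeled on the
`_TestCallbackHandler` added in the standard tests below; `chat_model`, the
`Joke` schema, and the prompt are illustrative only:

    from typing import Any, Optional

    from langchain_core.callbacks import BaseCallbackHandler


    class MetadataSpy(BaseCallbackHandler):
        """Record every metadata dict passed to on_chat_model_start."""

        def __init__(self) -> None:
            super().__init__()
            self.metadatas: list[Optional[dict]] = []

        def on_chat_model_start(
            self,
            serialized: Any,
            messages: Any,
            *,
            metadata: Optional[dict[str, Any]] = None,
            **kwargs: Any,
        ) -> None:
            self.metadatas.append(metadata)


    spy = MetadataSpy()
    structured_llm = chat_model.with_structured_output(Joke)  # illustrative model and schema
    structured_llm.invoke("Tell me a joke about cats.", config={"callbacks": [spy]})
    # spy.metadatas[0]["structured_output_format"] now holds
    # {"kwargs": {...}, "schema": <OpenAI-format tool dict>}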
@@ -365,11 +365,28 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         else:
             config = ensure_config(config)
             messages = self._convert_input(input).to_messages()
+            structured_output_format = kwargs.pop("structured_output_format", None)
+            if structured_output_format:
+                try:
+                    structured_output_format_dict = {
+                        "structured_output_format": {
+                            "kwargs": structured_output_format.get("kwargs", {}),
+                            "schema": convert_to_openai_tool(
+                                structured_output_format["schema"]
+                            ),
+                        }
+                    }
+                except ValueError:
+                    structured_output_format_dict = {}
+            else:
+                structured_output_format_dict = {}
+
             params = self._get_invocation_params(stop=stop, **kwargs)
             options = {"stop": stop, **kwargs}
             inheritable_metadata = {
                 **(config.get("metadata") or {}),
                 **self._get_ls_params(stop=stop, **kwargs),
+                **structured_output_format_dict,
             }
             callback_manager = CallbackManager.configure(
                 config.get("callbacks"),
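
The block above is the heart of the change: `structured_output_format` is
popped from kwargs so it never reaches the provider request, the schema is
normalized to OpenAI tool format for consistent tracing, and a ValueError from
conversion degrades to empty metadata instead of failing the call. The same
block is repeated in `astream`, `generate`, and `agenerate` below, with the
result spread into `inheritable_metadata` after `_get_ls_params`.

A sketch of the resulting metadata for a toy Pydantic schema; the `Joke` model
and the "method" value are illustrative, the outer shape follows the diff:

    from pydantic import BaseModel

    from langchain_core.utils.function_calling import convert_to_openai_tool


    class Joke(BaseModel):
        """A joke with a setup and a punchline."""

        setup: str
        punchline: str


    structured_output_format_dict = {
        "structured_output_format": {
            "kwargs": {"method": "function_calling"},  # provider-dependent
            "schema": convert_to_openai_tool(Joke),
            # -> {"type": "function", "function": {"name": "Joke", ...}}
        }
    }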
@@ -441,11 +458,29 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
 
         config = ensure_config(config)
         messages = self._convert_input(input).to_messages()
+
+        structured_output_format = kwargs.pop("structured_output_format", None)
+        if structured_output_format:
+            try:
+                structured_output_format_dict = {
+                    "structured_output_format": {
+                        "kwargs": structured_output_format.get("kwargs", {}),
+                        "schema": convert_to_openai_tool(
+                            structured_output_format["schema"]
+                        ),
+                    }
+                }
+            except ValueError:
+                structured_output_format_dict = {}
+        else:
+            structured_output_format_dict = {}
+
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop, **kwargs}
         inheritable_metadata = {
             **(config.get("metadata") or {}),
             **self._get_ls_params(stop=stop, **kwargs),
+            **structured_output_format_dict,
         }
         callback_manager = AsyncCallbackManager.configure(
             config.get("callbacks"),
@@ -606,11 +641,28 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             An LLMResult, which contains a list of candidate Generations for each input
                 prompt and additional model provider-specific output.
         """
+        structured_output_format = kwargs.pop("structured_output_format", None)
+        if structured_output_format:
+            try:
+                structured_output_format_dict = {
+                    "structured_output_format": {
+                        "kwargs": structured_output_format.get("kwargs", {}),
+                        "schema": convert_to_openai_tool(
+                            structured_output_format["schema"]
+                        ),
+                    }
+                }
+            except ValueError:
+                structured_output_format_dict = {}
+        else:
+            structured_output_format_dict = {}
+
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop}
         inheritable_metadata = {
             **(metadata or {}),
             **self._get_ls_params(stop=stop, **kwargs),
+            **structured_output_format_dict,
         }
 
         callback_manager = CallbackManager.configure(
@@ -697,11 +749,28 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             An LLMResult, which contains a list of candidate Generations for each input
                 prompt and additional model provider-specific output.
         """
+        structured_output_format = kwargs.pop("structured_output_format", None)
+        if structured_output_format:
+            try:
+                structured_output_format_dict = {
+                    "structured_output_format": {
+                        "kwargs": structured_output_format.get("kwargs", {}),
+                        "schema": convert_to_openai_tool(
+                            structured_output_format["schema"]
+                        ),
+                    }
+                }
+            except ValueError:
+                structured_output_format_dict = {}
+        else:
+            structured_output_format_dict = {}
+
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop}
         inheritable_metadata = {
             **(metadata or {}),
             **self._get_ls_params(stop=stop, **kwargs),
+            **structured_output_format_dict,
         }
 
         callback_manager = AsyncCallbackManager.configure(
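
`generate` and `agenerate` get the same treatment as the streaming paths; the
only differences are that `metadata` arrives as an explicit parameter rather
than via `config`, and the sync vs. async callback manager.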
@@ -1240,7 +1309,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         if self.bind_tools is BaseChatModel.bind_tools:
             msg = "with_structured_output is not implemented for this model."
             raise NotImplementedError(msg)
-        llm = self.bind_tools([schema], tool_choice="any")
+
+        llm = self.bind_tools(
+            [schema],
+            tool_choice="any",
+            structured_output_format={"kwargs": {}, "schema": schema},
+        )
         if isinstance(schema, type) and is_basemodel_subclass(schema):
             output_parser: OutputParserLike = PydanticToolsParser(
                 tools=[cast(TypeBaseModel, schema)], first_tool_only=True
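
The default `with_structured_output` reports an empty `kwargs` dict and the
raw, unconverted schema; normalization to OpenAI tool format happens later in
the core block above, which also silently drops the metadata if
`convert_to_openai_tool` raises a ValueError for the schema.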
@@ -1111,9 +1111,13 @@ class ChatAnthropic(BaseChatModel):
             Added support for TypedDict class as `schema`.
 
         """  # noqa: E501
-        tool_name = convert_to_anthropic_tool(schema)["name"]
-        llm = self.bind_tools([schema], tool_choice=tool_name)
+        formatted_tool = convert_to_anthropic_tool(schema)
+        tool_name = formatted_tool["name"]
+        llm = self.bind_tools(
+            [schema],
+            tool_choice=tool_name,
+            structured_output_format={"kwargs": {}, "schema": formatted_tool},
+        )
         if isinstance(schema, type) and is_basemodel_subclass(schema):
             output_parser: OutputParserLike = PydanticToolsParser(
                 tools=[schema], first_tool_only=True
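
For ChatAnthropic the change doubles as a small refactor:
`convert_to_anthropic_tool` is now called once and its result reused both for
the `tool_choice` name and for the reported schema, so the schema does not
have to be converted a second time for the metadata.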
@@ -965,8 +965,16 @@ class ChatFireworks(BaseChatModel):
                     "schema must be specified when method is 'function_calling'. "
                     "Received None."
                 )
-            tool_name = convert_to_openai_tool(schema)["function"]["name"]
-            llm = self.bind_tools([schema], tool_choice=tool_name)
+            formatted_tool = convert_to_openai_tool(schema)
+            tool_name = formatted_tool["function"]["name"]
+            llm = self.bind_tools(
+                [schema],
+                tool_choice=tool_name,
+                structured_output_format={
+                    "kwargs": {"method": "function_calling"},
+                    "schema": formatted_tool,
+                },
+            )
             if is_pydantic_schema:
                 output_parser: OutputParserLike = PydanticToolsParser(
                     tools=[schema],  # type: ignore[list-item]
@@ -977,7 +985,13 @@ class ChatFireworks(BaseChatModel):
                     key_name=tool_name, first_tool_only=True
                 )
         elif method == "json_mode":
-            llm = self.bind(response_format={"type": "json_object"})
+            llm = self.bind(
+                response_format={"type": "json_object"},
+                structured_output_format={
+                    "kwargs": {"method": "json_mode"},
+                    "schema": schema,
+                },
+            )
             output_parser = (
                 PydanticOutputParser(pydantic_object=schema)  # type: ignore[type-var, arg-type]
                 if is_pydantic_schema
@@ -996,8 +996,16 @@ class ChatGroq(BaseChatModel):
                     "schema must be specified when method is 'function_calling'. "
                     "Received None."
                 )
-            tool_name = convert_to_openai_tool(schema)["function"]["name"]
-            llm = self.bind_tools([schema], tool_choice=tool_name)
+            formatted_tool = convert_to_openai_tool(schema)
+            tool_name = formatted_tool["function"]["name"]
+            llm = self.bind_tools(
+                [schema],
+                tool_choice=tool_name,
+                structured_output_format={
+                    "kwargs": {"method": "function_calling"},
+                    "schema": formatted_tool,
+                },
+            )
             if is_pydantic_schema:
                 output_parser: OutputParserLike = PydanticToolsParser(
                     tools=[schema],  # type: ignore[list-item]
@@ -1008,7 +1016,13 @@ class ChatGroq(BaseChatModel):
                     key_name=tool_name, first_tool_only=True
                 )
         elif method == "json_mode":
-            llm = self.bind(response_format={"type": "json_object"})
+            llm = self.bind(
+                response_format={"type": "json_object"},
+                structured_output_format={
+                    "kwargs": {"method": "json_mode"},
+                    "schema": schema,
+                },
+            )
             output_parser = (
                 PydanticOutputParser(pydantic_object=schema)  # type: ignore[type-var, arg-type]
                 if is_pydantic_schema
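
ChatFireworks and ChatGroq follow the same pattern: the `kwargs` entry records
which structured-output method was bound ("function_calling" or "json_mode"),
so traces can distinguish tool-based extraction from JSON mode. The
function-calling path reports the already-converted `formatted_tool`, while
JSON mode reports the raw schema, which can be omitted in that mode (hence the
explicit error for function calling above).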
@@ -931,7 +931,14 @@ class ChatMistralAI(BaseChatModel):
                 )
             # TODO: Update to pass in tool name as tool_choice if/when Mistral supports
             # specifying a tool.
-            llm = self.bind_tools([schema], tool_choice="any")
+            llm = self.bind_tools(
+                [schema],
+                tool_choice="any",
+                structured_output_format={
+                    "kwargs": {"method": "function_calling"},
+                    "schema": schema,
+                },
+            )
             if is_pydantic_schema:
                 output_parser: OutputParserLike = PydanticToolsParser(
                     tools=[schema],  # type: ignore[list-item]
@@ -943,7 +950,16 @@ class ChatMistralAI(BaseChatModel):
                     key_name=key_name, first_tool_only=True
                 )
         elif method == "json_mode":
-            llm = self.bind(response_format={"type": "json_object"})
+            llm = self.bind(
+                response_format={"type": "json_object"},
+                structured_output_format={
+                    "kwargs": {
+                        # this is correct - name difference with mistral api
+                        "method": "json_mode"
+                    },
+                    "schema": schema,
+                },
+            )
             output_parser = (
                 PydanticOutputParser(pydantic_object=schema)  # type: ignore[type-var, arg-type]
                 if is_pydantic_schema
@@ -956,7 +972,13 @@ class ChatMistralAI(BaseChatModel):
                     "Received None."
                 )
             response_format = _convert_to_openai_response_format(schema, strict=True)
-            llm = self.bind(response_format=response_format)
+            llm = self.bind(
+                response_format=response_format,
+                structured_output_format={
+                    "kwargs": {"method": "json_schema"},
+                    "schema": schema,
+                },
+            )
 
             output_parser = (
                 PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
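
ChatMistralAI adds the same reporting to all three of its methods. The inline
comment in the json_mode hunk flags that "json_mode" is the intended
LangChain-side name despite the Mistral API naming the feature differently,
and the json_schema path is backed by
`_convert_to_openai_response_format(schema, strict=True)`.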
@@ -1085,8 +1085,16 @@ class ChatOllama(BaseChatModel):
                     "schema must be specified when method is not 'json_mode'. "
                     "Received None."
                 )
-            tool_name = convert_to_openai_tool(schema)["function"]["name"]
-            llm = self.bind_tools([schema], tool_choice=tool_name)
+            formatted_tool = convert_to_openai_tool(schema)
+            tool_name = formatted_tool["function"]["name"]
+            llm = self.bind_tools(
+                [schema],
+                tool_choice=tool_name,
+                structured_output_format={
+                    "kwargs": {"method": method},
+                    "schema": formatted_tool,
+                },
+            )
             if is_pydantic_schema:
                 output_parser: Runnable = PydanticToolsParser(
                     tools=[schema],  # type: ignore[list-item]
@@ -1097,7 +1105,13 @@ class ChatOllama(BaseChatModel):
                     key_name=tool_name, first_tool_only=True
                 )
         elif method == "json_mode":
-            llm = self.bind(format="json")
+            llm = self.bind(
+                format="json",
+                structured_output_format={
+                    "kwargs": {"method": method},
+                    "schema": schema,
+                },
+            )
             output_parser = (
                 PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
                 if is_pydantic_schema
@@ -1111,7 +1125,13 @@ class ChatOllama(BaseChatModel):
                 )
             if is_pydantic_schema:
                 schema = cast(TypeBaseModel, schema)
-                llm = self.bind(format=schema.model_json_schema())
+                llm = self.bind(
+                    format=schema.model_json_schema(),
+                    structured_output_format={
+                        "kwargs": {"method": method},
+                        "schema": schema,
+                    },
+                )
                 output_parser = PydanticOutputParser(pydantic_object=schema)
             else:
                 if is_typeddict(schema):
@@ -1126,7 +1146,13 @@ class ChatOllama(BaseChatModel):
                 else:
                     # is JSON schema
                     response_format = schema
-                llm = self.bind(format=response_format)
+                llm = self.bind(
+                    format=response_format,
+                    structured_output_format={
+                        "kwargs": {"method": method},
+                        "schema": response_format,
+                    },
+                )
                 output_parser = JsonOutputParser()
         else:
             raise ValueError(
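
ChatOllama has three binding sites (tool calling, `format="json"`, and a
JSON-schema `format`), and unlike Fireworks/Groq it passes the live `method`
variable into `kwargs` rather than a hardcoded string, so the reported method
always matches the branch taken.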
@@ -31,8 +31,8 @@ class TestChatOllama(ChatModelIntegrationTests):
                 "Fails with 'AssertionError'. Ollama does not support 'tool_choice' yet."
             )
         )
-    def test_structured_output(self, model: BaseChatModel) -> None:
-        super().test_structured_output(model)
+    def test_structured_output(self, model: BaseChatModel, schema_type: str) -> None:
+        super().test_structured_output(model, schema_type)
 
     @pytest.mark.xfail(
         reason=(
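
The xfail override in the Ollama integration tests gains a `schema_type`
parameter to match the newly parametrized signature of the base
`test_structured_output` (see the standard-tests hunks below).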
@@ -1390,7 +1390,13 @@ class BaseChatOpenAI(BaseChatModel):
             )
             tool_name = convert_to_openai_tool(schema)["function"]["name"]
             bind_kwargs = self._filter_disabled_params(
-                tool_choice=tool_name, parallel_tool_calls=False, strict=strict
+                tool_choice=tool_name,
+                parallel_tool_calls=False,
+                strict=strict,
+                structured_output_format={
+                    "kwargs": {"method": method},
+                    "schema": schema,
+                },
             )
 
             llm = self.bind_tools([schema], **bind_kwargs)
@@ -1404,7 +1410,13 @@ class BaseChatOpenAI(BaseChatModel):
                     key_name=tool_name, first_tool_only=True
                 )
         elif method == "json_mode":
-            llm = self.bind(response_format={"type": "json_object"})
+            llm = self.bind(
+                response_format={"type": "json_object"},
+                structured_output_format={
+                    "kwargs": {"method": method},
+                    "schema": schema,
+                },
+            )
             output_parser = (
                 PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
                 if is_pydantic_schema
@@ -1417,7 +1429,13 @@ class BaseChatOpenAI(BaseChatModel):
                     "Received None."
                 )
             response_format = _convert_to_openai_response_format(schema, strict=strict)
-            llm = self.bind(response_format=response_format)
+            llm = self.bind(
+                response_format=response_format,
+                structured_output_format={
+                    "kwargs": {"method": method},
+                    "schema": convert_to_openai_tool(schema),
+                },
+            )
             if is_pydantic_schema:
                 output_parser = _oai_structured_outputs_parser.with_types(
                     output_type=cast(type, schema)
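
BaseChatOpenAI wires the metadata through `_filter_disabled_params` on the
function-calling path and through `bind` on the json_mode and json_schema
paths. Note the asymmetry: json_schema reports `convert_to_openai_tool(schema)`
while json_mode reports the raw schema; in both cases the core block in
BaseChatModel re-applies `convert_to_openai_tool` and falls back to empty
metadata on ValueError.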
@@ -1,9 +1,11 @@
 import base64
 import json
 from typing import Any, List, Literal, Optional, cast
+from unittest.mock import MagicMock
 
 import httpx
 import pytest
+from langchain_core.callbacks import BaseCallbackHandler
 from langchain_core.language_models import BaseChatModel, GenericFakeChatModel
 from langchain_core.messages import (
     AIMessage,
@@ -17,7 +19,10 @@ from langchain_core.messages import (
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.tools import BaseTool, tool
-from langchain_core.utils.function_calling import tool_example_to_messages
+from langchain_core.utils.function_calling import (
+    convert_to_openai_tool,
+    tool_example_to_messages,
+)
 from pydantic import BaseModel, Field
 from pydantic.v1 import BaseModel as BaseModelV1
 from pydantic.v1 import Field as FieldV1
@@ -66,6 +71,24 @@ def _get_joke_class(
     raise ValueError("Invalid schema type")
 
 
+class _TestCallbackHandler(BaseCallbackHandler):
+    metadatas: list[Optional[dict]]
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.metadatas = []
+
+    def on_chat_model_start(
+        self,
+        serialized: Any,
+        messages: Any,
+        *,
+        metadata: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        self.metadatas.append(metadata)
+
+
 class _MagicFunctionSchema(BaseModel):
     input: int = Field(..., gt=-1000, lt=1000)
 
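
`_TestCallbackHandler` is deliberately minimal: it records every `metadata`
dict handed to `on_chat_model_start` so the tests below can assert both that
the hook fired exactly once per call and that the reported schema equals
`convert_to_openai_tool(schema)`.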
@@ -1207,13 +1230,46 @@ class ChatModelIntegrationTests(ChatModelTests):
 
         schema, validation_function = _get_joke_class(schema_type)  # type: ignore[arg-type]
         chat = model.with_structured_output(schema, **self.structured_output_kwargs)
-        result = chat.invoke("Tell me a joke about cats.")
+        mock_callback = MagicMock()
+        mock_callback.on_chat_model_start = MagicMock()
+
+        invoke_callback = _TestCallbackHandler()
+
+        result = chat.invoke(
+            "Tell me a joke about cats.", config={"callbacks": [invoke_callback]}
+        )
         validation_function(result)
 
-        for chunk in chat.stream("Tell me a joke about cats."):
+        assert len(invoke_callback.metadatas) == 1, (
+            "Expected on_chat_model_start to be called once"
+        )
+        assert isinstance(invoke_callback.metadatas[0], dict)
+        assert isinstance(
+            invoke_callback.metadatas[0]["structured_output_format"]["schema"], dict
+        )
+        assert invoke_callback.metadatas[0]["structured_output_format"][
+            "schema"
+        ] == convert_to_openai_tool(schema)
+
+        stream_callback = _TestCallbackHandler()
+
+        for chunk in chat.stream(
+            "Tell me a joke about cats.", config={"callbacks": [stream_callback]}
+        ):
             validation_function(chunk)
             assert chunk
+
+        assert len(stream_callback.metadatas) == 1, (
+            "Expected on_chat_model_start to be called once"
+        )
+        assert isinstance(stream_callback.metadatas[0], dict)
+        assert isinstance(
+            stream_callback.metadatas[0]["structured_output_format"]["schema"], dict
+        )
+        assert stream_callback.metadatas[0]["structured_output_format"][
+            "schema"
+        ] == convert_to_openai_tool(schema)
 
     @pytest.mark.parametrize("schema_type", ["pydantic", "typeddict", "json_schema"])
     async def test_structured_output_async(
         self, model: BaseChatModel, schema_type: str
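
The sync test now drives `invoke` and `stream` with dedicated handler
instances and asserts on the captured metadata. (The `mock_callback` MagicMock
created at the top of the hunk is never attached to a call in this diff.)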
@@ -1248,14 +1304,46 @@ class ChatModelIntegrationTests(ChatModelTests):
             pytest.skip("Test requires tool calling.")
 
         schema, validation_function = _get_joke_class(schema_type)  # type: ignore[arg-type]
 
         chat = model.with_structured_output(schema, **self.structured_output_kwargs)
-        result = await chat.ainvoke("Tell me a joke about cats.")
+        ainvoke_callback = _TestCallbackHandler()
+
+        result = await chat.ainvoke(
+            "Tell me a joke about cats.", config={"callbacks": [ainvoke_callback]}
+        )
         validation_function(result)
 
-        async for chunk in chat.astream("Tell me a joke about cats."):
+        assert len(ainvoke_callback.metadatas) == 1, (
+            "Expected on_chat_model_start to be called once"
+        )
+        assert isinstance(ainvoke_callback.metadatas[0], dict)
+        assert isinstance(
+            ainvoke_callback.metadatas[0]["structured_output_format"]["schema"], dict
+        )
+        assert ainvoke_callback.metadatas[0]["structured_output_format"][
+            "schema"
+        ] == convert_to_openai_tool(schema)
+
+        astream_callback = _TestCallbackHandler()
+
+        async for chunk in chat.astream(
+            "Tell me a joke about cats.", config={"callbacks": [astream_callback]}
+        ):
             validation_function(chunk)
             assert chunk
+
+        assert len(astream_callback.metadatas) == 1, (
+            "Expected on_chat_model_start to be called once"
+        )
+
+        assert isinstance(astream_callback.metadatas[0], dict)
+        assert isinstance(
+            astream_callback.metadatas[0]["structured_output_format"]["schema"], dict
+        )
+        assert astream_callback.metadatas[0]["structured_output_format"][
+            "schema"
+        ] == convert_to_openai_tool(schema)
 
     @pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Test requires pydantic 2.")
     def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None:
         """Test to verify we can generate structured output using