multiple: structured output tracing standard metadata (#29421)

Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
Erick Friis 2025-01-29 14:00:26 -08:00 committed by GitHub
parent 284c935b08
commit 8f95da4eb1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 288 additions and 28 deletions

View File

@@ -365,11 +365,28 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
else: else:
config = ensure_config(config) config = ensure_config(config)
messages = self._convert_input(input).to_messages() messages = self._convert_input(input).to_messages()
structured_output_format = kwargs.pop("structured_output_format", None)
if structured_output_format:
try:
structured_output_format_dict = {
"structured_output_format": {
"kwargs": structured_output_format.get("kwargs", {}),
"schema": convert_to_openai_tool(
structured_output_format["schema"]
),
}
}
except ValueError:
structured_output_format_dict = {}
else:
structured_output_format_dict = {}
params = self._get_invocation_params(stop=stop, **kwargs) params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop, **kwargs} options = {"stop": stop, **kwargs}
inheritable_metadata = { inheritable_metadata = {
**(config.get("metadata") or {}), **(config.get("metadata") or {}),
**self._get_ls_params(stop=stop, **kwargs), **self._get_ls_params(stop=stop, **kwargs),
**structured_output_format_dict,
} }
callback_manager = CallbackManager.configure( callback_manager = CallbackManager.configure(
config.get("callbacks"), config.get("callbacks"),
@@ -441,11 +458,29 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
config = ensure_config(config) config = ensure_config(config)
messages = self._convert_input(input).to_messages() messages = self._convert_input(input).to_messages()
structured_output_format = kwargs.pop("structured_output_format", None)
if structured_output_format:
try:
structured_output_format_dict = {
"structured_output_format": {
"kwargs": structured_output_format.get("kwargs", {}),
"schema": convert_to_openai_tool(
structured_output_format["schema"]
),
}
}
except ValueError:
structured_output_format_dict = {}
else:
structured_output_format_dict = {}
params = self._get_invocation_params(stop=stop, **kwargs) params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop, **kwargs} options = {"stop": stop, **kwargs}
inheritable_metadata = { inheritable_metadata = {
**(config.get("metadata") or {}), **(config.get("metadata") or {}),
**self._get_ls_params(stop=stop, **kwargs), **self._get_ls_params(stop=stop, **kwargs),
**structured_output_format_dict,
} }
callback_manager = AsyncCallbackManager.configure( callback_manager = AsyncCallbackManager.configure(
config.get("callbacks"), config.get("callbacks"),
@@ -606,11 +641,28 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
An LLMResult, which contains a list of candidate Generations for each input An LLMResult, which contains a list of candidate Generations for each input
prompt and additional model provider-specific output. prompt and additional model provider-specific output.
""" """
structured_output_format = kwargs.pop("structured_output_format", None)
if structured_output_format:
try:
structured_output_format_dict = {
"structured_output_format": {
"kwargs": structured_output_format.get("kwargs", {}),
"schema": convert_to_openai_tool(
structured_output_format["schema"]
),
}
}
except ValueError:
structured_output_format_dict = {}
else:
structured_output_format_dict = {}
params = self._get_invocation_params(stop=stop, **kwargs) params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop} options = {"stop": stop}
inheritable_metadata = { inheritable_metadata = {
**(metadata or {}), **(metadata or {}),
**self._get_ls_params(stop=stop, **kwargs), **self._get_ls_params(stop=stop, **kwargs),
**structured_output_format_dict,
} }
callback_manager = CallbackManager.configure( callback_manager = CallbackManager.configure(
@@ -697,11 +749,28 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
An LLMResult, which contains a list of candidate Generations for each input An LLMResult, which contains a list of candidate Generations for each input
prompt and additional model provider-specific output. prompt and additional model provider-specific output.
""" """
structured_output_format = kwargs.pop("structured_output_format", None)
if structured_output_format:
try:
structured_output_format_dict = {
"structured_output_format": {
"kwargs": structured_output_format.get("kwargs", {}),
"schema": convert_to_openai_tool(
structured_output_format["schema"]
),
}
}
except ValueError:
structured_output_format_dict = {}
else:
structured_output_format_dict = {}
params = self._get_invocation_params(stop=stop, **kwargs) params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop} options = {"stop": stop}
inheritable_metadata = { inheritable_metadata = {
**(metadata or {}), **(metadata or {}),
**self._get_ls_params(stop=stop, **kwargs), **self._get_ls_params(stop=stop, **kwargs),
**structured_output_format_dict,
} }
callback_manager = AsyncCallbackManager.configure( callback_manager = AsyncCallbackManager.configure(
@@ -1240,7 +1309,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
if self.bind_tools is BaseChatModel.bind_tools: if self.bind_tools is BaseChatModel.bind_tools:
msg = "with_structured_output is not implemented for this model." msg = "with_structured_output is not implemented for this model."
raise NotImplementedError(msg) raise NotImplementedError(msg)
llm = self.bind_tools([schema], tool_choice="any")
llm = self.bind_tools(
[schema],
tool_choice="any",
structured_output_format={"kwargs": {}, "schema": schema},
)
if isinstance(schema, type) and is_basemodel_subclass(schema): if isinstance(schema, type) and is_basemodel_subclass(schema):
output_parser: OutputParserLike = PydanticToolsParser( output_parser: OutputParserLike = PydanticToolsParser(
tools=[cast(TypeBaseModel, schema)], first_tool_only=True tools=[cast(TypeBaseModel, schema)], first_tool_only=True

View File

@@ -1111,9 +1111,13 @@ class ChatAnthropic(BaseChatModel):
Added support for TypedDict class as `schema`. Added support for TypedDict class as `schema`.
""" # noqa: E501 """ # noqa: E501
formatted_tool = convert_to_anthropic_tool(schema)
tool_name = convert_to_anthropic_tool(schema)["name"] tool_name = formatted_tool["name"]
llm = self.bind_tools([schema], tool_choice=tool_name) llm = self.bind_tools(
[schema],
tool_choice=tool_name,
structured_output_format={"kwargs": {}, "schema": formatted_tool},
)
if isinstance(schema, type) and is_basemodel_subclass(schema): if isinstance(schema, type) and is_basemodel_subclass(schema):
output_parser: OutputParserLike = PydanticToolsParser( output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], first_tool_only=True tools=[schema], first_tool_only=True

View File

@@ -965,8 +965,16 @@ class ChatFireworks(BaseChatModel):
"schema must be specified when method is 'function_calling'. " "schema must be specified when method is 'function_calling'. "
"Received None." "Received None."
) )
tool_name = convert_to_openai_tool(schema)["function"]["name"] formatted_tool = convert_to_openai_tool(schema)
llm = self.bind_tools([schema], tool_choice=tool_name) tool_name = formatted_tool["function"]["name"]
llm = self.bind_tools(
[schema],
tool_choice=tool_name,
structured_output_format={
"kwargs": {"method": "function_calling"},
"schema": formatted_tool,
},
)
if is_pydantic_schema: if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser( output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], # type: ignore[list-item] tools=[schema], # type: ignore[list-item]
@@ -977,7 +985,13 @@ class ChatFireworks(BaseChatModel):
key_name=tool_name, first_tool_only=True key_name=tool_name, first_tool_only=True
) )
elif method == "json_mode": elif method == "json_mode":
llm = self.bind(response_format={"type": "json_object"}) llm = self.bind(
response_format={"type": "json_object"},
structured_output_format={
"kwargs": {"method": "json_mode"},
"schema": schema,
},
)
output_parser = ( output_parser = (
PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type] PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type]
if is_pydantic_schema if is_pydantic_schema

View File

@@ -996,8 +996,16 @@ class ChatGroq(BaseChatModel):
"schema must be specified when method is 'function_calling'. " "schema must be specified when method is 'function_calling'. "
"Received None." "Received None."
) )
tool_name = convert_to_openai_tool(schema)["function"]["name"] formatted_tool = convert_to_openai_tool(schema)
llm = self.bind_tools([schema], tool_choice=tool_name) tool_name = formatted_tool["function"]["name"]
llm = self.bind_tools(
[schema],
tool_choice=tool_name,
structured_output_format={
"kwargs": {"method": "function_calling"},
"schema": formatted_tool,
},
)
if is_pydantic_schema: if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser( output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], # type: ignore[list-item] tools=[schema], # type: ignore[list-item]
@@ -1008,7 +1016,13 @@ class ChatGroq(BaseChatModel):
key_name=tool_name, first_tool_only=True key_name=tool_name, first_tool_only=True
) )
elif method == "json_mode": elif method == "json_mode":
llm = self.bind(response_format={"type": "json_object"}) llm = self.bind(
response_format={"type": "json_object"},
structured_output_format={
"kwargs": {"method": "json_mode"},
"schema": schema,
},
)
output_parser = ( output_parser = (
PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type] PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type]
if is_pydantic_schema if is_pydantic_schema

View File

@@ -931,7 +931,14 @@ class ChatMistralAI(BaseChatModel):
) )
# TODO: Update to pass in tool name as tool_choice if/when Mistral supports # TODO: Update to pass in tool name as tool_choice if/when Mistral supports
# specifying a tool. # specifying a tool.
llm = self.bind_tools([schema], tool_choice="any") llm = self.bind_tools(
[schema],
tool_choice="any",
structured_output_format={
"kwargs": {"method": "function_calling"},
"schema": schema,
},
)
if is_pydantic_schema: if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser( output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], # type: ignore[list-item] tools=[schema], # type: ignore[list-item]
@@ -943,7 +950,16 @@ class ChatMistralAI(BaseChatModel):
key_name=key_name, first_tool_only=True key_name=key_name, first_tool_only=True
) )
elif method == "json_mode": elif method == "json_mode":
llm = self.bind(response_format={"type": "json_object"}) llm = self.bind(
response_format={"type": "json_object"},
structured_output_format={
"kwargs": {
# this is correct - name difference with mistral api
"method": "json_mode"
},
"schema": schema,
},
)
output_parser = ( output_parser = (
PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type] PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type]
if is_pydantic_schema if is_pydantic_schema
@@ -956,7 +972,13 @@ class ChatMistralAI(BaseChatModel):
"Received None." "Received None."
) )
response_format = _convert_to_openai_response_format(schema, strict=True) response_format = _convert_to_openai_response_format(schema, strict=True)
llm = self.bind(response_format=response_format) llm = self.bind(
response_format=response_format,
structured_output_format={
"kwargs": {"method": "json_schema"},
"schema": schema,
},
)
output_parser = ( output_parser = (
PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type] PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]

View File

@@ -1085,8 +1085,16 @@ class ChatOllama(BaseChatModel):
"schema must be specified when method is not 'json_mode'. " "schema must be specified when method is not 'json_mode'. "
"Received None." "Received None."
) )
tool_name = convert_to_openai_tool(schema)["function"]["name"] formatted_tool = convert_to_openai_tool(schema)
llm = self.bind_tools([schema], tool_choice=tool_name) tool_name = formatted_tool["function"]["name"]
llm = self.bind_tools(
[schema],
tool_choice=tool_name,
structured_output_format={
"kwargs": {"method": method},
"schema": formatted_tool,
},
)
if is_pydantic_schema: if is_pydantic_schema:
output_parser: Runnable = PydanticToolsParser( output_parser: Runnable = PydanticToolsParser(
tools=[schema], # type: ignore[list-item] tools=[schema], # type: ignore[list-item]
@@ -1097,7 +1105,13 @@ class ChatOllama(BaseChatModel):
key_name=tool_name, first_tool_only=True key_name=tool_name, first_tool_only=True
) )
elif method == "json_mode": elif method == "json_mode":
llm = self.bind(format="json") llm = self.bind(
format="json",
structured_output_format={
"kwargs": {"method": method},
"schema": schema,
},
)
output_parser = ( output_parser = (
PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type] PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
if is_pydantic_schema if is_pydantic_schema
@@ -1111,7 +1125,13 @@ class ChatOllama(BaseChatModel):
) )
if is_pydantic_schema: if is_pydantic_schema:
schema = cast(TypeBaseModel, schema) schema = cast(TypeBaseModel, schema)
llm = self.bind(format=schema.model_json_schema()) llm = self.bind(
format=schema.model_json_schema(),
structured_output_format={
"kwargs": {"method": method},
"schema": schema,
},
)
output_parser = PydanticOutputParser(pydantic_object=schema) output_parser = PydanticOutputParser(pydantic_object=schema)
else: else:
if is_typeddict(schema): if is_typeddict(schema):
@@ -1126,7 +1146,13 @@ class ChatOllama(BaseChatModel):
else: else:
# is JSON schema # is JSON schema
response_format = schema response_format = schema
llm = self.bind(format=response_format) llm = self.bind(
format=response_format,
structured_output_format={
"kwargs": {"method": method},
"schema": response_format,
},
)
output_parser = JsonOutputParser() output_parser = JsonOutputParser()
else: else:
raise ValueError( raise ValueError(

View File

@@ -31,8 +31,8 @@ class TestChatOllama(ChatModelIntegrationTests):
"Fails with 'AssertionError'. Ollama does not support 'tool_choice' yet." "Fails with 'AssertionError'. Ollama does not support 'tool_choice' yet."
) )
) )
def test_structured_output(self, model: BaseChatModel) -> None: def test_structured_output(self, model: BaseChatModel, schema_type: str) -> None:
super().test_structured_output(model) super().test_structured_output(model, schema_type)
@pytest.mark.xfail( @pytest.mark.xfail(
reason=( reason=(

View File

@@ -1390,7 +1390,13 @@ class BaseChatOpenAI(BaseChatModel):
) )
tool_name = convert_to_openai_tool(schema)["function"]["name"] tool_name = convert_to_openai_tool(schema)["function"]["name"]
bind_kwargs = self._filter_disabled_params( bind_kwargs = self._filter_disabled_params(
tool_choice=tool_name, parallel_tool_calls=False, strict=strict tool_choice=tool_name,
parallel_tool_calls=False,
strict=strict,
structured_output_format={
"kwargs": {"method": method},
"schema": schema,
},
) )
llm = self.bind_tools([schema], **bind_kwargs) llm = self.bind_tools([schema], **bind_kwargs)
@@ -1404,7 +1410,13 @@ class BaseChatOpenAI(BaseChatModel):
key_name=tool_name, first_tool_only=True key_name=tool_name, first_tool_only=True
) )
elif method == "json_mode": elif method == "json_mode":
llm = self.bind(response_format={"type": "json_object"}) llm = self.bind(
response_format={"type": "json_object"},
structured_output_format={
"kwargs": {"method": method},
"schema": schema,
},
)
output_parser = ( output_parser = (
PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type] PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
if is_pydantic_schema if is_pydantic_schema
@@ -1417,7 +1429,13 @@ class BaseChatOpenAI(BaseChatModel):
"Received None." "Received None."
) )
response_format = _convert_to_openai_response_format(schema, strict=strict) response_format = _convert_to_openai_response_format(schema, strict=strict)
llm = self.bind(response_format=response_format) llm = self.bind(
response_format=response_format,
structured_output_format={
"kwargs": {"method": method},
"schema": convert_to_openai_tool(schema),
},
)
if is_pydantic_schema: if is_pydantic_schema:
output_parser = _oai_structured_outputs_parser.with_types( output_parser = _oai_structured_outputs_parser.with_types(
output_type=cast(type, schema) output_type=cast(type, schema)

View File

@@ -1,9 +1,11 @@
import base64 import base64
import json import json
from typing import Any, List, Literal, Optional, cast from typing import Any, List, Literal, Optional, cast
from unittest.mock import MagicMock
import httpx import httpx
import pytest import pytest
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.language_models import BaseChatModel, GenericFakeChatModel from langchain_core.language_models import BaseChatModel, GenericFakeChatModel
from langchain_core.messages import ( from langchain_core.messages import (
AIMessage, AIMessage,
@@ -17,7 +19,10 @@ from langchain_core.messages import (
from langchain_core.output_parsers import StrOutputParser from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import BaseTool, tool from langchain_core.tools import BaseTool, tool
from langchain_core.utils.function_calling import tool_example_to_messages from langchain_core.utils.function_calling import (
convert_to_openai_tool,
tool_example_to_messages,
)
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from pydantic.v1 import BaseModel as BaseModelV1 from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import Field as FieldV1 from pydantic.v1 import Field as FieldV1
@@ -66,6 +71,24 @@ def _get_joke_class(
raise ValueError("Invalid schema type") raise ValueError("Invalid schema type")
class _TestCallbackHandler(BaseCallbackHandler):
metadatas: list[Optional[dict]]
def __init__(self) -> None:
super().__init__()
self.metadatas = []
def on_chat_model_start(
self,
serialized: Any,
messages: Any,
*,
metadata: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> None:
self.metadatas.append(metadata)
class _MagicFunctionSchema(BaseModel): class _MagicFunctionSchema(BaseModel):
input: int = Field(..., gt=-1000, lt=1000) input: int = Field(..., gt=-1000, lt=1000)
@@ -1207,13 +1230,46 @@ class ChatModelIntegrationTests(ChatModelTests):
schema, validation_function = _get_joke_class(schema_type) # type: ignore[arg-type] schema, validation_function = _get_joke_class(schema_type) # type: ignore[arg-type]
chat = model.with_structured_output(schema, **self.structured_output_kwargs) chat = model.with_structured_output(schema, **self.structured_output_kwargs)
result = chat.invoke("Tell me a joke about cats.") mock_callback = MagicMock()
mock_callback.on_chat_model_start = MagicMock()
invoke_callback = _TestCallbackHandler()
result = chat.invoke(
"Tell me a joke about cats.", config={"callbacks": [invoke_callback]}
)
validation_function(result) validation_function(result)
for chunk in chat.stream("Tell me a joke about cats."): assert len(invoke_callback.metadatas) == 1, (
"Expected on_chat_model_start to be called once"
)
assert isinstance(invoke_callback.metadatas[0], dict)
assert isinstance(
invoke_callback.metadatas[0]["structured_output_format"]["schema"], dict
)
assert invoke_callback.metadatas[0]["structured_output_format"][
"schema"
] == convert_to_openai_tool(schema)
stream_callback = _TestCallbackHandler()
for chunk in chat.stream(
"Tell me a joke about cats.", config={"callbacks": [stream_callback]}
):
validation_function(chunk) validation_function(chunk)
assert chunk assert chunk
assert len(stream_callback.metadatas) == 1, (
"Expected on_chat_model_start to be called once"
)
assert isinstance(stream_callback.metadatas[0], dict)
assert isinstance(
stream_callback.metadatas[0]["structured_output_format"]["schema"], dict
)
assert stream_callback.metadatas[0]["structured_output_format"][
"schema"
] == convert_to_openai_tool(schema)
@pytest.mark.parametrize("schema_type", ["pydantic", "typeddict", "json_schema"]) @pytest.mark.parametrize("schema_type", ["pydantic", "typeddict", "json_schema"])
async def test_structured_output_async( async def test_structured_output_async(
self, model: BaseChatModel, schema_type: str self, model: BaseChatModel, schema_type: str
@@ -1248,14 +1304,46 @@ class ChatModelIntegrationTests(ChatModelTests):
pytest.skip("Test requires tool calling.") pytest.skip("Test requires tool calling.")
schema, validation_function = _get_joke_class(schema_type) # type: ignore[arg-type] schema, validation_function = _get_joke_class(schema_type) # type: ignore[arg-type]
chat = model.with_structured_output(schema, **self.structured_output_kwargs) chat = model.with_structured_output(schema, **self.structured_output_kwargs)
result = await chat.ainvoke("Tell me a joke about cats.") ainvoke_callback = _TestCallbackHandler()
result = await chat.ainvoke(
"Tell me a joke about cats.", config={"callbacks": [ainvoke_callback]}
)
validation_function(result) validation_function(result)
async for chunk in chat.astream("Tell me a joke about cats."): assert len(ainvoke_callback.metadatas) == 1, (
"Expected on_chat_model_start to be called once"
)
assert isinstance(ainvoke_callback.metadatas[0], dict)
assert isinstance(
ainvoke_callback.metadatas[0]["structured_output_format"]["schema"], dict
)
assert ainvoke_callback.metadatas[0]["structured_output_format"][
"schema"
] == convert_to_openai_tool(schema)
astream_callback = _TestCallbackHandler()
async for chunk in chat.astream(
"Tell me a joke about cats.", config={"callbacks": [astream_callback]}
):
validation_function(chunk) validation_function(chunk)
assert chunk assert chunk
assert len(astream_callback.metadatas) == 1, (
"Expected on_chat_model_start to be called once"
)
assert isinstance(astream_callback.metadatas[0], dict)
assert isinstance(
astream_callback.metadatas[0]["structured_output_format"]["schema"], dict
)
assert astream_callback.metadatas[0]["structured_output_format"][
"schema"
] == convert_to_openai_tool(schema)
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Test requires pydantic 2.") @pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Test requires pydantic 2.")
def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None: def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None:
"""Test to verify we can generate structured output using """Test to verify we can generate structured output using