From 8f95da4eb1203f8a6a4a3ca85b2c6004b340ed37 Mon Sep 17 00:00:00 2001
From: Erick Friis
Date: Wed, 29 Jan 2025 14:00:26 -0800
Subject: [PATCH] multiple: structured output tracing standard metadata
 (#29421)

Co-authored-by: Chester Curme
---
 .../language_models/chat_models.py            | 76 +++++++++++++-
 .../langchain_anthropic/chat_models.py        | 10 +-
 .../langchain_fireworks/chat_models.py        | 20 +++-
 .../groq/langchain_groq/chat_models.py        | 20 +++-
 .../langchain_mistralai/chat_models.py        | 28 +++++-
 .../ollama/langchain_ollama/chat_models.py    | 36 ++++++-
 .../chat_models/test_chat_models_standard.py  |  4 +-
 .../langchain_openai/chat_models/base.py      | 24 ++++-
 .../integration_tests/chat_models.py          | 98 ++++++++++++++++++-
 9 files changed, 288 insertions(+), 28 deletions(-)

diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index 6aaaf7d4ca8..dca8e9edaea 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -365,11 +365,28 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         else:
             config = ensure_config(config)
         messages = self._convert_input(input).to_messages()
+        structured_output_format = kwargs.pop("structured_output_format", None)
+        if structured_output_format:
+            try:
+                structured_output_format_dict = {
+                    "structured_output_format": {
+                        "kwargs": structured_output_format.get("kwargs", {}),
+                        "schema": convert_to_openai_tool(
+                            structured_output_format["schema"]
+                        ),
+                    }
+                }
+            except ValueError:
+                structured_output_format_dict = {}
+        else:
+            structured_output_format_dict = {}
+
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop, **kwargs}
         inheritable_metadata = {
             **(config.get("metadata") or {}),
             **self._get_ls_params(stop=stop, **kwargs),
+            **structured_output_format_dict,
         }
         callback_manager = CallbackManager.configure(
             config.get("callbacks"),
@@ -441,11 +458,29 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         else:
             config = ensure_config(config)
         messages = self._convert_input(input).to_messages()
+
+        structured_output_format = kwargs.pop("structured_output_format", None)
+        if structured_output_format:
+            try:
+                structured_output_format_dict = {
+                    "structured_output_format": {
+                        "kwargs": structured_output_format.get("kwargs", {}),
+                        "schema": convert_to_openai_tool(
+                            structured_output_format["schema"]
+                        ),
+                    }
+                }
+            except ValueError:
+                structured_output_format_dict = {}
+        else:
+            structured_output_format_dict = {}
+
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop, **kwargs}
         inheritable_metadata = {
             **(config.get("metadata") or {}),
             **self._get_ls_params(stop=stop, **kwargs),
+            **structured_output_format_dict,
         }
         callback_manager = AsyncCallbackManager.configure(
             config.get("callbacks"),
@@ -606,11 +641,28 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             An LLMResult, which contains a list of candidate Generations for each input
                 prompt and additional model provider-specific output.
         """
+        structured_output_format = kwargs.pop("structured_output_format", None)
+        if structured_output_format:
+            try:
+                structured_output_format_dict = {
+                    "structured_output_format": {
+                        "kwargs": structured_output_format.get("kwargs", {}),
+                        "schema": convert_to_openai_tool(
+                            structured_output_format["schema"]
+                        ),
+                    }
+                }
+            except ValueError:
+                structured_output_format_dict = {}
+        else:
+            structured_output_format_dict = {}
+
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop}
         inheritable_metadata = {
             **(metadata or {}),
             **self._get_ls_params(stop=stop, **kwargs),
+            **structured_output_format_dict,
         }
 
         callback_manager = CallbackManager.configure(
@@ -697,11 +749,28 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             An LLMResult, which contains a list of candidate Generations for each input
                 prompt and additional model provider-specific output.
         """
+        structured_output_format = kwargs.pop("structured_output_format", None)
+        if structured_output_format:
+            try:
+                structured_output_format_dict = {
+                    "structured_output_format": {
+                        "kwargs": structured_output_format.get("kwargs", {}),
+                        "schema": convert_to_openai_tool(
+                            structured_output_format["schema"]
+                        ),
+                    }
+                }
+            except ValueError:
+                structured_output_format_dict = {}
+        else:
+            structured_output_format_dict = {}
+
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop}
         inheritable_metadata = {
             **(metadata or {}),
             **self._get_ls_params(stop=stop, **kwargs),
+            **structured_output_format_dict,
         }
 
         callback_manager = AsyncCallbackManager.configure(
@@ -1240,7 +1309,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         if self.bind_tools is BaseChatModel.bind_tools:
             msg = "with_structured_output is not implemented for this model."
             raise NotImplementedError(msg)
-        llm = self.bind_tools([schema], tool_choice="any")
+
+        llm = self.bind_tools(
+            [schema],
+            tool_choice="any",
+            structured_output_format={"kwargs": {}, "schema": schema},
+        )
         if isinstance(schema, type) and is_basemodel_subclass(schema):
            output_parser: OutputParserLike = PydanticToolsParser(
                 tools=[cast(TypeBaseModel, schema)], first_tool_only=True
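The same extraction block appears in all four generation entry points (stream, astream, generate, agenerate): it pops structured_output_format out of kwargs, converts the recorded schema to an OpenAI-format tool, and merges the result into the run's inheritable metadata. The except ValueError guard drops the metadata silently when the schema cannot be converted, rather than failing the call. A minimal sketch of a callback handler that observes the new field, modeled on the _TestCallbackHandler added to the standard tests at the end of this patch (the handler name and usage are illustrative, not part of the patch):

from typing import Any, Optional

from langchain_core.callbacks import BaseCallbackHandler


class StructuredOutputTracer(BaseCallbackHandler):
    """Collects structured_output_format metadata emitted at chat-model start."""

    def __init__(self) -> None:
        super().__init__()
        self.formats: list[dict[str, Any]] = []

    def on_chat_model_start(
        self,
        serialized: Any,
        messages: Any,
        *,
        metadata: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        # Inheritable metadata assembled in BaseChatModel reaches this hook.
        if metadata and "structured_output_format" in metadata:
            self.formats.append(metadata["structured_output_format"])


# Usage sketch (model and MySchema are placeholders):
#   tracer = StructuredOutputTracer()
#   structured = model.with_structured_output(MySchema)
#   structured.invoke("...", config={"callbacks": [tracer]})
#   tracer.formats[0]["schema"]  # OpenAI-format tool dict for MySchema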
""" + structured_output_format = kwargs.pop("structured_output_format", None) + if structured_output_format: + try: + structured_output_format_dict = { + "structured_output_format": { + "kwargs": structured_output_format.get("kwargs", {}), + "schema": convert_to_openai_tool( + structured_output_format["schema"] + ), + } + } + except ValueError: + structured_output_format_dict = {} + else: + structured_output_format_dict = {} + params = self._get_invocation_params(stop=stop, **kwargs) options = {"stop": stop} inheritable_metadata = { **(metadata or {}), **self._get_ls_params(stop=stop, **kwargs), + **structured_output_format_dict, } callback_manager = CallbackManager.configure( @@ -697,11 +749,28 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): An LLMResult, which contains a list of candidate Generations for each input prompt and additional model provider-specific output. """ + structured_output_format = kwargs.pop("structured_output_format", None) + if structured_output_format: + try: + structured_output_format_dict = { + "structured_output_format": { + "kwargs": structured_output_format.get("kwargs", {}), + "schema": convert_to_openai_tool( + structured_output_format["schema"] + ), + } + } + except ValueError: + structured_output_format_dict = {} + else: + structured_output_format_dict = {} + params = self._get_invocation_params(stop=stop, **kwargs) options = {"stop": stop} inheritable_metadata = { **(metadata or {}), **self._get_ls_params(stop=stop, **kwargs), + **structured_output_format_dict, } callback_manager = AsyncCallbackManager.configure( @@ -1240,7 +1309,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): if self.bind_tools is BaseChatModel.bind_tools: msg = "with_structured_output is not implemented for this model." raise NotImplementedError(msg) - llm = self.bind_tools([schema], tool_choice="any") + + llm = self.bind_tools( + [schema], + tool_choice="any", + structured_output_format={"kwargs": {}, "schema": schema}, + ) if isinstance(schema, type) and is_basemodel_subclass(schema): output_parser: OutputParserLike = PydanticToolsParser( tools=[cast(TypeBaseModel, schema)], first_tool_only=True diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index d3d800ed28f..e58675728df 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -1111,9 +1111,13 @@ class ChatAnthropic(BaseChatModel): Added support for TypedDict class as `schema`. """ # noqa: E501 - - tool_name = convert_to_anthropic_tool(schema)["name"] - llm = self.bind_tools([schema], tool_choice=tool_name) + formatted_tool = convert_to_anthropic_tool(schema) + tool_name = formatted_tool["name"] + llm = self.bind_tools( + [schema], + tool_choice=tool_name, + structured_output_format={"kwargs": {}, "schema": formatted_tool}, + ) if isinstance(schema, type) and is_basemodel_subclass(schema): output_parser: OutputParserLike = PydanticToolsParser( tools=[schema], first_tool_only=True diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py index 92977104706..42bed993623 100644 --- a/libs/partners/fireworks/langchain_fireworks/chat_models.py +++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py @@ -965,8 +965,16 @@ class ChatFireworks(BaseChatModel): "schema must be specified when method is 'function_calling'. " "Received None." 
diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py
index 45e3cf792b3..962f72979b4 100644
--- a/libs/partners/groq/langchain_groq/chat_models.py
+++ b/libs/partners/groq/langchain_groq/chat_models.py
@@ -996,8 +996,16 @@ class ChatGroq(BaseChatModel):
                     "schema must be specified when method is 'function_calling'. "
                     "Received None."
                 )
-            tool_name = convert_to_openai_tool(schema)["function"]["name"]
-            llm = self.bind_tools([schema], tool_choice=tool_name)
+            formatted_tool = convert_to_openai_tool(schema)
+            tool_name = formatted_tool["function"]["name"]
+            llm = self.bind_tools(
+                [schema],
+                tool_choice=tool_name,
+                structured_output_format={
+                    "kwargs": {"method": "function_calling"},
+                    "schema": formatted_tool,
+                },
+            )
             if is_pydantic_schema:
                 output_parser: OutputParserLike = PydanticToolsParser(
                     tools=[schema],  # type: ignore[list-item]
@@ -1008,7 +1016,13 @@ class ChatGroq(BaseChatModel):
                     key_name=tool_name, first_tool_only=True
                 )
         elif method == "json_mode":
-            llm = self.bind(response_format={"type": "json_object"})
+            llm = self.bind(
+                response_format={"type": "json_object"},
+                structured_output_format={
+                    "kwargs": {"method": "json_mode"},
+                    "schema": schema,
+                },
+            )
             output_parser = (
                 PydanticOutputParser(pydantic_object=schema)  # type: ignore[type-var, arg-type]
                 if is_pydantic_schema
diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py
index 78739a5fea1..4e304e29e0d 100644
--- a/libs/partners/mistralai/langchain_mistralai/chat_models.py
+++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py
@@ -931,7 +931,14 @@ class ChatMistralAI(BaseChatModel):
             )
             # TODO: Update to pass in tool name as tool_choice if/when Mistral supports
             # specifying a tool.
-            llm = self.bind_tools([schema], tool_choice="any")
+            llm = self.bind_tools(
+                [schema],
+                tool_choice="any",
+                structured_output_format={
+                    "kwargs": {"method": "function_calling"},
+                    "schema": schema,
+                },
+            )
             if is_pydantic_schema:
                 output_parser: OutputParserLike = PydanticToolsParser(
                     tools=[schema],  # type: ignore[list-item]
@@ -943,7 +950,16 @@ class ChatMistralAI(BaseChatModel):
                     key_name=key_name, first_tool_only=True
                 )
         elif method == "json_mode":
-            llm = self.bind(response_format={"type": "json_object"})
+            llm = self.bind(
+                response_format={"type": "json_object"},
+                structured_output_format={
+                    "kwargs": {
+                        # this is correct - name difference with mistral api
+                        "method": "json_mode"
+                    },
+                    "schema": schema,
+                },
+            )
             output_parser = (
                 PydanticOutputParser(pydantic_object=schema)  # type: ignore[type-var, arg-type]
                 if is_pydantic_schema
@@ -956,7 +972,13 @@ class ChatMistralAI(BaseChatModel):
                 "Received None."
             )
             response_format = _convert_to_openai_response_format(schema, strict=True)
-            llm = self.bind(response_format=response_format)
+            llm = self.bind(
+                response_format=response_format,
+                structured_output_format={
+                    "kwargs": {"method": "json_schema"},
+                    "schema": schema,
+                },
+            )
 
             output_parser = (
                 PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
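Across providers, the traced kwargs carry one of three method names: "function_calling", "json_mode", or "json_schema". A provider may record the raw user schema, as ChatMistralAI does for "json_mode"; the core layer shown at the top of the patch then normalizes whatever was recorded through convert_to_openai_tool before attaching it to the run. A usage sketch (model name illustrative; assumes MISTRAL_API_KEY is set):

from langchain_mistralai import ChatMistralAI
from pydantic import BaseModel


class Answer(BaseModel):
    value: int


llm = ChatMistralAI(model="mistral-large-latest")
structured = llm.with_structured_output(Answer, method="json_mode")
# on_chat_model_start metadata now includes:
# {"structured_output_format": {"kwargs": {"method": "json_mode"},
#                               "schema": convert_to_openai_tool(Answer)}}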
diff --git a/libs/partners/ollama/langchain_ollama/chat_models.py b/libs/partners/ollama/langchain_ollama/chat_models.py
index b4f8c0f2d9a..7a179b2fbed 100644
--- a/libs/partners/ollama/langchain_ollama/chat_models.py
+++ b/libs/partners/ollama/langchain_ollama/chat_models.py
@@ -1085,8 +1085,16 @@ class ChatOllama(BaseChatModel):
                     "schema must be specified when method is not 'json_mode'. "
                     "Received None."
                 )
-            tool_name = convert_to_openai_tool(schema)["function"]["name"]
-            llm = self.bind_tools([schema], tool_choice=tool_name)
+            formatted_tool = convert_to_openai_tool(schema)
+            tool_name = formatted_tool["function"]["name"]
+            llm = self.bind_tools(
+                [schema],
+                tool_choice=tool_name,
+                structured_output_format={
+                    "kwargs": {"method": method},
+                    "schema": formatted_tool,
+                },
+            )
             if is_pydantic_schema:
                 output_parser: Runnable = PydanticToolsParser(
                     tools=[schema],  # type: ignore[list-item]
@@ -1097,7 +1105,13 @@ class ChatOllama(BaseChatModel):
                     key_name=tool_name, first_tool_only=True
                 )
         elif method == "json_mode":
-            llm = self.bind(format="json")
+            llm = self.bind(
+                format="json",
+                structured_output_format={
+                    "kwargs": {"method": method},
+                    "schema": schema,
+                },
+            )
             output_parser = (
                 PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
                 if is_pydantic_schema
@@ -1111,7 +1125,13 @@ class ChatOllama(BaseChatModel):
                 )
             if is_pydantic_schema:
                 schema = cast(TypeBaseModel, schema)
-                llm = self.bind(format=schema.model_json_schema())
+                llm = self.bind(
+                    format=schema.model_json_schema(),
+                    structured_output_format={
+                        "kwargs": {"method": method},
+                        "schema": schema,
+                    },
+                )
                 output_parser = PydanticOutputParser(pydantic_object=schema)
             else:
                 if is_typeddict(schema):
@@ -1126,7 +1146,13 @@ class ChatOllama(BaseChatModel):
                 else:
                     # is JSON schema
                     response_format = schema
-                llm = self.bind(format=response_format)
+                llm = self.bind(
+                    format=response_format,
+                    structured_output_format={
+                        "kwargs": {"method": method},
+                        "schema": response_format,
+                    },
+                )
                 output_parser = JsonOutputParser()
         else:
             raise ValueError(
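ChatOllama threads the caller's method value straight into the traced kwargs for all three branches, binding format="json", a Pydantic model_json_schema(), or a raw JSON schema respectively. A usage sketch (model name illustrative; assumes a local Ollama server is running):

from langchain_ollama import ChatOllama
from pydantic import BaseModel


class Answer(BaseModel):
    value: int


llm = ChatOllama(model="llama3.1")
structured = llm.with_structured_output(Answer, method="json_schema")
# Binds format=Answer.model_json_schema(); the provider records
# {"kwargs": {"method": "json_schema"}, "schema": Answer} for tracing.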
diff --git a/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_standard.py b/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_standard.py
index 4b8feccc8af..5f990f2251b 100644
--- a/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_standard.py
+++ b/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_standard.py
@@ -31,8 +31,8 @@ class TestChatOllama(ChatModelIntegrationTests):
             "Fails with 'AssertionError'. Ollama does not support 'tool_choice' yet."
         )
     )
-    def test_structured_output(self, model: BaseChatModel) -> None:
-        super().test_structured_output(model)
+    def test_structured_output(self, model: BaseChatModel, schema_type: str) -> None:
+        super().test_structured_output(model, schema_type)
 
     @pytest.mark.xfail(
         reason=(
diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 982979e8949..5ac28a717ac 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -1390,7 +1390,13 @@ class BaseChatOpenAI(BaseChatModel):
             )
             tool_name = convert_to_openai_tool(schema)["function"]["name"]
             bind_kwargs = self._filter_disabled_params(
-                tool_choice=tool_name, parallel_tool_calls=False, strict=strict
+                tool_choice=tool_name,
+                parallel_tool_calls=False,
+                strict=strict,
+                structured_output_format={
+                    "kwargs": {"method": method},
+                    "schema": schema,
+                },
             )
 
             llm = self.bind_tools([schema], **bind_kwargs)
@@ -1404,7 +1410,13 @@ class BaseChatOpenAI(BaseChatModel):
                     key_name=tool_name, first_tool_only=True
                 )
         elif method == "json_mode":
-            llm = self.bind(response_format={"type": "json_object"})
+            llm = self.bind(
+                response_format={"type": "json_object"},
+                structured_output_format={
+                    "kwargs": {"method": method},
+                    "schema": schema,
+                },
+            )
             output_parser = (
                 PydanticOutputParser(pydantic_object=schema)  # type: ignore[arg-type]
                 if is_pydantic_schema
@@ -1417,7 +1429,13 @@ class BaseChatOpenAI(BaseChatModel):
                 "Received None."
             )
             response_format = _convert_to_openai_response_format(schema, strict=strict)
-            llm = self.bind(response_format=response_format)
+            llm = self.bind(
+                response_format=response_format,
+                structured_output_format={
+                    "kwargs": {"method": method},
+                    "schema": convert_to_openai_tool(schema),
+                },
+            )
             if is_pydantic_schema:
                 output_parser = _oai_structured_outputs_parser.with_types(
                     output_type=cast(type, schema)
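One asymmetry worth noting: for method="json_schema", BaseChatOpenAI records convert_to_openai_tool(schema) directly, whereas the other branches record the raw schema and leave normalization to the core layer. Since convert_to_openai_tool passes an already-converted tool dict through unchanged, the callbacks see the same OpenAI-format schema either way, which is what the standard tests below assert. A sketch (model name illustrative; assumes OPENAI_API_KEY is set):

from langchain_openai import ChatOpenAI
from pydantic import BaseModel


class Joke(BaseModel):
    setup: str
    punchline: str


llm = ChatOpenAI(model="gpt-4o-mini")
structured = llm.with_structured_output(Joke, method="json_schema")
# Traced metadata:
# {"structured_output_format": {"kwargs": {"method": "json_schema"},
#                               "schema": convert_to_openai_tool(Joke)}}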
diff --git a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py
index 6fef3158930..f7371c6bcf7 100644
--- a/libs/standard-tests/langchain_tests/integration_tests/chat_models.py
+++ b/libs/standard-tests/langchain_tests/integration_tests/chat_models.py
@@ -1,9 +1,11 @@
 import base64
 import json
 from typing import Any, List, Literal, Optional, cast
+from unittest.mock import MagicMock
 
 import httpx
 import pytest
+from langchain_core.callbacks import BaseCallbackHandler
 from langchain_core.language_models import BaseChatModel, GenericFakeChatModel
 from langchain_core.messages import (
     AIMessage,
@@ -17,7 +19,10 @@
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.tools import BaseTool, tool
-from langchain_core.utils.function_calling import tool_example_to_messages
+from langchain_core.utils.function_calling import (
+    convert_to_openai_tool,
+    tool_example_to_messages,
+)
 from pydantic import BaseModel, Field
 from pydantic.v1 import BaseModel as BaseModelV1
 from pydantic.v1 import Field as FieldV1
@@ -66,6 +71,24 @@ def _get_joke_class(
     raise ValueError("Invalid schema type")
 
 
+class _TestCallbackHandler(BaseCallbackHandler):
+    metadatas: list[Optional[dict]]
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.metadatas = []
+
+    def on_chat_model_start(
+        self,
+        serialized: Any,
+        messages: Any,
+        *,
+        metadata: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        self.metadatas.append(metadata)
+
+
 class _MagicFunctionSchema(BaseModel):
     input: int = Field(..., gt=-1000, lt=1000)
 
@@ -1207,13 +1230,46 @@ class ChatModelIntegrationTests(ChatModelTests):
         schema, validation_function = _get_joke_class(schema_type)  # type: ignore[arg-type]
 
         chat = model.with_structured_output(schema, **self.structured_output_kwargs)
-        result = chat.invoke("Tell me a joke about cats.")
+        mock_callback = MagicMock()
+        mock_callback.on_chat_model_start = MagicMock()
+
+        invoke_callback = _TestCallbackHandler()
+
+        result = chat.invoke(
+            "Tell me a joke about cats.", config={"callbacks": [invoke_callback]}
+        )
         validation_function(result)
 
-        for chunk in chat.stream("Tell me a joke about cats."):
+        assert len(invoke_callback.metadatas) == 1, (
+            "Expected on_chat_model_start to be called once"
+        )
+        assert isinstance(invoke_callback.metadatas[0], dict)
+        assert isinstance(
+            invoke_callback.metadatas[0]["structured_output_format"]["schema"], dict
+        )
+        assert invoke_callback.metadatas[0]["structured_output_format"][
+            "schema"
+        ] == convert_to_openai_tool(schema)
+
+        stream_callback = _TestCallbackHandler()
+
+        for chunk in chat.stream(
+            "Tell me a joke about cats.", config={"callbacks": [stream_callback]}
+        ):
             validation_function(chunk)
         assert chunk
 
+        assert len(stream_callback.metadatas) == 1, (
+            "Expected on_chat_model_start to be called once"
+        )
+        assert isinstance(stream_callback.metadatas[0], dict)
+        assert isinstance(
+            stream_callback.metadatas[0]["structured_output_format"]["schema"], dict
+        )
+        assert stream_callback.metadatas[0]["structured_output_format"][
+            "schema"
+        ] == convert_to_openai_tool(schema)
+
     @pytest.mark.parametrize("schema_type", ["pydantic", "typeddict", "json_schema"])
     async def test_structured_output_async(
         self, model: BaseChatModel, schema_type: str
@@ -1248,14 +1304,46 @@ class ChatModelIntegrationTests(ChatModelTests):
             pytest.skip("Test requires tool calling.")
 
         schema, validation_function = _get_joke_class(schema_type)  # type: ignore[arg-type]
+
         chat = model.with_structured_output(schema, **self.structured_output_kwargs)
-        result = await chat.ainvoke("Tell me a joke about cats.")
+        ainvoke_callback = _TestCallbackHandler()
+
+        result = await chat.ainvoke(
+            "Tell me a joke about cats.", config={"callbacks": [ainvoke_callback]}
+        )
         validation_function(result)
 
-        async for chunk in chat.astream("Tell me a joke about cats."):
+        assert len(ainvoke_callback.metadatas) == 1, (
+            "Expected on_chat_model_start to be called once"
+        )
+        assert isinstance(ainvoke_callback.metadatas[0], dict)
+        assert isinstance(
+            ainvoke_callback.metadatas[0]["structured_output_format"]["schema"], dict
+        )
+        assert ainvoke_callback.metadatas[0]["structured_output_format"][
+            "schema"
+        ] == convert_to_openai_tool(schema)
+
+        astream_callback = _TestCallbackHandler()
+
+        async for chunk in chat.astream(
+            "Tell me a joke about cats.", config={"callbacks": [astream_callback]}
+        ):
             validation_function(chunk)
         assert chunk
 
+        assert len(astream_callback.metadatas) == 1, (
+            "Expected on_chat_model_start to be called once"
+        )
+
+        assert isinstance(astream_callback.metadatas[0], dict)
+        assert isinstance(
+            astream_callback.metadatas[0]["structured_output_format"]["schema"], dict
+        )
+        assert astream_callback.metadatas[0]["structured_output_format"][
+            "schema"
+        ] == convert_to_openai_tool(schema)
+
     @pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Test requires pydantic 2.")
     def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None:
         """Test to verify we can generate structured output using