diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index 516485654ce..32b83e91b21 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -1238,17 +1238,19 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         llm = self.bind_tools([schema], tool_choice="any")
         if isinstance(schema, type) and is_basemodel_subclass(schema):
             output_parser: OutputParserLike = PydanticToolsParser(
-                tools=[cast(TypeBaseModel, schema)], first_tool_only=True
+                tools=[cast(TypeBaseModel, schema)],
+                first_tool_only=True,
+                return_message=True,
             )
         else:
             key_name = convert_to_openai_tool(schema)["function"]["name"]
             output_parser = JsonOutputKeyToolsParser(
-                key_name=key_name, first_tool_only=True
+                key_name=key_name, first_tool_only=True, return_message=True
             )
         if include_raw:
             parser_assign = RunnablePassthrough.assign(
-                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
-            )
+                raw=itemgetter("raw") | output_parser
+            ).assign(parsed=(lambda x: x["raw"].parsed), parsing_error=lambda _: None)
             parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
             parser_with_fallback = parser_assign.with_fallbacks(
                 [parser_none], exception_key="parsing_error"
diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py
index fd64b824a8d..c8fbf8475f0 100644
--- a/libs/partners/anthropic/langchain_anthropic/chat_models.py
+++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py
@@ -48,8 +48,16 @@ from langchain_core.output_parsers import (
     JsonOutputKeyToolsParser,
     PydanticToolsParser,
 )
-from langchain_core.output_parsers.base import OutputParserLike
-from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.output_parsers.base import (
+    BaseGenerationOutputParser,
+    OutputParserLike,
+)
+from langchain_core.outputs import (
+    ChatGeneration,
+    ChatGenerationChunk,
+    ChatResult,
+    Generation,
+)
 from langchain_core.runnables import (
     Runnable,
     RunnableMap,
@@ -819,6 +827,7 @@ class ChatAnthropic(BaseChatModel):
         tool_choice: Optional[
             Union[Dict[str, str], Literal["any", "auto"], str]
         ] = None,
+        response_format: Optional[Union[dict, type]] = None,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, BaseMessage]:
         r"""Bind tool-like objects to this chat model.
@@ -954,8 +963,13 @@ class ChatAnthropic(BaseChatModel):
 
             AIMessage(content=[{'text': 'To get the current weather in San Francisco, I can use the GetWeather function. Let me check that for you.', 'type': 'text'}, {'id': 'toolu_01HtVtY1qhMFdPprx42qU2eA', 'input': {'location': 'San Francisco, CA'}, 'name': 'GetWeather', 'type': 'tool_use'}], response_metadata={'id': 'msg_016RfWHrRvW6DAGCdwB6Ac64', 'model': 'claude-3-5-sonnet-20240620', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 171, 'output_tokens': 82, 'cache_creation_input_tokens': 0, 'cache_read_input_tokens': 1470}}, id='run-88b1f825-dcb7-4277-ac27-53df55d22001-0', tool_calls=[{'name': 'GetWeather', 'args': {'location': 'San Francisco, CA'}, 'id': 'toolu_01HtVtY1qhMFdPprx42qU2eA', 'type': 'tool_call'}], usage_metadata={'input_tokens': 171, 'output_tokens': 82, 'total_tokens': 253})
         """  # noqa: E501
+        if response_format:
+            tools.append(response_format)
         formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools]
-        if not tool_choice:
+        # If we have a response format, enforce that a tool is called.
+        if response_format and not tool_choice:
+            kwargs["tool_choice"] = {"type": "any"}
+        elif not tool_choice:
             pass
         elif isinstance(tool_choice, dict):
             kwargs["tool_choice"] = tool_choice
@@ -968,7 +982,11 @@
                 f"Unrecognized 'tool_choice' type {tool_choice=}. Expected dict, "
                 f"str, or None."
             )
-        return self.bind(tools=formatted_tools, **kwargs)
+        llm = self.bind(tools=formatted_tools, **kwargs)
+        if response_format:
+            return llm | _ToolsToParsedMessage(response_format=response_format)
+        else:
+            return llm
 
     def with_structured_output(
         self,
@@ -1355,3 +1373,46 @@ def _create_usage_metadata(anthropic_usage: BaseModel) -> UsageMetadata:
             **{k: v for k, v in input_token_details.items() if v is not None}
         ),
     )
+
+
+class _ToolsToParsedMessage(BaseGenerationOutputParser):
+    """..."""
+
+    response_format: Union[dict, type[BaseModel]]
+    """..."""
+    model_config = ConfigDict(extra="forbid")
+
+    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
+        """Parse a list of candidate model Generations into a specific format.
+
+        Args:
+            result: A list of Generations to be parsed. The Generations are assumed
+                to be different candidate outputs for a single model input.
+
+        Returns:
+            Structured output.
+        """
+        if not result or not isinstance(result[0], ChatGeneration):
+            msg = "..."
+            raise ValueError(msg)
+        message = cast(AIMessage, result[0].message)
+        drop = None
+        for tool_call in message.tool_calls:
+            if tool_call["name"] == self._response_format_name:
+                message.parsed = (
+                    tool_call["args"]
+                    if isinstance(self.response_format, dict)
+                    else self.response_format(**tool_call["args"])
+                )
+                drop = tool_call["id"]
+                break
+        message.tool_calls = [tc for tc in message.tool_calls if tc["id"] != drop]
+        if isinstance(message, AIMessageChunk):
+            message.tool_call_chunks = [
+                tc for tc in message.tool_call_chunks if tc["id"] != drop
+            ]
+        return message
+
+    @property
+    def _response_format_name(self) -> str:
+        return convert_to_anthropic_tool(self.response_format)["name"]
diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 119ec24e32e..dc935a43353 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -10,6 +10,7 @@ import sys
 import warnings
 from io import BytesIO
 from math import ceil
+from operator import itemgetter
 from typing import (
     Any,
     AsyncIterator,
@@ -1092,6 +1093,7 @@ class BaseChatOpenAI(BaseChatModel):
             Union[dict, str, Literal["auto", "none", "required", "any"], bool]
         ] = None,
         strict: Optional[bool] = None,
+        response_format: Optional[_DictOrPydanticClass] = None,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, BaseMessage]:
         """Bind tool-like objects to this chat model.
@@ -1162,6 +1164,11 @@
                     f"Received: {tool_choice}"
                 )
             kwargs["tool_choice"] = tool_choice
+        if response_format:
+            response_format = _convert_to_openai_response_format(
+                response_format, strict=strict
+            )
+            kwargs["response_format"] = response_format
         return super().bind(tools=formatted_tools, **kwargs)
 
     def with_structured_output(
@@ -1503,9 +1510,9 @@
             )
 
         if include_raw:
-            parser_assign = RunnablePassthrough.assign(raw=output_parser).assign(
-                parsed=lambda x: x["raw"].parsed, parsing_error=lambda _: None
-            )
+            parser_assign = RunnablePassthrough.assign(
+                raw=itemgetter("raw") | output_parser
+            ).assign(parsed=lambda x: x["raw"].parsed, parsing_error=lambda _: None)
             parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
             parser_with_fallback = parser_assign.with_fallbacks(
                 [parser_none], exception_key="parsing_error"
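Taken together, these changes let bind_tools accept a response_format on both providers: the Anthropic path appends it as an extra tool, forces tool use when no tool_choice is given, and pipes the output through _ToolsToParsedMessage, while the OpenAI path converts it via _convert_to_openai_response_format and forwards it as the provider-native response_format request field. Below is a minimal usage sketch of what the Anthropic side appears to enable; the Weather schema and prompt are illustrative, and the .parsed attribute on the returned message is assumed from this diff's `message.parsed = ...` assignment rather than from any documented API.

    from pydantic import BaseModel

    from langchain_anthropic import ChatAnthropic


    class Weather(BaseModel):
        """Illustrative response format; not part of the diff itself."""

        location: str
        conditions: str


    # No explicit tool_choice: with a response_format bound, the code above
    # forces tool use via kwargs["tool_choice"] = {"type": "any"}.
    llm = ChatAnthropic(model="claude-3-5-sonnet-20240620").bind_tools(
        [],  # ordinary tools may be bound alongside the response format
        response_format=Weather,
    )

    msg = llm.invoke("What is the weather in San Francisco?")

    # _ToolsToParsedMessage drops the response-format tool call from
    # msg.tool_calls and attaches the validated object instead.
    print(msg.parsed)      # Weather(location=..., conditions=...)
    print(msg.tool_calls)  # remaining (real) tool calls, if any

The parser changes in langchain_core and langchain_openai serve the same end for with_structured_output(include_raw=True): "raw" is re-parsed into a message carrying .parsed (via return_message=True), and "parsed" is then read off that message.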