mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 20:28:10 +00:00
rfc: bind_tools(response_format)
This commit is contained in:
parent
87d8012ef6
commit
e7c2b41cab
@ -1238,17 +1238,19 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
llm = self.bind_tools([schema], tool_choice="any")
|
llm = self.bind_tools([schema], tool_choice="any")
|
||||||
if isinstance(schema, type) and is_basemodel_subclass(schema):
|
if isinstance(schema, type) and is_basemodel_subclass(schema):
|
||||||
output_parser: OutputParserLike = PydanticToolsParser(
|
output_parser: OutputParserLike = PydanticToolsParser(
|
||||||
tools=[cast(TypeBaseModel, schema)], first_tool_only=True
|
tools=[cast(TypeBaseModel, schema)],
|
||||||
|
first_tool_only=True,
|
||||||
|
return_message=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
key_name = convert_to_openai_tool(schema)["function"]["name"]
|
key_name = convert_to_openai_tool(schema)["function"]["name"]
|
||||||
output_parser = JsonOutputKeyToolsParser(
|
output_parser = JsonOutputKeyToolsParser(
|
||||||
key_name=key_name, first_tool_only=True
|
key_name=key_name, first_tool_only=True, return_message=True
|
||||||
)
|
)
|
||||||
if include_raw:
|
if include_raw:
|
||||||
parser_assign = RunnablePassthrough.assign(
|
parser_assign = RunnablePassthrough.assign(
|
||||||
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
|
raw=itemgetter("raw") | output_parser
|
||||||
)
|
).assign(parsed=(lambda x: x["raw"].parsed), parsing_error=lambda _: None)
|
||||||
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
|
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
|
||||||
parser_with_fallback = parser_assign.with_fallbacks(
|
parser_with_fallback = parser_assign.with_fallbacks(
|
||||||
[parser_none], exception_key="parsing_error"
|
[parser_none], exception_key="parsing_error"
|
||||||
|
@ -48,8 +48,16 @@ from langchain_core.output_parsers import (
|
|||||||
JsonOutputKeyToolsParser,
|
JsonOutputKeyToolsParser,
|
||||||
PydanticToolsParser,
|
PydanticToolsParser,
|
||||||
)
|
)
|
||||||
from langchain_core.output_parsers.base import OutputParserLike
|
from langchain_core.output_parsers.base import (
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
BaseGenerationOutputParser,
|
||||||
|
OutputParserLike,
|
||||||
|
)
|
||||||
|
from langchain_core.outputs import (
|
||||||
|
ChatGeneration,
|
||||||
|
ChatGenerationChunk,
|
||||||
|
ChatResult,
|
||||||
|
Generation,
|
||||||
|
)
|
||||||
from langchain_core.runnables import (
|
from langchain_core.runnables import (
|
||||||
Runnable,
|
Runnable,
|
||||||
RunnableMap,
|
RunnableMap,
|
||||||
@ -819,6 +827,7 @@ class ChatAnthropic(BaseChatModel):
|
|||||||
tool_choice: Optional[
|
tool_choice: Optional[
|
||||||
Union[Dict[str, str], Literal["any", "auto"], str]
|
Union[Dict[str, str], Literal["any", "auto"], str]
|
||||||
] = None,
|
] = None,
|
||||||
|
response_format: Optional[Union[dict, type]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||||
r"""Bind tool-like objects to this chat model.
|
r"""Bind tool-like objects to this chat model.
|
||||||
@ -954,8 +963,13 @@ class ChatAnthropic(BaseChatModel):
|
|||||||
AIMessage(content=[{'text': 'To get the current weather in San Francisco, I can use the GetWeather function. Let me check that for you.', 'type': 'text'}, {'id': 'toolu_01HtVtY1qhMFdPprx42qU2eA', 'input': {'location': 'San Francisco, CA'}, 'name': 'GetWeather', 'type': 'tool_use'}], response_metadata={'id': 'msg_016RfWHrRvW6DAGCdwB6Ac64', 'model': 'claude-3-5-sonnet-20240620', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 171, 'output_tokens': 82, 'cache_creation_input_tokens': 0, 'cache_read_input_tokens': 1470}}, id='run-88b1f825-dcb7-4277-ac27-53df55d22001-0', tool_calls=[{'name': 'GetWeather', 'args': {'location': 'San Francisco, CA'}, 'id': 'toolu_01HtVtY1qhMFdPprx42qU2eA', 'type': 'tool_call'}], usage_metadata={'input_tokens': 171, 'output_tokens': 82, 'total_tokens': 253})
|
AIMessage(content=[{'text': 'To get the current weather in San Francisco, I can use the GetWeather function. Let me check that for you.', 'type': 'text'}, {'id': 'toolu_01HtVtY1qhMFdPprx42qU2eA', 'input': {'location': 'San Francisco, CA'}, 'name': 'GetWeather', 'type': 'tool_use'}], response_metadata={'id': 'msg_016RfWHrRvW6DAGCdwB6Ac64', 'model': 'claude-3-5-sonnet-20240620', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 171, 'output_tokens': 82, 'cache_creation_input_tokens': 0, 'cache_read_input_tokens': 1470}}, id='run-88b1f825-dcb7-4277-ac27-53df55d22001-0', tool_calls=[{'name': 'GetWeather', 'args': {'location': 'San Francisco, CA'}, 'id': 'toolu_01HtVtY1qhMFdPprx42qU2eA', 'type': 'tool_call'}], usage_metadata={'input_tokens': 171, 'output_tokens': 82, 'total_tokens': 253})
|
||||||
|
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
if response_format:
|
||||||
|
tools.append(response_format)
|
||||||
formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools]
|
formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools]
|
||||||
if not tool_choice:
|
# If we have a response format, enforce that a tool is called.
|
||||||
|
if response_format and not tool_choice:
|
||||||
|
kwargs["tool_choice"] = {"type": "any"}
|
||||||
|
elif not tool_choice:
|
||||||
pass
|
pass
|
||||||
elif isinstance(tool_choice, dict):
|
elif isinstance(tool_choice, dict):
|
||||||
kwargs["tool_choice"] = tool_choice
|
kwargs["tool_choice"] = tool_choice
|
||||||
@ -968,7 +982,11 @@ class ChatAnthropic(BaseChatModel):
|
|||||||
f"Unrecognized 'tool_choice' type {tool_choice=}. Expected dict, "
|
f"Unrecognized 'tool_choice' type {tool_choice=}. Expected dict, "
|
||||||
f"str, or None."
|
f"str, or None."
|
||||||
)
|
)
|
||||||
return self.bind(tools=formatted_tools, **kwargs)
|
llm = self.bind(tools=formatted_tools, **kwargs)
|
||||||
|
if response_format:
|
||||||
|
return llm | _ToolsToParsedMessage(response_format=response_format)
|
||||||
|
else:
|
||||||
|
return llm
|
||||||
|
|
||||||
def with_structured_output(
|
def with_structured_output(
|
||||||
self,
|
self,
|
||||||
@ -1355,3 +1373,46 @@ def _create_usage_metadata(anthropic_usage: BaseModel) -> UsageMetadata:
|
|||||||
**{k: v for k, v in input_token_details.items() if v is not None}
|
**{k: v for k, v in input_token_details.items() if v is not None}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _ToolsToParsedMessage(BaseGenerationOutputParser):
|
||||||
|
"""..."""
|
||||||
|
|
||||||
|
response_format: Union[dict, type[BaseModel]]
|
||||||
|
"""..."""
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||||
|
"""Parse a list of candidate model Generations into a specific format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: A list of Generations to be parsed. The Generations are assumed
|
||||||
|
to be different candidate outputs for a single model input.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Structured output.
|
||||||
|
"""
|
||||||
|
if not result or not isinstance(result[0], ChatGeneration):
|
||||||
|
msg = "..."
|
||||||
|
raise ValueError(msg)
|
||||||
|
message = cast(AIMessage, result[0].message)
|
||||||
|
drop = None
|
||||||
|
for tool_call in message.tool_calls:
|
||||||
|
if tool_call["name"] == self._response_format_name:
|
||||||
|
message.parsed = (
|
||||||
|
tool_call["args"]
|
||||||
|
if isinstance(self.response_format, dict)
|
||||||
|
else self.response_format(**tool_call["args"])
|
||||||
|
)
|
||||||
|
drop = tool_call["id"]
|
||||||
|
break
|
||||||
|
message.tool_calls = [tc for tc in message.tool_calls if tc["id"] != drop]
|
||||||
|
if isinstance(message, AIMessageChunk):
|
||||||
|
message.tool_call_chunks = [
|
||||||
|
tc for tc in message.tool_call_chunks if tc["id"] != drop
|
||||||
|
]
|
||||||
|
return message
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _response_format_name(self) -> str:
|
||||||
|
return convert_to_anthropic_tool(self.response_format)["name"]
|
||||||
|
@ -10,6 +10,7 @@ import sys
|
|||||||
import warnings
|
import warnings
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from math import ceil
|
from math import ceil
|
||||||
|
from operator import itemgetter
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
@ -1092,6 +1093,7 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
|
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
|
||||||
] = None,
|
] = None,
|
||||||
strict: Optional[bool] = None,
|
strict: Optional[bool] = None,
|
||||||
|
response_format: Optional[_DictOrPydanticClass] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||||
"""Bind tool-like objects to this chat model.
|
"""Bind tool-like objects to this chat model.
|
||||||
@ -1162,6 +1164,11 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
f"Received: {tool_choice}"
|
f"Received: {tool_choice}"
|
||||||
)
|
)
|
||||||
kwargs["tool_choice"] = tool_choice
|
kwargs["tool_choice"] = tool_choice
|
||||||
|
if response_format:
|
||||||
|
response_format = _convert_to_openai_response_format(
|
||||||
|
response_format, strict=strict
|
||||||
|
)
|
||||||
|
kwargs["response_format"] = response_format
|
||||||
return super().bind(tools=formatted_tools, **kwargs)
|
return super().bind(tools=formatted_tools, **kwargs)
|
||||||
|
|
||||||
def with_structured_output(
|
def with_structured_output(
|
||||||
@ -1503,9 +1510,9 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if include_raw:
|
if include_raw:
|
||||||
parser_assign = RunnablePassthrough.assign(raw=output_parser).assign(
|
parser_assign = RunnablePassthrough.assign(
|
||||||
parsed=lambda x: x["raw"].parsed, parsing_error=lambda _: None
|
raw=itemgetter("raw") | output_parser
|
||||||
)
|
).assign(parsed=lambda x: x["raw"].parsed, parsing_error=lambda _: None)
|
||||||
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
|
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
|
||||||
parser_with_fallback = parser_assign.with_fallbacks(
|
parser_with_fallback = parser_assign.with_fallbacks(
|
||||||
[parser_none], exception_key="parsing_error"
|
[parser_none], exception_key="parsing_error"
|
||||||
|
Loading…
Reference in New Issue
Block a user