mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-28 06:48:50 +00:00
anthropic[minor]: tool use (#20016)
This commit is contained in:
@@ -1,13 +1,31 @@
|
||||
import os
|
||||
import re
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
|
||||
import warnings
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypedDict,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import anthropic
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core._api import beta, deprecated
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
agenerate_from_stream,
|
||||
@@ -17,14 +35,26 @@ from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableMap,
|
||||
RunnablePassthrough,
|
||||
)
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import (
|
||||
build_extra_kwargs,
|
||||
convert_to_secret_str,
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
from langchain_anthropic.output_parsers import ToolsOutputParser
|
||||
|
||||
_message_type_lookups = {"human": "user", "ai": "assistant"}
|
||||
|
||||
@@ -56,6 +86,41 @@ def _format_image(image_url: str) -> Dict:
|
||||
}
|
||||
|
||||
|
||||
def _merge_messages(
|
||||
messages: List[BaseMessage],
|
||||
) -> List[Union[SystemMessage, AIMessage, HumanMessage]]:
|
||||
"""Merge runs of human/tool messages into single human messages with content blocks.""" # noqa: E501
|
||||
merged: list = []
|
||||
for curr in messages:
|
||||
if isinstance(curr, ToolMessage):
|
||||
if isinstance(curr.content, str):
|
||||
curr = HumanMessage(
|
||||
[
|
||||
{
|
||||
"type": "tool_result",
|
||||
"content": curr.content,
|
||||
"tool_use_id": curr.tool_call_id,
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
curr = HumanMessage(curr.content)
|
||||
last = merged[-1] if merged else None
|
||||
if isinstance(last, HumanMessage) and isinstance(curr, HumanMessage):
|
||||
if isinstance(last.content, str):
|
||||
new_content: List = [{"type": "text", "text": last.content}]
|
||||
else:
|
||||
new_content = last.content
|
||||
if isinstance(curr.content, str):
|
||||
new_content.append({"type": "text", "text": curr.content})
|
||||
else:
|
||||
new_content.extend(curr.content)
|
||||
last.content = new_content
|
||||
else:
|
||||
merged.append(curr)
|
||||
return merged
|
||||
|
||||
|
||||
def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[Dict]]:
|
||||
"""Format messages for anthropic."""
|
||||
|
||||
@@ -70,7 +135,9 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D
|
||||
"""
|
||||
system: Optional[str] = None
|
||||
formatted_messages: List[Dict] = []
|
||||
for i, message in enumerate(messages):
|
||||
|
||||
merged_messages = _merge_messages(messages)
|
||||
for i, message in enumerate(merged_messages):
|
||||
if message.type == "system":
|
||||
if i != 0:
|
||||
raise ValueError("System message must be at beginning of message list.")
|
||||
@@ -104,7 +171,7 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D
|
||||
elif isinstance(item, dict):
|
||||
if "type" not in item:
|
||||
raise ValueError("Dict content item must have a type key")
|
||||
if item["type"] == "image_url":
|
||||
elif item["type"] == "image_url":
|
||||
# convert format
|
||||
source = _format_image(item["image_url"]["url"])
|
||||
content.append(
|
||||
@@ -113,6 +180,9 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D
|
||||
"source": source,
|
||||
}
|
||||
)
|
||||
elif item["type"] == "tool_use":
|
||||
item.pop("text", None)
|
||||
content.append(item)
|
||||
else:
|
||||
content.append(item)
|
||||
else:
|
||||
@@ -175,6 +245,9 @@ class ChatAnthropic(BaseChatModel):
|
||||
|
||||
anthropic_api_key: Optional[SecretStr] = None
|
||||
|
||||
default_headers: Optional[Mapping[str, str]] = None
|
||||
"""Headers to pass to the Anthropic clients, will be used for every API call."""
|
||||
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
streaming: bool = False
|
||||
@@ -207,9 +280,15 @@ class ChatAnthropic(BaseChatModel):
|
||||
or "https://api.anthropic.com"
|
||||
)
|
||||
values["anthropic_api_url"] = api_url
|
||||
values["_client"] = anthropic.Client(api_key=api_key, base_url=api_url)
|
||||
values["_client"] = anthropic.Client(
|
||||
api_key=api_key,
|
||||
base_url=api_url,
|
||||
default_headers=values.get("default_headers"),
|
||||
)
|
||||
values["_async_client"] = anthropic.AsyncClient(
|
||||
api_key=api_key, base_url=api_url
|
||||
api_key=api_key,
|
||||
base_url=api_url,
|
||||
default_headers=values.get("default_headers"),
|
||||
)
|
||||
return values
|
||||
|
||||
@@ -232,6 +311,7 @@ class ChatAnthropic(BaseChatModel):
|
||||
"stop_sequences": stop,
|
||||
"system": system,
|
||||
**self.model_kwargs,
|
||||
**kwargs,
|
||||
}
|
||||
rtn = {k: v for k, v in rtn.items() if v is not None}
|
||||
|
||||
@@ -245,6 +325,13 @@ class ChatAnthropic(BaseChatModel):
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
params = self._format_params(messages=messages, stop=stop, **kwargs)
|
||||
if "extra_body" in params and params["extra_body"].get("tools"):
|
||||
warnings.warn("stream: Tool use is not yet supported in streaming mode.")
|
||||
result = self._generate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
yield cast(ChatGenerationChunk, result.generations[0])
|
||||
return
|
||||
with self._client.messages.stream(**params) as stream:
|
||||
for text in stream.text_stream:
|
||||
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
|
||||
@@ -260,6 +347,13 @@ class ChatAnthropic(BaseChatModel):
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
params = self._format_params(messages=messages, stop=stop, **kwargs)
|
||||
if "extra_body" in params and params["extra_body"].get("tools"):
|
||||
warnings.warn("stream: Tool use is not yet supported in streaming mode.")
|
||||
result = await self._agenerate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
yield cast(ChatGenerationChunk, result.generations[0])
|
||||
return
|
||||
async with self._async_client.messages.stream(**params) as stream:
|
||||
async for text in stream.text_stream:
|
||||
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
|
||||
@@ -273,8 +367,12 @@ class ChatAnthropic(BaseChatModel):
|
||||
llm_output = {
|
||||
k: v for k, v in data_dict.items() if k not in ("content", "role", "type")
|
||||
}
|
||||
if len(content) == 1 and content[0]["type"] == "text":
|
||||
msg = AIMessage(content=content[0]["text"])
|
||||
else:
|
||||
msg = AIMessage(content=content)
|
||||
return ChatResult(
|
||||
generations=[ChatGeneration(message=AIMessage(content=content[0]["text"]))],
|
||||
generations=[ChatGeneration(message=msg)],
|
||||
llm_output=llm_output,
|
||||
)
|
||||
|
||||
@@ -285,12 +383,17 @@ class ChatAnthropic(BaseChatModel):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
params = self._format_params(messages=messages, stop=stop, **kwargs)
|
||||
if self.streaming:
|
||||
if "extra_body" in params and params["extra_body"].get("tools"):
|
||||
warnings.warn(
|
||||
"stream: Tool use is not yet supported in streaming mode."
|
||||
)
|
||||
else:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
data = self._client.messages.create(**params)
|
||||
return self._format_output(data, **kwargs)
|
||||
|
||||
@@ -301,15 +404,91 @@ class ChatAnthropic(BaseChatModel):
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
stream_iter = self._astream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
params = self._format_params(messages=messages, stop=stop, **kwargs)
|
||||
if self.streaming:
|
||||
if "extra_body" in params and params["extra_body"].get("tools"):
|
||||
warnings.warn(
|
||||
"stream: Tool use is not yet supported in streaming mode."
|
||||
)
|
||||
else:
|
||||
stream_iter = self._astream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
data = await self._async_client.messages.create(**params)
|
||||
return self._format_output(data, **kwargs)
|
||||
|
||||
@beta()
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
"""Bind tool-like objects to this chat model.
|
||||
|
||||
Args:
|
||||
tools: A list of tool definitions to bind to this chat model.
|
||||
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
|
||||
models, callables, and BaseTools will be automatically converted to
|
||||
their schema dictionary representation.
|
||||
**kwargs: Any additional parameters to bind.
|
||||
"""
|
||||
formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools]
|
||||
extra_body = kwargs.pop("extra_body", {})
|
||||
extra_body["tools"] = formatted_tools
|
||||
return self.bind(extra_body=extra_body, **kwargs)
|
||||
|
||||
@beta()
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Union[Dict, Type[BaseModel]],
|
||||
*,
|
||||
include_raw: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
|
||||
llm = self.bind_tools([schema])
|
||||
if isinstance(schema, type) and issubclass(schema, BaseModel):
|
||||
output_parser = ToolsOutputParser(
|
||||
first_tool_only=True, pydantic_schemas=[schema]
|
||||
)
|
||||
else:
|
||||
output_parser = ToolsOutputParser(first_tool_only=True, args_only=True)
|
||||
|
||||
if include_raw:
|
||||
parser_assign = RunnablePassthrough.assign(
|
||||
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
|
||||
)
|
||||
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
|
||||
parser_with_fallback = parser_assign.with_fallbacks(
|
||||
[parser_none], exception_key="parsing_error"
|
||||
)
|
||||
return RunnableMap(raw=llm) | parser_with_fallback
|
||||
else:
|
||||
return llm | output_parser
|
||||
|
||||
|
||||
class AnthropicTool(TypedDict):
|
||||
name: str
|
||||
description: str
|
||||
input_schema: Dict[str, Any]
|
||||
|
||||
|
||||
def convert_to_anthropic_tool(
|
||||
tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
|
||||
) -> AnthropicTool:
|
||||
# already in Anthropic tool format
|
||||
if isinstance(tool, dict) and all(
|
||||
k in tool for k in ("name", "description", "input_schema")
|
||||
):
|
||||
return AnthropicTool(tool) # type: ignore
|
||||
else:
|
||||
formatted = convert_to_openai_tool(tool)["function"]
|
||||
return AnthropicTool(
|
||||
name=formatted["name"],
|
||||
description=formatted["description"],
|
||||
input_schema=formatted["parameters"],
|
||||
)
|
||||
|
||||
|
||||
@deprecated(since="0.1.0", removal="0.2.0", alternative="ChatAnthropic")
|
||||
class ChatAnthropicMessages(ChatAnthropic):
|
||||
|
Reference in New Issue
Block a user