anthropic[minor]: tool use (#20016)

This commit is contained in:
Bagatur
2024-04-04 13:22:48 -07:00
committed by GitHub
parent 3aacd11846
commit 209de0a561
13 changed files with 1021 additions and 196 deletions

View File

@@ -1,13 +1,31 @@
import os
import re
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
import warnings
from operator import itemgetter
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypedDict,
Union,
cast,
)
import anthropic
from langchain_core._api.deprecation import deprecated
from langchain_core._api import beta, deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
@@ -17,14 +35,26 @@ from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import (
Runnable,
RunnableMap,
RunnablePassthrough,
)
from langchain_core.tools import BaseTool
from langchain_core.utils import (
build_extra_kwargs,
convert_to_secret_str,
get_pydantic_field_names,
)
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_anthropic.output_parsers import ToolsOutputParser
_message_type_lookups = {"human": "user", "ai": "assistant"}
@@ -56,6 +86,41 @@ def _format_image(image_url: str) -> Dict:
}
def _merge_messages(
    messages: List[BaseMessage],
) -> List[Union[SystemMessage, AIMessage, HumanMessage]]:
    """Merge runs of human/tool messages into single human messages with content blocks.

    Anthropic's Messages API has no "tool" role and rejects consecutive
    messages with the same role, so ToolMessages are rewritten as human
    messages carrying ``tool_result`` content blocks, and adjacent human
    messages are collapsed into one message whose content is a list of blocks.

    Unlike a naive in-place merge, this never mutates the caller's message
    objects: merged turns are fresh copies, so the input list is left intact.
    """
    merged: list = []
    for curr in messages:
        if isinstance(curr, ToolMessage):
            # Re-express the tool result as a human-turn content block.
            if isinstance(curr.content, str):
                curr = HumanMessage(
                    [
                        {
                            "type": "tool_result",
                            "content": curr.content,
                            "tool_use_id": curr.tool_call_id,
                        }
                    ]
                )
            else:
                curr = HumanMessage(curr.content)
        last = merged[-1] if merged else None
        if isinstance(last, HumanMessage) and isinstance(curr, HumanMessage):
            # Normalize both sides to block lists, then combine. Copy
            # ``last.content`` (list(...)) instead of aliasing it so the
            # original message's content list is never extended in place.
            if isinstance(last.content, str):
                combined: List = [{"type": "text", "text": last.content}]
            else:
                combined = list(last.content)
            if isinstance(curr.content, str):
                combined.append({"type": "text", "text": curr.content})
            else:
                combined.extend(curr.content)
            # Replace the tail with a copy carrying the merged content,
            # preserving any other fields on the original message.
            merged[-1] = last.copy(update={"content": combined})
        else:
            merged.append(curr)
    return merged
def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[Dict]]:
"""Format messages for anthropic."""
@@ -70,7 +135,9 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D
"""
system: Optional[str] = None
formatted_messages: List[Dict] = []
for i, message in enumerate(messages):
merged_messages = _merge_messages(messages)
for i, message in enumerate(merged_messages):
if message.type == "system":
if i != 0:
raise ValueError("System message must be at beginning of message list.")
@@ -104,7 +171,7 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D
elif isinstance(item, dict):
if "type" not in item:
raise ValueError("Dict content item must have a type key")
if item["type"] == "image_url":
elif item["type"] == "image_url":
# convert format
source = _format_image(item["image_url"]["url"])
content.append(
@@ -113,6 +180,9 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D
"source": source,
}
)
elif item["type"] == "tool_use":
item.pop("text", None)
content.append(item)
else:
content.append(item)
else:
@@ -175,6 +245,9 @@ class ChatAnthropic(BaseChatModel):
anthropic_api_key: Optional[SecretStr] = None
default_headers: Optional[Mapping[str, str]] = None
"""Headers to pass to the Anthropic clients, will be used for every API call."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
streaming: bool = False
@@ -207,9 +280,15 @@ class ChatAnthropic(BaseChatModel):
or "https://api.anthropic.com"
)
values["anthropic_api_url"] = api_url
values["_client"] = anthropic.Client(api_key=api_key, base_url=api_url)
values["_client"] = anthropic.Client(
api_key=api_key,
base_url=api_url,
default_headers=values.get("default_headers"),
)
values["_async_client"] = anthropic.AsyncClient(
api_key=api_key, base_url=api_url
api_key=api_key,
base_url=api_url,
default_headers=values.get("default_headers"),
)
return values
@@ -232,6 +311,7 @@ class ChatAnthropic(BaseChatModel):
"stop_sequences": stop,
"system": system,
**self.model_kwargs,
**kwargs,
}
rtn = {k: v for k, v in rtn.items() if v is not None}
@@ -245,6 +325,13 @@ class ChatAnthropic(BaseChatModel):
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
params = self._format_params(messages=messages, stop=stop, **kwargs)
if "extra_body" in params and params["extra_body"].get("tools"):
warnings.warn("stream: Tool use is not yet supported in streaming mode.")
result = self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
yield cast(ChatGenerationChunk, result.generations[0])
return
with self._client.messages.stream(**params) as stream:
for text in stream.text_stream:
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
@@ -260,6 +347,13 @@ class ChatAnthropic(BaseChatModel):
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
params = self._format_params(messages=messages, stop=stop, **kwargs)
if "extra_body" in params and params["extra_body"].get("tools"):
warnings.warn("stream: Tool use is not yet supported in streaming mode.")
result = await self._agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
yield cast(ChatGenerationChunk, result.generations[0])
return
async with self._async_client.messages.stream(**params) as stream:
async for text in stream.text_stream:
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
@@ -273,8 +367,12 @@ class ChatAnthropic(BaseChatModel):
llm_output = {
k: v for k, v in data_dict.items() if k not in ("content", "role", "type")
}
if len(content) == 1 and content[0]["type"] == "text":
msg = AIMessage(content=content[0]["text"])
else:
msg = AIMessage(content=content)
return ChatResult(
generations=[ChatGeneration(message=AIMessage(content=content[0]["text"]))],
generations=[ChatGeneration(message=msg)],
llm_output=llm_output,
)
@@ -285,12 +383,17 @@ class ChatAnthropic(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
params = self._format_params(messages=messages, stop=stop, **kwargs)
if self.streaming:
if "extra_body" in params and params["extra_body"].get("tools"):
warnings.warn(
"stream: Tool use is not yet supported in streaming mode."
)
else:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
data = self._client.messages.create(**params)
return self._format_output(data, **kwargs)
@@ -301,15 +404,91 @@ class ChatAnthropic(BaseChatModel):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
params = self._format_params(messages=messages, stop=stop, **kwargs)
if self.streaming:
if "extra_body" in params and params["extra_body"].get("tools"):
warnings.warn(
"stream: Tool use is not yet supported in streaming mode."
)
else:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
data = await self._async_client.messages.create(**params)
return self._format_output(data, **kwargs)
@beta()
def bind_tools(
    self,
    tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
    **kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
    """Bind tool-like objects to this chat model.

    Args:
        tools: A list of tool definitions to bind to this chat model.
            Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
            models, callables, and BaseTools will be automatically converted to
            their schema dictionary representation.
        **kwargs: Any additional parameters to bind.

    Returns:
        A Runnable that invokes this model with the tools attached.
    """
    # Tool definitions ride along in ``extra_body``; reuse any extra_body
    # the caller already supplied and add the converted tools to it.
    extra_body = kwargs.pop("extra_body", {})
    extra_body["tools"] = [convert_to_anthropic_tool(t) for t in tools]
    return self.bind(extra_body=extra_body, **kwargs)
@beta()
def with_structured_output(
    self,
    schema: Union[Dict, Type[BaseModel]],
    *,
    include_raw: bool = False,
    **kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
    """Return a runnable that constrains model output to ``schema`` via tool use.

    Args:
        schema: Either a dict tool definition or a pydantic model class.
            Pydantic classes are parsed back into model instances; dict
            schemas yield the raw tool-call arguments.
        include_raw: When True, the runnable returns a dict with keys
            ``raw`` (the model message), ``parsed`` (parsed output or None),
            and ``parsing_error`` (the exception or None) rather than
            raising on parse failures.
        **kwargs: Accepted for interface compatibility.
    """
    llm = self.bind_tools([schema])
    is_pydantic_class = isinstance(schema, type) and issubclass(schema, BaseModel)
    if is_pydantic_class:
        parser = ToolsOutputParser(first_tool_only=True, pydantic_schemas=[schema])
    else:
        parser = ToolsOutputParser(first_tool_only=True, args_only=True)

    if not include_raw:
        return llm | parser

    # include_raw mode must not raise: run the parse step with a fallback
    # that reports the failure through "parsing_error" instead.
    assign_parsed = RunnablePassthrough.assign(
        parsed=itemgetter("raw") | parser, parsing_error=lambda _: None
    )
    assign_none = RunnablePassthrough.assign(parsed=lambda _: None)
    return RunnableMap(raw=llm) | assign_parsed.with_fallbacks(
        [assign_none], exception_key="parsing_error"
    )
class AnthropicTool(TypedDict):
    """Shape of one entry in the Anthropic ``tools`` request field."""

    # Tool name.
    name: str
    # Description the model uses to decide when to invoke the tool.
    description: str
    # Schema for the tool's input; populated from the OpenAI-format
    # "parameters" field when converted via convert_to_anthropic_tool.
    input_schema: Dict[str, Any]
def convert_to_anthropic_tool(
    tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
) -> AnthropicTool:
    """Convert a tool-like object into an Anthropic tool-definition dict.

    Args:
        tool: A dict already in Anthropic format (has ``name``,
            ``description``, ``input_schema``), or anything accepted by
            ``convert_to_openai_tool`` (pydantic model, callable, BaseTool).

    Returns:
        An ``AnthropicTool`` dict.
    """
    # Already in Anthropic tool format: copy it into the TypedDict as-is.
    if isinstance(tool, dict) and all(
        k in tool for k in ("name", "description", "input_schema")
    ):
        return AnthropicTool(tool)  # type: ignore
    else:
        formatted = convert_to_openai_tool(tool)["function"]
        return AnthropicTool(
            name=formatted["name"],
            # The OpenAI-format "description" key is absent for tools
            # without one; default to "" instead of raising KeyError.
            description=formatted.get("description", ""),
            input_schema=formatted["parameters"],
        )
@deprecated(since="0.1.0", removal="0.2.0", alternative="ChatAnthropic")
class ChatAnthropicMessages(ChatAnthropic):