anthropic[minor]: tool use (#20016)
@@ -13,6 +13,9 @@ class ToolMessage(BaseMessage):

    tool_call_id: str
    """Tool call that this message is responding to."""

    # TODO: Add is_error param?
    # is_error: bool = False
    # """Whether the tool errored."""

    type: Literal["tool"] = "tool"
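
For reference, the new tool_call_id field ties a ToolMessage back to the tool_use block it answers. A minimal sketch (the id value is made up):

```python
from langchain_core.messages import ToolMessage

# Feed a tool's output back to the model, keyed to the id of the
# tool_use block the model emitted ("toolu_abc123" is illustrative).
msg = ToolMessage("72F and sunny", tool_call_id="toolu_abc123")
```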
@@ -1,13 +1,31 @@
import os
import re
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
import warnings
from operator import itemgetter
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypedDict,
    Union,
    cast,
)

import anthropic
from langchain_core._api.deprecation import deprecated
from langchain_core._api import beta, deprecated
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
@@ -17,14 +35,26 @@ from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import (
    Runnable,
    RunnableMap,
    RunnablePassthrough,
)
from langchain_core.tools import BaseTool
from langchain_core.utils import (
    build_extra_kwargs,
    convert_to_secret_str,
    get_pydantic_field_names,
)
from langchain_core.utils.function_calling import convert_to_openai_tool

from langchain_anthropic.output_parsers import ToolsOutputParser

_message_type_lookups = {"human": "user", "ai": "assistant"}
@@ -56,6 +86,41 @@ def _format_image(image_url: str) -> Dict:
    }


def _merge_messages(
    messages: List[BaseMessage],
) -> List[Union[SystemMessage, AIMessage, HumanMessage]]:
    """Merge runs of human/tool messages into single human messages with content blocks."""  # noqa: E501
    merged: list = []
    for curr in messages:
        if isinstance(curr, ToolMessage):
            if isinstance(curr.content, str):
                curr = HumanMessage(
                    [
                        {
                            "type": "tool_result",
                            "content": curr.content,
                            "tool_use_id": curr.tool_call_id,
                        }
                    ]
                )
            else:
                curr = HumanMessage(curr.content)
        last = merged[-1] if merged else None
        if isinstance(last, HumanMessage) and isinstance(curr, HumanMessage):
            if isinstance(last.content, str):
                new_content: List = [{"type": "text", "text": last.content}]
            else:
                new_content = last.content
            if isinstance(curr.content, str):
                new_content.append({"type": "text", "text": curr.content})
            else:
                new_content.extend(curr.content)
            last.content = new_content
        else:
            merged.append(curr)
    return merged
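
In effect _merge_messages collapses a run of tool results and follow-up human text into one user turn of content blocks, which is the shape the Anthropic Messages API expects. A small illustration (not part of the diff):

```python
from langchain_core.messages import HumanMessage, ToolMessage

from langchain_anthropic.chat_models import _merge_messages

merged = _merge_messages(
    [ToolMessage("72F", tool_call_id="1"), HumanMessage("Convert that to Celsius.")]
)
# -> a single HumanMessage whose content is:
# [{"type": "tool_result", "content": "72F", "tool_use_id": "1"},
#  {"type": "text", "text": "Convert that to Celsius."}]
```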

def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[Dict]]:
    """Format messages for anthropic."""

@@ -70,7 +135,9 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D
    """
    system: Optional[str] = None
    formatted_messages: List[Dict] = []
    for i, message in enumerate(messages):

    merged_messages = _merge_messages(messages)
    for i, message in enumerate(merged_messages):
        if message.type == "system":
            if i != 0:
                raise ValueError("System message must be at beginning of message list.")
@@ -104,7 +171,7 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D
            elif isinstance(item, dict):
                if "type" not in item:
                    raise ValueError("Dict content item must have a type key")
                if item["type"] == "image_url":
                elif item["type"] == "image_url":
                    # convert format
                    source = _format_image(item["image_url"]["url"])
                    content.append(
@@ -113,6 +180,9 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D
                            "source": source,
                        }
                    )
                elif item["type"] == "tool_use":
                    item.pop("text", None)
                    content.append(item)
                else:
                    content.append(item)
            else:
@@ -175,6 +245,9 @@ class ChatAnthropic(BaseChatModel):

    anthropic_api_key: Optional[SecretStr] = None

    default_headers: Optional[Mapping[str, str]] = None
    """Headers to pass to the Anthropic clients, will be used for every API call."""

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    streaming: bool = False
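
The new default_headers field is also how callers opt in to the tools beta, e.g. (header value taken from the tests later in this diff):

```python
from langchain_anthropic import ChatAnthropic

llm = ChatAnthropic(
    model="claude-3-opus-20240229",
    default_headers={"anthropic-beta": "tools-2024-04-04"},
)
```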
@@ -207,9 +280,15 @@ class ChatAnthropic(BaseChatModel):
            or "https://api.anthropic.com"
        )
        values["anthropic_api_url"] = api_url
        values["_client"] = anthropic.Client(api_key=api_key, base_url=api_url)
        values["_client"] = anthropic.Client(
            api_key=api_key,
            base_url=api_url,
            default_headers=values.get("default_headers"),
        )
        values["_async_client"] = anthropic.AsyncClient(
            api_key=api_key, base_url=api_url
            api_key=api_key,
            base_url=api_url,
            default_headers=values.get("default_headers"),
        )
        return values
@@ -232,6 +311,7 @@ class ChatAnthropic(BaseChatModel):
            "stop_sequences": stop,
            "system": system,
            **self.model_kwargs,
            **kwargs,
        }
        rtn = {k: v for k, v in rtn.items() if v is not None}
@@ -245,6 +325,13 @@ class ChatAnthropic(BaseChatModel):
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        params = self._format_params(messages=messages, stop=stop, **kwargs)
        if "extra_body" in params and params["extra_body"].get("tools"):
            warnings.warn("stream: Tool use is not yet supported in streaming mode.")
            result = self._generate(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            yield cast(ChatGenerationChunk, result.generations[0])
            return
        with self._client.messages.stream(**params) as stream:
            for text in stream.text_stream:
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
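
So with tools bound, .stream() degrades gracefully: it warns and yields the whole non-streamed response as a single chunk. An illustrative sketch (assumes llm_with_tools as constructed in the tests below):

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    chunks = list(llm_with_tools.stream("what's the weather in san francisco, ca"))

# Expected: one chunk carrying the full response, plus the warning above.
assert len(chunks) == 1
assert any("not yet supported" in str(w.message) for w in caught)
```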
@@ -260,6 +347,13 @@ class ChatAnthropic(BaseChatModel):
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        params = self._format_params(messages=messages, stop=stop, **kwargs)
        if "extra_body" in params and params["extra_body"].get("tools"):
            warnings.warn("stream: Tool use is not yet supported in streaming mode.")
            result = await self._agenerate(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            yield cast(ChatGenerationChunk, result.generations[0])
            return
        async with self._async_client.messages.stream(**params) as stream:
            async for text in stream.text_stream:
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
@@ -273,8 +367,12 @@ class ChatAnthropic(BaseChatModel):
        llm_output = {
            k: v for k, v in data_dict.items() if k not in ("content", "role", "type")
        }
        if len(content) == 1 and content[0]["type"] == "text":
            msg = AIMessage(content=content[0]["text"])
        else:
            msg = AIMessage(content=content)
        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=content[0]["text"]))],
            generations=[ChatGeneration(message=msg)],
            llm_output=llm_output,
        )
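
The reworked _format_output keeps plain-text replies as strings but preserves raw content blocks for tool calls. The two resulting message shapes (block values illustrative):

```python
from langchain_core.messages import AIMessage

# Plain-text reply: content collapses to a string.
text_reply = AIMessage(content="Hello!")

# Tool-use reply: content keeps the Anthropic content blocks.
tool_reply = AIMessage(
    content=[
        {"type": "text", "text": "Let me check the weather."},
        {
            "type": "tool_use",
            "id": "toolu_123",
            "name": "get_weather",
            "input": {"location": "San Francisco, CA"},
        },
    ]
)
```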
@@ -285,12 +383,17 @@ class ChatAnthropic(BaseChatModel):
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)
        params = self._format_params(messages=messages, stop=stop, **kwargs)
        if self.streaming:
            if "extra_body" in params and params["extra_body"].get("tools"):
                warnings.warn(
                    "stream: Tool use is not yet supported in streaming mode."
                )
            else:
                stream_iter = self._stream(
                    messages, stop=stop, run_manager=run_manager, **kwargs
                )
                return generate_from_stream(stream_iter)
        data = self._client.messages.create(**params)
        return self._format_output(data, **kwargs)
@@ -301,15 +404,91 @@ class ChatAnthropic(BaseChatModel):
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)
        params = self._format_params(messages=messages, stop=stop, **kwargs)
        if self.streaming:
            if "extra_body" in params and params["extra_body"].get("tools"):
                warnings.warn(
                    "stream: Tool use is not yet supported in streaming mode."
                )
            else:
                stream_iter = self._astream(
                    messages, stop=stop, run_manager=run_manager, **kwargs
                )
                return await agenerate_from_stream(stream_iter)
        data = await self._async_client.messages.create(**params)
        return self._format_output(data, **kwargs)

    @beta()
    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.

        Args:
            tools: A list of tool definitions to bind to this chat model.
                Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
                models, callables, and BaseTools will be automatically converted to
                their schema dictionary representation.
            **kwargs: Any additional parameters to bind.
        """
        formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools]
        extra_body = kwargs.pop("extra_body", {})
        extra_body["tools"] = formatted_tools
        return self.bind(extra_body=extra_body, **kwargs)
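
The integration tests below exercise bind_tools end to end; in miniature, binding a pydantic schema might look like this (GetWeather is a hypothetical schema, llm as constructed above):

```python
from langchain_core.pydantic_v1 import BaseModel, Field

class GetWeather(BaseModel):
    """Get weather report for a city."""

    location: str = Field(..., description="City and state, e.g. San Francisco, CA")

llm_with_tools = llm.bind_tools([GetWeather])
msg = llm_with_tools.invoke("what's the weather in san francisco, ca")
# msg.content should be a list of blocks including a {"type": "tool_use", ...} entry.
```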

    @beta()
    def with_structured_output(
        self,
        schema: Union[Dict, Type[BaseModel]],
        *,
        include_raw: bool = False,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
        llm = self.bind_tools([schema])
        if isinstance(schema, type) and issubclass(schema, BaseModel):
            output_parser = ToolsOutputParser(
                first_tool_only=True, pydantic_schemas=[schema]
            )
        else:
            output_parser = ToolsOutputParser(first_tool_only=True, args_only=True)

        if include_raw:
            parser_assign = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
            )
            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
            parser_with_fallback = parser_assign.with_fallbacks(
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback
        else:
            return llm | output_parser
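
With a pydantic schema the chain returns a model instance; with include_raw=True it returns the raw message, the parsed value, and any parsing error. A sketch reusing the hypothetical GetWeather above:

```python
structured_llm = llm.with_structured_output(GetWeather)
result = structured_llm.invoke("what's the weather in san francisco, ca")
# result should be a GetWeather instance, e.g. GetWeather(location="San Francisco, CA")

raw_chain = llm.with_structured_output(GetWeather, include_raw=True)
out = raw_chain.invoke("what's the weather in san francisco, ca")
# out == {"raw": AIMessage(...), "parsed": GetWeather(...), "parsing_error": None}
```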

class AnthropicTool(TypedDict):
    name: str
    description: str
    input_schema: Dict[str, Any]


def convert_to_anthropic_tool(
    tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
) -> AnthropicTool:
    # already in Anthropic tool format
    if isinstance(tool, dict) and all(
        k in tool for k in ("name", "description", "input_schema")
    ):
        return AnthropicTool(tool)  # type: ignore
    else:
        formatted = convert_to_openai_tool(tool)["function"]
        return AnthropicTool(
            name=formatted["name"],
            description=formatted["description"],
            input_schema=formatted["parameters"],
        )


@deprecated(since="0.1.0", removal="0.2.0", alternative="ChatAnthropic")
class ChatAnthropicMessages(ChatAnthropic):
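
The unit tests at the end of this diff check that every accepted shape converts to the same AnthropicTool; the key move is routing through convert_to_openai_tool and renaming "parameters" to "input_schema". Roughly:

```python
from langchain_anthropic.chat_models import convert_to_anthropic_tool

def get_weather(location: str) -> str:
    """Get weather report for a city

    Args:
        location: city to look up
    """
    ...

tool = convert_to_anthropic_tool(get_weather)
# Approximately:
# {"name": "get_weather",
#  "description": "Get weather report for a city",
#  "input_schema": {"type": "object",
#                   "properties": {"location": {...}},
#                   "required": ["location"]}}
```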
@@ -1,38 +1,13 @@
import json
from typing import (
    Any,
    AsyncIterator,
    Dict,
    Iterator,
    List,
    Optional,
    Sequence,
    Type,
    Union,
    cast,
)

from langchain_core._api.beta_decorator import beta
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    BaseMessageChunk,
    SystemMessage,
)
from langchain_core.output_parsers.openai_tools import (
    JsonOutputKeyToolsParser,
    PydanticToolsParser,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_core._api import deprecated
from langchain_core.pydantic_v1 import Field

from langchain_anthropic.chat_models import ChatAnthropic
@@ -168,143 +143,16 @@ def _xml_to_tool_calls(elem: Any, tools: List[Dict]) -> List[Dict[str, Any]]:
    return [_xml_to_function_call(invoke, tools) for invoke in invokes]


@beta()
@deprecated(
    "0.1.5",
    removal="0.2.0",
    alternative="ChatAnthropic",
    message=(
        "Tool-calling is now officially supported by the Anthropic API so this "
        "workaround is no longer needed."
    ),
)
class ChatAnthropicTools(ChatAnthropic):
    """Chat model for interacting with Anthropic functions."""

    _xmllib: Any = Field(default=None)

    @root_validator()
    def check_xml_lib(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        try:
            # do this as an optional dep for temporary nature of this feature
            import defusedxml.ElementTree as DET  # type: ignore

            values["_xmllib"] = DET
        except ImportError:
            raise ImportError(
                "Could not import defusedxml python package. "
                "Please install it using `pip install defusedxml`"
            )
        return values

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], BaseTool]],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tools to the chat model."""
        formatted_tools = [convert_to_openai_function(tool) for tool in tools]
        return super().bind(tools=formatted_tools, **kwargs)

    def with_structured_output(
        self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
        if kwargs:
            raise ValueError("kwargs are not supported for with_structured_output")
        llm = self.bind_tools([schema])
        if isinstance(schema, type) and issubclass(schema, BaseModel):
            # schema is pydantic
            return llm | PydanticToolsParser(tools=[schema], first_tool_only=True)
        else:
            # schema is dict
            key_name = convert_to_openai_function(schema)["name"]
            return llm | JsonOutputKeyToolsParser(
                key_name=key_name, first_tool_only=True
            )

    def _format_params(
        self,
        *,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Dict:
        tools: List[Dict] = kwargs.get("tools", None)
        # experimental tools are sent in as part of system prompt, so if
        # both are set, turn system prompt into tools + system prompt (tools first)
        if tools:
            tool_system = get_system_message(tools)

            if messages[0].type == "system":
                sys_content = messages[0].content
                new_sys_content = f"{tool_system}\n\n{sys_content}"
                messages = [SystemMessage(content=new_sys_content), *messages[1:]]
            else:
                messages = [SystemMessage(content=tool_system), *messages]

        return super()._format_params(messages=messages, stop=stop, **kwargs)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        # streaming not supported for functions
        result = self._generate(
            messages=messages, stop=stop, run_manager=run_manager, **kwargs
        )
        to_yield = result.generations[0]
        chunk = ChatGenerationChunk(
            message=cast(BaseMessageChunk, to_yield.message),
            generation_info=to_yield.generation_info,
        )
        if run_manager:
            run_manager.on_llm_new_token(
                cast(str, to_yield.message.content), chunk=chunk
            )
        yield chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        # streaming not supported for functions
        result = await self._agenerate(
            messages=messages, stop=stop, run_manager=run_manager, **kwargs
        )
        to_yield = result.generations[0]
        chunk = ChatGenerationChunk(
            message=cast(BaseMessageChunk, to_yield.message),
            generation_info=to_yield.generation_info,
        )
        if run_manager:
            await run_manager.on_llm_new_token(
                cast(str, to_yield.message.content), chunk=chunk
            )
        yield chunk

    def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
        """Format the output of the model, parsing xml as a tool call."""
        text = data.content[0].text
        tools = kwargs.get("tools", None)

        additional_kwargs: Dict[str, Any] = {}

        if tools:
            # parse out the xml from the text
            try:
                # get everything between <function_calls> and </function_calls>
                start = text.find("<function_calls>")
                end = text.find("</function_calls>") + len("</function_calls>")
                xml_text = text[start:end]

                xml = self._xmllib.fromstring(xml_text)
                additional_kwargs["tool_calls"] = _xml_to_tool_calls(xml, tools)
                text = ""
            except Exception:
                pass

        return ChatResult(
            generations=[
                ChatGeneration(
                    message=AIMessage(content=text, additional_kwargs=additional_kwargs)
                )
            ],
            llm_output=data,
        )
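
Since ChatAnthropicTools is now deprecated in favor of native tool use, migration is essentially a constructor swap. A hedged sketch (model name and beta header as used elsewhere in this diff, GetWeather hypothetical):

```python
# Before (XML-prompt workaround, deprecated in 0.1.5):
# from langchain_anthropic.experimental import ChatAnthropicTools
# llm_with_tools = ChatAnthropicTools(model="claude-3-opus-20240229").bind_tools([GetWeather])

# After (native tool use, beta):
from langchain_anthropic import ChatAnthropic

llm_with_tools = ChatAnthropic(
    model="claude-3-opus-20240229",
    default_headers={"anthropic-beta": "tools-2024-04-04"},
).bind_tools([GetWeather])
```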
@@ -0,0 +1,66 @@
from typing import Any, List, Optional, Type, TypedDict, cast

from langchain_core.messages import BaseMessage
from langchain_core.output_parsers import BaseGenerationOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel


class _ToolCall(TypedDict):
    name: str
    args: dict
    id: str
    index: int


class ToolsOutputParser(BaseGenerationOutputParser):
    first_tool_only: bool = False
    args_only: bool = False
    pydantic_schemas: Optional[List[Type[BaseModel]]] = None

    class Config:
        extra = "forbid"

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        """Parse a list of candidate model Generations into a specific format.

        Args:
            result: A list of Generations to be parsed. The Generations are assumed
                to be different candidate outputs for a single model input.

        Returns:
            Structured output.
        """
        if not result or not isinstance(result[0], ChatGeneration):
            return None if self.first_tool_only else []
        tool_calls: List = _extract_tool_calls(result[0].message)
        if self.pydantic_schemas:
            tool_calls = [self._pydantic_parse(tc) for tc in tool_calls]
        elif self.args_only:
            tool_calls = [tc["args"] for tc in tool_calls]
        else:
            pass

        if self.first_tool_only:
            return tool_calls[0] if tool_calls else None
        else:
            return tool_calls

    def _pydantic_parse(self, tool_call: _ToolCall) -> BaseModel:
        cls_ = {schema.__name__: schema for schema in self.pydantic_schemas or []}[
            tool_call["name"]
        ]
        return cls_(**tool_call["args"])


def _extract_tool_calls(msg: BaseMessage) -> List[_ToolCall]:
    if isinstance(msg.content, str):
        return []
    tool_calls = []
    for i, block in enumerate(cast(List[dict], msg.content)):
        if block["type"] != "tool_use":
            continue
        tool_calls.append(
            _ToolCall(name=block["name"], args=block["input"], id=block["id"], index=i)
        )
    return tool_calls
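
Standalone, the parser turns tool_use blocks into tool-call dicts; this mirrors the unit tests at the end of this diff:

```python
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration

from langchain_anthropic.output_parsers import ToolsOutputParser

gen = ChatGeneration(
    message=AIMessage(
        [{"type": "tool_use", "input": {"bar": 0}, "id": "1", "name": "_Foo1"}]
    )
)
assert ToolsOutputParser().parse_result([gen]) == [
    {"name": "_Foo1", "args": {"bar": 0}, "id": "1", "index": 0}
]
```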
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-anthropic"
version = "0.1.4"
version = "0.1.5"
description = "An integration package connecting AnthropicMessages and LangChain"
authors = []
readme = "README.md"
@@ -212,3 +212,47 @@ async def test_astreaming() -> None:
    response = await llm.agenerate([[HumanMessage(content="I'm Pickle Rick")]])
    assert callback_handler.llm_streams > 0
    assert isinstance(response, LLMResult)


def test_tool_use() -> None:
    llm = ChatAnthropic(
        model="claude-3-opus-20240229",
        default_headers={"anthropic-beta": "tools-2024-04-04"},
    )

    llm_with_tools = llm.bind_tools(
        [
            {
                "name": "get_weather",
                "description": "Get weather report for a city",
                "input_schema": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                },
            }
        ]
    )
    response = llm_with_tools.invoke("what's the weather in san francisco, ca")
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, list)


def test_with_structured_output() -> None:
    llm = ChatAnthropic(
        model="claude-3-opus-20240229",
        default_headers={"anthropic-beta": "tools-2024-04-04"},
    )

    structured_llm = llm.with_structured_output(
        {
            "name": "get_weather",
            "description": "Get weather report for a city",
            "input_schema": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
            },
        }
    )
    response = structured_llm.invoke("what's the weather in san francisco, ca")
    assert isinstance(response, dict)
    assert response["location"]
@@ -1,13 +1,17 @@
"""Test chat model integration."""

import os
from typing import Any, Callable, Dict, Literal, Type

import pytest
from anthropic.types import ContentBlock, Message, Usage
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import BaseTool

from langchain_anthropic import ChatAnthropic, ChatAnthropicMessages
from langchain_anthropic.chat_models import _merge_messages, convert_to_anthropic_tool

os.environ["ANTHROPIC_API_KEY"] = "foo"
@@ -83,3 +87,175 @@ def test__format_output() -> None:
    llm = ChatAnthropic(model="test", anthropic_api_key="test")
    actual = llm._format_output(anthropic_msg)
    assert expected == actual


def test__merge_messages() -> None:
    messages = [
        SystemMessage("foo"),
        HumanMessage("bar"),
        AIMessage(
            [
                {"text": "baz", "type": "text"},
                {
                    "tool_input": {"a": "b"},
                    "type": "tool_use",
                    "id": "1",
                    "text": None,
                    "name": "buz",
                },
                {"text": "baz", "type": "text"},
                {
                    "tool_input": {"a": "c"},
                    "type": "tool_use",
                    "id": "2",
                    "text": None,
                    "name": "blah",
                },
            ]
        ),
        ToolMessage("buz output", tool_call_id="1"),
        ToolMessage("blah output", tool_call_id="2"),
        HumanMessage("next thing"),
    ]
    expected = [
        SystemMessage("foo"),
        HumanMessage("bar"),
        AIMessage(
            [
                {"text": "baz", "type": "text"},
                {
                    "tool_input": {"a": "b"},
                    "type": "tool_use",
                    "id": "1",
                    "text": None,
                    "name": "buz",
                },
                {"text": "baz", "type": "text"},
                {
                    "tool_input": {"a": "c"},
                    "type": "tool_use",
                    "id": "2",
                    "text": None,
                    "name": "blah",
                },
            ]
        ),
        HumanMessage(
            [
                {"type": "tool_result", "content": "buz output", "tool_use_id": "1"},
                {"type": "tool_result", "content": "blah output", "tool_use_id": "2"},
                {"type": "text", "text": "next thing"},
            ]
        ),
    ]
    actual = _merge_messages(messages)
    assert expected == actual


@pytest.fixture()
def pydantic() -> Type[BaseModel]:
    class dummy_function(BaseModel):
        """dummy function"""

        arg1: int = Field(..., description="foo")
        arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")

    return dummy_function


@pytest.fixture()
def function() -> Callable:
    def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
        """dummy function

        Args:
            arg1: foo
            arg2: one of 'bar', 'baz'
        """
        pass

    return dummy_function


@pytest.fixture()
def dummy_tool() -> BaseTool:
    class Schema(BaseModel):
        arg1: int = Field(..., description="foo")
        arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")

    class DummyFunction(BaseTool):
        args_schema: Type[BaseModel] = Schema
        name: str = "dummy_function"
        description: str = "dummy function"

        def _run(self, *args: Any, **kwargs: Any) -> Any:
            pass

    return DummyFunction()


@pytest.fixture()
def json_schema() -> Dict:
    return {
        "title": "dummy_function",
        "description": "dummy function",
        "type": "object",
        "properties": {
            "arg1": {"description": "foo", "type": "integer"},
            "arg2": {
                "description": "one of 'bar', 'baz'",
                "enum": ["bar", "baz"],
                "type": "string",
            },
        },
        "required": ["arg1", "arg2"],
    }


@pytest.fixture()
def openai_function() -> Dict:
    return {
        "name": "dummy_function",
        "description": "dummy function",
        "parameters": {
            "type": "object",
            "properties": {
                "arg1": {"description": "foo", "type": "integer"},
                "arg2": {
                    "description": "one of 'bar', 'baz'",
                    "enum": ["bar", "baz"],
                    "type": "string",
                },
            },
            "required": ["arg1", "arg2"],
        },
    }


def test_convert_to_anthropic_tool(
    pydantic: Type[BaseModel],
    function: Callable,
    dummy_tool: BaseTool,
    json_schema: Dict,
    openai_function: Dict,
) -> None:
    expected = {
        "name": "dummy_function",
        "description": "dummy function",
        "input_schema": {
            "type": "object",
            "properties": {
                "arg1": {"description": "foo", "type": "integer"},
                "arg2": {
                    "description": "one of 'bar', 'baz'",
                    "enum": ["bar", "baz"],
                    "type": "string",
                },
            },
            "required": ["arg1", "arg2"],
        },
    }

    for fn in (pydantic, function, dummy_tool, json_schema, expected, openai_function):
        actual = convert_to_anthropic_tool(fn)  # type: ignore
        assert actual == expected
@@ -0,0 +1,72 @@
from typing import Any, List, Literal

from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration
from langchain_core.pydantic_v1 import BaseModel

from langchain_anthropic.output_parsers import ToolsOutputParser

_CONTENT: List = [
    {
        "type": "text",
        "text": "thought",
    },
    {"type": "tool_use", "input": {"bar": 0}, "id": "1", "name": "_Foo1"},
    {
        "type": "text",
        "text": "thought",
    },
    {"type": "tool_use", "input": {"baz": "a"}, "id": "2", "name": "_Foo2"},
]

_RESULT: List = [ChatGeneration(message=AIMessage(_CONTENT))]


class _Foo1(BaseModel):
    bar: int


class _Foo2(BaseModel):
    baz: Literal["a", "b"]


def test_tools_output_parser() -> None:
    output_parser = ToolsOutputParser()
    expected = [
        {"name": "_Foo1", "args": {"bar": 0}, "id": "1", "index": 1},
        {"name": "_Foo2", "args": {"baz": "a"}, "id": "2", "index": 3},
    ]
    actual = output_parser.parse_result(_RESULT)
    assert expected == actual


def test_tools_output_parser_args_only() -> None:
    output_parser = ToolsOutputParser(args_only=True)
    expected = [
        {"bar": 0},
        {"baz": "a"},
    ]
    actual = output_parser.parse_result(_RESULT)
    assert expected == actual

    expected = []
    actual = output_parser.parse_result([ChatGeneration(message=AIMessage(""))])
    assert expected == actual


def test_tools_output_parser_first_tool_only() -> None:
    output_parser = ToolsOutputParser(first_tool_only=True)
    expected: Any = {"name": "_Foo1", "args": {"bar": 0}, "id": "1", "index": 1}
    actual = output_parser.parse_result(_RESULT)
    assert expected == actual

    expected = None
    actual = output_parser.parse_result([ChatGeneration(message=AIMessage(""))])
    assert expected == actual


def test_tools_output_parser_pydantic() -> None:
    output_parser = ToolsOutputParser(pydantic_schemas=[_Foo1, _Foo2])
    expected = [_Foo1(bar=0), _Foo2(baz="a")]
    actual = output_parser.parse_result(_RESULT)
    assert expected == actual