anthropic[minor]: add tool calling (#18554)

Erick Friis
2024-03-05 08:30:16 -08:00
committed by GitHub
parent 5fc67ca2c7
commit 4ac2cb4adc
8 changed files with 496 additions and 203 deletions

View File

@@ -8,19 +8,17 @@ This package contains the LangChain integration for Anthropic's generative model
## Chat Models
-| API Model Name     | Model Family   |
-| ------------------ | -------------- |
-| claude-instant-1.2 | Claude Instant |
-| claude-2.1         | Claude         |
-| claude-2.0         | Claude         |
Anthropic recommends using their chat models over text completions.
You can see their recommended models [here](https://docs.anthropic.com/claude/docs/models-overview#model-recommendations).
To use, you should have an Anthropic API key configured. Initialize the model as:
```
-from langchain_anthropic import ChatAnthropicMessages
+from langchain_anthropic import ChatAnthropic
from langchain_core.messages import AIMessage, HumanMessage
-model = ChatAnthropicMessages(model="claude-2.1", temperature=0, max_tokens=1024)
+model = ChatAnthropic(model="claude-3-opus-20240229", temperature=0, max_tokens=1024)
```
### Define the input message
@@ -32,3 +30,14 @@ model = ChatAnthropicMessages(model="claude-2.1", temperature=0, max_tokens=1024
`response = model.invoke([message])`
For a more detailed walkthrough see [here](https://python.langchain.com/docs/integrations/chat/anthropic).
## LLMs (Legacy)
You can use the Claude 2 models for text completions.
```python
from langchain_anthropic import AnthropicLLM
model = AnthropicLLM(model="claude-2.1", temperature=0, max_tokens=1024)
response = model.invoke("The best restaurant in San Francisco is: ")
```
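A quick end-to-end check of the updated README snippet (a minimal sketch; the message text here is illustrative and not part of the README):

```python
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage

model = ChatAnthropic(model="claude-3-opus-20240229", temperature=0, max_tokens=1024)

# Illustrative input; requires ANTHROPIC_API_KEY to be set.
message = HumanMessage(content="What is the capital of France?")
response = model.invoke([message])
print(response.content)
```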

View File

@@ -256,6 +256,14 @@ class ChatAnthropic(BaseChatModel):
await run_manager.on_llm_new_token(text, chunk=chunk)
yield chunk
def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
return ChatResult(
generations=[
ChatGeneration(message=AIMessage(content=data.content[0].text))
],
llm_output=data,
)
def _generate(
self,
messages: List[BaseMessage],
@@ -265,12 +273,7 @@ class ChatAnthropic(BaseChatModel):
) -> ChatResult:
params = self._format_params(messages=messages, stop=stop, **kwargs)
data = self._client.messages.create(**params)
-return ChatResult(
-generations=[
-ChatGeneration(message=AIMessage(content=data.content[0].text))
-],
-llm_output=data,
-)
+return self._format_output(data, **kwargs)
async def _agenerate(
self,
@@ -281,12 +284,7 @@ class ChatAnthropic(BaseChatModel):
) -> ChatResult:
params = self._format_params(messages=messages, stop=stop, **kwargs)
data = await self._async_client.messages.create(**params)
-return ChatResult(
-generations=[
-ChatGeneration(message=AIMessage(content=data.content[0].text))
-],
-llm_output=data,
-)
+return self._format_output(data, **kwargs)
@deprecated(since="0.1.0", removal="0.2.0", alternative="ChatAnthropic")
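Factoring the response handling into `_format_output` lets subclasses change how raw Anthropic responses become `ChatResult`s without re-implementing `_generate` and `_agenerate`; the new `ChatAnthropicTools` below overrides exactly this hook to parse tool calls. A minimal sketch of the pattern (this subclass is hypothetical, for illustration only):

```python
from typing import Any

from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, ChatResult

from langchain_anthropic import ChatAnthropic


class ShoutingChatAnthropic(ChatAnthropic):
    """Hypothetical subclass: customizes output via the _format_output hook."""

    def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
        # Uppercase the text; a stand-in for any custom post-processing.
        text = data.content[0].text.upper()
        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=text))],
            llm_output=data,
        )
```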

View File

@@ -0,0 +1,277 @@
import json
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Sequence,
Type,
Union,
cast,
)
from langchain_core._api.beta_decorator import beta
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import (
AIMessage,
BaseMessage,
BaseMessageChunk,
SystemMessage,
)
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_anthropic.chat_models import ChatAnthropic
SYSTEM_PROMPT_FORMAT = """In this environment you have access to a set of tools you can use to answer the user's question.
You may call them like this:
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>
Here are the tools available:
<tools>
{formatted_tools}
</tools>""" # noqa: E501
TOOL_FORMAT = """<tool_description>
<tool_name>{tool_name}</tool_name>
<description>{tool_description}</description>
<parameters>
{formatted_parameters}
</parameters>
</tool_description>"""
TOOL_PARAMETER_FORMAT = """<parameter>
<name>{parameter_name}</name>
<type>{parameter_type}</type>
<description>{parameter_description}</description>
</parameter>"""
def get_system_message(tools: List[Dict]) -> str:
tools_data: List[Dict] = [
{
"tool_name": tool["name"],
"tool_description": tool["description"],
"formatted_parameters": "\n".join(
[
TOOL_PARAMETER_FORMAT.format(
parameter_name=name,
parameter_type=parameter["type"],
parameter_description=parameter.get("description"),
)
for name, parameter in tool["parameters"]["properties"].items()
]
),
}
for tool in tools
]
tools_formatted = "\n".join(
[
TOOL_FORMAT.format(
tool_name=tool["tool_name"],
tool_description=tool["tool_description"],
formatted_parameters=tool["formatted_parameters"],
)
for tool in tools_data
]
)
return SYSTEM_PROMPT_FORMAT.format(formatted_tools=tools_formatted)
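# Illustrative example (not part of the original module): for a tool dict of
# the shape produced by convert_to_openai_function, e.g.
#   {"name": "get_weather", "description": "Look up the weather.",
#    "parameters": {"properties": {
#        "city": {"type": "string", "description": "City name."}}}},
# get_system_message([tool]) fills SYSTEM_PROMPT_FORMAT with one
# <tool_description> block whose <parameters> section contains a single
# <parameter> entry for "city".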
def _xml_to_dict(t: Any) -> Union[str, Dict[str, Any]]:
# Base case: If the element has no children, return its text or an empty string.
if len(t) == 0:
return t.text or ""
# Recursive case: The element has children. Convert them into a dictionary.
d: Dict[str, Any] = {}
for child in t:
if child.tag not in d:
d[child.tag] = _xml_to_dict(child)
else:
# Handle multiple children with the same tag
if not isinstance(d[child.tag], list):
d[child.tag] = [d[child.tag]] # Convert existing entry into a list
d[child.tag].append(_xml_to_dict(child))
return d
def _xml_to_tool_calls(elem: Any) -> List[Dict[str, Any]]:
"""
Convert an XML element and its children into a dictionary of dictionaries.
"""
invokes = elem.findall("invoke")
return [
{
"function": {
"name": invoke.find("tool_name").text,
"arguments": json.dumps(_xml_to_dict(invoke.find("parameters"))),
},
"type": "function",
}
for invoke in invokes
]
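# Illustrative example (not part of the original module): for model output
# containing
#   <function_calls><invoke><tool_name>get_weather</tool_name>
#   <parameters><city>Paris</city></parameters></invoke></function_calls>
# _xml_to_tool_calls on the parsed root yields
#   [{"function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
#     "type": "function"}]
# (arguments is a JSON string; _xml_to_dict collapses repeated sibling tags
# into lists, so repeated parameter tags become JSON arrays).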
@beta()
class ChatAnthropicTools(ChatAnthropic):
"""Chat model for interacting with Anthropic functions."""
_xmllib: Any = Field(default=None)
@root_validator()
def check_xml_lib(cls, values: Dict[str, Any]) -> Dict[str, Any]:
try:
# do this as an optional dep for temporary nature of this feature
import defusedxml.ElementTree as DET # type: ignore
values["_xmllib"] = DET
except ImportError:
raise ImportError(
"Could not import defusedxml python package. "
"Please install it using `pip install defusedxml`"
)
return values
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tools to the chat model."""
formatted_tools = [convert_to_openai_function(tool) for tool in tools]
return super().bind(tools=formatted_tools, **kwargs)
def with_structured_output(
self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
if kwargs:
raise ValueError("kwargs are not supported for with_structured_output")
llm = self.bind_tools([schema])
if isinstance(schema, type) and issubclass(schema, BaseModel):
# schema is pydantic
return llm | PydanticToolsParser(tools=[schema], first_tool_only=True)
else:
# schema is dict
key_name = convert_to_openai_function(schema)["name"]
return llm | JsonOutputKeyToolsParser(
key_name=key_name, first_tool_only=True
)
def _format_params(
self,
*,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Dict:
tools: List[Dict] = kwargs.get("tools", None)
# experimental tools are sent in as part of system prompt, so if
# both are set, turn system prompt into tools + system prompt (tools first)
if tools:
tool_system = get_system_message(tools)
if messages[0].type == "system":
sys_content = messages[0].content
new_sys_content = f"{tool_system}\n\n{sys_content}"
messages = [SystemMessage(content=new_sys_content), *messages[1:]]
else:
messages = [SystemMessage(content=tool_system), *messages]
return super()._format_params(messages=messages, stop=stop, **kwargs)
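# Illustrative note (not part of the original module): if tools are bound and
# the first incoming message is SystemMessage("You are terse."), the system
# message actually sent to the API becomes the tool system prompt, a blank
# line, then "You are terse." -- tools first, original system content second.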
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
# streaming not supported for functions
result = self._generate(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
to_yield = result.generations[0]
chunk = ChatGenerationChunk(
message=cast(BaseMessageChunk, to_yield.message),
generation_info=to_yield.generation_info,
)
if run_manager:
run_manager.on_llm_new_token(
cast(str, to_yield.message.content), chunk=chunk
)
yield chunk
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
# streaming not supported for functions
result = await self._agenerate(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
to_yield = result.generations[0]
chunk = ChatGenerationChunk(
message=cast(BaseMessageChunk, to_yield.message),
generation_info=to_yield.generation_info,
)
if run_manager:
await run_manager.on_llm_new_token(
cast(str, to_yield.message.content), chunk=chunk
)
yield chunk
def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
"""Format the output of the model, parsing xml as a tool call."""
text = data.content[0].text
tools = kwargs.get("tools", None)
additional_kwargs: Dict[str, Any] = {}
if tools:
# parse out the xml from the text
try:
# get everything between <function_calls> and </function_calls>
start = text.find("<function_calls>")
end = text.find("</function_calls>") + len("</function_calls>")
xml_text = text[start:end]
xml = self._xmllib.fromstring(xml_text)
additional_kwargs["tool_calls"] = _xml_to_tool_calls(xml)
text = ""
except Exception:
pass
return ChatResult(
generations=[
ChatGeneration(
message=AIMessage(content=text, additional_kwargs=additional_kwargs)
)
],
llm_output=data,
)
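Putting the new surface together, tool binding and structured output can be exercised like this (a minimal sketch mirroring the integration tests below; requires ANTHROPIC_API_KEY and the optional defusedxml dependency):

```python
from langchain_core.pydantic_v1 import BaseModel

from langchain_anthropic.experimental import ChatAnthropicTools


class Person(BaseModel):
    name: str
    age: int


llm = ChatAnthropicTools(model_name="claude-3-sonnet-20240229")

# Raw tool calls land in additional_kwargs["tool_calls"].
msg = llm.bind_tools([Person]).invoke("Erick is 27 years old")

# Or parse straight into the pydantic model.
person = llm.with_structured_output(Person).invoke("Erick is 27 years old")
```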

View File

@@ -198,6 +198,17 @@ files = [
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
]
[[package]]
name = "defusedxml"
version = "0.7.1"
description = "XML bomb protection for Python stdlib modules"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
files = [
{file = "defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61"},
{file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"},
]
[[package]]
name = "distro"
version = "1.9.0"
@@ -1195,4 +1206,4 @@ watchmedo = ["PyYAML (>=3.10)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
-content-hash = "87eac6e38dbdf3658a937aa5a67b5660ff50c4d1d20271e841461020e8aa1ea1"
+content-hash = "9894a8470203b5687f296626c352d47843fcb312029313f81ac582b867373bcd"

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-anthropic"
version = "0.1.1"
version = "0.1.2"
description = "An integration package connecting AnthropicMessages and LangChain"
authors = []
readme = "README.md"
@@ -14,6 +14,7 @@ license = "MIT"
python = ">=3.8.1,<4.0"
langchain-core = "^0.1"
anthropic = ">=0.17.0,<1"
defusedxml = {version = "^0.7.1", optional = true}
[tool.poetry.group.test]
optional = true
@@ -26,6 +27,7 @@ syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
langchain-core = { path = "../../core", develop = true }
defusedxml = "^0.7.1"
[tool.poetry.group.codespell]
optional = true

View File

@@ -0,0 +1,129 @@
"""Test ChatAnthropic chat model."""
import json

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_anthropic.experimental import ChatAnthropicTools
MODEL_NAME = "claude-3-sonnet-20240229"
#####################################
### Test Basic features, no tools ###
#####################################
def test_stream() -> None:
"""Test streaming tokens from Anthropic."""
llm = ChatAnthropicTools(model_name=MODEL_NAME)
for token in llm.stream("I'm Pickle Rick"):
assert isinstance(token.content, str)
async def test_astream() -> None:
"""Test streaming tokens from Anthropic."""
llm = ChatAnthropicTools(model_name=MODEL_NAME)
async for token in llm.astream("I'm Pickle Rick"):
assert isinstance(token.content, str)
async def test_abatch() -> None:
"""Test streaming tokens from ChatAnthropicTools."""
llm = ChatAnthropicTools(model_name=MODEL_NAME)
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token.content, str)
async def test_abatch_tags() -> None:
"""Test batch tokens from ChatAnthropicTools."""
llm = ChatAnthropicTools(model_name=MODEL_NAME)
result = await llm.abatch(
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
)
for token in result:
assert isinstance(token.content, str)
def test_batch() -> None:
"""Test batch tokens from ChatAnthropicTools."""
llm = ChatAnthropicTools(model_name=MODEL_NAME)
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token.content, str)
async def test_ainvoke() -> None:
"""Test invoke tokens from ChatAnthropicTools."""
llm = ChatAnthropicTools(model_name=MODEL_NAME)
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
assert isinstance(result.content, str)
def test_invoke() -> None:
"""Test invoke tokens from ChatAnthropicTools."""
llm = ChatAnthropicTools(model_name=MODEL_NAME)
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result.content, str)
def test_system_invoke() -> None:
"""Test invoke tokens with a system message"""
llm = ChatAnthropicTools(model_name=MODEL_NAME)
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are an expert cartographer. If asked, you are a cartographer. "
"STAY IN CHARACTER",
),
("human", "Are you a mathematician?"),
]
)
chain = prompt | llm
result = chain.invoke({})
assert isinstance(result.content, str)
##################
### Test Tools ###
##################
def test_tools() -> None:
class Person(BaseModel):
name: str
age: int
llm = ChatAnthropicTools(model_name=MODEL_NAME).bind_tools([Person])
result = llm.invoke("Erick is 27 years old")
assert result.content == "", f"content should be empty, not {result.content}"
assert "tool_calls" in result.additional_kwargs
tool_calls = result.additional_kwargs["tool_calls"]
assert len(tool_calls) == 1
tool_call = tool_calls[0]
assert tool_call["type"] == "function"
function = tool_call["function"]
assert function["name"] == "Person"
assert function["arguments"] == {"name": "Erick", "age": "27"}
def test_with_structured_output() -> None:
class Person(BaseModel):
name: str
age: int
chain = ChatAnthropicTools(model_name=MODEL_NAME).with_structured_output(Person)
result = chain.invoke("Erick is 27 years old")
assert isinstance(result, Person)
assert result.name == "Erick"
assert result.age == 27