mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 13:54:48 +00:00
anthropic[minor]: tool use (#20016)
This commit is contained in:
parent
3aacd11846
commit
209de0a561
@ -61,7 +61,10 @@
|
||||
"cell_type": "markdown",
|
||||
"id": "d1f9df276476f0bc",
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
"collapsed": false,
|
||||
"jupyter": {
|
||||
"outputs_hidden": false
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"The code provided assumes that your ANTHROPIC_API_KEY is set in your environment variables. If you would like to manually specify your API key and also choose a different model, you can use the following code:\n",
|
||||
@ -75,6 +78,17 @@
|
||||
"You can check the model comparison doc [here](https://docs.anthropic.com/claude/docs/models-overview#model-comparison)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "238bdbaa-526a-4130-89e9-523aa44bb196",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_anthropic import ChatAnthropic\n",
|
||||
"from langchain_core.prompts import ChatPromptTemplate"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
@ -99,9 +113,6 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_anthropic import ChatAnthropic\n",
|
||||
"from langchain_core.prompts import ChatPromptTemplate\n",
|
||||
"\n",
|
||||
"chat = ChatAnthropic(temperature=0, model_name=\"claude-3-opus-20240229\")\n",
|
||||
"\n",
|
||||
"system = (\n",
|
||||
@ -288,13 +299,400 @@
|
||||
"chat.invoke(messages)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ab0174d8-7140-413c-80a9-7cf3a8b81bb4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## [Beta] Tool-calling\n",
|
||||
"\n",
|
||||
"With Anthropic's [tool-calling, or tool-use, API](https://docs.anthropic.com/claude/docs/functions-external-tools), you can define tools for the model to invoke. This is extremely useful for building tool-using chains and agents, as well as for getting structured outputs from a model.\n",
|
||||
"\n",
|
||||
"::: {.callout-note}\n",
|
||||
"\n",
|
||||
"Anthropic's tool-calling functionality is still in beta.\n",
|
||||
"\n",
|
||||
":::\n",
|
||||
"\n",
|
||||
"### bind_tools()\n",
|
||||
"\n",
|
||||
"With `ChatAnthropic.bind_tools`, we can easily pass in Pydantic classes, dict schemas, LangChain tools, or even functions as tools to the model. Under the hood these are converted to an Anthropic tool schemas, which looks like:\n",
|
||||
"```\n",
|
||||
"{\n",
|
||||
" \"name\": \"...\",\n",
|
||||
" \"description\": \"...\",\n",
|
||||
" \"input_schema\": {...} # JSONSchema\n",
|
||||
"}\n",
|
||||
"```\n",
|
||||
"and passed in every model invocation."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3a381f8e",
|
||||
"execution_count": 4,
|
||||
"id": "42f87466-cb8e-490d-a9f8-aa0f8e9b4217",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
|
||||
"\n",
|
||||
"llm = ChatAnthropic(\n",
|
||||
" model=\"claude-3-opus-20240229\",\n",
|
||||
" default_headers={\"anthropic-beta\": \"tools-2024-04-04\"},\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class GetWeather(BaseModel):\n",
|
||||
" \"\"\"Get the current weather in a given location\"\"\"\n",
|
||||
"\n",
|
||||
" location: str = Field(..., description=\"The city and state, e.g. San Francisco, CA\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"llm_with_tools = llm.bind_tools([GetWeather])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "997be6ff-3fd3-4b1c-b7e3-2e5fed4ac964",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=[{'text': '<thinking>\\nBased on the user\\'s question, the relevant function to call is GetWeather, which requires the \"location\" parameter.\\n\\nThe user has directly specified the location as \"San Francisco\". Since San Francisco is a well known city, I can reasonably infer they mean San Francisco, CA without needing the state specified.\\n\\nAll the required parameters are provided, so I can proceed with the API call.\\n</thinking>', 'type': 'text'}, {'text': None, 'type': 'tool_use', 'id': 'toolu_01SCgExKzQ7eqSkMHfygvYuu', 'name': 'GetWeather', 'input': {'location': 'San Francisco, CA'}}], response_metadata={'id': 'msg_01GM3zQtoFv8jGQMW7abLnhi', 'model': 'claude-3-opus-20240229', 'stop_reason': 'tool_use', 'stop_sequence': None, 'usage': {'input_tokens': 487, 'output_tokens': 145}}, id='run-87b1331e-9251-4a68-acef-f0a018b639cc-0')"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"ai_msg = llm_with_tools.invoke(\n",
|
||||
" \"what is the weather like in San Francisco\",\n",
|
||||
")\n",
|
||||
"ai_msg"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1e63ac67-8c42-4468-8178-e54f13c3c5c3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Notice that the output message content is a list that contains a text block and then a tool_use block:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "7c4cd4c4-1c78-4d6c-8607-759e32a8903b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'text': '<thinking>\\nBased on the user\\'s question, the relevant function to call is GetWeather, which requires the \"location\" parameter.\\n\\nThe user has directly specified the location as \"San Francisco\". Since San Francisco is a well known city, I can reasonably infer they mean San Francisco, CA without needing the state specified.\\n\\nAll the required parameters are provided, so I can proceed with the API call.\\n</thinking>',\n",
|
||||
" 'type': 'text'}"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"ai_msg.content[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "5b92d91d-37cb-4843-8b2e-e337d2eec53e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'text': None,\n",
|
||||
" 'type': 'tool_use',\n",
|
||||
" 'id': 'toolu_01SCgExKzQ7eqSkMHfygvYuu',\n",
|
||||
" 'name': 'GetWeather',\n",
|
||||
" 'input': {'location': 'San Francisco, CA'}}"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"ai_msg.content[1]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "90e015e0-c6e5-4ff5-8fb9-be0cd3c86395",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"::: {.callout-tip}\n",
|
||||
"\n",
|
||||
"ChatAnthropic model outputs are always a single AI message that can have either a single string or a list of content blocks. The content blocks can be text blocks or tool-duse blocks. There can be multiple of each and they can be interspersed.\n",
|
||||
"\n",
|
||||
":::"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8652ee98-814c-4ed6-9def-275eeaa9651e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Parsing tool calls\n",
|
||||
"\n",
|
||||
"The `langchain_anthropic.output_parsers.ToolsOutputParser` makes it easy to extract just the tool calls from an Anthropic AI message:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "59c175b1-0929-4ed4-a608-f0006031a3c2",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[{'name': 'GetWeather',\n",
|
||||
" 'args': {'location': 'New York City, NY'},\n",
|
||||
" 'id': 'toolu_01UK2AEWa75PUGA3DpiaHfBN',\n",
|
||||
" 'index': 1},\n",
|
||||
" {'name': 'GetWeather',\n",
|
||||
" 'args': {'location': 'Los Angeles, CA'},\n",
|
||||
" 'id': 'toolu_01M84DY7xWg9bLoX6JCArczx',\n",
|
||||
" 'index': 2},\n",
|
||||
" {'name': 'GetWeather',\n",
|
||||
" 'args': {'location': 'San Francisco, CA'},\n",
|
||||
" 'id': 'toolu_01FEasmxGpxFPwf9SF3nCTeb',\n",
|
||||
" 'index': 3},\n",
|
||||
" {'name': 'GetWeather',\n",
|
||||
" 'args': {'location': 'Cleveland, OH'},\n",
|
||||
" 'id': 'toolu_01B8fZdiyPbzWyj5cDCzGSTe',\n",
|
||||
" 'index': 4}]"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_anthropic.output_parsers import ToolsOutputParser\n",
|
||||
"\n",
|
||||
"parser = ToolsOutputParser()\n",
|
||||
"chain = llm_with_tools | parser\n",
|
||||
"chain.invoke(\"What is the weather like in nyc, la, sf and cleveland\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c4394c23-8d79-4f2c-b0fe-7b877eaac7c7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The `index` tells us where in the original list of content blocks each tool call was.\n",
|
||||
"\n",
|
||||
"We can pass in Pydantic classes to parse our tool calls into pydantic objects:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"id": "08f6c62c-923b-400e-9bc8-8aff417466b2",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[GetWeather(location='New York City, NY'),\n",
|
||||
" GetWeather(location='Los Angeles, CA'),\n",
|
||||
" GetWeather(location='San Francisco, CA'),\n",
|
||||
" GetWeather(location='Cleveland, OH')]"
|
||||
]
|
||||
},
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"parser = ToolsOutputParser(pydantic_schemas=[GetWeather])\n",
|
||||
"chain = llm_with_tools | parser\n",
|
||||
"chain.invoke(\"What is the weather like in nyc, la, sf and cleveland\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8ccdc039-d8ce-4460-bb2f-543753aac016",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If we want we can return only the first tool call:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "7746c643-851f-4908-ac34-8ddbb949454d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'name': 'GetWeather',\n",
|
||||
" 'args': {'location': 'New York City, NY'},\n",
|
||||
" 'id': 'toolu_01EjFAADbpdrML1uaSMr9tN3',\n",
|
||||
" 'index': 1}"
|
||||
]
|
||||
},
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"parser = ToolsOutputParser(first_tool_only=True)\n",
|
||||
"chain = llm_with_tools | parser\n",
|
||||
"chain.invoke(\"What is the weather like in nyc\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ab05dd51-0a9e-4b7b-b182-65cec44941ac",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### with_structured_output()\n",
|
||||
"\n",
|
||||
"The [BaseChatModel.with_structured_output interface](/docs/modules/model_io/chat/structured_output) makes it easy to get structured output from chat models. You can use `ChatAnthropic.with_structured_output`, which uses tool-calling under the hood), to get the model to more reliably return an output in a specific format:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"id": "e047b831-2338-4c2d-9ee4-0763f74e80e1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"GetWeather(location='San Francisco, CA')"
|
||||
]
|
||||
},
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"structured_llm = llm.with_structured_output(GetWeather)\n",
|
||||
"structured_llm.invoke(\n",
|
||||
" \"what is the weather like in San Francisco\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5b61884e-3e4e-4145-b10d-188987ae1eb6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Passing tool results to model\n",
|
||||
"\n",
|
||||
"We can use `ToolMessage`s with the appropriate `tool_call_id`s to pass tool results back to the model:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "9d07a1c1-4542-440e-a1fb-392542267fb8",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content='Based on calling the GetWeather function, the weather in San Francisco, CA is:\\nRain with a high temperature of 54°F and winds from the southwest at 15-25 mph. There is a 100% chance of rain.', response_metadata={'id': 'msg_01J7nWVRPPTgae4eDpf9yR3M', 'model': 'claude-3-opus-20240229', 'stop_reason': 'end_turn', 'stop_sequence': None, 'usage': {'input_tokens': 670, 'output_tokens': 56}}, id='run-44fcd34f-9c24-464f-94dd-63bd0d22870d-0')"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_core.messages import AIMessage, HumanMessage, ToolMessage\n",
|
||||
"\n",
|
||||
"messages = [\n",
|
||||
" HumanMessage(\"What is the weather like in San Francisco\"),\n",
|
||||
" AIMessage(\n",
|
||||
" content=[\n",
|
||||
" {\n",
|
||||
" \"text\": '<thinking>\\nBased on the user\\'s question, the relevant function to call is GetWeather, which requires the \"location\" parameter.\\n\\nThe user has directly specified the location as \"San Francisco\". Since San Francisco is a well known city, I can reasonably infer they mean San Francisco, CA without needing the state specified.\\n\\nAll the required parameters are provided, so I can proceed with the API call.\\n</thinking>',\n",
|
||||
" \"type\": \"text\",\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"type\": \"tool_use\",\n",
|
||||
" \"id\": \"toolu_01SCgExKzQ7eqSkMHfygvYuu\",\n",
|
||||
" \"name\": \"GetWeather\",\n",
|
||||
" \"input\": {\"location\": \"San Francisco, CA\"},\n",
|
||||
" \"text\": None,\n",
|
||||
" },\n",
|
||||
" ],\n",
|
||||
" ),\n",
|
||||
" ToolMessage(\n",
|
||||
" \"Rain. High 54F. Winds SW at 15 to 25 mph. Chance of rain 100%.\",\n",
|
||||
" tool_call_id=\"toolu_01SCgExKzQ7eqSkMHfygvYuu\",\n",
|
||||
" ),\n",
|
||||
"]\n",
|
||||
"llm_with_tools.invoke(messages)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1c82d198-77ce-4d5a-a65b-a98fd3c10740",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Streaming\n",
|
||||
"\n",
|
||||
"::: {.callout-warning}\n",
|
||||
"\n",
|
||||
"Anthropic does not currently support streaming tool calls. Attempting to stream will yield a single final message.\n",
|
||||
"\n",
|
||||
":::"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "d1284ddc-eb82-44be-b034-5046809536de",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/bagatur/langchain/libs/partners/anthropic/langchain_anthropic/chat_models.py:328: UserWarning: stream: Tool use is not yet supported in streaming mode.\n",
|
||||
" warnings.warn(\"stream: Tool use is not yet supported in streaming mode.\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[AIMessage(content=[{'text': '<thinking>\\nThe user is asking for the current weather in a specific location, San Francisco. The GetWeather function is the relevant tool to answer this request, as it returns the current weather for a given location.\\n\\nThe GetWeather function has one required parameter:\\nlocation: The city and state, e.g. San Francisco, CA\\n\\nThe user provided the city San Francisco in their request. They did not specify the state, but it can be reasonably inferred that they are referring to San Francisco, California since that is the most well known city with that name.\\n\\nSince the required location parameter has been provided by the user, we can proceed with calling the GetWeather function.\\n</thinking>', 'type': 'text'}, {'text': None, 'type': 'tool_use', 'id': 'toolu_01V9ZripoQzuY8HubspJy6fP', 'name': 'GetWeather', 'input': {'location': 'San Francisco, CA'}}], id='run-b825206b-5b6b-48bc-ad8d-802dee310c7f')]"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"list(llm_with_tools.stream(\"What's the weather in san francisco\"))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@ -313,7 +711,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.4"
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -31,7 +31,7 @@ model = ChatAnthropic(model='claude-3-opus-20240229')
|
||||
|
||||
Read more in the [ChatAnthropic documentation](/docs/integrations/chat/anthropic).
|
||||
|
||||
## `AnthropicLLM`
|
||||
## [Legacy] `AnthropicLLM`
|
||||
|
||||
`AnthropicLLM` is a subclass of LangChain's `LLM`. It is a wrapper around Anthropic's
|
||||
text-based completion endpoints.
|
||||
|
@ -72,6 +72,7 @@ import ChatModelTabs from "@theme/ChatModelTabs";
|
||||
customVarName="llm"
|
||||
fireworksParams={`model="accounts/fireworks/models/firefunction-v1", temperature=0`}
|
||||
hideGoogle={true}
|
||||
hideAnthropic={true}
|
||||
/>
|
||||
|
||||
We can use the `bind_tools()` method to handle converting
|
||||
|
@ -783,7 +783,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.1"
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -39,7 +39,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 2,
|
||||
"id": "08029f4e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -49,7 +49,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 3,
|
||||
"id": "070bf702",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -481,6 +481,44 @@
|
||||
" \"Tell me a joke about cats, respond in JSON with `setup` and `punchline` keys\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "f94e9c7a-bfbd-409c-b3a6-59e485e4ea5b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Anthropic\n",
|
||||
"\n",
|
||||
"Anthropic's tool-calling API can be used for structuring outputs. Note that there is currently no way to force a tool-call via the API, so prompting the model correctly is still important."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "12682237-6689-4408-88b1-3595feac447f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Joke(setup='What do you call a cat that loves to bowl?', punchline='An alley cat!')"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_anthropic import ChatAnthropic\n",
|
||||
"\n",
|
||||
"model = ChatAnthropic(\n",
|
||||
" model=\"claude-3-opus-20240229\",\n",
|
||||
" default_headers={\"anthropic-beta\": \"tools-2024-04-04\"},\n",
|
||||
")\n",
|
||||
"model_with_structure = model.with_structured_output(Joke)\n",
|
||||
"model_with_structure.invoke(\"Tell me a joke about cats\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@ -499,7 +537,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.1"
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -13,6 +13,9 @@ class ToolMessage(BaseMessage):
|
||||
|
||||
tool_call_id: str
|
||||
"""Tool call that this message is responding to."""
|
||||
# TODO: Add is_error param?
|
||||
# is_error: bool = False
|
||||
# """Whether the tool errored."""
|
||||
|
||||
type: Literal["tool"] = "tool"
|
||||
|
||||
|
@ -1,13 +1,31 @@
|
||||
import os
|
||||
import re
|
||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
|
||||
import warnings
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypedDict,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import anthropic
|
||||
from langchain_core._api.deprecation import deprecated
|
||||
from langchain_core._api import beta, deprecated
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.language_models.chat_models import (
|
||||
BaseChatModel,
|
||||
agenerate_from_stream,
|
||||
@ -17,14 +35,26 @@ from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableMap,
|
||||
RunnablePassthrough,
|
||||
)
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import (
|
||||
build_extra_kwargs,
|
||||
convert_to_secret_str,
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
from langchain_anthropic.output_parsers import ToolsOutputParser
|
||||
|
||||
_message_type_lookups = {"human": "user", "ai": "assistant"}
|
||||
|
||||
@ -56,6 +86,41 @@ def _format_image(image_url: str) -> Dict:
|
||||
}
|
||||
|
||||
|
||||
def _merge_messages(
|
||||
messages: List[BaseMessage],
|
||||
) -> List[Union[SystemMessage, AIMessage, HumanMessage]]:
|
||||
"""Merge runs of human/tool messages into single human messages with content blocks.""" # noqa: E501
|
||||
merged: list = []
|
||||
for curr in messages:
|
||||
if isinstance(curr, ToolMessage):
|
||||
if isinstance(curr.content, str):
|
||||
curr = HumanMessage(
|
||||
[
|
||||
{
|
||||
"type": "tool_result",
|
||||
"content": curr.content,
|
||||
"tool_use_id": curr.tool_call_id,
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
curr = HumanMessage(curr.content)
|
||||
last = merged[-1] if merged else None
|
||||
if isinstance(last, HumanMessage) and isinstance(curr, HumanMessage):
|
||||
if isinstance(last.content, str):
|
||||
new_content: List = [{"type": "text", "text": last.content}]
|
||||
else:
|
||||
new_content = last.content
|
||||
if isinstance(curr.content, str):
|
||||
new_content.append({"type": "text", "text": curr.content})
|
||||
else:
|
||||
new_content.extend(curr.content)
|
||||
last.content = new_content
|
||||
else:
|
||||
merged.append(curr)
|
||||
return merged
|
||||
|
||||
|
||||
def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[Dict]]:
|
||||
"""Format messages for anthropic."""
|
||||
|
||||
@ -70,7 +135,9 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D
|
||||
"""
|
||||
system: Optional[str] = None
|
||||
formatted_messages: List[Dict] = []
|
||||
for i, message in enumerate(messages):
|
||||
|
||||
merged_messages = _merge_messages(messages)
|
||||
for i, message in enumerate(merged_messages):
|
||||
if message.type == "system":
|
||||
if i != 0:
|
||||
raise ValueError("System message must be at beginning of message list.")
|
||||
@ -104,7 +171,7 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D
|
||||
elif isinstance(item, dict):
|
||||
if "type" not in item:
|
||||
raise ValueError("Dict content item must have a type key")
|
||||
if item["type"] == "image_url":
|
||||
elif item["type"] == "image_url":
|
||||
# convert format
|
||||
source = _format_image(item["image_url"]["url"])
|
||||
content.append(
|
||||
@ -113,6 +180,9 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D
|
||||
"source": source,
|
||||
}
|
||||
)
|
||||
elif item["type"] == "tool_use":
|
||||
item.pop("text", None)
|
||||
content.append(item)
|
||||
else:
|
||||
content.append(item)
|
||||
else:
|
||||
@ -175,6 +245,9 @@ class ChatAnthropic(BaseChatModel):
|
||||
|
||||
anthropic_api_key: Optional[SecretStr] = None
|
||||
|
||||
default_headers: Optional[Mapping[str, str]] = None
|
||||
"""Headers to pass to the Anthropic clients, will be used for every API call."""
|
||||
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
streaming: bool = False
|
||||
@ -207,9 +280,15 @@ class ChatAnthropic(BaseChatModel):
|
||||
or "https://api.anthropic.com"
|
||||
)
|
||||
values["anthropic_api_url"] = api_url
|
||||
values["_client"] = anthropic.Client(api_key=api_key, base_url=api_url)
|
||||
values["_client"] = anthropic.Client(
|
||||
api_key=api_key,
|
||||
base_url=api_url,
|
||||
default_headers=values.get("default_headers"),
|
||||
)
|
||||
values["_async_client"] = anthropic.AsyncClient(
|
||||
api_key=api_key, base_url=api_url
|
||||
api_key=api_key,
|
||||
base_url=api_url,
|
||||
default_headers=values.get("default_headers"),
|
||||
)
|
||||
return values
|
||||
|
||||
@ -232,6 +311,7 @@ class ChatAnthropic(BaseChatModel):
|
||||
"stop_sequences": stop,
|
||||
"system": system,
|
||||
**self.model_kwargs,
|
||||
**kwargs,
|
||||
}
|
||||
rtn = {k: v for k, v in rtn.items() if v is not None}
|
||||
|
||||
@ -245,6 +325,13 @@ class ChatAnthropic(BaseChatModel):
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
params = self._format_params(messages=messages, stop=stop, **kwargs)
|
||||
if "extra_body" in params and params["extra_body"].get("tools"):
|
||||
warnings.warn("stream: Tool use is not yet supported in streaming mode.")
|
||||
result = self._generate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
yield cast(ChatGenerationChunk, result.generations[0])
|
||||
return
|
||||
with self._client.messages.stream(**params) as stream:
|
||||
for text in stream.text_stream:
|
||||
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
|
||||
@ -260,6 +347,13 @@ class ChatAnthropic(BaseChatModel):
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
params = self._format_params(messages=messages, stop=stop, **kwargs)
|
||||
if "extra_body" in params and params["extra_body"].get("tools"):
|
||||
warnings.warn("stream: Tool use is not yet supported in streaming mode.")
|
||||
result = await self._agenerate(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
yield cast(ChatGenerationChunk, result.generations[0])
|
||||
return
|
||||
async with self._async_client.messages.stream(**params) as stream:
|
||||
async for text in stream.text_stream:
|
||||
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
|
||||
@ -273,8 +367,12 @@ class ChatAnthropic(BaseChatModel):
|
||||
llm_output = {
|
||||
k: v for k, v in data_dict.items() if k not in ("content", "role", "type")
|
||||
}
|
||||
if len(content) == 1 and content[0]["type"] == "text":
|
||||
msg = AIMessage(content=content[0]["text"])
|
||||
else:
|
||||
msg = AIMessage(content=content)
|
||||
return ChatResult(
|
||||
generations=[ChatGeneration(message=AIMessage(content=content[0]["text"]))],
|
||||
generations=[ChatGeneration(message=msg)],
|
||||
llm_output=llm_output,
|
||||
)
|
||||
|
||||
@ -285,12 +383,17 @@ class ChatAnthropic(BaseChatModel):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
params = self._format_params(messages=messages, stop=stop, **kwargs)
|
||||
if self.streaming:
|
||||
if "extra_body" in params and params["extra_body"].get("tools"):
|
||||
warnings.warn(
|
||||
"stream: Tool use is not yet supported in streaming mode."
|
||||
)
|
||||
else:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
data = self._client.messages.create(**params)
|
||||
return self._format_output(data, **kwargs)
|
||||
|
||||
@ -301,15 +404,91 @@ class ChatAnthropic(BaseChatModel):
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
stream_iter = self._astream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
params = self._format_params(messages=messages, stop=stop, **kwargs)
|
||||
if self.streaming:
|
||||
if "extra_body" in params and params["extra_body"].get("tools"):
|
||||
warnings.warn(
|
||||
"stream: Tool use is not yet supported in streaming mode."
|
||||
)
|
||||
else:
|
||||
stream_iter = self._astream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
data = await self._async_client.messages.create(**params)
|
||||
return self._format_output(data, **kwargs)
|
||||
|
||||
@beta()
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
"""Bind tool-like objects to this chat model.
|
||||
|
||||
Args:
|
||||
tools: A list of tool definitions to bind to this chat model.
|
||||
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
|
||||
models, callables, and BaseTools will be automatically converted to
|
||||
their schema dictionary representation.
|
||||
**kwargs: Any additional parameters to bind.
|
||||
"""
|
||||
formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools]
|
||||
extra_body = kwargs.pop("extra_body", {})
|
||||
extra_body["tools"] = formatted_tools
|
||||
return self.bind(extra_body=extra_body, **kwargs)
|
||||
|
||||
@beta()
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Union[Dict, Type[BaseModel]],
|
||||
*,
|
||||
include_raw: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
|
||||
llm = self.bind_tools([schema])
|
||||
if isinstance(schema, type) and issubclass(schema, BaseModel):
|
||||
output_parser = ToolsOutputParser(
|
||||
first_tool_only=True, pydantic_schemas=[schema]
|
||||
)
|
||||
else:
|
||||
output_parser = ToolsOutputParser(first_tool_only=True, args_only=True)
|
||||
|
||||
if include_raw:
|
||||
parser_assign = RunnablePassthrough.assign(
|
||||
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
|
||||
)
|
||||
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
|
||||
parser_with_fallback = parser_assign.with_fallbacks(
|
||||
[parser_none], exception_key="parsing_error"
|
||||
)
|
||||
return RunnableMap(raw=llm) | parser_with_fallback
|
||||
else:
|
||||
return llm | output_parser
|
||||
|
||||
|
||||
class AnthropicTool(TypedDict):
|
||||
name: str
|
||||
description: str
|
||||
input_schema: Dict[str, Any]
|
||||
|
||||
|
||||
def convert_to_anthropic_tool(
    tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
) -> AnthropicTool:
    """Convert a tool-like object into Anthropic's native tool format.

    Dicts that already carry ``name``/``description``/``input_schema`` are
    passed through; everything else is routed through the OpenAI tool
    converter and re-shaped.
    """
    already_anthropic = isinstance(tool, dict) and {
        "name",
        "description",
        "input_schema",
    }.issubset(tool)
    if already_anthropic:
        # already in Anthropic tool format
        return AnthropicTool(tool)  # type: ignore
    oai_function = convert_to_openai_tool(tool)["function"]
    return AnthropicTool(
        name=oai_function["name"],
        description=oai_function["description"],
        input_schema=oai_function["parameters"],
    )
|
||||
|
||||
|
||||
@deprecated(since="0.1.0", removal="0.2.0", alternative="ChatAnthropic")
|
||||
class ChatAnthropicMessages(ChatAnthropic):
|
||||
|
@ -1,38 +1,13 @@
|
||||
import json
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core._api.beta_decorator import beta
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
JsonOutputKeyToolsParser,
|
||||
PydanticToolsParser,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
from langchain_anthropic.chat_models import ChatAnthropic
|
||||
|
||||
@ -168,143 +143,16 @@ def _xml_to_tool_calls(elem: Any, tools: List[Dict]) -> List[Dict[str, Any]]:
|
||||
return [_xml_to_function_call(invoke, tools) for invoke in invokes]
|
||||
|
||||
|
||||
@beta()
@deprecated(
    "0.1.5",
    removal="0.2.0",
    alternative="ChatAnthropic",
    message=(
        "Tool-calling is now officially supported by the Anthropic API so this "
        "workaround is no longer needed."
    ),
)
class ChatAnthropicTools(ChatAnthropic):
    """Chat model for interacting with Anthropic functions.

    Deprecated prompt-based tool-calling workaround: tool definitions are
    rendered into the system prompt and the model's XML reply is parsed back
    into tool calls. Prefer ``ChatAnthropic.bind_tools``.
    """

    # Holds the defusedxml ElementTree module once the validator has run.
    _xmllib: Any = Field(default=None)

    @root_validator()
    def check_xml_lib(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Verify defusedxml is importable and stash it for XML parsing."""
        try:
            # do this as an optional dep for temporary nature of this feature
            import defusedxml.ElementTree as DET  # type: ignore

            values["_xmllib"] = DET
        except ImportError:
            raise ImportError(
                "Could not import defusedxml python package. "
                "Please install it using `pip install defusedxml`"
            )
        return values

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], BaseTool]],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tools to the chat model."""
        # Tools are kept in OpenAI function format here; _format_params later
        # turns them into system-prompt text via get_system_message.
        formatted_tools = [convert_to_openai_function(tool) for tool in tools]
        return super().bind(tools=formatted_tools, **kwargs)

    def with_structured_output(
        self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
        """Return a runnable that parses model output into ``schema``."""
        if kwargs:
            raise ValueError("kwargs are not supported for with_structured_output")
        llm = self.bind_tools([schema])
        if isinstance(schema, type) and issubclass(schema, BaseModel):
            # schema is pydantic
            return llm | PydanticToolsParser(tools=[schema], first_tool_only=True)
        else:
            # schema is dict
            key_name = convert_to_openai_function(schema)["name"]
            return llm | JsonOutputKeyToolsParser(
                key_name=key_name, first_tool_only=True
            )

    def _format_params(
        self,
        *,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Dict:
        """Inject tool instructions into the system prompt before formatting."""
        tools: List[Dict] = kwargs.get("tools", None)
        # experimental tools are sent in as part of system prompt, so if
        # both are set, turn system prompt into tools + system prompt (tools first)
        if tools:
            tool_system = get_system_message(tools)

            if messages[0].type == "system":
                sys_content = messages[0].content
                new_sys_content = f"{tool_system}\n\n{sys_content}"
                messages = [SystemMessage(content=new_sys_content), *messages[1:]]
            else:
                messages = [SystemMessage(content=tool_system), *messages]

        return super()._format_params(messages=messages, stop=stop, **kwargs)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        # streaming not supported for functions
        # Fall back to one synchronous _generate call and yield a single chunk.
        result = self._generate(
            messages=messages, stop=stop, run_manager=run_manager, **kwargs
        )
        to_yield = result.generations[0]
        chunk = ChatGenerationChunk(
            message=cast(BaseMessageChunk, to_yield.message),
            generation_info=to_yield.generation_info,
        )
        if run_manager:
            run_manager.on_llm_new_token(
                cast(str, to_yield.message.content), chunk=chunk
            )
        yield chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        # streaming not supported for functions
        # Async twin of _stream: a single _agenerate call, one yielded chunk.
        result = await self._agenerate(
            messages=messages, stop=stop, run_manager=run_manager, **kwargs
        )
        to_yield = result.generations[0]
        chunk = ChatGenerationChunk(
            message=cast(BaseMessageChunk, to_yield.message),
            generation_info=to_yield.generation_info,
        )
        if run_manager:
            await run_manager.on_llm_new_token(
                cast(str, to_yield.message.content), chunk=chunk
            )
        yield chunk

    def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
        """Format the output of the model, parsing xml as a tool call."""
        text = data.content[0].text
        tools = kwargs.get("tools", None)

        additional_kwargs: Dict[str, Any] = {}

        if tools:
            # parse out the xml from the text
            try:
                # get everything between <function_calls> and </function_calls>
                start = text.find("<function_calls>")
                end = text.find("</function_calls>") + len("</function_calls>")
                xml_text = text[start:end]

                xml = self._xmllib.fromstring(xml_text)
                additional_kwargs["tool_calls"] = _xml_to_tool_calls(xml, tools)
                text = ""
            except Exception:
                # Best effort: on any parse failure, return the raw text
                # unchanged instead of raising.
                pass

        return ChatResult(
            generations=[
                ChatGeneration(
                    message=AIMessage(content=text, additional_kwargs=additional_kwargs)
                )
            ],
            llm_output=data,
        )
|
||||
|
@ -0,0 +1,66 @@
|
||||
from typing import Any, List, Optional, Type, TypedDict, cast
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.output_parsers import BaseGenerationOutputParser
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
class _ToolCall(TypedDict):
    """Normalized representation of a single Anthropic ``tool_use`` block."""

    # Name of the tool the model invoked.
    name: str
    # Parsed input arguments of the call.
    args: dict
    # Provider-assigned id of the tool_use block.
    id: str
    # Position of the block within the message content list.
    index: int
|
||||
|
||||
|
||||
class ToolsOutputParser(BaseGenerationOutputParser):
    """Extract Anthropic ``tool_use`` blocks from a chat generation."""

    # Return only the first tool call (or None) instead of a list.
    first_tool_only: bool = False
    # Return only each call's arguments instead of the full call dict.
    args_only: bool = False
    # When given, instantiate the schema whose name matches each call.
    pydantic_schemas: Optional[List[Type[BaseModel]]] = None

    class Config:
        extra = "forbid"

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        """Parse a list of candidate model Generations into a specific format.

        Args:
            result: A list of Generations to be parsed. The Generations are assumed
                to be different candidate outputs for a single model input.

        Returns:
            Structured output.
        """
        if not result or not isinstance(result[0], ChatGeneration):
            return None if self.first_tool_only else []
        calls: List = _extract_tool_calls(result[0].message)
        if self.pydantic_schemas:
            calls = [self._pydantic_parse(call) for call in calls]
        elif self.args_only:
            calls = [call["args"] for call in calls]

        if not self.first_tool_only:
            return calls
        return calls[0] if calls else None

    def _pydantic_parse(self, tool_call: _ToolCall) -> BaseModel:
        """Instantiate the pydantic schema whose class name matches the call."""
        schemas_by_name = {
            schema.__name__: schema for schema in self.pydantic_schemas or []
        }
        return schemas_by_name[tool_call["name"]](**tool_call["args"])
|
||||
|
||||
|
||||
def _extract_tool_calls(msg: BaseMessage) -> List[_ToolCall]:
    """Collect the ``tool_use`` content blocks of a message as _ToolCall dicts."""
    if isinstance(msg.content, str):
        # Plain-text content carries no tool calls.
        return []
    blocks = cast(List[dict], msg.content)
    return [
        _ToolCall(
            name=block["name"], args=block["input"], id=block["id"], index=idx
        )
        for idx, block in enumerate(blocks)
        if block["type"] == "tool_use"
    ]
|
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-anthropic"
|
||||
version = "0.1.4"
|
||||
version = "0.1.5"
|
||||
description = "An integration package connecting AnthropicMessages and LangChain"
|
||||
authors = []
|
||||
readme = "README.md"
|
||||
|
@ -212,3 +212,47 @@ async def test_astreaming() -> None:
|
||||
response = await llm.agenerate([[HumanMessage(content="I'm Pickle Rick")]])
|
||||
assert callback_handler.llm_streams > 0
|
||||
assert isinstance(response, LLMResult)
|
||||
|
||||
|
||||
def test_tool_use() -> None:
    """Integration smoke test: a bound tool yields content blocks (needs API key)."""
    tool_definition = {
        "name": "get_weather",
        "description": "Get weather report for a city",
        "input_schema": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
        },
    }
    chat = ChatAnthropic(
        model="claude-3-opus-20240229",
        default_headers={"anthropic-beta": "tools-2024-04-04"},
    )
    response = chat.bind_tools([tool_definition]).invoke(
        "what's the weather in san francisco, ca"
    )
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, list)
|
||||
|
||||
|
||||
def test_with_structured_output() -> None:
    """Integration smoke test: structured output returns parsed args (needs API key)."""
    schema = {
        "name": "get_weather",
        "description": "Get weather report for a city",
        "input_schema": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
        },
    }
    chat = ChatAnthropic(
        model="claude-3-opus-20240229",
        default_headers={"anthropic-beta": "tools-2024-04-04"},
    )
    response = chat.with_structured_output(schema).invoke(
        "what's the weather in san francisco, ca"
    )
    assert isinstance(response, dict)
    assert response["location"]
|
||||
|
@ -1,13 +1,17 @@
|
||||
"""Test chat model integration."""
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Literal, Type
|
||||
|
||||
import pytest
|
||||
from anthropic.types import ContentBlock, Message, Usage
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from langchain_anthropic import ChatAnthropic, ChatAnthropicMessages
|
||||
from langchain_anthropic.chat_models import _merge_messages, convert_to_anthropic_tool
|
||||
|
||||
os.environ["ANTHROPIC_API_KEY"] = "foo"
|
||||
|
||||
@ -83,3 +87,175 @@ def test__format_output() -> None:
|
||||
llm = ChatAnthropic(model="test", anthropic_api_key="test")
|
||||
actual = llm._format_output(anthropic_msg)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test__merge_messages() -> None:
    """_merge_messages folds consecutive ToolMessages plus the following
    HumanMessage into one HumanMessage of tool_result/text content blocks,
    as shown by the ``expected`` value below."""
    messages = [
        SystemMessage("foo"),
        HumanMessage("bar"),
        AIMessage(
            [
                {"text": "baz", "type": "text"},
                {
                    "tool_input": {"a": "b"},
                    "type": "tool_use",
                    "id": "1",
                    "text": None,
                    "name": "buz",
                },
                {"text": "baz", "type": "text"},
                {
                    "tool_input": {"a": "c"},
                    "type": "tool_use",
                    "id": "2",
                    "text": None,
                    "name": "blah",
                },
            ]
        ),
        ToolMessage("buz output", tool_call_id="1"),
        ToolMessage("blah output", tool_call_id="2"),
        HumanMessage("next thing"),
    ]
    expected = [
        SystemMessage("foo"),
        HumanMessage("bar"),
        AIMessage(
            [
                {"text": "baz", "type": "text"},
                {
                    "tool_input": {"a": "b"},
                    "type": "tool_use",
                    "id": "1",
                    "text": None,
                    "name": "buz",
                },
                {"text": "baz", "type": "text"},
                {
                    "tool_input": {"a": "c"},
                    "type": "tool_use",
                    "id": "2",
                    "text": None,
                    "name": "blah",
                },
            ]
        ),
        # The two tool results and the trailing human turn are merged into a
        # single HumanMessage of typed content blocks.
        HumanMessage(
            [
                {"type": "tool_result", "content": "buz output", "tool_use_id": "1"},
                {"type": "tool_result", "content": "blah output", "tool_use_id": "2"},
                {"type": "text", "text": "next thing"},
            ]
        ),
    ]
    actual = _merge_messages(messages)
    assert expected == actual
|
||||
|
||||
|
||||
@pytest.fixture()
def pydantic() -> Type[BaseModel]:
    """Pydantic-model flavor of the shared dummy tool definition."""

    # NOTE: the class docstring and field descriptions below are runtime
    # data — the converters read them as the tool/argument descriptions.
    class dummy_function(BaseModel):
        """dummy function"""

        arg1: int = Field(..., description="foo")
        arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")

    return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
def function() -> Callable:
    """Plain-callable flavor of the dummy tool; the Google-style docstring
    supplies the tool and argument descriptions at conversion time."""

    def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
        """dummy function

        Args:
            arg1: foo
            arg2: one of 'bar', 'baz'
        """
        pass

    return dummy_function
|
||||
|
||||
|
||||
@pytest.fixture()
def dummy_tool() -> BaseTool:
    """BaseTool flavor of the shared dummy tool definition."""

    # Argument schema; descriptions become the converted tool's parameter docs.
    class Schema(BaseModel):
        arg1: int = Field(..., description="foo")
        arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")

    class DummyFunction(BaseTool):
        args_schema: Type[BaseModel] = Schema
        name: str = "dummy_function"
        description: str = "dummy function"

        # Tool is never executed in these unit tests; _run is a stub.
        def _run(self, *args: Any, **kwargs: Any) -> Any:
            pass

    return DummyFunction()
|
||||
|
||||
|
||||
@pytest.fixture()
def json_schema() -> Dict:
    """JSON-Schema flavor (title/description at top level) of the dummy tool."""
    return {
        "title": "dummy_function",
        "description": "dummy function",
        "type": "object",
        "properties": {
            "arg1": {"description": "foo", "type": "integer"},
            "arg2": {
                "description": "one of 'bar', 'baz'",
                "enum": ["bar", "baz"],
                "type": "string",
            },
        },
        "required": ["arg1", "arg2"],
    }
|
||||
|
||||
|
||||
@pytest.fixture()
def openai_function() -> Dict:
    """OpenAI function-calling-format flavor of the dummy tool."""
    return {
        "name": "dummy_function",
        "description": "dummy function",
        "parameters": {
            "type": "object",
            "properties": {
                "arg1": {"description": "foo", "type": "integer"},
                "arg2": {
                    "description": "one of 'bar', 'baz'",
                    "enum": ["bar", "baz"],
                    "type": "string",
                },
            },
            "required": ["arg1", "arg2"],
        },
    }
|
||||
|
||||
|
||||
def test_convert_to_anthropic_tool(
    pydantic: Type[BaseModel],
    function: Callable,
    dummy_tool: BaseTool,
    json_schema: Dict,
    openai_function: Dict,
) -> None:
    """Every supported input flavor converts to the same Anthropic tool dict."""
    expected = {
        "name": "dummy_function",
        "description": "dummy function",
        "input_schema": {
            "type": "object",
            "properties": {
                "arg1": {"description": "foo", "type": "integer"},
                "arg2": {
                    "description": "one of 'bar', 'baz'",
                    "enum": ["bar", "baz"],
                    "type": "string",
                },
            },
            "required": ["arg1", "arg2"],
        },
    }

    # `expected` itself is deliberately included: already-converted input
    # must pass through unchanged.
    for fn in (pydantic, function, dummy_tool, json_schema, expected, openai_function):
        actual = convert_to_anthropic_tool(fn)  # type: ignore
        assert actual == expected
|
||||
|
@ -0,0 +1,72 @@
|
||||
from typing import Any, List, Literal
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.outputs import ChatGeneration
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
from langchain_anthropic.output_parsers import ToolsOutputParser
|
||||
|
||||
# Message content mimicking an Anthropic tool-use response: text blocks
# interleaved with two tool_use blocks (at indices 1 and 3).
_CONTENT: List = [
    {
        "type": "text",
        "text": "thought",
    },
    {"type": "tool_use", "input": {"bar": 0}, "id": "1", "name": "_Foo1"},
    {
        "type": "text",
        "text": "thought",
    },
    {"type": "tool_use", "input": {"baz": "a"}, "id": "2", "name": "_Foo2"},
]

# Single-candidate generation list wrapping the content above.
_RESULT: List = [ChatGeneration(message=AIMessage(_CONTENT))]
|
||||
|
||||
|
||||
class _Foo1(BaseModel):
    """Schema matching the first tool_use block in _CONTENT."""

    bar: int
|
||||
|
||||
|
||||
class _Foo2(BaseModel):
    """Schema matching the second tool_use block in _CONTENT."""

    baz: Literal["a", "b"]
|
||||
|
||||
|
||||
def test_tools_output_parser() -> None:
    """Default mode returns full tool-call dicts with original block indices."""
    parsed = ToolsOutputParser().parse_result(_RESULT)
    assert parsed == [
        {"name": "_Foo1", "args": {"bar": 0}, "id": "1", "index": 1},
        {"name": "_Foo2", "args": {"baz": "a"}, "id": "2", "index": 3},
    ]
|
||||
|
||||
|
||||
def test_tools_output_parser_args_only() -> None:
    """args_only mode strips everything but each call's arguments."""
    parser = ToolsOutputParser(args_only=True)
    assert parser.parse_result(_RESULT) == [{"bar": 0}, {"baz": "a"}]

    # A text-only message yields an empty list.
    assert parser.parse_result([ChatGeneration(message=AIMessage(""))]) == []
|
||||
|
||||
|
||||
def test_tools_output_parser_first_tool_only() -> None:
    """first_tool_only mode returns a single call dict, or None when absent."""
    parser = ToolsOutputParser(first_tool_only=True)
    assert parser.parse_result(_RESULT) == {
        "name": "_Foo1",
        "args": {"bar": 0},
        "id": "1",
        "index": 1,
    }

    # A text-only message yields None.
    assert parser.parse_result([ChatGeneration(message=AIMessage(""))]) is None
|
||||
|
||||
|
||||
def test_tools_output_parser_pydantic() -> None:
    """pydantic_schemas mode instantiates the matching model classes."""
    parser = ToolsOutputParser(pydantic_schemas=[_Foo1, _Foo2])
    assert parser.parse_result(_RESULT) == [_Foo1(bar=0), _Foo2(baz="a")]
|
Loading…
Reference in New Issue
Block a user