docs: add ai21 tool calling example (#26199)

Add tool calling example to AI21 docs
This commit is contained in:
miri-bar 2024-09-09 19:34:54 +03:00 committed by GitHub
parent 76bce42629
commit 3e48c728d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -50,18 +50,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"id": "62e0dbc3",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import os\n",
"from getpass import getpass\n",
"\n",
"os.environ[\"AI21_API_KEY\"] = getpass()"
]
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
@ -73,14 +73,14 @@
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c2e19d3-7c58-4470-9e1a-718b27a32056",
"metadata": {},
"outputs": [],
"source": [
"# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
"# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")"
]
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
@ -115,15 +115,15 @@
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c40756fb-cbf8-4d44-a293-3989d707237e",
"metadata": {},
"outputs": [],
"source": [
"from langchain_ai21 import ChatAI21\n",
"\n",
"llm = ChatAI21(model=\"jamba-instruct\", temperature=0)"
]
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
@ -135,21 +135,8 @@
},
{
"cell_type": "code",
"execution_count": 3,
"id": "46b982dc-5d8a-46da-a711-81c03ccd6adc",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"J'adore programmer.\", id='run-2e8d16d6-a06e-45cb-8d0c-1c8208645033-0')"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
" (\n",
@ -160,7 +147,9 @@
"]\n",
"ai_msg = llm.invoke(messages)\n",
"ai_msg"
]
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
@ -174,7 +163,6 @@
},
{
"cell_type": "code",
"execution_count": 4,
"id": "39353473fce5dd2e",
"metadata": {
"collapsed": false,
@ -182,18 +170,6 @@
"outputs_hidden": false
}
},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='Ich liebe das Programmieren.', id='run-e1bd82dc-1a7e-4b2e-bde9-ac995929ac0f-0')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
@ -215,7 +191,95 @@
" \"input\": \"I love programming.\",\n",
" }\n",
")"
]
],
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "markdown",
"source": "# Tool Calls / Function Calling",
"id": "39c0ccd229927eab"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "This example shows how to use tool calling with AI21 models:",
"id": "2bf6b40be07fe2d4"
},
{
"metadata": {},
"cell_type": "code",
"source": [
"import os\n",
"from getpass import getpass\n",
"\n",
"from langchain_ai21.chat_models import ChatAI21\n",
"from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage\n",
"from langchain_core.tools import tool\n",
"from langchain_core.utils.function_calling import convert_to_openai_tool\n",
"\n",
"os.environ[\"AI21_API_KEY\"] = getpass()\n",
"\n",
"\n",
"@tool\n",
"def get_weather(location: str, date: str) -> str:\n",
" \"\"\"“Provide the weather for the specified location on the given date.”\"\"\"\n",
" if location == \"New York\" and date == \"2024-12-05\":\n",
" return \"25 celsius\"\n",
" elif location == \"New York\" and date == \"2024-12-06\":\n",
" return \"27 celsius\"\n",
" elif location == \"London\" and date == \"2024-12-05\":\n",
" return \"22 celsius\"\n",
" return \"32 celsius\"\n",
"\n",
"\n",
"llm = ChatAI21(model=\"jamba-1.5-mini\")\n",
"\n",
"llm_with_tools = llm.bind_tools([convert_to_openai_tool(get_weather)])\n",
"\n",
"chat_messages = [\n",
" SystemMessage(\n",
" content=\"You are a helpful assistant. You can use the provided tools \"\n",
" \"to assist with various tasks and provide accurate information\"\n",
" )\n",
"]\n",
"\n",
"human_messages = [\n",
" HumanMessage(\n",
" content=\"What is the forecast for the weather in New York on December 5, 2024?\"\n",
" ),\n",
" HumanMessage(content=\"And what about the 2024-12-06?\"),\n",
" HumanMessage(content=\"OK, thank you.\"),\n",
" HumanMessage(content=\"What is the expected weather in London on December 5, 2024?\"),\n",
"]\n",
"\n",
"\n",
"for human_message in human_messages:\n",
" print(f\"User: {human_message.content}\")\n",
" chat_messages.append(human_message)\n",
" response = llm_with_tools.invoke(chat_messages)\n",
" chat_messages.append(response)\n",
" if response.tool_calls:\n",
" tool_call = response.tool_calls[0]\n",
" if tool_call[\"name\"] == \"get_weather\":\n",
" weather = get_weather.invoke(\n",
" {\n",
" \"location\": tool_call[\"args\"][\"location\"],\n",
" \"date\": tool_call[\"args\"][\"date\"],\n",
" }\n",
" )\n",
" chat_messages.append(\n",
" ToolMessage(content=weather, tool_call_id=tool_call[\"id\"])\n",
" )\n",
" llm_answer = llm_with_tools.invoke(chat_messages)\n",
" print(f\"Assistant: {llm_answer.content}\")\n",
" else:\n",
" print(f\"Assistant: {response.content}\")"
],
"id": "a181a28df77120fb",
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",