{
"cells": [
{
"cell_type": "markdown",
"id": "b09b745d-f006-4ecc-8772-afa266c43605",
"metadata": {},
"source": [
"# How to add a human-in-the-loop for tools\n",
"\n",
"There are certain tools that we don't trust a model to execute on its own. One thing we can do in such situations is require human approval before the tool is invoked.\n",
"\n",
":::info\n",
"\n",
"This how-to guide shows a simple way to add human-in-the-loop for code running in a jupyter notebook or in a terminal.\n",
"\n",
"To build a production application, you will need to do more work to keep track of application state appropriately.\n",
"\n",
"We recommend using `langgraph` for powering such a capability. For more details, please see this [guide](https://langchain-ai.github.io/langgraph/how-tos/human-in-the-loop/).\n",
":::\n"
]
},
{
"cell_type": "markdown",
"id": "09178c30-a633-4d7b-88ea-092316f14b6f",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"We'll need to install the following packages:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e44bec05-9aa4-47b1-a660-c0a183533598",
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade --quiet langchain"
]
},
{
"cell_type": "markdown",
"id": "f09629b6-7f62-4879-a791-464739ca6b6b",
"metadata": {},
"source": [
"And set these environment variables:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "2bed0ccf-20cc-4fd3-9947-55471dd8c4da",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"# If you'd like to use LangSmith, uncomment the below:\n",
"# os.environ[\"LANGSMITH_TRACING\"] = \"true\"\n",
"# os.environ[\"LANGSMITH_API_KEY\"] = getpass.getpass()"
]
},
{
"cell_type": "markdown",
"id": "7ecd5d7e-7c3c-4180-8958-7db2c1e43564",
"metadata": {},
"source": [
"## Chain\n",
"\n",
"Let's create a few simple (dummy) tools and a tool-calling chain:"
]
},
{
"cell_type": "markdown",
"id": "43721981-4595-4721-bea0-5c67696426d3",
"metadata": {},
"source": [
"import ChatModelTabs from \"@theme/ChatModelTabs\";\n",
"\n",
"<ChatModelTabs customVarName=\"llm\"/>\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "e0ff02ac-e750-493b-9b09-4578711a6726",
"metadata": {},
"outputs": [],
"source": [
"# | output: false\n",
"# | echo: false\n",
"\n",
"from langchain_anthropic import ChatAnthropic\n",
"\n",
"llm = ChatAnthropic(model=\"claude-3-sonnet-20240229\", temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "0221fdfd-2a18-4449-a123-e6b0b15bb3d9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'name': 'count_emails',\n",
" 'args': {'last_n_days': 5},\n",
" 'id': 'toolu_01QYZdJ4yPiqsdeENWHqioFW',\n",
" 'output': 10}]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from typing import Dict, List\n",
"\n",
"from langchain_core.messages import AIMessage\n",
"from langchain_core.runnables import Runnable, RunnablePassthrough\n",
"from langchain_core.tools import tool\n",
"\n",
"\n",
"@tool\n",
"def count_emails(last_n_days: int) -> int:\n",
" \"\"\"Dummy function to count number of e-mails. Returns 2 * last_n_days.\"\"\"\n",
" return last_n_days * 2\n",
"\n",
"\n",
"@tool\n",
"def send_email(message: str, recipient: str) -> str:\n",
" \"\"\"Dummy function for sending an e-mail.\"\"\"\n",
" return f\"Successfully sent email to {recipient}.\"\n",
"\n",
"\n",
"tools = [count_emails, send_email]\n",
"llm_with_tools = llm.bind_tools(tools)\n",
"\n",
"\n",
"def call_tools(msg: AIMessage) -> List[Dict]:\n",
" \"\"\"Simple sequential tool calling helper.\"\"\"\n",
" tool_map = {tool.name: tool for tool in tools}\n",
" tool_calls = msg.tool_calls.copy()\n",
" for tool_call in tool_calls:\n",
" tool_call[\"output\"] = tool_map[tool_call[\"name\"]].invoke(tool_call[\"args\"])\n",
" return tool_calls\n",
"\n",
"\n",
"chain = llm_with_tools | call_tools\n",
"chain.invoke(\"how many emails did i get in the last 5 days?\")"
]
},
{
"cell_type": "markdown",
"id": "258c1c7b-a765-4558-93fe-d0defbc29223",
"metadata": {},
"source": [
"## Adding human approval\n",
"\n",
"Let's add a step in the chain that will ask a person to approve or reject the tall call request.\n",
"\n",
"On rejection, the step will raise an exception which will stop execution of the rest of the chain."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "341fb055-0315-47bc-8f72-ed6103d2981f",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"\n",
"class NotApproved(Exception):\n",
" \"\"\"Custom exception.\"\"\"\n",
"\n",
"\n",
"def human_approval(msg: AIMessage) -> AIMessage:\n",
" \"\"\"Responsible for passing through its input or raising an exception.\n",
"\n",
" Args:\n",
" msg: output from the chat model\n",
"\n",
" Returns:\n",
" msg: original output from the msg\n",
" \"\"\"\n",
" tool_strs = \"\\n\\n\".join(\n",
" json.dumps(tool_call, indent=2) for tool_call in msg.tool_calls\n",
" )\n",
" input_msg = (\n",
" f\"Do you approve of the following tool invocations\\n\\n{tool_strs}\\n\\n\"\n",
" \"Anything except 'Y'/'Yes' (case-insensitive) will be treated as a no.\\n >>>\"\n",
" )\n",
" resp = input(input_msg)\n",
" if resp.lower() not in (\"yes\", \"y\"):\n",
" raise NotApproved(f\"Tool invocations not approved:\\n\\n{tool_strs}\")\n",
" return msg"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "25dca07b-56ca-4b94-9955-d4f3e9895e03",
"metadata": {},
"outputs": [
{
"name": "stdin",
"output_type": "stream",
"text": [
"Do you approve of the following tool invocations\n",
"\n",
"{\n",
" \"name\": \"count_emails\",\n",
" \"args\": {\n",
" \"last_n_days\": 5\n",
" },\n",
" \"id\": \"toolu_01WbD8XeMoQaRFtsZezfsHor\"\n",
"}\n",
"\n",
"Anything except 'Y'/'Yes' (case-insensitive) will be treated as a no.\n",
" >>> yes\n"
]
},
{
"data": {
"text/plain": [
"[{'name': 'count_emails',\n",
" 'args': {'last_n_days': 5},\n",
" 'id': 'toolu_01WbD8XeMoQaRFtsZezfsHor',\n",
" 'output': 10}]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain = llm_with_tools | human_approval | call_tools\n",
"chain.invoke(\"how many emails did i get in the last 5 days?\")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "f558f2cd-847b-4ef9-a770-3961082b540c",
"metadata": {},
"outputs": [
{
"name": "stdin",
"output_type": "stream",
"text": [
"Do you approve of the following tool invocations\n",
"\n",
"{\n",
" \"name\": \"send_email\",\n",
" \"args\": {\n",
" \"recipient\": \"sally@gmail.com\",\n",
" \"message\": \"What's up homie\"\n",
" },\n",
" \"id\": \"toolu_014XccHFzBiVcc9GV1harV9U\"\n",
"}\n",
"\n",
"Anything except 'Y'/'Yes' (case-insensitive) will be treated as a no.\n",
" >>> no\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Tool invocations not approved:\n",
"\n",
"{\n",
" \"name\": \"send_email\",\n",
" \"args\": {\n",
" \"recipient\": \"sally@gmail.com\",\n",
" \"message\": \"What's up homie\"\n",
" },\n",
" \"id\": \"toolu_014XccHFzBiVcc9GV1harV9U\"\n",
"}\n"
]
}
],
"source": [
"try:\n",
" chain.invoke(\"Send sally@gmail.com an email saying 'What's up homie'\")\n",
"except NotApproved as e:\n",
" print()\n",
" print(e)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}