Compare commits

...

11 Commits

Author SHA1 Message Date
Jacob Lee
080f287a2c Update tool_runtime.ipynb 2024-07-12 10:57:15 -07:00
jacoblee93
36919b19b6 Update imports 2024-07-12 10:55:10 -07:00
jacoblee93
55a6347478 Merge branch 'jacob/curry_tools' of https://github.com/langchain-ai/langchain into jacob/currying_docs 2024-07-12 10:53:01 -07:00
Jacob Lee
6ef6c9e7f1 Revert __init__.py 2024-07-12 10:52:28 -07:00
jacoblee93
b1ac3925f7 Adds additional runtime arg passing docs via currying 2024-07-12 10:50:43 -07:00
Jacob Lee
244cd5c141 Update __init__.py 2024-07-12 10:48:28 -07:00
Jacob Lee
412bc82c11 Update curry.py 2024-07-12 10:42:08 -07:00
jacoblee93
fc3353636a Fix lint 2024-07-11 12:55:29 -07:00
jacoblee93
7130aa826f Lint 2024-07-11 12:46:16 -07:00
Jacob Lee
367b2d8dbe Update libs/core/langchain_core/utils/curry.py
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
2024-07-11 12:44:55 -07:00
jacoblee93
432ccd686d Adds curry utils function 2024-07-11 11:19:51 -07:00
3 changed files with 200 additions and 42 deletions

View File

@@ -28,21 +28,17 @@
"which shows how to create an agent that keeps track of a given user's favorite pets.\n",
":::\n",
"\n",
"You may need to bind values to a tool that are only known at runtime. For example, the tool logic may require using the ID of the user who made the request.\n",
"There are times where tools need to use runtime values that should not be populated by the LLM. For example, the tool logic may require using the ID of the user who made the request. In this case, allowing the LLM to control the parameter is a security risk.\n",
"\n",
"Most of the time, such values should not be controlled by the LLM. In fact, allowing the LLM to control the user ID may lead to a security risk.\n",
"Instead, the LLM should only control the parameters of the tool that are meant to be controlled by the LLM, while other parameters (such as user ID) should be fixed by the application logic. These defined parameters should not be part of the tool's final schema.\n",
"\n",
"Instead, the LLM should only control the parameters of the tool that are meant to be controlled by the LLM, while other parameters (such as user ID) should be fixed by the application logic.\n",
"\n",
"This how-to guide shows a simple design pattern that creates the tool dynamically at run time and binds to them appropriate values."
"This how-to guide shows some design patterns that create the tool dynamically at run time and binds appropriate values to them."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can bind them to chat models as follows:\n",
"\n",
"```{=mdx}\n",
"import ChatModelTabs from \"@theme/ChatModelTabs\";\n",
"\n",
@@ -55,25 +51,14 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.2.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"outputs": [],
"source": [
"# | output: false\n",
"# | echo: false\n",
"\n",
"%pip install -qU langchain langchain_openai\n",
"%pip install -qU langchain_core langchain_openai\n",
"\n",
"import os\n",
"from getpass import getpass\n",
@@ -90,10 +75,17 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Passing request time information\n",
"## Using the `curry` utility function\n",
"\n",
"The idea is to create the tool dynamically at request time, and bind to it the appropriate information. For example,\n",
"this information may be the user ID as resolved from the request itself."
":::caution Compatibility\n",
"\n",
"This function is only available in `langchain_core>=0.2.17`.\n",
"\n",
":::\n",
"\n",
"We can bind arguments to the tool's inner function via a utility wrapper. This will use a technique called [currying](https://en.wikipedia.org/wiki/Currying) to bind arguments to the function while also removing it from the function signature.\n",
"\n",
"Below, we initialize a tool that lists a user's favorite pet. It requires a `user_id` that we'll curry ahead of time."
]
},
{
@@ -102,18 +94,98 @@
"metadata": {},
"outputs": [],
"source": [
"from typing import List\n",
"from langchain_core.tools import StructuredTool\n",
"from langchain_core.utils.curry import curry\n",
"\n",
"from langchain_core.output_parsers import JsonOutputParser\n",
"from langchain_core.tools import BaseTool, tool"
"user_to_pets = {\"eugene\": [\"cats\"]}\n",
"\n",
"\n",
"def list_favorite_pets(user_id: str) -> None:\n",
" \"\"\"List favorite pets, if any.\"\"\"\n",
" return user_to_pets.get(user_id, [])\n",
"\n",
"\n",
"curried_function = curry(list_favorite_pets, user_id=\"eugene\")\n",
"\n",
"curried_tool = StructuredTool.from_function(curried_function)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we examine the schema of the curried tool, we can see that it no longer has `user_id` as part of its signature:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'title': 'list_favorite_petsSchema',\n",
" 'description': 'List favorite pets, if any.',\n",
" 'type': 'object',\n",
" 'properties': {}}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"curried_tool.input_schema.schema()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"But if we invoke it, we can see that it returns Eugene's favorite pets, `cats`:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['cats']"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"curried_tool.invoke({})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using scope\n",
"\n",
"We can achieve a similar result by wrapping the tool declarations themselves in a function. This lets us take advantage of the closure created by the wrapper to pass a variable into each tool. Here's an example:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from typing import List\n",
"\n",
"from langchain_core.tools import BaseTool, tool\n",
"\n",
"user_to_pets = {}\n",
"\n",
"\n",
@@ -133,7 +205,7 @@
"\n",
" @tool\n",
" def list_favorite_pets() -> None:\n",
" \"\"\"List favorite pets if any.\"\"\"\n",
" \"\"\"List favorite pets, if any.\"\"\"\n",
" return user_to_pets.get(user_id, [])\n",
"\n",
" return [update_favorite_pets, delete_favorite_pets, list_favorite_pets]"
@@ -143,12 +215,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Verify that the tools work correctly"
"Verify that the tools work correctly:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"metadata": {},
"outputs": [
{
@@ -169,21 +241,14 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"\n",
"def handle_run_time_request(user_id: str, query: str):\n",
" \"\"\"Handle run time request.\"\"\"\n",
" tools = generate_tools_for_user(user_id)\n",
" llm_with_tools = llm.bind_tools(tools)\n",
" prompt = ChatPromptTemplate.from_messages(\n",
" [(\"system\", \"You are a helpful assistant.\")],\n",
" )\n",
" chain = prompt | llm_with_tools\n",
" return llm_with_tools.invoke(query)"
]
},
@@ -196,7 +261,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -204,10 +269,10 @@
"text/plain": [
"[{'name': 'update_favorite_pets',\n",
" 'args': {'pets': ['cats', 'parrots']},\n",
" 'id': 'call_jJvjPXsNbFO5MMgW0q84iqCN'}]"
" 'id': 'call_c8agYHY1COFSAgwZR11OGCmQ'}]"
]
},
"execution_count": 6,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@@ -248,7 +313,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.5"
}
},
"nbformat": 4,

View File

@@ -0,0 +1,42 @@
import asyncio
import inspect
from functools import wraps
from typing import Any, Callable
def curry(func: Callable[..., Any], **curried_kwargs: Any) -> Callable[..., Any]:
"""Util that wraps a function and partially applies kwargs to it.
Returns a new function whose signature omits the curried variables.
Args:
func: The function to curry.
curried_kwargs: Arguments to apply to the function.
Returns:
A new function with curried arguments applied.
.. versionadded:: 0.2.18
"""
@wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
new_kwargs = {**curried_kwargs, **kwargs}
return await func(*args, **new_kwargs)
@wraps(func)
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
new_kwargs = {**curried_kwargs, **kwargs}
return func(*args, **new_kwargs)
sig = inspect.signature(func)
# Create a new signature without the curried parameters
new_params = [p for name, p in sig.parameters.items() if name not in curried_kwargs]
if asyncio.iscoroutinefunction(func):
async_wrapper = wraps(func)(async_wrapper)
setattr(async_wrapper, "__signature__", sig.replace(parameters=new_params))
return async_wrapper
else:
sync_wrapper = wraps(func)(sync_wrapper)
setattr(sync_wrapper, "__signature__", sig.replace(parameters=new_params))
return sync_wrapper

View File

@@ -0,0 +1,51 @@
from typing import Any
from langchain_core.utils.curry import curry
def test_curry() -> None:
def test_fn(a: str, b: str) -> str:
return a + b
curried = curry(test_fn, a="hey")
assert curried(b=" you") == "hey you"
def test_curry_with_kwargs_values() -> None:
def test_fn(a: str, b: str, **kwargs: Any) -> str:
return a + b + kwargs["c"]
curried = curry(test_fn, c=" you you")
assert curried(a="hey", b=" hey") == "hey hey you you"
def test_noop_curry() -> None:
def test_fn(a: str, b: str) -> str:
return a + b
curried = curry(test_fn)
assert curried(a="bye", b=" you") == "bye you"
async def test_async_curry() -> None:
async def test_fn(a: str, b: str) -> str:
return a + b
curried = curry(test_fn, a="hey")
assert await curried(b=" you") == "hey you"
async def test_async_curry_with_kwargs_values() -> None:
async def test_fn(a: str, b: str, **kwargs: Any) -> str:
return a + b + kwargs["c"]
curried = curry(test_fn, c=" you you")
assert await curried(a="hey", b=" hey") == "hey hey you you"
async def test_noop_async_curry() -> None:
async def test_fn(a: str, b: str) -> str:
return a + b
curried = curry(test_fn)
assert await curried(a="bye", b=" you") == "bye you"