mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-23 21:31:02 +00:00
Compare commits
11 Commits
erick/infr
...
jacob/curr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
080f287a2c | ||
|
|
36919b19b6 | ||
|
|
55a6347478 | ||
|
|
6ef6c9e7f1 | ||
|
|
b1ac3925f7 | ||
|
|
244cd5c141 | ||
|
|
412bc82c11 | ||
|
|
fc3353636a | ||
|
|
7130aa826f | ||
|
|
367b2d8dbe | ||
|
|
432ccd686d |
@@ -28,21 +28,17 @@
|
||||
"which shows how to create an agent that keeps track of a given user's favorite pets.\n",
|
||||
":::\n",
|
||||
"\n",
|
||||
"You may need to bind values to a tool that are only known at runtime. For example, the tool logic may require using the ID of the user who made the request.\n",
|
||||
"There are times where tools need to use runtime values that should not be populated by the LLM. For example, the tool logic may require using the ID of the user who made the request. In this case, allowing the LLM to control the parameter is a security risk.\n",
|
||||
"\n",
|
||||
"Most of the time, such values should not be controlled by the LLM. In fact, allowing the LLM to control the user ID may lead to a security risk.\n",
|
||||
"Instead, the LLM should only control the parameters of the tool that are meant to be controlled by the LLM, while other parameters (such as user ID) should be fixed by the application logic. These defined parameters should not be part of the tool's final schema.\n",
|
||||
"\n",
|
||||
"Instead, the LLM should only control the parameters of the tool that are meant to be controlled by the LLM, while other parameters (such as user ID) should be fixed by the application logic.\n",
|
||||
"\n",
|
||||
"This how-to guide shows a simple design pattern that creates the tool dynamically at run time and binds to them appropriate values."
|
||||
"This how-to guide shows some design patterns that create the tool dynamically at run time and binds appropriate values to them."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can bind them to chat models as follows:\n",
|
||||
"\n",
|
||||
"```{=mdx}\n",
|
||||
"import ChatModelTabs from \"@theme/ChatModelTabs\";\n",
|
||||
"\n",
|
||||
@@ -55,25 +51,14 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.2.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
|
||||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n",
|
||||
"Note: you may need to restart the kernel to use updated packages.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# | output: false\n",
|
||||
"# | echo: false\n",
|
||||
"\n",
|
||||
"%pip install -qU langchain langchain_openai\n",
|
||||
"%pip install -qU langchain_core langchain_openai\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"from getpass import getpass\n",
|
||||
@@ -90,10 +75,17 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Passing request time information\n",
|
||||
"## Using the `curry` utility function\n",
|
||||
"\n",
|
||||
"The idea is to create the tool dynamically at request time, and bind to it the appropriate information. For example,\n",
|
||||
"this information may be the user ID as resolved from the request itself."
|
||||
":::caution Compatibility\n",
|
||||
"\n",
|
||||
"This function is only available in `langchain_core>=0.2.17`.\n",
|
||||
"\n",
|
||||
":::\n",
|
||||
"\n",
|
||||
"We can bind arguments to the tool's inner function via a utility wrapper. This will use a technique called [currying](https://en.wikipedia.org/wiki/Currying) to bind arguments to the function while also removing it from the function signature.\n",
|
||||
"\n",
|
||||
"Below, we initialize a tool that lists a user's favorite pet. It requires a `user_id` that we'll curry ahead of time."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -102,18 +94,98 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import List\n",
|
||||
"from langchain_core.tools import StructuredTool\n",
|
||||
"from langchain_core.utils.curry import curry\n",
|
||||
"\n",
|
||||
"from langchain_core.output_parsers import JsonOutputParser\n",
|
||||
"from langchain_core.tools import BaseTool, tool"
|
||||
"user_to_pets = {\"eugene\": [\"cats\"]}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def list_favorite_pets(user_id: str) -> None:\n",
|
||||
" \"\"\"List favorite pets, if any.\"\"\"\n",
|
||||
" return user_to_pets.get(user_id, [])\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"curried_function = curry(list_favorite_pets, user_id=\"eugene\")\n",
|
||||
"\n",
|
||||
"curried_tool = StructuredTool.from_function(curried_function)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If we examine the schema of the curried tool, we can see that it no longer has `user_id` as part of its signature:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'title': 'list_favorite_petsSchema',\n",
|
||||
" 'description': 'List favorite pets, if any.',\n",
|
||||
" 'type': 'object',\n",
|
||||
" 'properties': {}}"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"curried_tool.input_schema.schema()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"But if we invoke it, we can see that it returns Eugene's favorite pets, `cats`:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['cats']"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"curried_tool.invoke({})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Using scope\n",
|
||||
"\n",
|
||||
"We can achieve a similar result by wrapping the tool declarations themselves in a function. This lets us take advantage of the closure created by the wrapper to pass a variable into each tool. Here's an example:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import List\n",
|
||||
"\n",
|
||||
"from langchain_core.tools import BaseTool, tool\n",
|
||||
"\n",
|
||||
"user_to_pets = {}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
@@ -133,7 +205,7 @@
|
||||
"\n",
|
||||
" @tool\n",
|
||||
" def list_favorite_pets() -> None:\n",
|
||||
" \"\"\"List favorite pets if any.\"\"\"\n",
|
||||
" \"\"\"List favorite pets, if any.\"\"\"\n",
|
||||
" return user_to_pets.get(user_id, [])\n",
|
||||
"\n",
|
||||
" return [update_favorite_pets, delete_favorite_pets, list_favorite_pets]"
|
||||
@@ -143,12 +215,12 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Verify that the tools work correctly"
|
||||
"Verify that the tools work correctly:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -169,21 +241,14 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_core.prompts import ChatPromptTemplate\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def handle_run_time_request(user_id: str, query: str):\n",
|
||||
" \"\"\"Handle run time request.\"\"\"\n",
|
||||
" tools = generate_tools_for_user(user_id)\n",
|
||||
" llm_with_tools = llm.bind_tools(tools)\n",
|
||||
" prompt = ChatPromptTemplate.from_messages(\n",
|
||||
" [(\"system\", \"You are a helpful assistant.\")],\n",
|
||||
" )\n",
|
||||
" chain = prompt | llm_with_tools\n",
|
||||
" return llm_with_tools.invoke(query)"
|
||||
]
|
||||
},
|
||||
@@ -196,7 +261,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -204,10 +269,10 @@
|
||||
"text/plain": [
|
||||
"[{'name': 'update_favorite_pets',\n",
|
||||
" 'args': {'pets': ['cats', 'parrots']},\n",
|
||||
" 'id': 'call_jJvjPXsNbFO5MMgW0q84iqCN'}]"
|
||||
" 'id': 'call_c8agYHY1COFSAgwZR11OGCmQ'}]"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -248,7 +313,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.4"
|
||||
"version": "3.10.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
42
libs/core/langchain_core/utils/curry.py
Normal file
42
libs/core/langchain_core/utils/curry.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
def curry(func: Callable[..., Any], **curried_kwargs: Any) -> Callable[..., Any]:
|
||||
"""Util that wraps a function and partially applies kwargs to it.
|
||||
Returns a new function whose signature omits the curried variables.
|
||||
|
||||
Args:
|
||||
func: The function to curry.
|
||||
curried_kwargs: Arguments to apply to the function.
|
||||
|
||||
Returns:
|
||||
A new function with curried arguments applied.
|
||||
|
||||
.. versionadded:: 0.2.18
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
new_kwargs = {**curried_kwargs, **kwargs}
|
||||
return await func(*args, **new_kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
new_kwargs = {**curried_kwargs, **kwargs}
|
||||
return func(*args, **new_kwargs)
|
||||
|
||||
sig = inspect.signature(func)
|
||||
# Create a new signature without the curried parameters
|
||||
new_params = [p for name, p in sig.parameters.items() if name not in curried_kwargs]
|
||||
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
async_wrapper = wraps(func)(async_wrapper)
|
||||
setattr(async_wrapper, "__signature__", sig.replace(parameters=new_params))
|
||||
return async_wrapper
|
||||
else:
|
||||
sync_wrapper = wraps(func)(sync_wrapper)
|
||||
setattr(sync_wrapper, "__signature__", sig.replace(parameters=new_params))
|
||||
return sync_wrapper
|
||||
51
libs/core/tests/unit_tests/utils/test_curry.py
Normal file
51
libs/core/tests/unit_tests/utils/test_curry.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.utils.curry import curry
|
||||
|
||||
|
||||
def test_curry() -> None:
|
||||
def test_fn(a: str, b: str) -> str:
|
||||
return a + b
|
||||
|
||||
curried = curry(test_fn, a="hey")
|
||||
assert curried(b=" you") == "hey you"
|
||||
|
||||
|
||||
def test_curry_with_kwargs_values() -> None:
|
||||
def test_fn(a: str, b: str, **kwargs: Any) -> str:
|
||||
return a + b + kwargs["c"]
|
||||
|
||||
curried = curry(test_fn, c=" you you")
|
||||
assert curried(a="hey", b=" hey") == "hey hey you you"
|
||||
|
||||
|
||||
def test_noop_curry() -> None:
|
||||
def test_fn(a: str, b: str) -> str:
|
||||
return a + b
|
||||
|
||||
curried = curry(test_fn)
|
||||
assert curried(a="bye", b=" you") == "bye you"
|
||||
|
||||
|
||||
async def test_async_curry() -> None:
|
||||
async def test_fn(a: str, b: str) -> str:
|
||||
return a + b
|
||||
|
||||
curried = curry(test_fn, a="hey")
|
||||
assert await curried(b=" you") == "hey you"
|
||||
|
||||
|
||||
async def test_async_curry_with_kwargs_values() -> None:
|
||||
async def test_fn(a: str, b: str, **kwargs: Any) -> str:
|
||||
return a + b + kwargs["c"]
|
||||
|
||||
curried = curry(test_fn, c=" you you")
|
||||
assert await curried(a="hey", b=" hey") == "hey hey you you"
|
||||
|
||||
|
||||
async def test_noop_async_curry() -> None:
|
||||
async def test_fn(a: str, b: str) -> str:
|
||||
return a + b
|
||||
|
||||
curried = curry(test_fn)
|
||||
assert await curried(a="bye", b=" you") == "bye you"
|
||||
Reference in New Issue
Block a user