add return_direct flag to tool (#537)

adds a return_direct flag to tools, which just returns the tool output
as the final output
This commit is contained in:
Harrison Chase
2023-01-06 06:40:32 -08:00
committed by GitHub
parent 1f248c47f3
commit 4974f49bb7
4 changed files with 126 additions and 24 deletions

View File

@@ -44,7 +44,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"id": "36ed392e",
"metadata": {},
"outputs": [],
@@ -63,7 +63,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"id": "56ff7670",
"metadata": {},
"outputs": [],
@@ -94,7 +94,6 @@
"source": [
"# Construct the agent. We will use the default agent type here.\n",
"# See documentation for a full list of options.\n",
"llm = OpenAI(temperature=0)\n",
"agent = initialize_agent(tools, llm, agent=\"zero-shot-react-description\", verbose=True)"
]
},
@@ -248,10 +247,86 @@
"agent.run(\"Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?\")"
]
},
{
"cell_type": "markdown",
"id": "a8d95999",
"metadata": {},
"source": [
"## Using tools to return directly\n",
"\n",
"Often, it can be desirable to have a tool output returned directly to the user, if it's called. You can do this easily with LangChain by setting the `return_direct` flag for a tool to be True."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c8f65640",
"metadata": {},
"outputs": [],
"source": [
"llm_math_chain = LLMMathChain(llm=llm)\n",
"tools = [\n",
" Tool(\n",
" name=\"Calculator\",\n",
" func=llm_math_chain.run,\n",
" description=\"useful for when you need to answer questions about math\",\n",
" return_direct=True\n",
" )\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4bd1c4bb",
"metadata": {},
"outputs": [],
"source": [
"llm = OpenAI(temperature=0)\n",
"agent = initialize_agent(tools, llm, agent=\"zero-shot-react-description\", verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "52fe0594",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to calculate this\n",
"Action: Calculator\n",
"Action Input: 2**.12\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.2599210498948732\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'Answer: 1.2599210498948732'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.run(\"whats 2**.12\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3450512e",
"id": "7cc1b875",
"metadata": {},
"outputs": [],
"source": []
@@ -259,7 +334,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.0 64-bit ('llm-env')",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -273,7 +348,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
"version": "3.10.9"
},
"vscode": {
"interpreter": {