mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-12 12:11:34 +00:00
Compare commits
43 Commits
v0.0.335
...
eugene/aut
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
486d772184 | ||
|
|
29ddf63abb | ||
|
|
35318007d8 | ||
|
|
900cf6ead3 | ||
|
|
beec2d99b9 | ||
|
|
a695ba5ba4 | ||
|
|
71c277c17c | ||
|
|
e67f4b8d3b | ||
|
|
7181211d78 | ||
|
|
2c798fcc77 | ||
|
|
bf12905e7f | ||
|
|
e4cf87ab34 | ||
|
|
9bbc5af2a8 | ||
|
|
06e34e54f5 | ||
|
|
5a68ba4174 | ||
|
|
8e8c368019 | ||
|
|
acbd19fac5 | ||
|
|
aa56956d16 | ||
|
|
0d2cf70d1e | ||
|
|
188ac56c97 | ||
|
|
383f09b271 | ||
|
|
0af306fc54 | ||
|
|
15d5c49076 | ||
|
|
65660535bc | ||
|
|
6c41dd82f0 | ||
|
|
78d788c28c | ||
|
|
e9deeab37f | ||
|
|
4d595eec5b | ||
|
|
047b001336 | ||
|
|
840e936c7c | ||
|
|
6cc6b490be | ||
|
|
f45d1ed4f5 | ||
|
|
4ffc417858 | ||
|
|
9e74a70859 | ||
|
|
0997f2c0f1 | ||
|
|
437b545426 | ||
|
|
9b9d07572b | ||
|
|
6a90c6c2c8 | ||
|
|
8371187689 | ||
|
|
4309c17ffa | ||
|
|
183a9d4e66 | ||
|
|
c1b444e1e7 | ||
|
|
5f117384c0 |
0
libs/langchain/langchain/automaton/__init__.py
Normal file
0
libs/langchain/langchain/automaton/__init__.py
Normal file
126
libs/langchain/langchain/automaton/chat_agent.py
Normal file
126
libs/langchain/langchain/automaton/chat_agent.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Generalized chat agent, works with any chat model."""
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import re
|
||||
from typing import Sequence, Union, List
|
||||
from langchain.automaton.tool_utils import generate_tool_info
|
||||
|
||||
from langchain.automaton.runnables import create_llm_program
|
||||
from langchain.automaton.typedefs import (
|
||||
MessageLog,
|
||||
AgentFinish,
|
||||
MessageLike,
|
||||
FunctionCall,
|
||||
FunctionResult,
|
||||
)
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import BaseMessage, HumanMessage
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from langchain.prompts import SystemMessagePromptTemplate
|
||||
|
||||
|
||||
# System prompt for the generalized chat agent.  Instructs the model to emit
# tool invocations as a python-literal blob wrapped in <action>...</action>
# tags, which `decode` below parses.  Placeholders `tools_description` and
# `tool_names` are filled from `generate_tool_info(tools)`.
TEMPLATE_ = SystemMessagePromptTemplate.from_template(
    """Respond to the human as helpfully and accurately as \
possible. You have access to the following tools:
{tools_description}

Use a blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).

Valid "action" values: "Final Answer" or {tool_names}

Provide only ONE action per $BLOB, as shown.

<action>
{{
  "action": $TOOL_NAME,
  "action_input": $INPUT
}}
</action>

When invoking a tool do not provide any clarifying information.

The human will forward results of tool invocations as "Observations".

When you know the answer paraphrase the information in the observations properly and respond to the user. \
If you do not know the answer use more tools.

You can only take a single action at a time."""
)
|
||||
|
||||
|
||||
def generate_prompt(tools: Sequence[BaseTool]) -> MessageLog:
    """Build the initial message log holding the system prompt.

    Args:
        tools: tools the agent may invoke; their names and signatures are
            interpolated into the system prompt template.

    Returns:
        A ``MessageLog`` seeded with the formatted system message.
    """
    info = generate_tool_info(tools)
    system_message = TEMPLATE_.format(**info)
    return MessageLog(messages=[system_message])
|
||||
|
||||
|
||||
def decode(text: Union[BaseMessage, str]) -> MessageLike:
    """Decode a model response into an agent action.

    Searches for an ``<action>...</action>`` blob containing a dict with an
    ``action`` key (tool name) and an ``action_input`` key (tool arguments).

    Args:
        text: the raw model output, either a chat message or a plain string.
            (The original implementation raised ``NotImplementedError`` for
            plain strings despite the declared signature; both are now
            supported.)

    Returns:
        A ``FunctionCall`` when a tool invocation was requested, or an
        ``AgentFinish`` when the model produced a final answer -- either via
        the special "Final Answer" action or by emitting no action blob.
    """
    pattern = re.compile(r"<action>(?P<action_blob>.*?)<\/action>", re.DOTALL)
    _text = text.content if isinstance(text, BaseMessage) else text
    match = pattern.search(_text)
    if not match:
        # No action blob: treat the whole response as the final answer.
        return AgentFinish(result=text)
    # literal_eval parses the python-literal blob without executing code,
    # unlike eval(), so malformed-but-literal model output stays safe.
    data = ast.literal_eval(match.group("action_blob"))
    name = data["action"]
    if name == "Final Answer":  # Special cased "tool" for final answer
        return AgentFinish(result=data["action_input"])
    return FunctionCall(name=name, arguments=data["action_input"] or {})
|
||||
|
||||
|
||||
def prompt_generator(log: MessageLog) -> List[BaseMessage]:
    """Project a message log onto the chat messages sent to the model.

    Plain chat messages pass through unchanged; tool results are rendered as
    human "Observation" messages, matching the contract in the system prompt.

    Args:
        log: the accumulated log of message-like objects.

    Returns:
        The list of ``BaseMessage`` objects to feed the chat model.
    """
    prompt: List[BaseMessage] = []
    for item in log.messages:
        if isinstance(item, BaseMessage):
            prompt.append(item)
        elif isinstance(item, FunctionResult):
            prompt.append(HumanMessage(content=f"Observation: {item.result}"))
        # Other message kinds (e.g. FunctionCall) are intentionally dropped:
        # the model already sees the action blob inside its own AI message.
    return prompt
|
||||
|
||||
|
||||
class ChatAgent:
    """An agent for chat models.

    Drives a think/act loop: the LLM program proposes an action, tools are
    executed, and results are appended to the shared message log until the
    model finishes or ``max_iterations`` is exhausted.
    """

    def __init__(
        self,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        *,
        max_iterations: int = 10,
    ) -> None:
        """Initialize the chat automaton.

        Args:
            llm: chat model that drives the agent.
            tools: tools the agent may invoke.
            max_iterations: upper bound on think/act cycles before stopping,
                protecting against non-terminating loops.
        """
        self.llm_program = create_llm_program(
            llm,
            prompt_generator=prompt_generator,
            tools=tools,
            parser=decode,
        )
        self.max_iterations = max_iterations

    def run(self, message_log: MessageLog) -> None:
        """Run the agent until it finishes or ``max_iterations`` is reached.

        The log is mutated in place; inspect it afterwards for the result.

        Args:
            message_log: non-empty log seeded with at least the user request.

        Raises:
            AssertionError: if ``message_log`` is empty.
        """
        if not message_log:
            # Plain string literal: the original used an f-string with no
            # placeholders (ruff F541).
            raise AssertionError("Expected at least one message in message_log")

        for _ in range(self.max_iterations):
            last_message = message_log[-1]

            # The parser emits AgentFinish when the model gives a final
            # answer; that terminates the loop.
            if isinstance(last_message, AgentFinish):
                break

            messages = self.llm_program.invoke(message_log)
            message_log.add_messages(messages)
|
||||
683
libs/langchain/langchain/automaton/notebooks/chat_agent.ipynb
Normal file
683
libs/langchain/langchain/automaton/notebooks/chat_agent.ipynb
Normal file
@@ -0,0 +1,683 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a91c50ff-bcb9-484b-a4bd-f56e75c182ba",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# General Chat Agent"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "80595f69-8b46-494e-bd2f-0c9538ed526a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "724c11c6-c98a-47ef-8697-87cd888df7a0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import List, Optional, Any\n",
|
||||
"from langchain.tools import tool, Tool\n",
|
||||
"from langchain.schema.messages import AIMessage, HumanMessage\n",
|
||||
"from langchain.chat_models.anthropic import ChatAnthropic\n",
|
||||
"from langchain.chat_models.openai import ChatOpenAI\n",
|
||||
"from langchain.automaton.typedefs import MessageLog\n",
|
||||
"from langchain.automaton.chat_agent import ChatAgent\n",
|
||||
"from langchain.automaton.tool_utils import generate_tool_info\n",
|
||||
"from langchain.prompts import ChatPromptTemplate"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "ed3855cd-06b0-4a45-8b76-88926433fc60",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_tools() -> List[Tool]:\n",
|
||||
" @tool\n",
|
||||
" def name() -> str:\n",
|
||||
" \"\"\"Use to look up the user's name\"\"\"\n",
|
||||
" return \"Eugene\"\n",
|
||||
"\n",
|
||||
" @tool\n",
|
||||
" def get_weather(city: str) -> str:\n",
|
||||
" \"\"\"Get weather in a specific city.\"\"\"\n",
|
||||
" return \"42F and sunny\"\n",
|
||||
"\n",
|
||||
" @tool\n",
|
||||
" def add(x: int, y: int) -> int:\n",
|
||||
" \"\"\"Use to add two numbers. For example; { \"x\": 2, \"y\": 10} \"\"\"\n",
|
||||
" return x + y\n",
|
||||
"\n",
|
||||
" @tool\n",
|
||||
" def add_3(x: int, y: int, z: int) -> int:\n",
|
||||
" \"\"\"Add 3 numbers together.\"\"\"\n",
|
||||
" return x + y + z\n",
|
||||
"\n",
|
||||
" @tool\n",
|
||||
" def divide(x: float, y: float) -> float:\n",
|
||||
" \"\"\"Divide x by y\"\"\"\n",
|
||||
" return x / y\n",
|
||||
"\n",
|
||||
" return list(locals().values())\n",
|
||||
"\n",
|
||||
"tools = get_tools()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "97a59c88-2f0a-4982-81e3-4ffba4f222da",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = ChatAnthropic(temperature=0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "6d2c6418-1fcb-4f4a-9926-2cb3639e8019",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tool_info = generate_tool_info(tools)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "51d874e1-bbd8-4d47-b9b8-230b138519ce",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Generate a prompt with an example trace"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "6b8197d5-533a-4350-a4a3-cbd302733995",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sys_msg = \"\"\"Respond to the human as helpfully and accurately as \\\n",
|
||||
"possible. You have access to the following tools:\n",
|
||||
"{tools_description}\n",
|
||||
"\n",
|
||||
"Use a blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).\n",
|
||||
"\n",
|
||||
"Valid \"action\" values: \"Final Answer\" or {tool_names}\n",
|
||||
"\n",
|
||||
"Provide only ONE action per $BLOB, as shown.\n",
|
||||
"\n",
|
||||
"<action>\n",
|
||||
"{{\n",
|
||||
" \"action\": $TOOL_NAME,\n",
|
||||
" \"action_input\": $INPUT\n",
|
||||
"}}\n",
|
||||
"</action>\n",
|
||||
"\n",
|
||||
"When invoking a tool do not provide any clarifying information.\n",
|
||||
"\n",
|
||||
"The human will forward results of tool invocations as \"Observations\".\n",
|
||||
"\n",
|
||||
"When you know the answer paraphrase the information in the observations properly and respond to the user. \\\n",
|
||||
"If you do not know the answer use more tools.\n",
|
||||
"\n",
|
||||
"You can only take a single action at a time.\"\"\"\n",
|
||||
"\n",
|
||||
"messages = ChatPromptTemplate.from_messages([\n",
|
||||
" (\"system\", sys_msg),\n",
|
||||
" ('human', 'what is 5+8'),\n",
|
||||
" ('ai', '<action> {{ \"action\": \"add\", \"action_input\": {{ \"x\": 5, \"y\": 8 }} }} </action>'),\n",
|
||||
" ('human', 'Observation: 13'),\n",
|
||||
" ('ai', '5 + 8 is 13'),\n",
|
||||
"]).format_messages(**tool_info)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "2694bf0f-722f-4d1c-b10d-bfdc90eabc8e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.callbacks.manager import (\n",
|
||||
" trace_as_chain_group,\n",
|
||||
" atrace_as_chain_group,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "778b2eda-5128-40ea-949f-9727610cc194",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"agent = ChatAgent(llm, tools)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "7bec72cc-c882-415a-964e-9d53951630f3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class VerboseMessageLog(MessageLog):\n",
|
||||
" def add_messages(self, *args, **kwargs):\n",
|
||||
" for arg in args:\n",
|
||||
" print(arg)\n",
|
||||
" return super().add_messages(*args, **kwargs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "e7d1ec69-de12-4c66-abee-3b2bbabeec13",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[human: Sum of first 10 numbers starting from 1? use one tool at a time]\n",
|
||||
"[ai: <action> \n",
|
||||
"{\n",
|
||||
" \"action\": \"add\",\n",
|
||||
" \"action_input\": {\n",
|
||||
" \"x\": 1, \n",
|
||||
" \"y\": 2\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"</action>, FunctionCall(name='add', arguments={'x': 1, 'y': 2}), FunctionResult(name='add', result=3, error=None)]\n",
|
||||
"[ai: <action>\n",
|
||||
"{\n",
|
||||
" \"action\": \"add\", \n",
|
||||
" \"action_input\": {\n",
|
||||
" \"x\": 3,\n",
|
||||
" \"y\": 3\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"</action>, FunctionCall(name='add', arguments={'x': 3, 'y': 3}), FunctionResult(name='add', result=6, error=None)]\n",
|
||||
"[ai: <action>\n",
|
||||
"{\n",
|
||||
" \"action\": \"add\",\n",
|
||||
" \"action_input\": {\n",
|
||||
" \"x\": 6, \n",
|
||||
" \"y\": 4\n",
|
||||
" } \n",
|
||||
"}\n",
|
||||
"</action>, FunctionCall(name='add', arguments={'x': 6, 'y': 4}), FunctionResult(name='add', result=10, error=None)]\n",
|
||||
"[ai: <action>\n",
|
||||
"{\n",
|
||||
" \"action\": \"add\",\n",
|
||||
" \"action_input\": {\n",
|
||||
" \"x\": 10,\n",
|
||||
" \"y\": 5\n",
|
||||
" }\n",
|
||||
"} \n",
|
||||
"</action>, FunctionCall(name='add', arguments={'x': 10, 'y': 5}), FunctionResult(name='add', result=15, error=None)]\n",
|
||||
"[ai: Based on the observations so far, the sum of the first 10 numbers starting from 1 is 15. I will continue invoking the add tool one input at a time to calculate the full sum:\n",
|
||||
"\n",
|
||||
"<action>\n",
|
||||
"{\n",
|
||||
" \"action\": \"add\",\n",
|
||||
" \"action_input\": {\n",
|
||||
" \"x\": 15,\n",
|
||||
" \"y\": 6\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"</action>, FunctionCall(name='add', arguments={'x': 15, 'y': 6}), FunctionResult(name='add', result=21, error=None)]\n",
|
||||
"[ai: <action>\n",
|
||||
"{\n",
|
||||
" \"action\": \"add\",\n",
|
||||
" \"action_input\": {\n",
|
||||
" \"x\": 21, \n",
|
||||
" \"y\": 7\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"</action>, FunctionCall(name='add', arguments={'x': 21, 'y': 7}), FunctionResult(name='add', result=28, error=None)]\n",
|
||||
"[ai: <action>\n",
|
||||
"{\n",
|
||||
" \"action\": \"add\",\n",
|
||||
" \"action_input\": {\n",
|
||||
" \"x\": 28,\n",
|
||||
" \"y\": 8 \n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"</action>, FunctionCall(name='add', arguments={'x': 28, 'y': 8}), FunctionResult(name='add', result=36, error=None)]\n",
|
||||
"[ai: <action>\n",
|
||||
"{\n",
|
||||
" \"action\": \"add\",\n",
|
||||
" \"action_input\": {\n",
|
||||
" \"x\": 36,\n",
|
||||
" \"y\": 9\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"</action>, FunctionCall(name='add', arguments={'x': 36, 'y': 9}), FunctionResult(name='add', result=45, error=None)]\n",
|
||||
"[ai: <action>\n",
|
||||
"{\n",
|
||||
" \"action\": \"add\",\n",
|
||||
" \"action_input\": {\n",
|
||||
" \"x\": 45,\n",
|
||||
" \"y\": 10\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"</action>, FunctionCall(name='add', arguments={'x': 45, 'y': 10}), FunctionResult(name='add', result=55, error=None)]\n",
|
||||
"[ai: Based on the observations, the sum of the first 10 numbers starting from 1 is 55., AgentFinish(result=ai: Based on the observations, the sum of the first 10 numbers starting from 1 is 55.)]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"message_log = VerboseMessageLog(messages=messages)\n",
|
||||
"question = HumanMessage(content=\"Sum of first 10 numbers starting from 1? use one tool at a time\")\n",
|
||||
"message_log.add_messages([question])\n",
|
||||
"\n",
|
||||
"agent.run(message_log)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d01d4dd6-8706-4d84-9660-b0c6cc2eed1b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Visualize"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"id": "c4ec3106-7679-4cd0-9cba-134814cfa8ca",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import html\n",
|
||||
"from langchain.automaton.chat_agent import prompt_generator"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"id": "b9f800f1-1e90-4199-a016-015be152ae11",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"messages = prompt_generator(message_log)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"id": "9ba80285-b28c-419d-a8a6-043133f64190",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">system</div>\n",
|
||||
" <div>Respond to the human as helpfully and accurately as possible. You have access to the following tools:\n",
|
||||
"name: name() -> str - Use to look up the user's name\n",
|
||||
"get_weather: get_weather(city: str) -> str - Get weather in a specific city.\n",
|
||||
"add: add(x: int, y: int) -> int - Use to add two numbers. For example; { "x": 2, "y": 10}\n",
|
||||
"add_3: add_3(x: int, y: int, z: int) -> int - Add 3 numbers together.\n",
|
||||
"divide: divide(x: float, y: float) -> float - Divide x by y\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Use a blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).\n",
|
||||
"\n",
|
||||
"Valid "action" values: "Final Answer" or name, get_weather, add, add_3, divide\n",
|
||||
"\n",
|
||||
"Provide only ONE action per $BLOB, as shown.\n",
|
||||
"\n",
|
||||
"<action>\n",
|
||||
"{\n",
|
||||
" "action": $TOOL_NAME,\n",
|
||||
" "action_input": $INPUT\n",
|
||||
"}\n",
|
||||
"</action>\n",
|
||||
"\n",
|
||||
"When invoking a tool do not provide any clarifying information.\n",
|
||||
"\n",
|
||||
"The human will forward results of tool invocations as "Observations".\n",
|
||||
"\n",
|
||||
"When you know the answer paraphrase the information in the observations properly and respond to the user. If you do not know the answer use more tools.\n",
|
||||
"\n",
|
||||
"You can only take a single action at a time.</div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:Tomato\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">human</div>\n",
|
||||
" <div>what is 5+8</div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">ai</div>\n",
|
||||
" <div><action> { "action": "add", "action_input": { "x": 5, "y": 8 } } </action></div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:Tomato\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">human</div>\n",
|
||||
" <div>Observation: 13</div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">ai</div>\n",
|
||||
" <div>5 + 8 is 13</div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:Tomato\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">human</div>\n",
|
||||
" <div>Sum of first 10 numbers starting from 1? use one tool at a time</div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">ai</div>\n",
|
||||
" <div> <action> \n",
|
||||
"{\n",
|
||||
" "action": "add",\n",
|
||||
" "action_input": {\n",
|
||||
" "x": 1, \n",
|
||||
" "y": 2\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"</action></div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:Tomato\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">human</div>\n",
|
||||
" <div>Observation: 3</div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">ai</div>\n",
|
||||
" <div> <action>\n",
|
||||
"{\n",
|
||||
" "action": "add", \n",
|
||||
" "action_input": {\n",
|
||||
" "x": 3,\n",
|
||||
" "y": 3\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"</action></div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:Tomato\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">human</div>\n",
|
||||
" <div>Observation: 6</div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">ai</div>\n",
|
||||
" <div> <action>\n",
|
||||
"{\n",
|
||||
" "action": "add",\n",
|
||||
" "action_input": {\n",
|
||||
" "x": 6, \n",
|
||||
" "y": 4\n",
|
||||
" } \n",
|
||||
"}\n",
|
||||
"</action></div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:Tomato\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">human</div>\n",
|
||||
" <div>Observation: 10</div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">ai</div>\n",
|
||||
" <div> <action>\n",
|
||||
"{\n",
|
||||
" "action": "add",\n",
|
||||
" "action_input": {\n",
|
||||
" "x": 10,\n",
|
||||
" "y": 5\n",
|
||||
" }\n",
|
||||
"} \n",
|
||||
"</action></div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:Tomato\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">human</div>\n",
|
||||
" <div>Observation: 15</div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">ai</div>\n",
|
||||
" <div> Based on the observations so far, the sum of the first 10 numbers starting from 1 is 15. I will continue invoking the add tool one input at a time to calculate the full sum:\n",
|
||||
"\n",
|
||||
"<action>\n",
|
||||
"{\n",
|
||||
" "action": "add",\n",
|
||||
" "action_input": {\n",
|
||||
" "x": 15,\n",
|
||||
" "y": 6\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"</action></div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:Tomato\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">human</div>\n",
|
||||
" <div>Observation: 21</div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">ai</div>\n",
|
||||
" <div> <action>\n",
|
||||
"{\n",
|
||||
" "action": "add",\n",
|
||||
" "action_input": {\n",
|
||||
" "x": 21, \n",
|
||||
" "y": 7\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"</action></div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:Tomato\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">human</div>\n",
|
||||
" <div>Observation: 28</div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">ai</div>\n",
|
||||
" <div> <action>\n",
|
||||
"{\n",
|
||||
" "action": "add",\n",
|
||||
" "action_input": {\n",
|
||||
" "x": 28,\n",
|
||||
" "y": 8 \n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"</action></div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:Tomato\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">human</div>\n",
|
||||
" <div>Observation: 36</div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">ai</div>\n",
|
||||
" <div> <action>\n",
|
||||
"{\n",
|
||||
" "action": "add",\n",
|
||||
" "action_input": {\n",
|
||||
" "x": 36,\n",
|
||||
" "y": 9\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"</action></div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:Tomato\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">human</div>\n",
|
||||
" <div>Observation: 45</div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">ai</div>\n",
|
||||
" <div> <action>\n",
|
||||
"{\n",
|
||||
" "action": "add",\n",
|
||||
" "action_input": {\n",
|
||||
" "x": 45,\n",
|
||||
" "y": 10\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"</action></div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:Tomato\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">human</div>\n",
|
||||
" <div>Observation: 55</div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">ai</div>\n",
|
||||
" <div> Based on the observations, the sum of the first 10 numbers starting from 1 is 55.</div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"\n",
|
||||
"# Generate HTML divs\n",
|
||||
"html_divs = []\n",
|
||||
"for idx, msg in enumerate(messages):\n",
|
||||
" background_color = \"DodgerBlue\" if idx % 2 == 0 else 'Tomato'\n",
|
||||
" div_content = f'''\n",
|
||||
" <div style=\"border: 1px solid black; padding: 10px; background-color:{background_color}\">\n",
|
||||
" <div style=\"display: flex;\">\n",
|
||||
" <div style=\"font-weight: bold; margin-right: 10px;\">{html.escape(msg.type)}</div>\n",
|
||||
" <div>{html.escape(msg.content)}</div>\n",
|
||||
" </div>\n",
|
||||
" </div>\n",
|
||||
"'''\n",
|
||||
" html_divs.append(div_content)\n",
|
||||
"\n",
|
||||
"# Display HTML divs in a Jupyter Notebook cell\n",
|
||||
"from IPython.display import HTML, display\n",
|
||||
"\n",
|
||||
"html_output = '\\n'.join(html_divs)\n",
|
||||
"display(HTML(html_output))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
185
libs/langchain/langchain/automaton/notebooks/openai.ipynb
Normal file
185
libs/langchain/langchain/automaton/notebooks/openai.ipynb
Normal file
@@ -0,0 +1,185 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a91c50ff-bcb9-484b-a4bd-f56e75c182ba",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Open AI Chat"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "80595f69-8b46-494e-bd2f-0c9538ed526a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The autoreload extension is already loaded. To reload it, use:\n",
|
||||
" %reload_ext autoreload\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"id": "f73ed8ca-7384-4afc-af8c-bfd5b794df01",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import List, Optional, Any\n",
|
||||
"from langchain.tools import tool, Tool, format_tool_to_openai_function\n",
|
||||
"from langchain.schema.messages import AIMessage, HumanMessage\n",
|
||||
"from langchain.chat_models.openai import ChatOpenAI\n",
|
||||
"from langchain.prompts import ChatPromptTemplate\n",
|
||||
"\n",
|
||||
"from langchain.automaton.typedefs import MessageLog\n",
|
||||
"from langchain.automaton.tool_utils import generate_tool_info\n",
|
||||
"from langchain.automaton.openai_agent import OpenAIAgent\n",
|
||||
"from langchain.chat_models.openai import ChatOpenAI"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"id": "ed3855cd-06b0-4a45-8b76-88926433fc60",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_tools() -> List[Tool]:\n",
|
||||
" @tool\n",
|
||||
" def name() -> str:\n",
|
||||
" \"\"\"Use to look up the user's name\"\"\"\n",
|
||||
" return \"Eugene\"\n",
|
||||
"\n",
|
||||
" @tool\n",
|
||||
" def get_weather(city: str) -> str:\n",
|
||||
" \"\"\"Get weather in a specific city.\"\"\"\n",
|
||||
" return \"42F and sunny\"\n",
|
||||
"\n",
|
||||
" @tool\n",
|
||||
" def add(x: int, y: int) -> int:\n",
|
||||
" \"\"\"Use to add two numbers. For example; { \"x\": 2, \"y\": 10} \"\"\"\n",
|
||||
" return x + y\n",
|
||||
"\n",
|
||||
" @tool\n",
|
||||
" def add_3(x: int, y: int, z: int) -> int:\n",
|
||||
" \"\"\"Add 3 numbers together.\"\"\"\n",
|
||||
" return x + y + z\n",
|
||||
"\n",
|
||||
" @tool\n",
|
||||
" def divide(x: float, y: float) -> float:\n",
|
||||
" \"\"\"Divide x by y\"\"\"\n",
|
||||
" return x / y\n",
|
||||
"\n",
|
||||
" return list(locals().values())\n",
|
||||
"\n",
|
||||
"tools = get_tools()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"id": "9015d76d-414f-449b-9647-310917d43cb8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"openai_funcs = [format_tool_to_openai_function(tool_) for tool_ in tools]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"id": "97a59c88-2f0a-4982-81e3-4ffba4f222da",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = ChatOpenAI(temperature=0).bind(functions=openai_funcs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"id": "778b2eda-5128-40ea-949f-9727610cc194",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"agent = OpenAIAgent(llm, tools)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"id": "7bec72cc-c882-415a-964e-9d53951630f3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class VerboseMessageLog(MessageLog):\n",
|
||||
" def add_messages(self, *args, **kwargs):\n",
|
||||
" for arg in args:\n",
|
||||
" print(arg)\n",
|
||||
" return super().add_messages(*args, **kwargs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"id": "e7d1ec69-de12-4c66-abee-3b2bbabeec13",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[human: Sum of first 10 numbers starting from 1? use one tool at a time]\n",
|
||||
"[ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 1,\\n \"y\": 2\\n}'}}, FunctionCall(name='add', arguments={'x': 1, 'y': 2}), FunctionResult(name='add', result=3, error=None)]\n",
|
||||
"[ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 3,\\n \"y\": 3\\n}'}}, FunctionCall(name='add', arguments={'x': 3, 'y': 3}), FunctionResult(name='add', result=6, error=None)]\n",
|
||||
"[ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 6,\\n \"y\": 4\\n}'}}, FunctionCall(name='add', arguments={'x': 6, 'y': 4}), FunctionResult(name='add', result=10, error=None)]\n",
|
||||
"[ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 10,\\n \"y\": 5\\n}'}}, FunctionCall(name='add', arguments={'x': 10, 'y': 5}), FunctionResult(name='add', result=15, error=None)]\n",
|
||||
"[ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 15,\\n \"y\": 6\\n}'}}, FunctionCall(name='add', arguments={'x': 15, 'y': 6}), FunctionResult(name='add', result=21, error=None)]\n",
|
||||
"[ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 21,\\n \"y\": 7\\n}'}}, FunctionCall(name='add', arguments={'x': 21, 'y': 7}), FunctionResult(name='add', result=28, error=None)]\n",
|
||||
"[ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 28,\\n \"y\": 8\\n}'}}, FunctionCall(name='add', arguments={'x': 28, 'y': 8}), FunctionResult(name='add', result=36, error=None)]\n",
|
||||
"[ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 36,\\n \"y\": 9\\n}'}}, FunctionCall(name='add', arguments={'x': 36, 'y': 9}), FunctionResult(name='add', result=45, error=None)]\n",
|
||||
"[ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 45,\\n \"y\": 10\\n}'}}, FunctionCall(name='add', arguments={'x': 45, 'y': 10}), FunctionResult(name='add', result=55, error=None)]\n",
|
||||
"[ai: The sum of the first 10 numbers starting from 1 is 55., AgentFinish(result=ai: The sum of the first 10 numbers starting from 1 is 55.)]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"message_log = VerboseMessageLog(messages=[])\n",
|
||||
"question = HumanMessage(content=\"Sum of first 10 numbers starting from 1? use one tool at a time\")\n",
|
||||
"message_log.add_messages([question])\n",
|
||||
"agent.run(message_log)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,198 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9875f554-8f4d-409b-929f-00e8bc70185d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Think Act Agent"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "80595f69-8b46-494e-bd2f-0c9538ed526a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "724c11c6-c98a-47ef-8697-87cd888df7a0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import List, Optional, Any\n",
|
||||
"from langchain.automaton.think_act_agent import ThinkActAgent, generate_memory, ThinkActPromptGenerator\n",
|
||||
"from langchain.schema.messages import HumanMessage\n",
|
||||
"from langchain.llms import OpenAI\n",
|
||||
"from langchain.tools import tool, Tool"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "ed3855cd-06b0-4a45-8b76-88926433fc60",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_tools() -> List[Tool]:\n",
|
||||
" @tool\n",
|
||||
" def name() -> str:\n",
|
||||
" \"\"\"Use to look up the user's name\"\"\"\n",
|
||||
" return \"Eugene\"\n",
|
||||
"\n",
|
||||
" @tool\n",
|
||||
" def get_weather(city: str) -> str:\n",
|
||||
" \"\"\"Get weather in a specific city.\"\"\"\n",
|
||||
" return \"42F and sunny\"\n",
|
||||
"\n",
|
||||
" @tool\n",
|
||||
" def add(x: int, y: int) -> int:\n",
|
||||
" \"\"\"Use to add two numbers. For example; { \"x\": 2, \"y\": 10} \"\"\"\n",
|
||||
" return x + y\n",
|
||||
"\n",
|
||||
" @tool\n",
|
||||
" def divide(x: float, y: float) -> float:\n",
|
||||
" \"\"\"Divide x by y\"\"\"\n",
|
||||
" return x / y\n",
|
||||
"\n",
|
||||
" return list(locals().values())\n",
|
||||
"\n",
|
||||
"tools = get_tools()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "d416bb0e-28e4-4da3-8b10-aa563f87a48e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = OpenAI(temperature=0)\n",
|
||||
"agent = ThinkActAgent(llm, tools)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "94535cb6-8f0d-43c8-9e23-46e4834cc72d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"message_log = generate_memory(tools)\n",
|
||||
"question = HumanMessage(content='what is the 2 + 5?')\n",
|
||||
"message_log.add_messages([question])\n",
|
||||
"agent.run(message_log)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "a3e99642-ac4c-4795-8aa0-a35037d699dd",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Respond to the human as helpfully and accurately as \n",
|
||||
"possible. You have access to the following tools:\n",
|
||||
"name: name() -> str - Use to look up the user's name\n",
|
||||
"get_weather: get_weather(city: str) -> str - Get weather in a specific city.\n",
|
||||
"add: add(x: int, y: int) -> int - Use to add two numbers. For example; { \"x\": 2, \"y\": 10}\n",
|
||||
"divide: divide(x: float, y: float) -> float - Divide x by y\n",
|
||||
"\n",
|
||||
"Use a blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).\n",
|
||||
"\n",
|
||||
"Valid \"action\" values: \"Final Answer\" or name, get_weather, add, divide\n",
|
||||
"\n",
|
||||
"You can only use a single tool at a time.\n",
|
||||
"\n",
|
||||
"<action>\n",
|
||||
"{\n",
|
||||
" \"action\": $TOOL_NAME,\n",
|
||||
" \"action_input\": $INPUT\n",
|
||||
"}\n",
|
||||
"</action>\n",
|
||||
"\n",
|
||||
"Follow this format:\n",
|
||||
"\n",
|
||||
"Question: input question to answer\n",
|
||||
"Thought: consider previous and subsequent steps\n",
|
||||
"<action>\n",
|
||||
"$BLOB\n",
|
||||
"</action>\n",
|
||||
"\n",
|
||||
"Observation: action result\n",
|
||||
"... (repeat Thought/Action/Observation N times)\n",
|
||||
"Thought: I know what to respond\n",
|
||||
"<action>\n",
|
||||
"{\n",
|
||||
" \"action\": \"Final Answer\",\n",
|
||||
" \"action_input\": \"Final response to human\"\n",
|
||||
"}\n",
|
||||
"</action>\n",
|
||||
"\n",
|
||||
"Begin:\n",
|
||||
"\n",
|
||||
"Reminder to ALWAYS respond with a valid blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is <action>$BLOB</action> then Observation.\n",
|
||||
"Question: what is the 2 + 5?\n",
|
||||
"Thought:I need to use the add tool\n",
|
||||
"<action>\n",
|
||||
"{\n",
|
||||
" \"action\": \"add\",\n",
|
||||
" \"action_input\": { \"x\": 2, \"y\": 5 }\n",
|
||||
"}\n",
|
||||
"</action>\n",
|
||||
"Observation: 7\n",
|
||||
"Thought:I know what to respond\n",
|
||||
"<action>\n",
|
||||
"{\n",
|
||||
" \"action\": \"Final Answer\",\n",
|
||||
" \"action_input\": \"The answer is 7\"\n",
|
||||
"}\n",
|
||||
"</action>\n",
|
||||
"Answer: The answer is 7\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(ThinkActPromptGenerator.from_message_log(message_log).to_string())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "aa0b8911-9117-42db-8f0d-bdfeb3472f65",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
89
libs/langchain/langchain/automaton/openai_agent.py
Normal file
89
libs/langchain/langchain/automaton/openai_agent.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Specialized open ai functions based agent."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Sequence, List
|
||||
|
||||
from langchain.automaton.runnables import create_llm_program
|
||||
from langchain.automaton.typedefs import (
|
||||
MessageLog,
|
||||
AgentFinish,
|
||||
FunctionCall,
|
||||
FunctionResult,
|
||||
)
|
||||
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
||||
from langchain.schema import Generation
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import BaseMessage, FunctionMessage, AIMessage
|
||||
from langchain.schema.output_parser import BaseGenerationOutputParser
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
|
||||
class OpenAIFunctionsParser(BaseGenerationOutputParser):
    """Parse a chat generation into a ``FunctionCall`` or ``AgentFinish``.

    If the model responded with an OpenAI ``function_call`` payload, the
    result is a ``FunctionCall`` describing the tool to invoke; a plain
    assistant message (no ``additional_kwargs``) ends the agent loop with
    an ``AgentFinish`` carrying that message.
    """

    def parse_result(self, result: List[Generation]):
        """Parse a single-generation result.

        Raises:
            AssertionError: If more than one generation is supplied.
            RuntimeError: If the function-call payload cannot be parsed.
        """
        if len(result) != 1:
            # Fixed: the original used an f-string with no placeholders.
            raise AssertionError("Expected exactly one result")
        first_result = result[0]

        message = first_result.message

        # No additional_kwargs means no function_call payload: the model
        # produced a final, user-facing answer.
        if not isinstance(message, AIMessage) or not message.additional_kwargs:
            return AgentFinish(result=message)

        parser = JsonOutputFunctionsParser(strict=False, args_only=False)
        try:
            function_request = parser.parse_result(result)
        except Exception as e:
            raise RuntimeError(f"Error parsing result: {result} {repr(e)}") from e

        return FunctionCall(
            name=function_request["name"],
            arguments=function_request["arguments"],
        )
|
||||
|
||||
|
||||
def prompt_generator(log: MessageLog) -> List[BaseMessage]:
    """Render a message log into the list of chat messages sent to the LLM.

    Chat messages pass through unchanged; ``FunctionResult`` entries become
    ``FunctionMessage``s with a JSON-encoded result. All other internal
    message types are omitted from the prompt.
    """
    rendered: List[BaseMessage] = []
    for entry in log.messages:
        # FunctionResult is an InternalMessage, never a BaseMessage, so the
        # order of these checks does not matter.
        if isinstance(entry, FunctionResult):
            rendered.append(
                FunctionMessage(name=entry.name, content=json.dumps(entry.result))
            )
        elif isinstance(entry, BaseMessage):
            rendered.append(entry)
        # Anything else (e.g. FunctionCall, AgentFinish) is internal-only.
    return rendered
|
||||
|
||||
|
||||
class OpenAIAgent:
    """Agent loop driven by OpenAI function-calling.

    Repeatedly invokes the LLM program and appends the resulting messages
    to the message log until the parser emits an ``AgentFinish`` or the
    iteration budget is exhausted.
    """

    def __init__(
        self,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        *,
        max_iterations: int = 10,
    ) -> None:
        """Initialize the chat automaton.

        Args:
            llm: Chat model; expected to be bound with the OpenAI function
                definitions for ``tools`` by the caller.
            tools: Tools the agent may invoke.
            max_iterations: Upper bound on LLM round-trips; guards against
                infinite tool-call loops.
        """
        self.llm_program = create_llm_program(
            llm,
            prompt_generator=prompt_generator,
            tools=tools,
            parser=OpenAIFunctionsParser(),
        )
        self.max_iterations = max_iterations

    def run(self, message_log: MessageLog) -> None:
        """Run the agent loop, mutating ``message_log`` in place."""
        if not message_log:
            # Fixed: the original used an f-string with no placeholders.
            raise AssertionError("Expected at least one message in message_log")

        for _ in range(self.max_iterations):
            last_message = message_log[-1]

            # The parser appends AgentFinish when the model stops calling
            # functions; that terminates the loop.
            if isinstance(last_message, AgentFinish):
                break

            messages = self.llm_program.invoke(message_log)
            message_log.add_messages(messages)
|
||||
108
libs/langchain/langchain/automaton/runnables.py
Normal file
108
libs/langchain/langchain/automaton/runnables.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Module contains useful runnables for agents."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Callable, List, Optional, Union
|
||||
|
||||
from langchain.automaton.typedefs import (
|
||||
MessageLike,
|
||||
MessageLog,
|
||||
FunctionResult,
|
||||
FunctionCall,
|
||||
)
|
||||
from langchain.schema import BaseMessage, AIMessage, PromptValue
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.runnable.base import RunnableLambda, Runnable
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
|
||||
# PUBLIC API
|
||||
|
||||
|
||||
def create_tool_invoker(
    tools: Sequence[BaseTool],
) -> Runnable[FunctionCall, FunctionResult]:
    """Create a runnable that routes a ``FunctionCall`` to the matching tool.

    The returned runnable looks the tool up by name, runs it with the call's
    arguments, and wraps the outcome (or the error repr) in a
    ``FunctionResult``.
    """
    registry = {tool_.name: tool_ for tool_ in tools}

    def _invoke(call: FunctionCall) -> FunctionResult:
        """Run the named tool via ``.run`` and capture result or error."""
        try:
            selected = registry[call.name]
        except KeyError:
            raise AssertionError(f"No such tool: {call.name}")
        try:
            output = selected.run(call.arguments or {})
            failure = None
        except Exception as e:
            # Include the arguments so failures are diagnosable downstream.
            output = None
            failure = repr(e) + repr(call.arguments)
        return FunctionResult(name=call.name, result=output, error=failure)

    return RunnableLambda(func=_invoke)
|
||||
|
||||
|
||||
LogOrMessages = Union[MessageLog, Sequence[MessageLike]]
|
||||
|
||||
|
||||
def create_llm_program(
    llm: BaseLanguageModel,
    prompt_generator: Callable[
        [LogOrMessages], Union[str, PromptValue, Sequence[BaseMessage]]
    ],
    *,
    tools: Optional[Sequence[BaseTool]] = None,
    stop: Optional[Sequence[str]] = None,
    parser: Union[
        Runnable[Union[BaseMessage, str], MessageLike],
        Callable[[Union[BaseMessage, str]], MessageLike],
        None,
    ] = None,
    invoke_tools: bool = True,
) -> Runnable[MessageLog, List[MessageLike]]:
    """Create a runnable that can update memory.

    The returned runnable turns a message log into a prompt, invokes the
    LLM (optionally with stop sequences), optionally parses the raw output
    into an internal message, and — when the parsed message is a
    FunctionCall — invokes the matching tool and appends its result.

    Returns:
        A runnable producing the list of NEW messages (raw LLM output,
        parsed message, and possibly a FunctionResult); it does not mutate
        the input log itself.
    """

    # Built once, outside the closure; None disables tool invocation.
    # NOTE(review): when invoke_tools=True and tools is None,
    # create_tool_invoker receives None — presumably callers always pass
    # tools in that case; confirm.
    tool_invoker = create_tool_invoker(tools) if invoke_tools else None

    def _bound(message_log: MessageLog) -> List[MessageLike]:
        """A function that can be invoked with a message log."""
        messages = []
        prompt = prompt_generator(message_log)
        llm_chain = llm
        if stop:
            # Bind stop sequences without mutating the shared llm object.
            llm_chain = llm_chain.bind(stop=stop)

        result = llm_chain.invoke(prompt)

        # Normalize the raw LLM output to a chat message.
        if isinstance(result, BaseMessage):
            messages.append(result)
        elif isinstance(result, str):
            messages.append(AIMessage(content=result))
        else:
            raise NotImplementedError(f"Unsupported type {type(result)}")

        if parser:
            # Accept either a Runnable or a plain callable as the parser.
            if not isinstance(parser, Runnable):
                _parser = RunnableLambda(parser)
            else:
                _parser = parser
            parsed_result = _parser.invoke(result)
            # A falsy parse result (e.g. None) means "nothing to add".
            if parsed_result:
                if not isinstance(parsed_result, MessageLike):
                    raise TypeError(
                        f"Expected a MessageLike type got: {type(parsed_result)}"
                    )
                messages.append(parsed_result)

        if not messages:
            raise AssertionError(f"Expected at least one message")
        last_message = messages[-1]

        # Eagerly execute a requested tool so its result rides along in the
        # same batch of new messages.
        if tool_invoker and isinstance(last_message, FunctionCall):
            function_result = tool_invoker.invoke(last_message)
            messages.append(function_result)

        return messages

    return RunnableLambda(
        func=_bound,
    )
|
||||
41
libs/langchain/langchain/automaton/tests/test_mrkl_agent.py
Normal file
41
libs/langchain/langchain/automaton/tests/test_mrkl_agent.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
|
||||
from langchain.automaton.mrkl_agent import ActionParser
|
||||
|
||||
|
||||
def test_structured_output_chat() -> None:
    """Exercise StructuredChatOutputParser on a fenced JSON action blob."""
    # NOTE(review): StructuredChatOutputParser.parse returns an agent
    # action/finish object, not a dict — the `== {}` assertion below looks
    # like a placeholder that would fail; confirm intent.
    parser = StructuredChatOutputParser()
    output = parser.parse(
        """
```json
{
  "action": "hello",
  "action_input": {
    "a": 2
  }
}
```
"""
    )
    assert output == {}
|
||||
|
||||
|
||||
def test_parser() -> None:
    """Test the parser."""
    # NOTE(review): ActionParser.decode returns a MessageLike (FunctionCall /
    # AgentFinish) or None — never a raw dict, and this blob has no "action"
    # key. The assertion below looks stale; confirm against ActionParser.
    sample_text = """
Some text before
<action>
{
    "key": "value",
    "number": 42
}
</action>
Some text after
"""
    action_parser = ActionParser()
    action = action_parser.decode(sample_text)
    assert action == {
        "key": "value",
        "number": 42,
    }
||||
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, cast
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.automaton.openai_agent import OpenAIAgent
|
||||
from langchain.automaton.tests.utils import (
|
||||
FakeChatOpenAI,
|
||||
construct_func_invocation_message,
|
||||
)
|
||||
from langchain.automaton.typedefs import (
|
||||
FunctionCall,
|
||||
FunctionResult,
|
||||
MessageLog,
|
||||
AgentFinish,
|
||||
)
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain.tools import tool, Tool
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
|
||||
@pytest.fixture()
def tools() -> List[BaseTool]:
    """Provide two trivial tools (get_time, get_location) for agent tests."""

    @tool
    def get_time() -> str:
        """Get time."""
        return "9 PM"

    @tool
    def get_location() -> str:
        """Get location."""
        return "the park"

    # Order matters: tests unpack as (get_time, get_location).
    return cast(List[Tool], [get_time, get_location])
|
||||
|
||||
|
||||
def test_openai_agent(tools: List[Tool]) -> None:
    """End-to-end check of OpenAIAgent with a scripted fake chat model.

    The fake LLM first emits a function_call for get_time, then a plain
    final answer; the agent loop should record the call, its result, and
    finish.
    """
    get_time, get_location = tools
    # Scripted responses, consumed one per LLM invocation.
    llm = FakeChatOpenAI(
        message_iter=iter(
            [
                construct_func_invocation_message(get_time, {}),
                AIMessage(
                    content="The time is 9 PM.",
                ),
            ]
        )
    )

    agent = OpenAIAgent(llm=llm, tools=tools, max_iterations=10)

    message_log = MessageLog(
        [
            SystemMessage(
                content="What time is it?",
            )
        ]
    )

    # Full expected transcript, in order.
    expected_messages = [
        SystemMessage(
            content="What time is it?",
        ),
        AIMessage(
            content="",
            additional_kwargs={
                "function_call": {
                    "name": "get_time",
                    "arguments": "{}",
                }
            },
        ),
        FunctionCall(
            name="get_time",
            arguments={},
        ),
        FunctionResult(
            name="get_time",
            result="9 PM",
            error=None,
        ),
        AIMessage(
            content="The time is 9 PM.",
        ),
        # NOTE(review): AgentFinish is constructed positionally here, but it
        # is declared as a pydantic model with a `result` field — confirm
        # positional construction is supported.
        AgentFinish(
            AIMessage(
                content="The time is 9 PM.",
            ),
        ),
    ]

    agent.run(message_log)
    assert message_log.messages == expected_messages
|
||||
46
libs/langchain/langchain/automaton/tests/utils.py
Normal file
46
libs/langchain/langchain/automaton/tests/utils.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Iterator, List, Any, Mapping
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import BaseMessage, ChatResult, ChatGeneration, AIMessage
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
|
||||
class FakeChatOpenAI(BaseChatModel):
    """A fake chat model that returns a pre-defined response."""

    # Iterator over canned responses; each _generate call consumes one.
    message_iter: Iterator[BaseMessage]

    @property
    def _llm_type(self) -> str:
        """The type of the model."""
        return "fake-openai-chat-model"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: List[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a response to the given messages.

        The inputs are ignored: the next canned message is returned
        regardless. Raises StopIteration once the iterator is exhausted.
        """
        message = next(self.message_iter)
        return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
|
||||
|
||||
def construct_func_invocation_message(
    tool: BaseTool, args: Mapping[str, Any]
) -> AIMessage:
    """Build an AIMessage carrying an OpenAI-style function_call payload.

    The arguments are JSON-encoded, matching the wire format the OpenAI
    chat API uses for function calls.
    """
    function_call = {
        "name": tool.name,
        "arguments": json.dumps(args),
    }
    return AIMessage(
        content="",
        additional_kwargs={"function_call": function_call},
    )
|
||||
204
libs/langchain/langchain/automaton/think_act_agent.py
Normal file
204
libs/langchain/langchain/automaton/think_act_agent.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""Implementation of a think act agent.
|
||||
|
||||
Uses priming messages with text based LLM.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import re
|
||||
from typing import Sequence, Optional, Union, List
|
||||
|
||||
from langchain.automaton.runnables import (
|
||||
create_llm_program,
|
||||
)
|
||||
from langchain.automaton.tool_utils import generate_tool_info
|
||||
from langchain.automaton.typedefs import (
|
||||
Agent,
|
||||
MessageLog,
|
||||
MessageLike,
|
||||
FunctionCall,
|
||||
FunctionResult,
|
||||
AgentFinish,
|
||||
PrimingMessage,
|
||||
)
|
||||
from langchain.prompts import SystemMessagePromptTemplate
|
||||
from langchain.schema import (
|
||||
PromptValue,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
AIMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.tools import BaseTool, Tool
|
||||
|
||||
TEMPLATE_ = """\
|
||||
Respond to the human as helpfully and accurately as
|
||||
possible. You have access to the following tools:
|
||||
{tools_description}
|
||||
Use a blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
||||
|
||||
Valid "action" values: "Final Answer" or {tool_names}
|
||||
|
||||
You can only use a single tool at a time.
|
||||
|
||||
<action>
|
||||
{{
|
||||
"action": $TOOL_NAME,
|
||||
"action_input": $INPUT
|
||||
}}
|
||||
</action>
|
||||
|
||||
Follow this format:
|
||||
|
||||
Question: input question to answer
|
||||
Thought: consider previous and subsequent steps
|
||||
<action>
|
||||
$BLOB
|
||||
</action>
|
||||
|
||||
Observation: action result
|
||||
... (repeat Thought/Action/Observation N times)
|
||||
Thought: I know what to respond
|
||||
<action>
|
||||
{{
|
||||
"action": "Final Answer",
|
||||
"action_input": "Final response to human"
|
||||
}}
|
||||
</action>
|
||||
|
||||
Begin:
|
||||
|
||||
Reminder to ALWAYS respond with a valid blob of a single action. \
|
||||
Use tools if necessary. \
|
||||
Respond directly if appropriate. \
|
||||
Format is <action>$BLOB</action> then Observation. \
|
||||
"""
|
||||
|
||||
|
||||
def generate_memory(tools: Sequence[Tool]) -> MessageLog:
    """Build the initial message log: one system message describing the tools."""
    info = generate_tool_info(tools)
    system_message = SystemMessagePromptTemplate.from_template(TEMPLATE_).format(
        **info
    )
    return MessageLog(messages=[system_message])
|
||||
|
||||
|
||||
class ActionParser:
    """A utility class to decode ``<action>...</action>`` blocks."""

    def __init__(self) -> None:
        """Initialize the ActionParser.

        The pattern is compiled once; DOTALL lets the blob span lines.
        """
        self.pattern = re.compile(r"<action>(?P<action_blob>.*?)<\/action>", re.DOTALL)

    def decode(self, text: Union[BaseMessage, str]) -> Optional[MessageLike]:
        """Decode the first action block found in ``text``.

        Returns:
            ``AgentFinish`` for the "Final Answer" pseudo-tool,
            ``FunctionCall`` for any other action, or ``None`` when no
            action block is present.
        """
        if isinstance(text, BaseMessage):
            text = text.content
        match = self.pattern.search(text)
        if match:
            action_blob = match.group("action_blob")
            # NOTE(review): literal_eval parses Python literals; JSON-only
            # tokens (true/false/null) would raise here — confirm the
            # prompt keeps the model on Python-compatible blobs.
            data = ast.literal_eval(action_blob)
            name = data["action"]
            if name == "Final Answer":  # Special cased "tool" for final answer
                return AgentFinish(result=data["action_input"])
            # Robustness fix: tolerate a missing "action_input" key
            # (previously a KeyError); absent or falsy input becomes {}.
            return FunctionCall(
                name=data["action"], arguments=data.get("action_input") or {}
            )
        else:
            return None
|
||||
|
||||
|
||||
class ThinkActPromptGenerator(PromptValue):
    """think-act paradigm

    prompt value can do additional priming of the LLM, used in to_string method right now.
    """

    # Log of chat + internal messages to render into a prompt.
    message_log: MessageLog

    class Config:
        arbitrary_types_allowed = True

    def to_string(self) -> str:
        """The string variant of the prompt.

        Renders the log as Question/Thought/Observation/Answer text.
        PrimingMessages are appended WITHOUT a trailing newline so the LLM
        continues on the same line (e.g. after "Thought:").
        """
        finalized = []
        messages = self.message_log.messages
        for idx, message in enumerate(messages):
            if isinstance(message, PrimingMessage):
                # No newline: the model completes this line.
                component = message.content
                finalized.append(component)
                continue

            if isinstance(message, FunctionResult):
                component = f"Observation: {message.result}"
            elif isinstance(message, HumanMessage):
                component = f"Question: {message.content.strip()}"
            elif isinstance(message, (AIMessage, SystemMessage)):
                component = message.content.strip()
            elif isinstance(message, FunctionCall):
                # This is an internal message, and should not be returned to the user.
                continue
            elif isinstance(message, AgentFinish):
                component = f"Answer: {message.result}"
            else:
                raise NotImplementedError()
            finalized.extend([component, "\n"])
        return "".join(finalized).strip()

    def to_messages(self) -> List[BaseMessage]:
        """Return prompt as a list of Messages.

        Chat messages pass through; FunctionResults become system-message
        observations; other internal messages are dropped.
        """
        messages = []
        for message in self.message_log.messages:
            if isinstance(message, BaseMessage):
                messages.append(message)
            elif isinstance(message, FunctionResult):
                messages.append(
                    SystemMessage(content=f"Observation: `{message.result}`")
                )
        return messages

    @classmethod
    def from_message_log(cls, message_log: MessageLog):
        """Create a ThinkActPromptGenerator from a MessageLog, allowing positional arguments."""
        return cls(message_log=message_log)
|
||||
|
||||
|
||||
class ThinkActAgent(Agent):
    """Text-completion agent using the Thought/Action/Observation loop.

    Drives a plain (non-chat) LLM with the think-act prompt, stopping
    generation at "Observation:" so tools can be executed between turns.
    """

    def __init__(
        self,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        *,
        max_iterations: int = 10,
    ) -> None:
        """Initialize the chat automaton.

        Args:
            llm: Text-completion language model.
            tools: Tools the agent may invoke.
            max_iterations: Upper bound on think/act cycles.
        """
        self.think_act = create_llm_program(
            llm,
            prompt_generator=ThinkActPromptGenerator.from_message_log,
            # Stop before the model hallucinates its own observations.
            stop=["Observation:", "observation:"],
            parser=ActionParser().decode,
            tools=tools,
            invoke_tools=True,
        )
        self.max_iterations = max_iterations

    def run(self, message_log: MessageLog) -> None:
        """Run the agent loop, mutating ``message_log`` in place."""
        if not message_log:
            # Fixed: the original used an f-string with no placeholders.
            raise AssertionError("Expected at least one message in message_log")

        for _ in range(self.max_iterations):
            last_message = message_log[-1]

            if isinstance(last_message, AgentFinish):
                break

            # Prime the LLM to start with "Thought: " after an observation
            if isinstance(last_message, (FunctionResult, HumanMessage)):
                message_log.add_messages([PrimingMessage(content="Thought:")])

            message_log.add_messages(self.think_act.invoke(message_log))
|
||||
25
libs/langchain/langchain/automaton/tool_utils.py
Normal file
25
libs/langchain/langchain/automaton/tool_utils.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import Sequence, TypedDict
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
|
||||
def _generate_tools_descriptions(tools: Sequence[BaseTool]) -> str:
    """Render one "name: description" line per tool, newline-terminated.

    An empty tool sequence yields a single newline.
    """
    body = "\n".join(f"{tool.name}: {tool.description}" for tool in tools)
    return body + "\n"
|
||||
|
||||
|
||||
class ToolInfo(TypedDict):
    """A dictionary containing information about a tool."""

    # Comma-separated tool names, e.g. "search, calculator".
    tool_names: str
    # One "name: description" line per tool, newline-terminated.
    tools_description: str
|
||||
|
||||
|
||||
def generate_tool_info(tools: Sequence[BaseTool]) -> ToolInfo:
    """Summarize ``tools`` as joined names plus rendered descriptions."""
    names = ", ".join(tool.name for tool in tools)
    return {
        "tool_names": names,
        "tools_description": _generate_tools_descriptions(tools),
    }
|
||||
81
libs/langchain/langchain/automaton/typedefs.py
Normal file
81
libs/langchain/langchain/automaton/typedefs.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional, Sequence, Mapping, overload, Union
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema import (
|
||||
BaseMessage,
|
||||
)
|
||||
|
||||
|
||||
class InternalMessage(Serializable):
    """Base class for internal (non-chat) messages stored in a MessageLog."""

    @property
    def lc_serializable(self) -> bool:
        """Indicate whether the class is serializable."""
        return True
|
||||
|
||||
|
||||
class FunctionCall(InternalMessage):
    """A request to invoke a named function/tool."""

    # Name of the function/tool to invoke.
    name: str
    # Keyword arguments for the call; may be None (presumably a no-argument
    # call — confirm against the tool-invocation code).
    arguments: Optional[Mapping[str, Any]]
|
||||
|
||||
|
||||
class FunctionResult(InternalMessage):
    """The outcome of executing a function/tool call."""

    # Name of the function that was invoked.
    name: str
    # Value returned by the function.
    result: Any
    # Error description; defaults to None (presumably None means success —
    # confirm against the tool-invocation code).
    error: Optional[str] = None
|
||||
|
||||
|
||||
class PrimingMessage(InternalMessage):
    """A message that is used to prime the language model."""

    # Text that steers the model's next completion (e.g. "Thought:").
    content: str
|
||||
|
||||
|
||||
class AgentFinish(InternalMessage):
    """Terminal marker: the agent has finished with the given result."""

    # Final result of the agent run.
    result: Any
|
||||
|
||||
|
||||
# Anything a MessageLog may hold: a chat message or an internal automaton message.
MessageLike = Union[BaseMessage, InternalMessage]
|
||||
|
||||
|
||||
class MessageLog:
    """A generalized message log for message-like items.

    Supports appending, integer/slice indexing, truthiness, and ``len``.
    """

    def __init__(self, messages: Sequence[MessageLike] = ()) -> None:
        """Initialize the log from an optional initial sequence of messages."""
        self.messages = list(messages)

    def add_messages(self, messages: Sequence[MessageLike]) -> None:
        """Append the given messages to the end of the log."""
        self.messages.extend(messages)

    @overload
    def __getitem__(self, index: int) -> MessageLike:
        ...

    @overload
    def __getitem__(self, index: slice) -> MessageLog:
        ...

    def __getitem__(self, index: Union[int, slice]) -> Union[MessageLike, MessageLog]:
        """Index into the log; a slice produces a new MessageLog."""
        if isinstance(index, slice):
            # List slicing already applies slice.indices() semantics.
            return MessageLog(messages=self.messages[index])
        return self.messages[index]

    def __bool__(self) -> bool:
        """True when the log contains at least one message."""
        return bool(self.messages)

    def __len__(self) -> int:
        """Number of messages in the log."""
        return len(self.messages)
|
||||
|
||||
|
||||
class Agent:  # This is just approximate still
    """Abstract interface for agents that operate on a MessageLog."""

    def run(self, message_log: MessageLog) -> None:
        """Run the agent on a message.

        Returns None; implementations are expected to record their output in
        ``message_log`` (see ThinkActAgent, which appends to it in place).
        """
        raise NotImplementedError
|
||||
@@ -79,6 +79,15 @@ class BaseMessage(Serializable):
|
||||
"""Whether this class is LangChain serializable."""
|
||||
return True
|
||||
|
||||
def __repr__(self):
    # Prefer additional_kwargs (e.g. function-call payloads) when present;
    # otherwise fall back to the plain text content.
    payload = self.additional_kwargs if self.additional_kwargs else self.content
    return f"{self.type}: {payload}"
|
||||
|
||||
def __str__(self):
    # Delegate to __repr__ so str() and repr() stay in sync.
    return repr(self)
|
||||
|
||||
def __add__(self, other: Any) -> ChatPromptTemplate:
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
|
||||
|
||||
Reference in New Issue
Block a user