Compare commits

...

43 Commits

Author SHA1 Message Date
Eugene Yurtsev
486d772184 x 2023-09-01 15:15:47 -04:00
Eugene Yurtsev
29ddf63abb x 2023-09-01 14:55:02 -04:00
Eugene Yurtsev
35318007d8 x 2023-09-01 14:54:47 -04:00
Eugene Yurtsev
900cf6ead3 x 2023-09-01 14:17:55 -04:00
Eugene Yurtsev
beec2d99b9 x 2023-09-01 14:16:45 -04:00
Eugene Yurtsev
a695ba5ba4 Merge branch 'master' into eugene/automaton_variant_2 2023-09-01 14:16:11 -04:00
Eugene Yurtsev
71c277c17c x 2023-09-01 14:16:01 -04:00
Eugene Yurtsev
e67f4b8d3b x 2023-09-01 14:14:29 -04:00
Eugene Yurtsev
7181211d78 x 2023-09-01 14:13:02 -04:00
Eugene Yurtsev
2c798fcc77 x 2023-09-01 14:02:35 -04:00
Eugene Yurtsev
bf12905e7f x 2023-09-01 11:23:02 -04:00
Eugene Yurtsev
e4cf87ab34 x 2023-09-01 10:56:12 -04:00
Eugene Yurtsev
9bbc5af2a8 x 2023-09-01 10:19:45 -04:00
Eugene Yurtsev
06e34e54f5 x 2023-08-31 15:22:38 -04:00
Eugene Yurtsev
5a68ba4174 x 2023-08-31 12:21:47 -04:00
Eugene Yurtsev
8e8c368019 x 2023-08-31 11:58:51 -04:00
Eugene Yurtsev
acbd19fac5 x 2023-08-31 11:24:43 -04:00
Eugene Yurtsev
aa56956d16 x 2023-08-31 10:20:04 -04:00
Eugene Yurtsev
0d2cf70d1e x 2023-08-30 09:18:04 -04:00
Eugene Yurtsev
188ac56c97 x 2023-08-30 09:17:27 -04:00
Eugene Yurtsev
383f09b271 x 2023-08-29 13:55:37 -04:00
Eugene Yurtsev
0af306fc54 Merge branch 'master' into eugene/automaton_variant_1 2023-08-29 13:53:10 -04:00
Eugene Yurtsev
15d5c49076 Merge branch 'master' into eugene/automaton_variant_1 2023-08-07 10:40:09 -04:00
Eugene Yurtsev
65660535bc x 2023-08-07 10:39:56 -04:00
Eugene Yurtsev
6c41dd82f0 x 2023-08-07 10:38:14 -04:00
Eugene Yurtsev
78d788c28c x 2023-08-07 10:37:41 -04:00
Eugene Yurtsev
e9deeab37f x 2023-08-07 10:32:53 -04:00
Eugene Yurtsev
4d595eec5b x 2023-08-06 23:13:32 -04:00
Eugene Yurtsev
047b001336 x 2023-08-06 22:50:46 -04:00
Eugene Yurtsev
840e936c7c x 2023-08-06 22:30:33 -04:00
Eugene Yurtsev
6cc6b490be x 2023-08-06 22:23:51 -04:00
Eugene Yurtsev
f45d1ed4f5 x 2023-08-06 22:20:29 -04:00
Eugene Yurtsev
4ffc417858 x 2023-08-05 23:10:15 -04:00
Eugene Yurtsev
9e74a70859 x 2023-08-05 22:39:04 -04:00
Eugene Yurtsev
0997f2c0f1 x 2023-08-05 22:37:16 -04:00
Eugene Yurtsev
437b545426 x 2023-08-05 16:01:12 -04:00
Eugene Yurtsev
9b9d07572b x 2023-08-05 14:47:17 -04:00
Eugene Yurtsev
6a90c6c2c8 Merge branch 'eugene/fix_mutation_in_place' into eugene/automaton 2023-08-04 10:56:24 -04:00
Eugene Yurtsev
8371187689 x 2023-08-04 10:46:56 -04:00
Eugene Yurtsev
4309c17ffa x 2023-08-04 10:44:57 -04:00
Eugene Yurtsev
183a9d4e66 x 2023-08-04 10:19:10 -04:00
Eugene Yurtsev
c1b444e1e7 x 2023-08-03 15:39:45 -04:00
Eugene Yurtsev
5f117384c0 x 2023-08-03 12:28:01 -04:00
15 changed files with 1892 additions and 0 deletions

View File

@@ -0,0 +1,126 @@
"""Generalized chat agent, works with any chat model."""
from __future__ import annotations
import ast
import re
from typing import Sequence, Union, List
from langchain.automaton.tool_utils import generate_tool_info
from langchain.automaton.runnables import create_llm_program
from langchain.automaton.typedefs import (
MessageLog,
AgentFinish,
MessageLike,
FunctionCall,
FunctionResult,
)
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import BaseMessage, HumanMessage
from langchain.tools import BaseTool
from langchain.prompts import SystemMessagePromptTemplate
TEMPLATE_ = SystemMessagePromptTemplate.from_template(
"""Respond to the human as helpfully and accurately as \
possible. You have access to the following tools:
{tools_description}
Use a blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
Valid "action" values: "Final Answer" or {tool_names}
Provide only ONE action per $BLOB, as shown.
<action>
{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}
</action>
When invoking a tool do not provide any clarifying information.
The human will forward results of tool invocations as "Observations".
When you know the answer paraphrase the information in the observations properly and respond to the user. \
If you do not know the answer use more tools.
You can only take a single action at a time."""
)
def generate_prompt(tools: Sequence[BaseTool]) -> MessageLog:
"""Generate a prompt for the agent."""
tool_info = generate_tool_info(tools)
msg = TEMPLATE_.format(**tool_info)
return MessageLog(messages=[msg])
def decode(text: Union[BaseMessage, str]) -> MessageLike:
"""Decode the action."""
pattern = re.compile(r"<action>(?P<action_blob>.*?)<\/action>", re.DOTALL)
if not isinstance(text, BaseMessage):
raise NotImplementedError()
_text = text.content
match = pattern.search(_text)
if match:
action_blob = match.group("action_blob")
data = ast.literal_eval(action_blob)
name = data["action"]
if name == "Final Answer": # Special cased "tool" for final answer
return AgentFinish(result=data["action_input"])
return FunctionCall(name=data["action"], arguments=data["action_input"] or {})
else:
return AgentFinish(result=text)
def prompt_generator(log: MessageLog) -> List[BaseMessage]:
"""Generate a prompt from a log of message like objects."""
messages = []
for message in log.messages:
if isinstance(message, BaseMessage):
messages.append(message)
elif isinstance(message, FunctionResult):
messages.append(
HumanMessage(
content=f"Observation: {message.result}",
)
)
else:
pass
return messages
class ChatAgent:
"""An agent for chat models."""
def __init__(
self,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
*,
max_iterations: int = 10,
) -> None:
"""Initialize the chat automaton."""
self.llm_program = create_llm_program(
llm,
prompt_generator=prompt_generator,
tools=tools,
parser=decode,
)
self.max_iterations = max_iterations
def run(self, message_log: MessageLog) -> None:
"""Run the agent."""
if not message_log:
raise AssertionError(f"Expected at least one message in message_log")
for _ in range(self.max_iterations):
last_message = message_log[-1]
if isinstance(last_message, AgentFinish):
break
messages = self.llm_program.invoke(message_log)
message_log.add_messages(messages)

View File

@@ -0,0 +1,683 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a91c50ff-bcb9-484b-a4bd-f56e75c182ba",
"metadata": {},
"source": [
"# General Chat Agent"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "80595f69-8b46-494e-bd2f-0c9538ed526a",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "724c11c6-c98a-47ef-8697-87cd888df7a0",
"metadata": {},
"outputs": [],
"source": [
"from typing import List, Optional, Any\n",
"from langchain.tools import tool, Tool\n",
"from langchain.schema.messages import AIMessage, HumanMessage\n",
"from langchain.chat_models.anthropic import ChatAnthropic\n",
"from langchain.chat_models.openai import ChatOpenAI\n",
"from langchain.automaton.typedefs import MessageLog\n",
"from langchain.automaton.chat_agent import ChatAgent\n",
"from langchain.automaton.tool_utils import generate_tool_info\n",
"from langchain.prompts import ChatPromptTemplate"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ed3855cd-06b0-4a45-8b76-88926433fc60",
"metadata": {},
"outputs": [],
"source": [
"def get_tools() -> List[Tool]:\n",
" @tool\n",
" def name() -> str:\n",
" \"\"\"Use to look up the user's name\"\"\"\n",
" return \"Eugene\"\n",
"\n",
" @tool\n",
" def get_weather(city: str) -> str:\n",
" \"\"\"Get weather in a specific city.\"\"\"\n",
" return \"42F and sunny\"\n",
"\n",
" @tool\n",
" def add(x: int, y: int) -> int:\n",
" \"\"\"Use to add two numbers. For example; { \"x\": 2, \"y\": 10} \"\"\"\n",
" return x + y\n",
"\n",
" @tool\n",
" def add_3(x: int, y: int, z: int) -> int:\n",
" \"\"\"Add 3 numbers together.\"\"\"\n",
" return x + y + z\n",
"\n",
" @tool\n",
" def divide(x: float, y: float) -> float:\n",
" \"\"\"Divide x by y\"\"\"\n",
" return x / y\n",
"\n",
" return list(locals().values())\n",
"\n",
"tools = get_tools()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "97a59c88-2f0a-4982-81e3-4ffba4f222da",
"metadata": {},
"outputs": [],
"source": [
"llm = ChatAnthropic(temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6d2c6418-1fcb-4f4a-9926-2cb3639e8019",
"metadata": {},
"outputs": [],
"source": [
"tool_info = generate_tool_info(tools)"
]
},
{
"cell_type": "markdown",
"id": "51d874e1-bbd8-4d47-b9b8-230b138519ce",
"metadata": {},
"source": [
"Generate a prompt with an example trace"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "6b8197d5-533a-4350-a4a3-cbd302733995",
"metadata": {},
"outputs": [],
"source": [
"sys_msg = \"\"\"Respond to the human as helpfully and accurately as \\\n",
"possible. You have access to the following tools:\n",
"{tools_description}\n",
"\n",
"Use a blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).\n",
"\n",
"Valid \"action\" values: \"Final Answer\" or {tool_names}\n",
"\n",
"Provide only ONE action per $BLOB, as shown.\n",
"\n",
"<action>\n",
"{{\n",
" \"action\": $TOOL_NAME,\n",
" \"action_input\": $INPUT\n",
"}}\n",
"</action>\n",
"\n",
"When invoking a tool do not provide any clarifying information.\n",
"\n",
"The human will forward results of tool invocations as \"Observations\".\n",
"\n",
"When you know the answer paraphrase the information in the observations properly and respond to the user. \\\n",
"If you do not know the answer use more tools.\n",
"\n",
"You can only take a single action at a time.\"\"\"\n",
"\n",
"messages = ChatPromptTemplate.from_messages([\n",
" (\"system\", sys_msg),\n",
" ('human', 'what is 5+8'),\n",
" ('ai', '<action> {{ \"action\": \"add\", \"action_input\": {{ \"x\": 5, \"y\": 8 }} }} </action>'),\n",
" ('human', 'Observation: 13'),\n",
" ('ai', '5 + 8 is 13'),\n",
"]).format_messages(**tool_info)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "2694bf0f-722f-4d1c-b10d-bfdc90eabc8e",
"metadata": {},
"outputs": [],
"source": [
"from langchain.callbacks.manager import (\n",
" trace_as_chain_group,\n",
" atrace_as_chain_group,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "778b2eda-5128-40ea-949f-9727610cc194",
"metadata": {},
"outputs": [],
"source": [
"agent = ChatAgent(llm, tools)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "7bec72cc-c882-415a-964e-9d53951630f3",
"metadata": {},
"outputs": [],
"source": [
"class VerboseMessageLog(MessageLog):\n",
" def add_messages(self, *args, **kwargs):\n",
" for arg in args:\n",
" print(arg)\n",
" return super().add_messages(*args, **kwargs)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "e7d1ec69-de12-4c66-abee-3b2bbabeec13",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[human: Sum of first 10 numbers starting from 1? use one tool at a time]\n",
"[ai: <action> \n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": {\n",
" \"x\": 1, \n",
" \"y\": 2\n",
" }\n",
"}\n",
"</action>, FunctionCall(name='add', arguments={'x': 1, 'y': 2}), FunctionResult(name='add', result=3, error=None)]\n",
"[ai: <action>\n",
"{\n",
" \"action\": \"add\", \n",
" \"action_input\": {\n",
" \"x\": 3,\n",
" \"y\": 3\n",
" }\n",
"}\n",
"</action>, FunctionCall(name='add', arguments={'x': 3, 'y': 3}), FunctionResult(name='add', result=6, error=None)]\n",
"[ai: <action>\n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": {\n",
" \"x\": 6, \n",
" \"y\": 4\n",
" } \n",
"}\n",
"</action>, FunctionCall(name='add', arguments={'x': 6, 'y': 4}), FunctionResult(name='add', result=10, error=None)]\n",
"[ai: <action>\n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": {\n",
" \"x\": 10,\n",
" \"y\": 5\n",
" }\n",
"} \n",
"</action>, FunctionCall(name='add', arguments={'x': 10, 'y': 5}), FunctionResult(name='add', result=15, error=None)]\n",
"[ai: Based on the observations so far, the sum of the first 10 numbers starting from 1 is 15. I will continue invoking the add tool one input at a time to calculate the full sum:\n",
"\n",
"<action>\n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": {\n",
" \"x\": 15,\n",
" \"y\": 6\n",
" }\n",
"}\n",
"</action>, FunctionCall(name='add', arguments={'x': 15, 'y': 6}), FunctionResult(name='add', result=21, error=None)]\n",
"[ai: <action>\n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": {\n",
" \"x\": 21, \n",
" \"y\": 7\n",
" }\n",
"}\n",
"</action>, FunctionCall(name='add', arguments={'x': 21, 'y': 7}), FunctionResult(name='add', result=28, error=None)]\n",
"[ai: <action>\n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": {\n",
" \"x\": 28,\n",
" \"y\": 8 \n",
" }\n",
"}\n",
"</action>, FunctionCall(name='add', arguments={'x': 28, 'y': 8}), FunctionResult(name='add', result=36, error=None)]\n",
"[ai: <action>\n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": {\n",
" \"x\": 36,\n",
" \"y\": 9\n",
" }\n",
"}\n",
"</action>, FunctionCall(name='add', arguments={'x': 36, 'y': 9}), FunctionResult(name='add', result=45, error=None)]\n",
"[ai: <action>\n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": {\n",
" \"x\": 45,\n",
" \"y\": 10\n",
" }\n",
"}\n",
"</action>, FunctionCall(name='add', arguments={'x': 45, 'y': 10}), FunctionResult(name='add', result=55, error=None)]\n",
"[ai: Based on the observations, the sum of the first 10 numbers starting from 1 is 55., AgentFinish(result=ai: Based on the observations, the sum of the first 10 numbers starting from 1 is 55.)]\n"
]
}
],
"source": [
"message_log = VerboseMessageLog(messages=messages)\n",
"question = HumanMessage(content=\"Sum of first 10 numbers starting from 1? use one tool at a time\")\n",
"message_log.add_messages([question])\n",
"\n",
"agent.run(message_log)"
]
},
{
"cell_type": "markdown",
"id": "d01d4dd6-8706-4d84-9660-b0c6cc2eed1b",
"metadata": {},
"source": [
"# Visualize"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "c4ec3106-7679-4cd0-9cba-134814cfa8ca",
"metadata": {},
"outputs": [],
"source": [
"import html\n",
"from langchain.automaton.chat_agent import prompt_generator"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "b9f800f1-1e90-4199-a016-015be152ae11",
"metadata": {},
"outputs": [],
"source": [
"messages = prompt_generator(message_log)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "9ba80285-b28c-419d-a8a6-043133f64190",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">system</div>\n",
" <div>Respond to the human as helpfully and accurately as possible. You have access to the following tools:\n",
"name: name() -&gt; str - Use to look up the user&#x27;s name\n",
"get_weather: get_weather(city: str) -&gt; str - Get weather in a specific city.\n",
"add: add(x: int, y: int) -&gt; int - Use to add two numbers. For example; { &quot;x&quot;: 2, &quot;y&quot;: 10}\n",
"add_3: add_3(x: int, y: int, z: int) -&gt; int - Add 3 numbers together.\n",
"divide: divide(x: float, y: float) -&gt; float - Divide x by y\n",
"\n",
"\n",
"Use a blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).\n",
"\n",
"Valid &quot;action&quot; values: &quot;Final Answer&quot; or name, get_weather, add, add_3, divide\n",
"\n",
"Provide only ONE action per $BLOB, as shown.\n",
"\n",
"&lt;action&gt;\n",
"{\n",
" &quot;action&quot;: $TOOL_NAME,\n",
" &quot;action_input&quot;: $INPUT\n",
"}\n",
"&lt;/action&gt;\n",
"\n",
"When invoking a tool do not provide any clarifying information.\n",
"\n",
"The human will forward results of tool invocations as &quot;Observations&quot;.\n",
"\n",
"When you know the answer paraphrase the information in the observations properly and respond to the user. If you do not know the answer use more tools.\n",
"\n",
"You can only take a single action at a time.</div>\n",
" </div>\n",
" </div>\n",
"\n",
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:Tomato\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">human</div>\n",
" <div>what is 5+8</div>\n",
" </div>\n",
" </div>\n",
"\n",
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">ai</div>\n",
" <div>&lt;action&gt; { &quot;action&quot;: &quot;add&quot;, &quot;action_input&quot;: { &quot;x&quot;: 5, &quot;y&quot;: 8 } } &lt;/action&gt;</div>\n",
" </div>\n",
" </div>\n",
"\n",
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:Tomato\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">human</div>\n",
" <div>Observation: 13</div>\n",
" </div>\n",
" </div>\n",
"\n",
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">ai</div>\n",
" <div>5 + 8 is 13</div>\n",
" </div>\n",
" </div>\n",
"\n",
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:Tomato\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">human</div>\n",
" <div>Sum of first 10 numbers starting from 1? use one tool at a time</div>\n",
" </div>\n",
" </div>\n",
"\n",
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">ai</div>\n",
" <div> &lt;action&gt; \n",
"{\n",
" &quot;action&quot;: &quot;add&quot;,\n",
" &quot;action_input&quot;: {\n",
" &quot;x&quot;: 1, \n",
" &quot;y&quot;: 2\n",
" }\n",
"}\n",
"&lt;/action&gt;</div>\n",
" </div>\n",
" </div>\n",
"\n",
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:Tomato\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">human</div>\n",
" <div>Observation: 3</div>\n",
" </div>\n",
" </div>\n",
"\n",
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">ai</div>\n",
" <div> &lt;action&gt;\n",
"{\n",
" &quot;action&quot;: &quot;add&quot;, \n",
" &quot;action_input&quot;: {\n",
" &quot;x&quot;: 3,\n",
" &quot;y&quot;: 3\n",
" }\n",
"}\n",
"&lt;/action&gt;</div>\n",
" </div>\n",
" </div>\n",
"\n",
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:Tomato\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">human</div>\n",
" <div>Observation: 6</div>\n",
" </div>\n",
" </div>\n",
"\n",
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">ai</div>\n",
" <div> &lt;action&gt;\n",
"{\n",
" &quot;action&quot;: &quot;add&quot;,\n",
" &quot;action_input&quot;: {\n",
" &quot;x&quot;: 6, \n",
" &quot;y&quot;: 4\n",
" } \n",
"}\n",
"&lt;/action&gt;</div>\n",
" </div>\n",
" </div>\n",
"\n",
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:Tomato\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">human</div>\n",
" <div>Observation: 10</div>\n",
" </div>\n",
" </div>\n",
"\n",
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">ai</div>\n",
" <div> &lt;action&gt;\n",
"{\n",
" &quot;action&quot;: &quot;add&quot;,\n",
" &quot;action_input&quot;: {\n",
" &quot;x&quot;: 10,\n",
" &quot;y&quot;: 5\n",
" }\n",
"} \n",
"&lt;/action&gt;</div>\n",
" </div>\n",
" </div>\n",
"\n",
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:Tomato\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">human</div>\n",
" <div>Observation: 15</div>\n",
" </div>\n",
" </div>\n",
"\n",
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">ai</div>\n",
" <div> Based on the observations so far, the sum of the first 10 numbers starting from 1 is 15. I will continue invoking the add tool one input at a time to calculate the full sum:\n",
"\n",
"&lt;action&gt;\n",
"{\n",
" &quot;action&quot;: &quot;add&quot;,\n",
" &quot;action_input&quot;: {\n",
" &quot;x&quot;: 15,\n",
" &quot;y&quot;: 6\n",
" }\n",
"}\n",
"&lt;/action&gt;</div>\n",
" </div>\n",
" </div>\n",
"\n",
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:Tomato\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">human</div>\n",
" <div>Observation: 21</div>\n",
" </div>\n",
" </div>\n",
"\n",
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">ai</div>\n",
" <div> &lt;action&gt;\n",
"{\n",
" &quot;action&quot;: &quot;add&quot;,\n",
" &quot;action_input&quot;: {\n",
" &quot;x&quot;: 21, \n",
" &quot;y&quot;: 7\n",
" }\n",
"}\n",
"&lt;/action&gt;</div>\n",
" </div>\n",
" </div>\n",
"\n",
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:Tomato\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">human</div>\n",
" <div>Observation: 28</div>\n",
" </div>\n",
" </div>\n",
"\n",
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">ai</div>\n",
" <div> &lt;action&gt;\n",
"{\n",
" &quot;action&quot;: &quot;add&quot;,\n",
" &quot;action_input&quot;: {\n",
" &quot;x&quot;: 28,\n",
" &quot;y&quot;: 8 \n",
" }\n",
"}\n",
"&lt;/action&gt;</div>\n",
" </div>\n",
" </div>\n",
"\n",
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:Tomato\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">human</div>\n",
" <div>Observation: 36</div>\n",
" </div>\n",
" </div>\n",
"\n",
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">ai</div>\n",
" <div> &lt;action&gt;\n",
"{\n",
" &quot;action&quot;: &quot;add&quot;,\n",
" &quot;action_input&quot;: {\n",
" &quot;x&quot;: 36,\n",
" &quot;y&quot;: 9\n",
" }\n",
"}\n",
"&lt;/action&gt;</div>\n",
" </div>\n",
" </div>\n",
"\n",
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:Tomato\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">human</div>\n",
" <div>Observation: 45</div>\n",
" </div>\n",
" </div>\n",
"\n",
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">ai</div>\n",
" <div> &lt;action&gt;\n",
"{\n",
" &quot;action&quot;: &quot;add&quot;,\n",
" &quot;action_input&quot;: {\n",
" &quot;x&quot;: 45,\n",
" &quot;y&quot;: 10\n",
" }\n",
"}\n",
"&lt;/action&gt;</div>\n",
" </div>\n",
" </div>\n",
"\n",
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:Tomato\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">human</div>\n",
" <div>Observation: 55</div>\n",
" </div>\n",
" </div>\n",
"\n",
"\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:DodgerBlue\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">ai</div>\n",
" <div> Based on the observations, the sum of the first 10 numbers starting from 1 is 55.</div>\n",
" </div>\n",
" </div>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"\n",
"# Generate HTML divs\n",
"html_divs = []\n",
"for idx, msg in enumerate(messages):\n",
" background_color = \"DodgerBlue\" if idx % 2 == 0 else 'Tomato'\n",
" div_content = f'''\n",
" <div style=\"border: 1px solid black; padding: 10px; background-color:{background_color}\">\n",
" <div style=\"display: flex;\">\n",
" <div style=\"font-weight: bold; margin-right: 10px;\">{html.escape(msg.type)}</div>\n",
" <div>{html.escape(msg.content)}</div>\n",
" </div>\n",
" </div>\n",
"'''\n",
" html_divs.append(div_content)\n",
"\n",
"# Display HTML divs in a Jupyter Notebook cell\n",
"from IPython.display import HTML, display\n",
"\n",
"html_output = '\\n'.join(html_divs)\n",
"display(HTML(html_output))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,185 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a91c50ff-bcb9-484b-a4bd-f56e75c182ba",
"metadata": {},
"source": [
"# Open AI Chat"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "80595f69-8b46-494e-bd2f-0c9538ed526a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "f73ed8ca-7384-4afc-af8c-bfd5b794df01",
"metadata": {},
"outputs": [],
"source": [
"from typing import List, Optional, Any\n",
"from langchain.tools import tool, Tool, format_tool_to_openai_function\n",
"from langchain.schema.messages import AIMessage, HumanMessage\n",
"from langchain.chat_models.openai import ChatOpenAI\n",
"from langchain.prompts import ChatPromptTemplate\n",
"\n",
"from langchain.automaton.typedefs import MessageLog\n",
"from langchain.automaton.tool_utils import generate_tool_info\n",
"from langchain.automaton.openai_agent import OpenAIAgent\n",
"from langchain.chat_models.openai import ChatOpenAI"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "ed3855cd-06b0-4a45-8b76-88926433fc60",
"metadata": {},
"outputs": [],
"source": [
"def get_tools() -> List[Tool]:\n",
" @tool\n",
" def name() -> str:\n",
" \"\"\"Use to look up the user's name\"\"\"\n",
" return \"Eugene\"\n",
"\n",
" @tool\n",
" def get_weather(city: str) -> str:\n",
" \"\"\"Get weather in a specific city.\"\"\"\n",
" return \"42F and sunny\"\n",
"\n",
" @tool\n",
" def add(x: int, y: int) -> int:\n",
" \"\"\"Use to add two numbers. For example; { \"x\": 2, \"y\": 10} \"\"\"\n",
" return x + y\n",
"\n",
" @tool\n",
" def add_3(x: int, y: int, z: int) -> int:\n",
" \"\"\"Add 3 numbers together.\"\"\"\n",
" return x + y + z\n",
"\n",
" @tool\n",
" def divide(x: float, y: float) -> float:\n",
" \"\"\"Divide x by y\"\"\"\n",
" return x / y\n",
"\n",
" return list(locals().values())\n",
"\n",
"tools = get_tools()"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "9015d76d-414f-449b-9647-310917d43cb8",
"metadata": {},
"outputs": [],
"source": [
"openai_funcs = [format_tool_to_openai_function(tool_) for tool_ in tools]"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "97a59c88-2f0a-4982-81e3-4ffba4f222da",
"metadata": {},
"outputs": [],
"source": [
"llm = ChatOpenAI(temperature=0).bind(functions=openai_funcs)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "778b2eda-5128-40ea-949f-9727610cc194",
"metadata": {},
"outputs": [],
"source": [
"agent = OpenAIAgent(llm, tools)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "7bec72cc-c882-415a-964e-9d53951630f3",
"metadata": {},
"outputs": [],
"source": [
"class VerboseMessageLog(MessageLog):\n",
" def add_messages(self, *args, **kwargs):\n",
" for arg in args:\n",
" print(arg)\n",
" return super().add_messages(*args, **kwargs)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "e7d1ec69-de12-4c66-abee-3b2bbabeec13",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[human: Sum of first 10 numbers starting from 1? use one tool at a time]\n",
"[ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 1,\\n \"y\": 2\\n}'}}, FunctionCall(name='add', arguments={'x': 1, 'y': 2}), FunctionResult(name='add', result=3, error=None)]\n",
"[ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 3,\\n \"y\": 3\\n}'}}, FunctionCall(name='add', arguments={'x': 3, 'y': 3}), FunctionResult(name='add', result=6, error=None)]\n",
"[ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 6,\\n \"y\": 4\\n}'}}, FunctionCall(name='add', arguments={'x': 6, 'y': 4}), FunctionResult(name='add', result=10, error=None)]\n",
"[ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 10,\\n \"y\": 5\\n}'}}, FunctionCall(name='add', arguments={'x': 10, 'y': 5}), FunctionResult(name='add', result=15, error=None)]\n",
"[ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 15,\\n \"y\": 6\\n}'}}, FunctionCall(name='add', arguments={'x': 15, 'y': 6}), FunctionResult(name='add', result=21, error=None)]\n",
"[ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 21,\\n \"y\": 7\\n}'}}, FunctionCall(name='add', arguments={'x': 21, 'y': 7}), FunctionResult(name='add', result=28, error=None)]\n",
"[ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 28,\\n \"y\": 8\\n}'}}, FunctionCall(name='add', arguments={'x': 28, 'y': 8}), FunctionResult(name='add', result=36, error=None)]\n",
"[ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 36,\\n \"y\": 9\\n}'}}, FunctionCall(name='add', arguments={'x': 36, 'y': 9}), FunctionResult(name='add', result=45, error=None)]\n",
"[ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 45,\\n \"y\": 10\\n}'}}, FunctionCall(name='add', arguments={'x': 45, 'y': 10}), FunctionResult(name='add', result=55, error=None)]\n",
"[ai: The sum of the first 10 numbers starting from 1 is 55., AgentFinish(result=ai: The sum of the first 10 numbers starting from 1 is 55.)]\n"
]
}
],
"source": [
"message_log = VerboseMessageLog(messages=[])\n",
"question = HumanMessage(content=\"Sum of first 10 numbers starting from 1? use one tool at a time\")\n",
"message_log.add_messages([question])\n",
"agent.run(message_log)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,198 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "9875f554-8f4d-409b-929f-00e8bc70185d",
"metadata": {},
"source": [
"# Think Act Agent"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "80595f69-8b46-494e-bd2f-0c9538ed526a",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "724c11c6-c98a-47ef-8697-87cd888df7a0",
"metadata": {},
"outputs": [],
"source": [
"from typing import List, Optional, Any\n",
"from langchain.automaton.think_act_agent import ThinkActAgent, generate_memory, ThinkActPromptGenerator\n",
"from langchain.schema.messages import HumanMessage\n",
"from langchain.llms import OpenAI\n",
"from langchain.tools import tool, Tool"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ed3855cd-06b0-4a45-8b76-88926433fc60",
"metadata": {},
"outputs": [],
"source": [
"def get_tools() -> List[Tool]:\n",
" @tool\n",
" def name() -> str:\n",
" \"\"\"Use to look up the user's name\"\"\"\n",
" return \"Eugene\"\n",
"\n",
" @tool\n",
" def get_weather(city: str) -> str:\n",
" \"\"\"Get weather in a specific city.\"\"\"\n",
" return \"42F and sunny\"\n",
"\n",
" @tool\n",
" def add(x: int, y: int) -> int:\n",
" \"\"\"Use to add two numbers. For example; { \"x\": 2, \"y\": 10} \"\"\"\n",
" return x + y\n",
"\n",
" @tool\n",
" def divide(x: float, y: float) -> float:\n",
" \"\"\"Divide x by y\"\"\"\n",
" return x / y\n",
"\n",
" return list(locals().values())\n",
"\n",
"tools = get_tools()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d416bb0e-28e4-4da3-8b10-aa563f87a48e",
"metadata": {},
"outputs": [],
"source": [
"llm = OpenAI(temperature=0)\n",
"agent = ThinkActAgent(llm, tools)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "94535cb6-8f0d-43c8-9e23-46e4834cc72d",
"metadata": {},
"outputs": [],
"source": [
"message_log = generate_memory(tools)\n",
"question = HumanMessage(content='what is the 2 + 5?')\n",
"message_log.add_messages([question])\n",
"agent.run(message_log)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a3e99642-ac4c-4795-8aa0-a35037d699dd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Respond to the human as helpfully and accurately as \n",
"possible. You have access to the following tools:\n",
"name: name() -> str - Use to look up the user's name\n",
"get_weather: get_weather(city: str) -> str - Get weather in a specific city.\n",
"add: add(x: int, y: int) -> int - Use to add two numbers. For example; { \"x\": 2, \"y\": 10}\n",
"divide: divide(x: float, y: float) -> float - Divide x by y\n",
"\n",
"Use a blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).\n",
"\n",
"Valid \"action\" values: \"Final Answer\" or name, get_weather, add, divide\n",
"\n",
"You can only use a single tool at a time.\n",
"\n",
"<action>\n",
"{\n",
" \"action\": $TOOL_NAME,\n",
" \"action_input\": $INPUT\n",
"}\n",
"</action>\n",
"\n",
"Follow this format:\n",
"\n",
"Question: input question to answer\n",
"Thought: consider previous and subsequent steps\n",
"<action>\n",
"$BLOB\n",
"</action>\n",
"\n",
"Observation: action result\n",
"... (repeat Thought/Action/Observation N times)\n",
"Thought: I know what to respond\n",
"<action>\n",
"{\n",
" \"action\": \"Final Answer\",\n",
" \"action_input\": \"Final response to human\"\n",
"}\n",
"</action>\n",
"\n",
"Begin:\n",
"\n",
"Reminder to ALWAYS respond with a valid blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is <action>$BLOB</action> then Observation.\n",
"Question: what is the 2 + 5?\n",
"Thought:I need to use the add tool\n",
"<action>\n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": { \"x\": 2, \"y\": 5 }\n",
"}\n",
"</action>\n",
"Observation: 7\n",
"Thought:I know what to respond\n",
"<action>\n",
"{\n",
" \"action\": \"Final Answer\",\n",
" \"action_input\": \"The answer is 7\"\n",
"}\n",
"</action>\n",
"Answer: The answer is 7\n"
]
}
],
"source": [
"print(ThinkActPromptGenerator.from_message_log(message_log).to_string())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aa0b8911-9117-42db-8f0d-bdfeb3472f65",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,89 @@
"""Specialized open ai functions based agent."""
from __future__ import annotations
import json
from typing import Sequence, List
from langchain.automaton.runnables import create_llm_program
from langchain.automaton.typedefs import (
MessageLog,
AgentFinish,
FunctionCall,
FunctionResult,
)
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.schema import Generation
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import BaseMessage, FunctionMessage, AIMessage
from langchain.schema.output_parser import BaseGenerationOutputParser
from langchain.tools import BaseTool
class OpenAIFunctionsParser(BaseGenerationOutputParser):
def parse_result(self, result: List[Generation]):
if len(result) != 1:
raise AssertionError(f"Expected exactly one result")
first_result = result[0]
message = first_result.message
if not isinstance(message, AIMessage) or not message.additional_kwargs:
return AgentFinish(result=message)
parser = JsonOutputFunctionsParser(strict=False, args_only=False)
try:
function_request = parser.parse_result(result)
except Exception as e:
raise RuntimeError(f"Error parsing result: {result} {repr(e)}") from e
return FunctionCall(
name=function_request["name"],
arguments=function_request["arguments"],
)
def prompt_generator(log: MessageLog) -> List[BaseMessage]:
"""Generate a prompt from a log of message like objects."""
messages = []
for message in log.messages:
if isinstance(message, BaseMessage):
messages.append(message)
elif isinstance(message, FunctionResult):
messages.append(
FunctionMessage(name=message.name, content=json.dumps(message.result))
)
else:
pass
return messages
class OpenAIAgent:
def __init__(
self,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
*,
max_iterations: int = 10,
) -> None:
"""Initialize the chat automaton."""
self.llm_program = create_llm_program(
llm,
prompt_generator=prompt_generator,
tools=tools,
parser=OpenAIFunctionsParser(),
)
self.max_iterations = max_iterations
def run(self, message_log: MessageLog) -> None:
"""Run the agent."""
if not message_log:
raise AssertionError(f"Expected at least one message in message_log")
for _ in range(self.max_iterations):
last_message = message_log[-1]
if isinstance(last_message, AgentFinish):
break
messages = self.llm_program.invoke(message_log)
message_log.add_messages(messages)

View File

@@ -0,0 +1,108 @@
"""Module contains useful runnables for agents."""
from __future__ import annotations
from typing import Sequence, Callable, List, Optional, Union
from langchain.automaton.typedefs import (
MessageLike,
MessageLog,
FunctionResult,
FunctionCall,
)
from langchain.schema import BaseMessage, AIMessage, PromptValue
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.runnable.base import RunnableLambda, Runnable
from langchain.tools import BaseTool
# PUBLIC API
def create_tool_invoker(
tools: Sequence[BaseTool],
) -> Runnable[FunctionCall, FunctionResult]:
"""Re-write with router."""
tools_by_name = {tool.name: tool for tool in tools}
def func(function_call: FunctionCall) -> FunctionResult:
"""A function that can invoke a tool using .run"""
try:
tool = tools_by_name[function_call.name]
except KeyError:
raise AssertionError(f"No such tool: {function_call.name}")
try:
result = tool.run(function_call.arguments or {})
error = None
except Exception as e:
result = None
error = repr(e) + repr(function_call.arguments)
return FunctionResult(name=function_call.name, result=result, error=error)
return RunnableLambda(func=func)
LogOrMessages = Union[MessageLog, Sequence[MessageLike]]
def create_llm_program(
llm: BaseLanguageModel,
prompt_generator: Callable[
[LogOrMessages], Union[str, PromptValue, Sequence[BaseMessage]]
],
*,
tools: Optional[Sequence[BaseTool]] = None,
stop: Optional[Sequence[str]] = None,
parser: Union[
Runnable[Union[BaseMessage, str], MessageLike],
Callable[[Union[BaseMessage, str]], MessageLike],
None,
] = None,
invoke_tools: bool = True,
) -> Runnable[MessageLog, List[MessageLike]]:
"""Create a runnable that can update memory."""
tool_invoker = create_tool_invoker(tools) if invoke_tools else None
def _bound(message_log: MessageLog) -> List[MessageLike]:
"""A function that can be invoked with a message log."""
messages = []
prompt = prompt_generator(message_log)
llm_chain = llm
if stop:
llm_chain = llm_chain.bind(stop=stop)
result = llm_chain.invoke(prompt)
if isinstance(result, BaseMessage):
messages.append(result)
elif isinstance(result, str):
messages.append(AIMessage(content=result))
else:
raise NotImplementedError(f"Unsupported type {type(result)}")
if parser:
if not isinstance(parser, Runnable):
_parser = RunnableLambda(parser)
else:
_parser = parser
parsed_result = _parser.invoke(result)
if parsed_result:
if not isinstance(parsed_result, MessageLike):
raise TypeError(
f"Expected a MessageLike type got: {type(parsed_result)}"
)
messages.append(parsed_result)
if not messages:
raise AssertionError(f"Expected at least one message")
last_message = messages[-1]
if tool_invoker and isinstance(last_message, FunctionCall):
function_result = tool_invoker.invoke(last_message)
messages.append(function_result)
return messages
return RunnableLambda(
func=_bound,
)

View File

@@ -0,0 +1,41 @@
from __future__ import annotations
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.automaton.mrkl_agent import ActionParser
def test_structured_output_chat() -> None:
parser = StructuredChatOutputParser()
output = parser.parse(
"""
```json
{
"action": "hello",
"action_input": {
"a": 2
}
}
```
"""
)
assert output == {}
def test_parser() -> None:
"""Tes the parser."""
sample_text = """
Some text before
<action>
{
"key": "value",
"number": 42
}
</action>
Some text after
"""
action_parser = ActionParser()
action = action_parser.decode(sample_text)
assert action == {
"key": "value",
"number": 42,
}

View File

@@ -0,0 +1,97 @@
from __future__ import annotations
from typing import List, cast
import pytest
from langchain.automaton.openai_agent import OpenAIAgent
from langchain.automaton.tests.utils import (
FakeChatOpenAI,
construct_func_invocation_message,
)
from langchain.automaton.typedefs import (
FunctionCall,
FunctionResult,
MessageLog,
AgentFinish,
)
from langchain.schema.messages import (
AIMessage,
SystemMessage,
)
from langchain.tools import tool, Tool
from langchain.tools.base import BaseTool
@pytest.fixture()
def tools() -> List[BaseTool]:
@tool
def get_time() -> str:
"""Get time."""
return "9 PM"
@tool
def get_location() -> str:
"""Get location."""
return "the park"
return cast(List[Tool], [get_time, get_location])
def test_openai_agent(tools: List[Tool]) -> None:
get_time, get_location = tools
llm = FakeChatOpenAI(
message_iter=iter(
[
construct_func_invocation_message(get_time, {}),
AIMessage(
content="The time is 9 PM.",
),
]
)
)
agent = OpenAIAgent(llm=llm, tools=tools, max_iterations=10)
message_log = MessageLog(
[
SystemMessage(
content="What time is it?",
)
]
)
expected_messages = [
SystemMessage(
content="What time is it?",
),
AIMessage(
content="",
additional_kwargs={
"function_call": {
"name": "get_time",
"arguments": "{}",
}
},
),
FunctionCall(
name="get_time",
arguments={},
),
FunctionResult(
name="get_time",
result="9 PM",
error=None,
),
AIMessage(
content="The time is 9 PM.",
),
AgentFinish(
AIMessage(
content="The time is 9 PM.",
),
),
]
agent.run(message_log)
assert message_log.messages == expected_messages

View File

@@ -0,0 +1,46 @@
from __future__ import annotations
import json
from typing import Iterator, List, Any, Mapping
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage, ChatResult, ChatGeneration, AIMessage
from langchain.tools import BaseTool
class FakeChatOpenAI(BaseChatModel):
"""A fake chat model that returns a pre-defined response."""
message_iter: Iterator[BaseMessage]
@property
def _llm_type(self) -> str:
"""The type of the model."""
return "fake-openai-chat-model"
def _generate(
self,
messages: List[BaseMessage],
stop: List[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
"""Generate a response to the given messages."""
message = next(self.message_iter)
return ChatResult(generations=[ChatGeneration(message=message)])
def construct_func_invocation_message(
tool: BaseTool, args: Mapping[str, Any]
) -> AIMessage:
"""Construct a function invocation message."""
return AIMessage(
content="",
additional_kwargs={
"function_call": {
"name": tool.name,
"arguments": json.dumps(args),
}
},
)

View File

@@ -0,0 +1,204 @@
"""Implementation of a think act agent.
Uses priming messages with text based LLM.
"""
from __future__ import annotations
import ast
import re
from typing import Sequence, Optional, Union, List
from langchain.automaton.runnables import (
create_llm_program,
)
from langchain.automaton.tool_utils import generate_tool_info
from langchain.automaton.typedefs import (
Agent,
MessageLog,
MessageLike,
FunctionCall,
FunctionResult,
AgentFinish,
PrimingMessage,
)
from langchain.prompts import SystemMessagePromptTemplate
from langchain.schema import (
PromptValue,
BaseMessage,
HumanMessage,
AIMessage,
SystemMessage,
)
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool, Tool
TEMPLATE_ = """\
Respond to the human as helpfully and accurately as
possible. You have access to the following tools:
{tools_description}
Use a blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
Valid "action" values: "Final Answer" or {tool_names}
You can only use a single tool at a time.
<action>
{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}
</action>
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
<action>
$BLOB
</action>
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
<action>
{{
"action": "Final Answer",
"action_input": "Final response to human"
}}
</action>
Begin:
Reminder to ALWAYS respond with a valid blob of a single action. \
Use tools if necessary. \
Respond directly if appropriate. \
Format is <action>$BLOB</action> then Observation. \
"""
def generate_memory(tools: Sequence[Tool]) -> MessageLog:
"""Set up basic memory for agent."""
tools_info = generate_tool_info(tools)
return MessageLog(
messages=[
SystemMessagePromptTemplate.from_template(TEMPLATE_).format(**tools_info)
]
)
class ActionParser:
"""A utility class to encode and decode action blocks."""
def __init__(self) -> None:
"""Initialize the ActionParser."""
self.pattern = re.compile(r"<action>(?P<action_blob>.*?)<\/action>", re.DOTALL)
def decode(self, text: Union[BaseMessage, str]) -> Optional[MessageLike]:
"""Decode the action."""
if isinstance(text, BaseMessage):
text = text.content
match = self.pattern.search(text)
if match:
action_blob = match.group("action_blob")
data = ast.literal_eval(action_blob)
name = data["action"]
if name == "Final Answer": # Special cased "tool" for final answer
return AgentFinish(result=data["action_input"])
return FunctionCall(
name=data["action"], arguments=data["action_input"] or {}
)
else:
return None
class ThinkActPromptGenerator(PromptValue):
"""think-act paradigm
prompt value can do additional priming of the LLM, used in to_string method right now.
"""
message_log: MessageLog
class Config:
arbitrary_types_allowed = True
def to_string(self) -> str:
"""The string variant of the prompt."""
finalized = []
messages = self.message_log.messages
for idx, message in enumerate(messages):
if isinstance(message, PrimingMessage):
component = message.content
finalized.append(component)
continue
if isinstance(message, FunctionResult):
component = f"Observation: {message.result}"
elif isinstance(message, HumanMessage):
component = f"Question: {message.content.strip()}"
elif isinstance(message, (AIMessage, SystemMessage)):
component = message.content.strip()
elif isinstance(message, FunctionCall):
# This is an internal message, and should not be returned to the user.
continue
elif isinstance(message, AgentFinish):
component = f"Answer: {message.result}"
else:
raise NotImplementedError()
finalized.extend([component, "\n"])
return "".join(finalized).strip()
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as a list of Messages."""
messages = []
for message in self.message_log.messages:
if isinstance(message, BaseMessage):
messages.append(message)
elif isinstance(message, FunctionResult):
messages.append(
SystemMessage(content=f"Observation: `{message.result}`")
)
return messages
@classmethod
def from_message_log(cls, message_log: MessageLog):
"""Create a ThinkActPromptGenerator from a MessageLog, allowing positional arguments."""
return cls(message_log=message_log)
class ThinkActAgent(Agent):
def __init__(
self,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
*,
max_iterations: int = 10,
) -> None:
"""Initialize the chat automaton."""
self.think_act = create_llm_program(
llm,
prompt_generator=ThinkActPromptGenerator.from_message_log,
stop=["Observation:", "observation:"],
parser=ActionParser().decode,
tools=tools,
invoke_tools=True,
)
self.max_iterations = max_iterations
def run(self, message_log: MessageLog) -> None:
"""Run the agent."""
if not message_log:
raise AssertionError(f"Expected at least one message in message_log")
for _ in range(self.max_iterations):
last_message = message_log[-1]
if isinstance(last_message, AgentFinish):
break
# Prime the LLM to start with "Thought: " after an observation
if isinstance(last_message, (FunctionResult, HumanMessage)):
message_log.add_messages([PrimingMessage(content="Thought:")])
message_log.add_messages(self.think_act.invoke(message_log))

View File

@@ -0,0 +1,25 @@
from typing import Sequence, TypedDict
from langchain.tools import BaseTool
def _generate_tools_descriptions(tools: Sequence[BaseTool]) -> str:
"""Generate a description of the tools."""
return "\n".join([f"{tool_.name}: {tool_.description}" for tool_ in tools]) + "\n"
class ToolInfo(TypedDict):
"""A dictionary containing information about a tool."""
tool_names: str
tools_description: str
def generate_tool_info(tools: Sequence[BaseTool]) -> ToolInfo:
"""Generate a string containing the names of the tools and their descriptions."""
tools_description = _generate_tools_descriptions(tools)
tool_names = ", ".join([tool_.name for tool_ in tools])
return {
"tool_names": tool_names,
"tools_description": tools_description,
}

View File

@@ -0,0 +1,81 @@
from __future__ import annotations
from typing import Any, Optional, Sequence, Mapping, overload, Union
from langchain.load.serializable import Serializable
from langchain.schema import (
BaseMessage,
)
class InternalMessage(Serializable):
@property
def lc_serializable(self) -> bool:
"""Indicate whether the class is serializable."""
return True
class FunctionCall(InternalMessage):
name: str
arguments: Optional[Mapping[str, Any]]
class FunctionResult(InternalMessage):
name: str
result: Any
error: Optional[str] = None
class PrimingMessage(InternalMessage):
"""A message that is used to prime the language model."""
content: str
class AgentFinish(InternalMessage):
result: Any
MessageLike = Union[BaseMessage, InternalMessage]
class MessageLog:
"""A generalized message log for message like items."""
def __init__(self, messages: Sequence[MessageLike] = ()) -> None:
"""Initialize the message log."""
self.messages = list(messages)
def add_messages(self, messages: Sequence[MessageLike]) -> None:
"""Add messages to the message log."""
self.messages.extend(messages)
@overload
def __getitem__(self, index: int) -> MessageLike:
...
@overload
def __getitem__(self, index: slice) -> MessageLog:
...
def __getitem__(self, index: Union[int, slice]) -> Union[MessageLike, MessageLog]:
"""Use to index into the chat template."""
if isinstance(index, slice):
start, stop, step = index.indices(len(self.messages))
messages = self.messages[start:stop:step]
return MessageLog(messages=messages)
else:
return self.messages[index]
def __bool__(self):
return bool(self.messages)
def __len__(self) -> int:
"""Get the length of the chat template."""
return len(self.messages)
class Agent: # This is just approximate still
def run(self, message_log: MessageLog) -> None:
"""Run the agent on a message."""
raise NotImplementedError

View File

@@ -79,6 +79,15 @@ class BaseMessage(Serializable):
"""Whether this class is LangChain serializable."""
return True
def __repr__(self):
if self.additional_kwargs:
return f"{self.type}: {self.additional_kwargs}"
else:
return f"{self.type}: {self.content}"
def __str__(self):
return self.__repr__()
def __add__(self, other: Any) -> ChatPromptTemplate:
from langchain.prompts.chat import ChatPromptTemplate