Compare commits

...

54 Commits

Author SHA1 Message Date
Eugene Yurtsev
b3cbff741e x 2023-09-07 17:06:51 -04:00
Eugene Yurtsev
00c3a2d3d2 x 2023-09-07 16:24:21 -04:00
Eugene Yurtsev
ba61d64c04 x 2023-09-07 15:54:08 -04:00
Eugene Yurtsev
0396570c82 x 2023-09-07 15:47:47 -04:00
Eugene Yurtsev
cf6bca3da7 x 2023-09-07 15:45:39 -04:00
Eugene Yurtsev
95d7e87a68 x 2023-09-07 14:26:21 -04:00
Eugene Yurtsev
62b936dc7c x 2023-09-07 13:57:57 -04:00
Eugene Yurtsev
23e1ef8f4a x 2023-09-07 13:45:55 -04:00
Eugene Yurtsev
4108862154 x 2023-09-07 13:43:57 -04:00
Eugene Yurtsev
c9cbfc9d51 Merge branch 'master' into eugene/automaton_variant_3 2023-09-07 13:20:09 -04:00
Eugene Yurtsev
96f3fc2d58 x 2023-09-07 13:19:53 -04:00
Eugene Yurtsev
486d772184 x 2023-09-01 15:15:47 -04:00
Eugene Yurtsev
29ddf63abb x 2023-09-01 14:55:02 -04:00
Eugene Yurtsev
35318007d8 x 2023-09-01 14:54:47 -04:00
Eugene Yurtsev
900cf6ead3 x 2023-09-01 14:17:55 -04:00
Eugene Yurtsev
beec2d99b9 x 2023-09-01 14:16:45 -04:00
Eugene Yurtsev
a695ba5ba4 Merge branch 'master' into eugene/automaton_variant_2 2023-09-01 14:16:11 -04:00
Eugene Yurtsev
71c277c17c x 2023-09-01 14:16:01 -04:00
Eugene Yurtsev
e67f4b8d3b x 2023-09-01 14:14:29 -04:00
Eugene Yurtsev
7181211d78 x 2023-09-01 14:13:02 -04:00
Eugene Yurtsev
2c798fcc77 x 2023-09-01 14:02:35 -04:00
Eugene Yurtsev
bf12905e7f x 2023-09-01 11:23:02 -04:00
Eugene Yurtsev
e4cf87ab34 x 2023-09-01 10:56:12 -04:00
Eugene Yurtsev
9bbc5af2a8 x 2023-09-01 10:19:45 -04:00
Eugene Yurtsev
06e34e54f5 x 2023-08-31 15:22:38 -04:00
Eugene Yurtsev
5a68ba4174 x 2023-08-31 12:21:47 -04:00
Eugene Yurtsev
8e8c368019 x 2023-08-31 11:58:51 -04:00
Eugene Yurtsev
acbd19fac5 x 2023-08-31 11:24:43 -04:00
Eugene Yurtsev
aa56956d16 x 2023-08-31 10:20:04 -04:00
Eugene Yurtsev
0d2cf70d1e x 2023-08-30 09:18:04 -04:00
Eugene Yurtsev
188ac56c97 x 2023-08-30 09:17:27 -04:00
Eugene Yurtsev
383f09b271 x 2023-08-29 13:55:37 -04:00
Eugene Yurtsev
0af306fc54 Merge branch 'master' into eugene/automaton_variant_1 2023-08-29 13:53:10 -04:00
Eugene Yurtsev
15d5c49076 Merge branch 'master' into eugene/automaton_variant_1 2023-08-07 10:40:09 -04:00
Eugene Yurtsev
65660535bc x 2023-08-07 10:39:56 -04:00
Eugene Yurtsev
6c41dd82f0 x 2023-08-07 10:38:14 -04:00
Eugene Yurtsev
78d788c28c x 2023-08-07 10:37:41 -04:00
Eugene Yurtsev
e9deeab37f x 2023-08-07 10:32:53 -04:00
Eugene Yurtsev
4d595eec5b x 2023-08-06 23:13:32 -04:00
Eugene Yurtsev
047b001336 x 2023-08-06 22:50:46 -04:00
Eugene Yurtsev
840e936c7c x 2023-08-06 22:30:33 -04:00
Eugene Yurtsev
6cc6b490be x 2023-08-06 22:23:51 -04:00
Eugene Yurtsev
f45d1ed4f5 x 2023-08-06 22:20:29 -04:00
Eugene Yurtsev
4ffc417858 x 2023-08-05 23:10:15 -04:00
Eugene Yurtsev
9e74a70859 x 2023-08-05 22:39:04 -04:00
Eugene Yurtsev
0997f2c0f1 x 2023-08-05 22:37:16 -04:00
Eugene Yurtsev
437b545426 x 2023-08-05 16:01:12 -04:00
Eugene Yurtsev
9b9d07572b x 2023-08-05 14:47:17 -04:00
Eugene Yurtsev
6a90c6c2c8 Merge branch 'eugene/fix_mutation_in_place' into eugene/automaton 2023-08-04 10:56:24 -04:00
Eugene Yurtsev
8371187689 x 2023-08-04 10:46:56 -04:00
Eugene Yurtsev
4309c17ffa x 2023-08-04 10:44:57 -04:00
Eugene Yurtsev
183a9d4e66 x 2023-08-04 10:19:10 -04:00
Eugene Yurtsev
c1b444e1e7 x 2023-08-03 15:39:45 -04:00
Eugene Yurtsev
5f117384c0 x 2023-08-03 12:28:01 -04:00
18 changed files with 1856 additions and 0 deletions

View File

@@ -0,0 +1,66 @@
from __future__ import annotations
import json
from typing import List, Sequence
from langchain.automaton.chat_agent import ChatAgent
from langchain.automaton.typedefs import (
AgentFinish,
FunctionCall,
FunctionResult,
MessageLike,
)
from langchain.chat_models.openai import ChatOpenAI
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.schema import AIMessage, BaseMessage, FunctionMessage, Generation
from langchain.schema.output_parser import BaseGenerationOutputParser
from langchain.tools import BaseTool, format_tool_to_openai_function
class OpenAIFunctionsParser(BaseGenerationOutputParser):
def parse_result(self, result: List[Generation]):
if len(result) != 1:
raise AssertionError(f"Expected exactly one result")
first_result = result[0]
message = first_result.message
if not isinstance(message, AIMessage) or not message.additional_kwargs:
return AgentFinish(result=message)
parser = JsonOutputFunctionsParser(strict=False, args_only=False)
try:
function_request = parser.parse_result(result)
except Exception as e:
raise RuntimeError(f"Error parsing result: {result} {repr(e)}") from e
return FunctionCall(
name=function_request["name"],
named_arguments=function_request["arguments"],
)
def prompt_generator(input_messages: Sequence[MessageLike]) -> List[BaseMessage]:
"""Generate a prompt from a log of message like objects."""
messages = []
for message in input_messages:
if isinstance(message, BaseMessage):
messages.append(message)
elif isinstance(message, FunctionResult):
messages.append(
FunctionMessage(name=message.name, content=json.dumps(message.result))
)
else:
pass
return messages
def create_openai_agent(llm: ChatOpenAI, tools: Sequence[BaseTool]) -> ChatAgent:
"""Create an agent that uses OpenAI's API."""
openai_funcs = [format_tool_to_openai_function(tool_) for tool_ in tools]
return ChatAgent(
llm.bind(functions=openai_funcs),
prompt_generator=prompt_generator,
tools=tools,
parser=OpenAIFunctionsParser(),
)

View File

@@ -0,0 +1,231 @@
"""Implementation of a think act agent.
Uses priming messages with text based LLM.
"""
from __future__ import annotations
import ast
import re
from typing import Iterator, List, Optional, Sequence, Union
from langchain.automaton.runnables import (
create_llm_program,
)
from langchain.automaton.tool_utils import generate_tool_info
from langchain.automaton.typedefs import (
AdHocMessage,
Agent,
AgentFinish,
FunctionCall,
FunctionResult,
MessageLike,
)
from langchain.prompts import SystemMessagePromptTemplate
from langchain.schema import (
AIMessage,
BaseMessage,
HumanMessage,
PromptValue,
SystemMessage,
)
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool, Tool
TEMPLATE_ = """\
Respond to the human as helpfully and accurately as
possible. You have access to the following tools:
{tools_description}
Use a blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
Valid "action" values: "Final Answer" or {tool_names}
You can only use a single tool at a time.
<action>
{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}
</action>
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
<action>
$BLOB
</action>
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
<action>
{{
"action": "Final Answer",
"action_input": "Final response to human"
}}
</action>
Begin:
Reminder to ALWAYS respond with a valid blob of a single action. \
Use tools if necessary. \
Respond directly if appropriate. \
Format is <action>$BLOB</action> then Observation. \
"""
def generate_memory(tools: Sequence[Tool]) -> List[MessageLike]:
"""Set up basic memory for agent."""
tools_info = generate_tool_info(tools)
return [SystemMessagePromptTemplate.from_template(TEMPLATE_).format(**tools_info)]
class ActionParser:
"""A utility class to encode and decode action blocks."""
def __init__(self) -> None:
"""Initialize the ActionParser."""
self.pattern = re.compile(r"<action>(?P<action_blob>.*?)<\/action>", re.DOTALL)
def decode(self, text: Union[BaseMessage, str]) -> Optional[MessageLike]:
"""Decode the action."""
if isinstance(text, BaseMessage):
text = text.content
match = self.pattern.search(text)
if match:
action_blob = match.group("action_blob")
try:
data = ast.literal_eval(action_blob)
except SyntaxError:
return AdHocMessage(
type="error", data=f"Invalid action blob {action_blob}"
)
name = data["action"]
if name == "Final Answer": # Special cased "tool" for final answer
return AgentFinish(result=data["action_input"])
action_input = data["action_input"]
if isinstance(action_input, str) and not action_input:
named_arguments = {}
elif isinstance(action_input, dict):
named_arguments = action_input
else:
return AdHocMessage(
type="error",
data=f"Invalid action blob {action_blob}, action_input must be a dict",
)
return FunctionCall(
name=data["action"], named_arguments=named_arguments or {}
)
else:
return None
class ThinkActPromptGenerator(PromptValue):
"""think-act paradigm
prompt value can do additional priming of the LLM, used in to_string method right now.
"""
messages: Sequence[MessageLike]
class Config:
arbitrary_types_allowed = True
def to_string(self) -> str:
"""The string variant of the prompt."""
finalized = []
messages = self.messages
for idx, message in enumerate(messages):
if isinstance(message, AdHocMessage):
if message.type == "error":
finalized.extend(
[
f"Error: Malformed <action> blob with error: {message.data}. "
f"Please re-write the action correctly."
"\n",
]
)
continue
if message.type != "prime":
raise AssertionError()
component = message.data
finalized.append(component)
continue
if isinstance(message, FunctionResult):
component = f"Observation: {message.result}"
elif isinstance(message, HumanMessage):
component = f"Question: {message.content.strip()}"
elif isinstance(message, (AIMessage, SystemMessage)):
component = message.content.strip()
elif isinstance(message, FunctionCall):
# This is an internal message, and should not be returned to the user.
continue
elif isinstance(message, AgentFinish):
component = f"Answer: {message.result}"
else:
raise NotImplementedError()
finalized.extend([component, "\n"])
return "".join(finalized).strip()
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as a list of Messages."""
messages = []
for message in self.messages:
if isinstance(message, BaseMessage):
messages.append(message)
elif isinstance(message, FunctionResult):
messages.append(
SystemMessage(content=f"Observation: `{message.result}`")
)
return messages
@classmethod
def from_message_log(cls, messages: Sequence[MessageLike]):
"""Create a ThinkActPromptGenerator from a MessageLog, allowing positional arguments."""
return cls(messages=messages)
class ThinkActAgent(Agent):
def __init__(
self,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
*,
max_iterations: int = 10,
) -> None:
"""Initialize the chat automaton."""
self.think_act = create_llm_program(
llm,
prompt_generator=ThinkActPromptGenerator.from_message_log,
stop=["Observation:", "observation:"],
parser=ActionParser().decode,
tools=tools,
invoke_tools=True,
)
self.max_iterations = max_iterations
def run(
self,
messages: Sequence[MessageLike],
*,
max_iterations: int = 100,
config,
) -> Iterator[MessageLike]:
"""Run the agent."""
all_messages = list(messages)
for _ in range(max_iterations):
if all_messages and isinstance(all_messages[-1], AgentFinish):
break
if all_messages and isinstance(
all_messages[-1], (FunctionResult, HumanMessage)
):
all_messages.append(AdHocMessage(type="prime", data="Thought:"))
new_messages = self.think_act.invoke(all_messages, config=config)
yield from new_messages
all_messages.extend(new_messages)

View File

@@ -0,0 +1,106 @@
from __future__ import annotations
import ast
import re
from typing import List, Sequence, Union
from langchain.automaton.chat_agent import ChatAgent
from langchain.automaton.tool_utils import generate_tool_info
from langchain.automaton.typedefs import (
AgentFinish,
FunctionCall,
FunctionResult,
MessageLike,
)
from langchain.prompts import SystemMessagePromptTemplate
from langchain.schema import BaseMessage, HumanMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
TEMPLATE_ = SystemMessagePromptTemplate.from_template(
"""Respond to the human as helpfully and accurately as \
possible. You have access to the following tools:
{tools_description}
Use a blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
Valid "action" values: "Final Answer" or {tool_names}
Provide only ONE action per $BLOB, as shown.
<action>
{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}
</action>
When invoking a tool do not provide any clarifying information.
The human will forward results of tool invocations as "Observations".
When you know the answer paraphrase the information in the observations properly and respond to the user. \
If you do not know the answer use more tools.
You can only take a single action at a time."""
)
def get_start_state(tools: Sequence[BaseTool]) -> List[BaseMessage]:
"""Generate a prompt for the agent."""
tool_info = generate_tool_info(tools)
msg = TEMPLATE_.format(**tool_info)
return [msg]
def _decode(text: Union[BaseMessage, str]) -> MessageLike:
"""Decode the action."""
pattern = re.compile(r"<action>(?P<action_blob>.*?)<\/action>", re.DOTALL)
if not isinstance(text, BaseMessage):
raise NotImplementedError()
_text = text.content
match = pattern.search(_text)
if match:
action_blob = match.group("action_blob")
data = ast.literal_eval(action_blob)
name = data["action"]
if name == "Final Answer": # Special cased "tool" for final answer
return AgentFinish(result=data["action_input"])
return FunctionCall(
name=data["action"], named_arguments=data["action_input"] or {}
)
else:
return AgentFinish(result=text)
def generate_prompt(current_messages: Sequence[MessageLike]) -> List[BaseMessage]:
"""Generate a prompt from a log of message like objects."""
messages = []
for message in current_messages:
if isinstance(message, BaseMessage):
messages.append(message)
elif isinstance(message, FunctionResult):
messages.append(
HumanMessage(
content=f"Observation: {message.result}",
)
)
else:
pass
return messages
# PUBLIC API
def create_xml_agent(
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
) -> ChatAgent:
"""XML based chat agent."""
return ChatAgent(
llm=llm,
tools=tools,
prompt_generator=generate_prompt,
parser=_decode,
)

View File

@@ -0,0 +1,71 @@
"""Generalized chat agent, works with any chat model."""
from __future__ import annotations
from typing import Callable, Iterator, Optional, Sequence, TypeVar, Union
from langchain.automaton.runnables import create_llm_program
from langchain.automaton.typedefs import (
Agent,
AgentFinish,
MessageLike,
)
from langchain.schema import PromptValue
from langchain.schema.language_model import (
BaseLanguageModel,
LanguageModelInput,
LanguageModelOutput,
)
from langchain.schema.messages import BaseMessage
from langchain.schema.output_parser import BaseOutputParser
from langchain.schema.runnable import (
Runnable,
)
from langchain.tools import BaseTool
T = TypeVar("T")
class ChatAgent(Agent):
"""A generalized chat agent."""
def __init__(
self,
llm: BaseLanguageModel[LanguageModelInput, LanguageModelOutput],
prompt_generator: Union[
Callable[[T], Union[str, PromptValue, Sequence[BaseMessage]]], Runnable
],
*,
tools: Optional[Sequence[BaseTool]] = None,
stop: Optional[Sequence[str]] = None,
parser: Union[
Runnable[Union[BaseMessage, str], MessageLike],
Callable[[Union[BaseMessage, str]], MessageLike],
BaseOutputParser,
None,
] = None,
) -> None:
"""Initialize the chat agent."""
invoke_tools = bool(tools)
self.llm_program = create_llm_program(
llm,
prompt_generator=prompt_generator,
tools=tools,
parser=parser,
stop=stop,
invoke_tools=invoke_tools,
)
def run(
self,
messages: Sequence[MessageLike],
*,
max_iterations: int = 100,
) -> Iterator[MessageLike]:
"""Run the agent."""
all_messages = list(messages)
for _ in range(max_iterations):
if all_messages and isinstance(all_messages[-1], AgentFinish):
break
new_messages = self.llm_program.invoke(all_messages)
yield from new_messages
all_messages.extend(new_messages)

View File

@@ -0,0 +1,325 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a91c50ff-bcb9-484b-a4bd-f56e75c182ba",
"metadata": {},
"source": [
"# General Chat Agent w/ XML like encoding"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "80595f69-8b46-494e-bd2f-0c9538ed526a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import sys\n",
"sys.path.insert(0, '../../../')"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "724c11c6-c98a-47ef-8697-87cd888df7a0",
"metadata": {},
"outputs": [],
"source": [
"from typing import List, Optional, Any\n",
"from langchain.tools import tool, Tool\n",
"from langchain.schema.messages import AIMessage, HumanMessage\n",
"from langchain.chat_models.anthropic import ChatAnthropic\n",
"from langchain.automaton.agent_implementations import xml_agent"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "ed3855cd-06b0-4a45-8b76-88926433fc60",
"metadata": {},
"outputs": [],
"source": [
"def get_tools() -> List[Tool]:\n",
" @tool\n",
" def name() -> str:\n",
" \"\"\"Use to look up the user's name\"\"\"\n",
" return \"Eugene\"\n",
"\n",
" @tool\n",
" def get_weather(city: str) -> str:\n",
" \"\"\"Get weather in a specific city.\"\"\"\n",
" return \"42F and sunny\"\n",
"\n",
" @tool\n",
" def add(x: int, y: int) -> int:\n",
" \"\"\"Use to add two numbers. For example; { \"x\": 2, \"y\": 10}\"\"\"\n",
" return x + y\n",
"\n",
" @tool\n",
" def add_3(x: int, y: int, z: int) -> int:\n",
" \"\"\"Add 3 numbers together.\"\"\"\n",
" return x + y + z\n",
"\n",
" @tool\n",
" def divide(x: float, y: float) -> float:\n",
" \"\"\"Divide x by y\"\"\"\n",
" return x / y\n",
"\n",
" return list(locals().values())\n",
"\n",
"\n",
"tools = get_tools()"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "97a59c88-2f0a-4982-81e3-4ffba4f222da",
"metadata": {},
"outputs": [],
"source": [
"llm = ChatAnthropic(temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "00177716-422e-436c-b07d-b70b2ccb5839",
"metadata": {},
"outputs": [],
"source": [
"start_messages = xml_agent.get_start_state(tools)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "778b2eda-5128-40ea-949f-9727610cc194",
"metadata": {},
"outputs": [],
"source": [
"agent = xml_agent.create_xml_agent(llm, tools)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "3bfcb9a7-af38-4b7a-87cc-84a842d6c072",
"metadata": {},
"outputs": [],
"source": [
"question = HumanMessage(\n",
" content=\"Sum of first 10 numbers starting from 1? use one tool at a time\"\n",
")\n",
"messages = start_messages + [question]"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "29ec7747-5cf7-49e3-b39e-23cf0a0818f3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ai: <action>\n",
"{\n",
" \"action\": \"add\", \n",
" \"action_input\": {\n",
" \"x\": 1,\n",
" \"y\": 2\n",
" } \n",
"}\n",
"</action>\n",
"FunctionCall(name=add, named_arguments={'x': 1, 'y': 2})\n",
"FunctionResult(name=add, result=3, error=None)\n",
"ai: <action>\n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": {\n",
" \"x\": 3, \n",
" \"y\": 3\n",
" }\n",
"} \n",
"</action>\n",
"FunctionCall(name=add, named_arguments={'x': 3, 'y': 3})\n",
"FunctionResult(name=add, result=6, error=None)\n",
"ai: <action>\n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": {\n",
" \"x\": 6,\n",
" \"y\": 4 \n",
" }\n",
"}\n",
"</action>\n",
"FunctionCall(name=add, named_arguments={'x': 6, 'y': 4})\n",
"FunctionResult(name=add, result=10, error=None)\n",
"ai: <action>\n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": {\n",
" \"x\": 10, \n",
" \"y\": 5\n",
" }\n",
"}\n",
"</action> \n",
"FunctionCall(name=add, named_arguments={'x': 10, 'y': 5})\n",
"FunctionResult(name=add, result=15, error=None)\n",
"ai: <action>\n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": {\n",
" \"x\": 15,\n",
" \"y\": 6\n",
" }\n",
"} \n",
"</action>\n",
"FunctionCall(name=add, named_arguments={'x': 15, 'y': 6})\n",
"FunctionResult(name=add, result=21, error=None)\n",
"ai: <action>\n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": {\n",
" \"x\": 21,\n",
" \"y\": 7\n",
" }\n",
"}\n",
"</action>\n",
"FunctionCall(name=add, named_arguments={'x': 21, 'y': 7})\n",
"FunctionResult(name=add, result=28, error=None)\n",
"ai: <action>\n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": {\n",
" \"x\": 28, \n",
" \"y\": 8\n",
" }\n",
"}\n",
"</action>\n",
"FunctionCall(name=add, named_arguments={'x': 28, 'y': 8})\n",
"FunctionResult(name=add, result=36, error=None)\n",
"ai: <action>\n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": {\n",
" \"x\": 36,\n",
" \"y\": 9 \n",
" }\n",
"} \n",
"</action>\n",
"FunctionCall(name=add, named_arguments={'x': 36, 'y': 9})\n",
"FunctionResult(name=add, result=45, error=None)\n",
"ai: <action>\n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": {\n",
" \"x\": 45,\n",
" \"y\": 10\n",
" }\n",
"}\n",
"</action>\n",
"FunctionCall(name=add, named_arguments={'x': 45, 'y': 10})\n",
"FunctionResult(name=add, result=55, error=None)\n",
"ai: <action>\n",
"{\n",
" \"action\": \"Final Answer\",\n",
" \"action_input\": {\n",
" \"answer\": \"The sum of the first 10 numbers starting from 1 is 55.\"\n",
" }\n",
"} \n",
"</action>\n",
"AgentFinish(result={'answer': 'The sum of the first 10 numbers starting from 1 is 55.'})\n"
]
}
],
"source": [
"for msg in agent.run(messages):\n",
" print(msg)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "a3a4eeca-9126-4d97-902d-fa4564563130",
"metadata": {},
"outputs": [
{
"ename": "ImportError",
"evalue": "cannot import name 'prompt_generator' from 'langchain.automaton.chat_agent' (/home/eugene/src/langchain/libs/langchain/langchain/automaton/notebooks/../../../langchain/automaton/chat_agent.py)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[40], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Visualize\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mhtml\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mlangchain\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mautomaton\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mchat_agent\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m prompt_generator\n\u001b[1;32m 6\u001b[0m messages \u001b[38;5;241m=\u001b[39m prompt_generator(message_log)\n\u001b[1;32m 8\u001b[0m \u001b[38;5;66;03m# Generate HTML divs\u001b[39;00m\n",
"\u001b[0;31mImportError\u001b[0m: cannot import name 'prompt_generator' from 'langchain.automaton.chat_agent' (/home/eugene/src/langchain/libs/langchain/langchain/automaton/notebooks/../../../langchain/automaton/chat_agent.py)"
]
}
],
"source": [
"# # Visualize\n",
"\n",
"# import html\n",
"# from langchain.automaton.chat_agent import prompt_generator\n",
"\n",
"# messages = prompt_generator(message_log)\n",
"\n",
"# # Generate HTML divs\n",
"# html_divs = []\n",
"# for idx, msg in enumerate(messages):\n",
"# background_color = \"DodgerBlue\" if idx % 2 == 0 else \"Tomato\"\n",
"# div_content = f\"\"\"\n",
"# <div style=\"border: 1px solid black; padding: 10px; background-color:{background_color}\">\n",
"# <div style=\"display: flex;\">\n",
"# <div style=\"font-weight: bold; margin-right: 10px;\">{html.escape(msg.type)}</div>\n",
"# <div>{html.escape(msg.content)}</div>\n",
"# </div>\n",
"# </div>\n",
"# \"\"\"\n",
"# html_divs.append(div_content)\n",
"\n",
"# # Display HTML divs in a Jupyter Notebook cell\n",
"# from IPython.display import HTML, display\n",
"\n",
"# html_output = \"\\n\".join(html_divs)\n",
"# display(HTML(html_output))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,167 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a91c50ff-bcb9-484b-a4bd-f56e75c182ba",
"metadata": {},
"source": [
"# Open AI Chat"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d77b586-fced-4a34-b778-b05141ab32bc",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import sys\n",
"sys.path.insert(0, '../../../')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f73ed8ca-7384-4afc-af8c-bfd5b794df01",
"metadata": {},
"outputs": [],
"source": [
"from typing import List, Optional, Any\n",
"from langchain.tools import tool, Tool\n",
"from langchain.automaton.agent_implementations.openai_agent import create_openai_agent\n",
"from langchain.schema.messages import HumanMessage\n",
"from langchain.chat_models.openai import ChatOpenAI"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ed3855cd-06b0-4a45-8b76-88926433fc60",
"metadata": {},
"outputs": [],
"source": [
"def get_tools() -> List[Tool]:\n",
" @tool\n",
" def name() -> str:\n",
" \"\"\"Use to look up the user's name\"\"\"\n",
" return \"Eugene\"\n",
"\n",
" @tool\n",
" def get_weather(city: str) -> str:\n",
" \"\"\"Get weather in a specific city.\"\"\"\n",
" return \"42F and sunny\"\n",
"\n",
" @tool\n",
" def add(x: int, y: int) -> int:\n",
" \"\"\"Use to add two numbers. For example; { \"x\": 2, \"y\": 10} \"\"\"\n",
" return x + y\n",
"\n",
" @tool\n",
" def add_3(x: int, y: int, z: int) -> int:\n",
" \"\"\"Add 3 numbers together.\"\"\"\n",
" return x + y + z\n",
"\n",
" @tool\n",
" def divide(x: float, y: float) -> float:\n",
" \"\"\"Divide x by y\"\"\"\n",
" return x / y\n",
"\n",
" return list(locals().values())\n",
"\n",
"tools = get_tools()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "a0790c55-eacb-4e0c-970b-5784d62ef5c2",
"metadata": {},
"outputs": [],
"source": [
"llm = ChatOpenAI(temperature=0)\n",
"agent = create_openai_agent(llm, tools)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "e7d1ec69-de12-4c66-abee-3b2bbabeec13",
"metadata": {},
"outputs": [],
"source": [
"a = agent.run([HumanMessage(content=\"Sum of first 10 numbers starting from 1? use one tool at a time\")])"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "75c6cc2c-7604-44c0-b8cf-4513620782f9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 1,\\n \"y\": 2\\n}'}}\n",
"FunctionCall(name=add, named_arguments={'x': 1, 'y': 2})\n",
"FunctionResult(name=add, result=3, error=None)\n",
"ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 3,\\n \"y\": 3\\n}'}}\n",
"FunctionCall(name=add, named_arguments={'x': 3, 'y': 3})\n",
"FunctionResult(name=add, result=6, error=None)\n",
"ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 6,\\n \"y\": 4\\n}'}}\n",
"FunctionCall(name=add, named_arguments={'x': 6, 'y': 4})\n",
"FunctionResult(name=add, result=10, error=None)\n",
"ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 10,\\n \"y\": 5\\n}'}}\n",
"FunctionCall(name=add, named_arguments={'x': 10, 'y': 5})\n",
"FunctionResult(name=add, result=15, error=None)\n",
"ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 15,\\n \"y\": 6\\n}'}}\n",
"FunctionCall(name=add, named_arguments={'x': 15, 'y': 6})\n",
"FunctionResult(name=add, result=21, error=None)\n",
"ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 21,\\n \"y\": 7\\n}'}}\n",
"FunctionCall(name=add, named_arguments={'x': 21, 'y': 7})\n",
"FunctionResult(name=add, result=28, error=None)\n",
"ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 28,\\n \"y\": 8\\n}'}}\n",
"FunctionCall(name=add, named_arguments={'x': 28, 'y': 8})\n",
"FunctionResult(name=add, result=36, error=None)\n",
"ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 36,\\n \"y\": 9\\n}'}}\n",
"FunctionCall(name=add, named_arguments={'x': 36, 'y': 9})\n",
"FunctionResult(name=add, result=45, error=None)\n",
"ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 45,\\n \"y\": 10\\n}'}}\n",
"FunctionCall(name=add, named_arguments={'x': 45, 'y': 10})\n",
"FunctionResult(name=add, result=55, error=None)\n",
"ai: The sum of the first 10 numbers starting from 1 is 55.\n",
"AgentFinish(result=ai: The sum of the first 10 numbers starting from 1 is 55.)\n"
]
}
],
"source": [
"for msg in a:\n",
" print(msg)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,222 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "9875f554-8f4d-409b-929f-00e8bc70185d",
"metadata": {},
"source": [
"# Think Act Agent"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "80595f69-8b46-494e-bd2f-0c9538ed526a",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import sys\n",
"sys.path.insert(0, '../../../')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "724c11c6-c98a-47ef-8697-87cd888df7a0",
"metadata": {},
"outputs": [],
"source": [
"from typing import List, Optional, Any\n",
"from langchain.tools import tool, Tool\n",
"from langchain.schema.messages import AIMessage, HumanMessage\n",
"from langchain.chat_models.anthropic import ChatAnthropic\n",
"from langchain.llms import OpenAI\n",
"from langchain.automaton.agent_implementations import think_act_agent"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ed3855cd-06b0-4a45-8b76-88926433fc60",
"metadata": {},
"outputs": [],
"source": [
"def get_tools() -> List[Tool]:\n",
" @tool\n",
" def name() -> str:\n",
" \"\"\"Use to look up the user's name\"\"\"\n",
" return \"Eugene\"\n",
"\n",
" @tool\n",
" def get_weather(city: str) -> str:\n",
" \"\"\"Get weather in a specific city.\"\"\"\n",
" return \"42F and sunny\"\n",
"\n",
" @tool\n",
" def add(x: int, y: int) -> int:\n",
" \"\"\"Use to add two numbers. For example; { \"x\": 2, \"y\": 10}\"\"\"\n",
" return x + y\n",
"\n",
" @tool\n",
" def divide(x: float, y: float) -> float:\n",
" \"\"\"Divide x by y\"\"\"\n",
" return x / y\n",
"\n",
" return list(locals().values())\n",
"\n",
"\n",
"tools = get_tools()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f6abff39-5243-45d1-a250-a862bbd665ff",
"metadata": {},
"outputs": [],
"source": [
"llm = OpenAI(temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "d416bb0e-28e4-4da3-8b10-aa563f87a48e",
"metadata": {},
"outputs": [],
"source": [
"agent = think_act_agent.ThinkActAgent(llm, tools)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f9c1ed10-309b-47a1-b22e-153e53c3a86e",
"metadata": {},
"outputs": [],
"source": [
"start_messages = think_act_agent.generate_memory(tools)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "04488a11-6ef2-428b-b1ae-7800b424480a",
"metadata": {},
"outputs": [],
"source": [
"from langchain.callbacks.manager import (\n",
" trace_as_chain_group,\n",
" atrace_as_chain_group,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "32a3ce50-e993-4f29-b2f8-02261d417e24",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ai: I need to use the divide tool and the get_weather tool\n",
"<action>\n",
"{\n",
" \"action\": \"divide\",\n",
" \"action_input\": { \"x\": 12381237, \"y\": 23 }\n",
"}\n",
"</action>\n",
"\n",
"FunctionCall(name=divide, named_arguments={'x': 12381237, 'y': 23})\n",
"FunctionResult(name=divide, result=538314.6521739131, error=None)\n",
"ai: I need to use the name tool and the get_weather tool\n",
"<action>\n",
"{\n",
" \"action\": \"name\",\n",
" \"action_input\": \"\"\n",
"}\n",
"</action>\n",
"\n",
"FunctionCall(name=name, named_arguments={})\n",
"FunctionResult(name=name, result=Eugene, error=None)\n",
"ai: I need to use the get_weather tool\n",
"<action>\n",
"{\n",
" \"action\": \"get_weather\",\n",
" \"action_input\": \"New York\"\n",
"}\n",
"</action>\n",
"\n",
"AdHocMessage(type=error, data=Invalid action blob \n",
"{\n",
" \"action\": \"get_weather\",\n",
" \"action_input\": \"New York\"\n",
"}\n",
", action_input must be a dict)\n",
"ai: \n",
"Thought:I need to use the get_weather tool\n",
"<action>\n",
"{\n",
" \"action\": \"get_weather\",\n",
" \"action_input\": {\"city\": \"New York\"}\n",
"}\n",
"</action>\n",
"\n",
"FunctionCall(name=get_weather, named_arguments={'city': 'New York'})\n",
"FunctionResult(name=get_weather, result=42F and sunny, error=None)\n",
"ai: I know what to respond\n",
"<action>\n",
"{\n",
" \"action\": \"Final Answer\",\n",
" \"action_input\": \"Hi Eugene, the answer to 12381237/23 is 538314.6521739131. The weather in New York is 42F and sunny.\"\n",
"}\n",
"</action>\n",
"AgentFinish(result=Hi Eugene, the answer to 12381237/23 is 538314.6521739131. The weather in New York is 42F and sunny.)\n"
]
}
],
"source": [
"question = HumanMessage(content=\"what is 12381237/23? What is my name? And what is the weather in new york?\")\n",
"messages = start_messages + [question]\n",
"\n",
"with trace_as_chain_group(\"think_act_agent\") as group_manager:\n",
" for msg in agent.run(messages, config={\"callbacks\": group_manager}):\n",
" print(msg)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "248d63ee-1ec4-49cb-9e6b-84e75a7644ca",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,225 @@
"""Module contains useful runnables for agents."""
from __future__ import annotations
from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union
from langchain.automaton.typedefs import (
FunctionCall,
FunctionResult,
MessageLike,
)
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import AIMessage, BaseMessage, PromptValue
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.output_parser import BaseOutputParser
from langchain.schema.runnable import (
Runnable,
RunnableConfig,
RunnableLambda,
RunnableMap,
RunnablePassthrough,
patch_config,
)
from langchain.tools import BaseTool
T = TypeVar("T")
def _to_message(result: Union[BaseMessage, str]) -> BaseMessage:
"""Convert to a list of messages."""
if isinstance(result, BaseMessage):
return result
elif isinstance(result, str):
return AIMessage(content=result)
else:
raise NotImplementedError(f"Unsupported type {type(result)}")
def _to_list(element: Union[None, T, List[T]]) -> List[T]:
"""Convert to a sequence."""
if element is None:
return []
elif isinstance(element, list):
return element
else:
return [element]
def _to_runnable_parser(parser: Optional[BaseOutputParser]) -> Runnable:
"""Adapt a parser to a runnable."""
if parser is None:
# Then create a runnable that returns no messages
return RunnableLambda(lambda *args, **kwargs: None)
elif isinstance(parser, Runnable):
return parser
elif isinstance(parser, Callable):
return RunnableLambda(parser)
else:
raise ValueError(f"Expected BaseOutputParser, got {parser}")
def _concatenate_head_and_tail(intermediate_input: Dict[str, Any]) -> List[BaseMessage]:
"""Concatenate head and tail into a single list."""
head = _to_list(intermediate_input["head"])
tail = _to_list(intermediate_input["tail"])
return head + tail
def _apply_and_concat(
head: Union[Runnable, Callable], tail: Union[Runnable, Callable]
) -> Runnable:
"""Apply head and tail and concatenate the results.
Note: Probably generalize to _apply(funcs) and _concatenate runnables
Args:
head: A runnable or callable
tail: A runnable or callable
Returns:
A runnable that applies head and tail and concatenates the results in order.
"""
head_ = head if isinstance(head, Runnable) else RunnableLambda(head)
tail_ = tail if isinstance(tail, Runnable) else RunnableLambda(tail)
return (
RunnableMap(
steps={
"head": head_,
"tail": tail_,
}
)
| _concatenate_head_and_tail
)
# PUBLIC API
def create_tool_invoker(
tools: Sequence[BaseTool],
) -> Runnable[MessageLike, Optional[FunctionResult]]:
"""See if possible to re-write with router
TODO:
* re-write with router
* potentially remove hack replace MessageLike with FunctionCall, requires
a branching runnable
"""
tools_by_name = {tool.name: tool for tool in tools}
def func(
function_call: MessageLike,
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
) -> Optional[FunctionResult]:
"""A function that can invoke a tool using .run"""
if not isinstance(
function_call, FunctionCall
): # TODO(Hack): Workaround lack of conditional apply
return None
try:
tool = tools_by_name[function_call.name]
except KeyError:
raise AssertionError(f"No such tool: {function_call.name}")
try:
result = tool.invoke(
function_call.named_arguments or {},
patch_config(config, callbacks=run_manager.get_child()),
)
error = None
except Exception as e:
result = None
error = repr(e) + repr(function_call.named_arguments)
return FunctionResult(name=function_call.name, result=result, error=error)
async def afunc(
function_call: MessageLike,
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
) -> Optional[FunctionResult]:
"""A function that can invoke a tool using .run"""
if not isinstance(
function_call, FunctionCall
): # TODO(Hack): Workaround lack of conditional apply
return None
try:
tool = tools_by_name[function_call.name]
except KeyError:
raise AssertionError(f"No such tool: {function_call.name}")
try:
result = await tool.ainvoke(
function_call.named_arguments or {},
patch_config(config, callbacks=run_manager.get_child()),
)
error = None
except Exception as e:
result = None
error = repr(e) + repr(function_call.named_arguments)
return FunctionResult(name=function_call.name, result=result, error=error)
return RunnableLambda(func=func, afunc=afunc)
def create_llm_program(
llm: BaseLanguageModel,
prompt_generator: Union[
Callable[
[Sequence[MessageLike]], Union[str, PromptValue, Sequence[BaseMessage]]
],
Runnable,
],
*,
tools: Optional[Sequence[BaseTool]] = None,
stop: Optional[Sequence[str]] = None,
parser: Union[
Runnable[Union[BaseMessage, str], MessageLike],
Callable[[Union[BaseMessage, str]], MessageLike],
BaseOutputParser,
None,
] = None,
invoke_tools: bool = True, # TODO(Eugene): Perhaps remove.
) -> Runnable[Sequence[MessageLike], List[MessageLike]]:
"""Create a runnable that provides a generalized interface to an LLM with actions.
Args:
llm: A language model
prompt_generator: A function that takes a list of messages and returns a prompt
tools: A list of tools to invoke
stop: optional list of stop tokens
parser: optional parser to apply to the output of the LLM
invoke_tools: Whether to invoke tools on the output of the LLM
Returns:
A runnable that returns a list of messages
"""
if not isinstance(prompt_generator, Runnable):
_prompt_generator = RunnableLambda(prompt_generator)
else:
_prompt_generator = Runnable
if stop:
llm = llm.bind(stop=stop)
# Add parser to the end of the chain and concatenate original llm output
# with the parser output.
# The parser is always created even if it is None, to make sure that
# the _to_message adapter is always applied (regardless of the parser).
_parser = _to_runnable_parser(parser)
chain = _prompt_generator | llm | _apply_and_concat(_to_message, _parser)
# Add tool invoker to the end of the chain.
if invoke_tools and tools:
tool_invoker = create_tool_invoker(tools)
invoke_on_last = RunnableLambda(lambda msgs: msgs[-1]) | tool_invoker
complete_chain = chain | _apply_and_concat(
RunnablePassthrough(), invoke_on_last
)
else:
complete_chain = chain
return complete_chain

View File

@@ -0,0 +1,42 @@
from __future__ import annotations
import langchain.automaton.agent_implementations.xml_agent
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.automaton.mrkl_agent import ActionParser
def test_structured_output_chat() -> None:
parser = StructuredChatOutputParser()
output = parser.parse(
"""
```json
{
"action": "hello",
"action_input": {
"a": 2
}
}
```
"""
)
assert output == {}
def test_parser() -> None:
"""Tes the parser."""
sample_text = """
Some text before
<action>
{
"key": "value",
"number": 42
}
</action>
Some text after
"""
action_parser = ActionParser()
action = langchain.automaton.agent_implementations.xml_agent.decode(sample_text)
assert action == {
"key": "value",
"number": 42,
}

View File

@@ -0,0 +1,97 @@
from __future__ import annotations
from typing import List, cast
import pytest
from langchain.automaton.openai_agent import OpenAIAgent
from langchain.automaton.tests.utils import (
FakeChatModel,
construct_func_invocation_message,
)
from langchain.automaton.typedefs import (
AgentFinish,
FunctionCall,
FunctionResult,
MessageLog,
)
from langchain.schema.messages import (
AIMessage,
SystemMessage,
)
from langchain.tools import Tool, tool
from langchain.tools.base import BaseTool
@pytest.fixture()
def tools() -> List[BaseTool]:
@tool
def get_time() -> str:
"""Get time."""
return "9 PM"
@tool
def get_location() -> str:
"""Get location."""
return "the park"
return cast(List[Tool], [get_time, get_location])
def test_openai_agent(tools: List[Tool]) -> None:
get_time, get_location = tools
llm = FakeChatModel(
message_iter=iter(
[
construct_func_invocation_message(get_time, {}),
AIMessage(
content="The time is 9 PM.",
),
]
)
)
agent = OpenAIAgent(llm=llm, tools=tools, max_iterations=10)
message_log = MessageLog(
[
SystemMessage(
content="What time is it?",
)
]
)
expected_messages = [
SystemMessage(
content="What time is it?",
),
AIMessage(
content="",
additional_kwargs={
"function_call": {
"name": "get_time",
"arguments": "{}",
}
},
),
FunctionCall(
name="get_time",
arguments={},
),
FunctionResult(
name="get_time",
result="9 PM",
error=None,
),
AIMessage(
content="The time is 9 PM.",
),
AgentFinish(
AIMessage(
content="The time is 9 PM.",
),
),
]
agent.run(message_log)
assert message_log.messages == expected_messages

View File

@@ -0,0 +1,153 @@
"""Test parsers"""
from typing import Any, List, Optional, cast
import pytest
from langchain.automaton.runnables import (
RunnablePassthrough,
_apply_and_concat,
_to_list,
_to_runnable_parser,
create_llm_program,
)
from langchain.automaton.tests.utils import (
FakeChatModel,
)
from langchain.automaton.typedefs import FunctionCall, FunctionResult, MessageLike
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
from langchain.schema.runnable import RunnableLambda
from langchain.tools import BaseTool, tool
def test_apply_and_concat() -> None:
"""Test apply and concat."""
msg = HumanMessage(content="Hello")
llm = RunnableLambda(lambda *args, **kwargs: msg)
# Test that it works with a runnable
chain = llm | _apply_and_concat(RunnablePassthrough(), RunnablePassthrough())
assert chain.invoke({}) == [msg, msg]
chain = llm | _apply_and_concat(
lambda msg: msg.content[0], lambda msg: msg.content[1]
)
assert chain.invoke({}) == ["H", "e"]
@pytest.mark.parametrize(
"input_value, expected_output",
[
(None, []), # Test when input is None
([1, 2, 3], [1, 2, 3]), # Test when input is a list of integers
(5, [5]), # Test when input is a single integer
(["a", "b", "c"], ["a", "b", "c"]), # Test when input is a list of strings
("xyz", ["xyz"]), # Test when input is a single string
([], []), # Test when input is an empty list
],
)
def test_to_list(input_value: Any, expected_output: List) -> None:
assert _to_list(input_value) == expected_output
@pytest.mark.parametrize(
"parser, expected_output",
[
(None, None),
(RunnablePassthrough(), AIMessage(content="Hello")),
],
)
def to_runnable_parser(parser: Any, expected_output: Optional[BaseMessage]) -> None:
"""To runnable parser."""
parser_ = _to_runnable_parser(parser)
assert parser_.invoke(AIMessage(content="Hello")) == [expected_output]
@pytest.fixture()
def tools() -> List[BaseTool]:
"""Make a tools fixture."""
@tool
def get_time() -> str:
"""Get time."""
return "9 PM"
@tool
def get_location() -> str:
"""Get location."""
return "the park"
return cast(List[BaseTool], [get_time, get_location])
@pytest.fixture()
def fake_llm() -> BaseLanguageModel:
"""Make a fake chat model."""
llm = FakeChatModel(
message_iter=iter(
[
AIMessage(
content="Hello",
),
]
)
)
return llm
def test_simple_llm_program(fake_llm: BaseLanguageModel) -> None:
"""Test simple llm program with no parser or tools."""
get_time, _ = tools
program = create_llm_program(
fake_llm,
prompt_generator=lambda x: x,
)
assert program.invoke("What time is it?") == [AIMessage(content="Hello")]
def test_llm_program_with_parser(fake_llm: BaseLanguageModel) -> None:
"""Test simple llm program with no parser or tools."""
parser = RunnableLambda(lambda msg: AIMessage(content=msg.content + " parsed"))
program = create_llm_program(
fake_llm,
prompt_generator=lambda x: x,
parser=parser,
)
assert program.invoke("What time is it?") == [
AIMessage(content="Hello"),
AIMessage(content="Hello parsed"),
]
@pytest.mark.parametrize(
"parser, output",
[
(
RunnableLambda(lambda msg: AIMessage(content="Goodbye")),
[AIMessage(content="Hello"), AIMessage(content="Goodbye")],
),
(
RunnableLambda(lambda msg: FunctionCall(name="get_time")),
[
AIMessage(content="Hello"),
FunctionCall(name="get_time"),
FunctionResult(result="9 PM", name="get_time"),
],
),
],
)
def test_llm_program_with_parser_and_tools(
tools: List[BaseTool],
fake_llm: BaseLanguageModel,
parser: Any,
output: List[MessageLike],
) -> None:
"""Test simple llm program with no parser or tools."""
program = create_llm_program(
fake_llm,
prompt_generator=lambda x: x,
parser=parser,
tools=tools,
invoke_tools=True,
)
assert program.invoke("What time is it?") == output

View File

@@ -0,0 +1,46 @@
from __future__ import annotations
import json
from typing import Any, Iterator, List, Mapping
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult
from langchain.tools import BaseTool
class FakeChatModel(BaseChatModel):
"""A fake chat model that returns a pre-defined response."""
message_iter: Iterator[BaseMessage]
@property
def _llm_type(self) -> str:
"""The type of the model."""
return "fake-openai-chat-model"
def _generate(
self,
messages: List[BaseMessage],
stop: List[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
"""Generate a response to the given messages."""
message = next(self.message_iter)
return ChatResult(generations=[ChatGeneration(message=message)])
def construct_func_invocation_message(
tool: BaseTool, args: Mapping[str, Any]
) -> AIMessage:
"""Construct a function invocation message."""
return AIMessage(
content="",
additional_kwargs={
"function_call": {
"name": tool.name,
"arguments": json.dumps(args),
}
},
)

View File

@@ -0,0 +1,25 @@
from typing import Sequence, TypedDict
from langchain.tools import BaseTool
def _generate_tools_descriptions(tools: Sequence[BaseTool]) -> str:
"""Generate a description of the tools."""
return "\n".join([f"{tool_.name}: {tool_.description}" for tool_ in tools]) + "\n"
class ToolInfo(TypedDict):
"""A dictionary containing information about a tool."""
tool_names: str
tools_description: str
def generate_tool_info(tools: Sequence[BaseTool]) -> ToolInfo:
"""Generate a string containing the names of the tools and their descriptions."""
tools_description = _generate_tools_descriptions(tools)
tool_names = ", ".join([tool_.name for tool_ in tools])
return {
"tool_names": tool_names,
"tools_description": tools_description,
}

View File

@@ -0,0 +1,71 @@
from __future__ import annotations
from typing import Any, Iterator, Mapping, Optional, Sequence, Union
from langchain.load.serializable import Serializable
from langchain.schema import (
BaseMessage,
)
class InternalMessage(Serializable):
@property
def lc_serializable(self) -> bool:
"""Indicate whether the class is serializable."""
return True
class FunctionCall(InternalMessage):
"""A request for a function invocation.
This message can be used to request a function invocation
using the function name and the arguments to pass to the function.
"""
name: str
"""The name of the function to invoke."""
named_arguments: Optional[Mapping[str, Any]] = None
"""The named arguments to pass to the function."""
class Config:
extra = "forbid"
def __str__(self):
return f"FunctionCall(name={self.name}, named_arguments={self.named_arguments})"
class FunctionResult(InternalMessage):
"""A result of a function invocation."""
name: str
result: Any
error: Optional[str] = None
def __str__(self):
return f"FunctionResult(name={self.name}, result={self.result}, error={self.error})"
class AdHocMessage(InternalMessage):
"""A message that is used to prime the language model."""
type: str
data: Any # Make sure this is serializable
def __str__(self):
return f"AdHocMessage(type={self.type}, data={self.data})"
class AgentFinish(InternalMessage):
result: Any
def __str__(self):
return f"AgentFinish(result={self.result})"
MessageLike = Union[BaseMessage, InternalMessage]
class Agent: # This is just approximate still, may end up being a runnable
def run(self, messages: Sequence[MessageLike]) -> Iterator[MessageLike]:
"""Run the agent on a message."""
raise NotImplementedError

View File

@@ -79,6 +79,15 @@ class BaseMessage(Serializable):
"""Whether this class is LangChain serializable."""
return True
def __repr__(self):
if self.additional_kwargs:
return f"{self.type}: {self.additional_kwargs}"
else:
return f"{self.type}: {self.content}"
def __str__(self):
return self.__repr__()
def __add__(self, other: Any) -> ChatPromptTemplate:
from langchain.prompts.chat import ChatPromptTemplate