Compare commits

...

67 Commits

Author SHA1 Message Date
Eugene Yurtsev
f0ea4b8b25 x 2023-09-12 10:52:28 -04:00
Eugene Yurtsev
d45d2294fc x 2023-09-11 14:37:17 -04:00
Eugene Yurtsev
9933c6c44c x 2023-09-11 12:50:10 -04:00
Eugene Yurtsev
f869131743 x 2023-09-11 11:27:04 -04:00
Eugene Yurtsev
db116dd4cd x 2023-09-11 11:25:17 -04:00
Eugene Yurtsev
47cdc8873c x 2023-09-11 11:05:38 -04:00
Eugene Yurtsev
5a8c0ab5c5 x 2023-09-11 09:45:56 -04:00
Eugene Yurtsev
ff4d1d150a Merge branch 'master' into eugene/automaton_variant_3 2023-09-08 17:00:42 -04:00
Eugene Yurtsev
cd3a57602d x 2023-09-08 17:00:14 -04:00
Eugene Yurtsev
5601d60df0 x 2023-09-08 16:40:27 -04:00
Eugene Yurtsev
1cd9fb6444 x 2023-09-08 16:38:13 -04:00
Eugene Yurtsev
6eb5a1c5e6 x 2023-09-08 15:43:41 -04:00
Eugene Yurtsev
71dee4c885 x 2023-09-08 09:24:49 -04:00
Eugene Yurtsev
b3cbff741e x 2023-09-07 17:06:51 -04:00
Eugene Yurtsev
00c3a2d3d2 x 2023-09-07 16:24:21 -04:00
Eugene Yurtsev
ba61d64c04 x 2023-09-07 15:54:08 -04:00
Eugene Yurtsev
0396570c82 x 2023-09-07 15:47:47 -04:00
Eugene Yurtsev
cf6bca3da7 x 2023-09-07 15:45:39 -04:00
Eugene Yurtsev
95d7e87a68 x 2023-09-07 14:26:21 -04:00
Eugene Yurtsev
62b936dc7c x 2023-09-07 13:57:57 -04:00
Eugene Yurtsev
23e1ef8f4a x 2023-09-07 13:45:55 -04:00
Eugene Yurtsev
4108862154 x 2023-09-07 13:43:57 -04:00
Eugene Yurtsev
c9cbfc9d51 Merge branch 'master' into eugene/automaton_variant_3 2023-09-07 13:20:09 -04:00
Eugene Yurtsev
96f3fc2d58 x 2023-09-07 13:19:53 -04:00
Eugene Yurtsev
486d772184 x 2023-09-01 15:15:47 -04:00
Eugene Yurtsev
29ddf63abb x 2023-09-01 14:55:02 -04:00
Eugene Yurtsev
35318007d8 x 2023-09-01 14:54:47 -04:00
Eugene Yurtsev
900cf6ead3 x 2023-09-01 14:17:55 -04:00
Eugene Yurtsev
beec2d99b9 x 2023-09-01 14:16:45 -04:00
Eugene Yurtsev
a695ba5ba4 Merge branch 'master' into eugene/automaton_variant_2 2023-09-01 14:16:11 -04:00
Eugene Yurtsev
71c277c17c x 2023-09-01 14:16:01 -04:00
Eugene Yurtsev
e67f4b8d3b x 2023-09-01 14:14:29 -04:00
Eugene Yurtsev
7181211d78 x 2023-09-01 14:13:02 -04:00
Eugene Yurtsev
2c798fcc77 x 2023-09-01 14:02:35 -04:00
Eugene Yurtsev
bf12905e7f x 2023-09-01 11:23:02 -04:00
Eugene Yurtsev
e4cf87ab34 x 2023-09-01 10:56:12 -04:00
Eugene Yurtsev
9bbc5af2a8 x 2023-09-01 10:19:45 -04:00
Eugene Yurtsev
06e34e54f5 x 2023-08-31 15:22:38 -04:00
Eugene Yurtsev
5a68ba4174 x 2023-08-31 12:21:47 -04:00
Eugene Yurtsev
8e8c368019 x 2023-08-31 11:58:51 -04:00
Eugene Yurtsev
acbd19fac5 x 2023-08-31 11:24:43 -04:00
Eugene Yurtsev
aa56956d16 x 2023-08-31 10:20:04 -04:00
Eugene Yurtsev
0d2cf70d1e x 2023-08-30 09:18:04 -04:00
Eugene Yurtsev
188ac56c97 x 2023-08-30 09:17:27 -04:00
Eugene Yurtsev
383f09b271 x 2023-08-29 13:55:37 -04:00
Eugene Yurtsev
0af306fc54 Merge branch 'master' into eugene/automaton_variant_1 2023-08-29 13:53:10 -04:00
Eugene Yurtsev
15d5c49076 Merge branch 'master' into eugene/automaton_variant_1 2023-08-07 10:40:09 -04:00
Eugene Yurtsev
65660535bc x 2023-08-07 10:39:56 -04:00
Eugene Yurtsev
6c41dd82f0 x 2023-08-07 10:38:14 -04:00
Eugene Yurtsev
78d788c28c x 2023-08-07 10:37:41 -04:00
Eugene Yurtsev
e9deeab37f x 2023-08-07 10:32:53 -04:00
Eugene Yurtsev
4d595eec5b x 2023-08-06 23:13:32 -04:00
Eugene Yurtsev
047b001336 x 2023-08-06 22:50:46 -04:00
Eugene Yurtsev
840e936c7c x 2023-08-06 22:30:33 -04:00
Eugene Yurtsev
6cc6b490be x 2023-08-06 22:23:51 -04:00
Eugene Yurtsev
f45d1ed4f5 x 2023-08-06 22:20:29 -04:00
Eugene Yurtsev
4ffc417858 x 2023-08-05 23:10:15 -04:00
Eugene Yurtsev
9e74a70859 x 2023-08-05 22:39:04 -04:00
Eugene Yurtsev
0997f2c0f1 x 2023-08-05 22:37:16 -04:00
Eugene Yurtsev
437b545426 x 2023-08-05 16:01:12 -04:00
Eugene Yurtsev
9b9d07572b x 2023-08-05 14:47:17 -04:00
Eugene Yurtsev
6a90c6c2c8 Merge branch 'eugene/fix_mutation_in_place' into eugene/automaton 2023-08-04 10:56:24 -04:00
Eugene Yurtsev
8371187689 x 2023-08-04 10:46:56 -04:00
Eugene Yurtsev
4309c17ffa x 2023-08-04 10:44:57 -04:00
Eugene Yurtsev
183a9d4e66 x 2023-08-04 10:19:10 -04:00
Eugene Yurtsev
c1b444e1e7 x 2023-08-03 15:39:45 -04:00
Eugene Yurtsev
5f117384c0 x 2023-08-03 12:28:01 -04:00
24 changed files with 2263 additions and 0 deletions

View File

@@ -0,0 +1,75 @@
"""Generalized chat agent, works with any chat model."""
from __future__ import annotations
from typing import Callable, Optional, Sequence, TypeVar, List
from langchain.automaton.typedefs import MessageLike, AgentFinish, Agent
from langchain.schema.runnable import (
Runnable,
RunnableConfig,
)
T = TypeVar("T")
WorkingMemoryProcessor = Runnable[Sequence[MessageLike], List[MessageLike]]
LLMProgram = Runnable[Sequence[MessageLike], List[MessageLike]]
class SequentialAgent(Agent):
def __init__(
self,
llm_program: LLMProgram,
memory_processor: Optional[WorkingMemoryProcessor] = None,
) -> None:
"""Initialize the agent."""
self.llm_program = llm_program
self.memory_processor = memory_processor
def run(
self,
messages: Sequence[MessageLike],
*,
config: Optional[RunnableConfig] = None,
max_iterations: int = 10,
) -> List[MessageLike]:
"""Run the agent."""
messages = list(messages)
for iteration_num in range(max_iterations):
# Working memory / working state updates can take the form
# of appends or replacements
if self.memory_processor:
# This is a replacement
messages = self.memory_processor.invoke(messages)
if messages and isinstance(messages[-1], AgentFinish):
break
# This is an append to working memory
messages.extend(self.llm_program.invoke(messages, config=config))
return messages
class MessageAutomaton: # Just a sketch
def __init__(
self,
router: Callable[[Sequence[MessageLike]], Optional[WorkingMemoryProcessor]],
) -> None:
"""Initialize the automaton."""
self.router = router
def run(
self,
messages: Sequence[MessageLike],
*,
config: Optional[RunnableConfig] = None,
max_iterations: int = 10,
) -> List[MessageLike]:
"""Run the automaton."""
new_messages = list(messages)
for _ in range(max_iterations):
runnable = self.router(new_messages)
if not runnable:
break
new_messages.extend(runnable.invoke(new_messages, config=config))
return new_messages

View File

@@ -0,0 +1,67 @@
from __future__ import annotations
import json
from typing import List, Sequence
from langchain.automaton.agent import SequentialAgent
from langchain.automaton.prompt_generator import AdapterBasedGenerator
from langchain.automaton.runnables import create_llm_program
from langchain.automaton.typedefs import (
AgentFinish, FunctionCallRequest, FunctionCallResponse,
)
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.schema import AIMessage, FunctionMessage, Generation
from langchain.schema.language_model import (
BaseLanguageModel, LanguageModelInput, LanguageModelOutput,
)
from langchain.schema.output_parser import BaseGenerationOutputParser
from langchain.schema.runnable import Runnable
from langchain.tools import BaseTool, format_tool_to_openai_function
class OpenAIFunctionsParser(BaseGenerationOutputParser):
def parse_result(self, result: List[Generation]):
if len(result) != 1:
raise AssertionError(f"Expected exactly one result")
first_result = result[0]
message = first_result.message
if not isinstance(message, AIMessage) or not message.additional_kwargs:
return AgentFinish(result=message)
parser = JsonOutputFunctionsParser(strict=False, args_only=False)
try:
function_request = parser.parse_result(result)
except Exception as e:
raise RuntimeError(f"Error parsing result: {result} {repr(e)}") from e
return FunctionCallRequest(
name=function_request["name"],
named_arguments=function_request["arguments"],
)
def create_openai_agent(
llm: BaseLanguageModel[LanguageModelInput, LanguageModelOutput]
| Runnable[LanguageModelInput, LanguageModelOutput],
tools: Sequence[BaseTool],
) -> SequentialAgent:
"""Create an agent that uses OpenAI's API."""
openai_funcs = [format_tool_to_openai_function(tool_) for tool_ in tools]
prompt_generator = AdapterBasedGenerator(
msg_adapters={
FunctionCallResponse: lambda message: FunctionMessage(
name=message.name, content=json.dumps(message.result)
),
# No need to translate function call requests
},
)
llm_program = create_llm_program(
llm.bind(functions=openai_funcs),
prompt_generator=prompt_generator,
tools=tools,
parser=OpenAIFunctionsParser(),
)
return SequentialAgent(llm_program, memory_processor=None)

View File

@@ -0,0 +1,113 @@
"""Generalized chat agent, works with any chat model."""
from __future__ import annotations
from typing import (
Iterator,
Optional,
Sequence,
TypeVar,
List,
)
from langchain.automaton.processors import WorkingMemoryManager
from langchain.automaton.runnables import create_retriever
from langchain.automaton.typedefs import (
Agent,
AgentFinish,
MessageLike,
RetrievalRequest,
RetrievalResponse,
)
from langchain.schema import BaseRetriever
from langchain.schema.messages import BaseMessage, HumanMessage, AIMessage
from langchain.schema.runnable import (
Runnable,
RunnableConfig,
)
T = TypeVar("T")
def prompt_generator(input_messages: Sequence[MessageLike]) -> List[BaseMessage]:
"""Generate a prompt from a log of message like objects."""
messages = []
for message in input_messages:
if isinstance(message, BaseMessage):
messages.append(message)
elif isinstance(message, RetrievalResponse):
prompt = ""
if message.results:
for idx, doc in enumerate(message.results):
prompt += f"--- Result {idx} ---\n"
prompt += "Text:\n"
prompt += f"{doc.page_content}"
for field in ["title", "description", "source"]:
if field in doc.metadata:
prompt += f"{doc.metadata[field]}\n"
prompt += f"--- End Result {idx} ---\n"
else:
prompt = "Found no results for the query."
messages.append(
HumanMessage(
content="Context: <result>\n" + prompt + "\n</result>",
)
)
else:
pass
return messages
class RagAgent(Agent):
"""A generalized chat agent."""
def __init__(
self,
llm_program: Runnable[Sequence[MessageLike], List[MessageLike]],
retriever: Optional[BaseRetriever] = None,
memory_manager: WorkingMemoryManager = None,
) -> None:
"""Initialize the chat agent."""
self.llm_program = llm_program
self.retriever = create_retriever(retriever)
self.memory_manager = memory_manager
def step(
self,
messages: Sequence[MessageLike],
*,
config: Optional[RunnableConfig] = None,
) -> List[MessageLike]:
"""Take a single step with the agent."""
last_message = messages[-1] if messages else None
if not last_message:
return []
match last_message:
case AIMessage():
return []
case AgentFinish():
return []
case HumanMessage():
return [RetrievalRequest(query=last_message.content)]
case RetrievalRequest():
return [self.retriever.invoke(last_message, config=config)]
case _:
return self.llm_program.invoke(messages, config=config)
def run(
self,
messages: Sequence[MessageLike],
*,
config: Optional[dict] = None,
max_iterations: int = 100,
) -> Iterator[MessageLike]:
"""Run the agent."""
all_messages = list(messages)
for _ in range(max_iterations):
all_messages = self.memory_manager.process(all_messages)
new_messages = self.step(all_messages, config=config)
if not new_messages:
break
yield from new_messages
all_messages.extend(new_messages)

View File

@@ -0,0 +1,230 @@
"""Implementation of a think act agent.
Uses priming messages with text based LLM.
"""
from __future__ import annotations
import ast
import re
from typing import Iterator, List, Optional, Sequence, Union
from langchain.automaton.runnables import (
create_llm_program,
)
from langchain.automaton.tool_utils import generate_tool_info
from langchain.automaton.typedefs import (
AdHocMessage,
Agent,
AgentFinish,
FunctionCallRequest,
FunctionCallResponse,
MessageLike,
)
from langchain.prompts import SystemMessagePromptTemplate
from langchain.schema import (
AIMessage,
BaseMessage,
HumanMessage,
PromptValue,
SystemMessage,
)
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool, Tool
TEMPLATE_ = """\
Respond to the human as helpfully and accurately as possible. \
You have access to the following tools:
{tools_description}
Use a blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
Valid "action" values: "Final Answer" or {tool_names}
You can only use a single tool at a time.
<action>
{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}
</action>
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
<action>
$BLOB
</action>
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
<action>
{{
"action": "Final Answer",
"action_input": "Final response to human"
}}
</action>
Begin:
Reminder to ALWAYS respond with a valid blob of a single action. \
Use tools if necessary. Respond directly if appropriate. \
Format is <action>$BLOB</action> then Observation.\
"""
def generate_memory(tools: Sequence[Tool]) -> List[MessageLike]:
"""Set up basic memory for agent."""
tools_info = generate_tool_info(tools)
return [SystemMessagePromptTemplate.from_template(TEMPLATE_).format(**tools_info)]
class ActionParser:
"""A utility class to encode and decode action blocks."""
def __init__(self) -> None:
"""Initialize the ActionParser."""
self.pattern = re.compile(r"<action>(?P<action_blob>.*?)<\/action>", re.DOTALL)
def decode(self, text: Union[BaseMessage, str]) -> Optional[MessageLike]:
"""Decode the action."""
if isinstance(text, BaseMessage):
text = text.content
match = self.pattern.search(text)
if match:
action_blob = match.group("action_blob")
try:
data = ast.literal_eval(action_blob)
except SyntaxError:
return AdHocMessage(
type="error", data=f"Invalid action blob {action_blob}"
)
name = data["action"]
if name == "Final Answer": # Special cased "tool" for final answer
return AgentFinish(result=data["action_input"])
action_input = data["action_input"]
if isinstance(action_input, str) and not action_input:
named_arguments = {}
elif isinstance(action_input, dict):
named_arguments = action_input
else:
return AdHocMessage(
type="error",
data=f"Invalid action blob {action_blob}, action_input must be a dict",
)
return FunctionCallRequest(
name=data["action"], named_arguments=named_arguments or {}
)
else:
return None
class ThinkActPromptGenerator(PromptValue):
"""think-act paradigm
prompt value can do additional priming of the LLM, used in to_string method right now.
"""
messages: Sequence[MessageLike]
class Config:
arbitrary_types_allowed = True
def to_string(self) -> str:
"""The string variant of the prompt."""
finalized = []
messages = self.messages
for idx, message in enumerate(messages):
if isinstance(message, AdHocMessage):
if message.type == "error":
finalized.extend(
[
f"Error: Malformed <action> blob with error: {message.data}. "
f"Please re-write the action correctly."
"\n",
]
)
continue
if message.type != "prime":
raise AssertionError()
component = message.data
finalized.append(component)
continue
if isinstance(message, FunctionCallResponse):
component = f"Observation: {message.result}"
elif isinstance(message, HumanMessage):
component = f"Question: {message.content.strip()}"
elif isinstance(message, (AIMessage, SystemMessage)):
component = message.content.strip()
elif isinstance(message, FunctionCallRequest):
# This is an internal message, and should not be returned to the user.
continue
elif isinstance(message, AgentFinish):
component = f"Answer: {message.result}"
else:
raise NotImplementedError()
finalized.extend([component, "\n"])
return "".join(finalized).strip()
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as a list of Messages."""
messages = []
for message in self.messages:
if isinstance(message, BaseMessage):
messages.append(message)
elif isinstance(message, FunctionCallResponse):
messages.append(
SystemMessage(content=f"Observation: `{message.result}`")
)
return messages
@classmethod
def from_message_log(cls, messages: Sequence[MessageLike]):
"""Create a ThinkActPromptGenerator from a MessageLog, allowing positional arguments."""
return cls(messages=messages)
class ThinkActAgent(Agent):
def __init__(
self,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
*,
max_iterations: int = 10,
) -> None:
"""Initialize the chat automaton."""
self.think_act = create_llm_program(
llm,
prompt_generator=ThinkActPromptGenerator.from_message_log,
stop=["Observation:", "observation:"],
parser=ActionParser().decode,
tools=tools,
invoke_tools=True,
)
self.max_iterations = max_iterations
def run(
self,
messages: Sequence[MessageLike],
*,
config: Optional[dict] = None,
max_iterations: int = 100,
) -> Iterator[MessageLike]:
"""Run the agent."""
all_messages = list(messages)
for _ in range(max_iterations):
if all_messages and isinstance(all_messages[-1], AgentFinish):
break
if all_messages and isinstance(
all_messages[-1], (FunctionCallResponse, HumanMessage)
):
all_messages.append(AdHocMessage(type="prime", data="Thought:"))
new_messages = self.think_act.invoke(all_messages, config=config)
yield from new_messages
all_messages.extend(new_messages)

View File

@@ -0,0 +1,105 @@
from __future__ import annotations
import ast
import re
from typing import List, Sequence, Union, Optional
from langchain.automaton.agent import SequentialAgent
from langchain.automaton.agent import WorkingMemoryProcessor
from langchain.automaton.prompt_generator import AdapterBasedGenerator
from langchain.automaton.runnables import create_llm_program
from langchain.automaton.tool_utils import generate_tool_info
from langchain.automaton.typedefs import (
AgentFinish,
FunctionCallRequest,
FunctionCallResponse,
MessageLike,
)
from langchain.prompts import SystemMessagePromptTemplate
from langchain.schema import BaseMessage, HumanMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
TEMPLATE_ = SystemMessagePromptTemplate.from_template(
"""Respond to the human as helpfully and accurately as \
possible. You have access to the following tools:
{tools_description}
Use a blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
Valid "action" values: "Final Answer" or {tool_names}
Provide only ONE action per $BLOB, as shown.
<action>
{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}
</action>
When invoking a tool do not provide any clarifying information.
The human will forward results of tool invocations as "Observations".
When you know the answer paraphrase the information in the observations properly and respond to the user. \
If you do not know the answer use more tools.
You can only take a single action at a time."""
)
def get_start_state(tools: Sequence[BaseTool]) -> List[BaseMessage]:
"""Generate a prompt for the agent."""
tool_info = generate_tool_info(tools)
msg = TEMPLATE_.format(**tool_info)
return [msg]
def _decode(text: Union[BaseMessage, str]) -> MessageLike:
"""Decode the action."""
pattern = re.compile(r"<action>(?P<action_blob>.*?)<\/action>", re.DOTALL)
if not isinstance(text, BaseMessage):
raise NotImplementedError()
_text = text.content
match = pattern.search(_text)
if match:
action_blob = match.group("action_blob")
data = ast.literal_eval(action_blob)
name = data["action"]
if name == "Final Answer": # Special cased "tool" for final answer
return AgentFinish(result=data["action_input"])
return FunctionCallRequest(
name=data["action"], named_arguments=data["action_input"] or {}
)
else:
return AgentFinish(result=text)
# PUBLIC API
def create_xml_agent(
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
memory_processor: Optional[WorkingMemoryProcessor] = None,
) -> SequentialAgent:
"""XML based chat agent."""
prompt_generator = AdapterBasedGenerator(
msg_adapters={
FunctionCallResponse: lambda msg: HumanMessage(
content=f"Observation: {msg.result}"
),
}
)
llm_program = create_llm_program(
llm,
prompt_generator=prompt_generator,
tools=tools,
parser=_decode,
)
return SequentialAgent(
llm_program,
memory_processor=memory_processor,
)

View File

@@ -0,0 +1,325 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a91c50ff-bcb9-484b-a4bd-f56e75c182ba",
"metadata": {},
"source": [
"# General Chat Agent w/ XML like encoding"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "80595f69-8b46-494e-bd2f-0c9538ed526a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import sys\n",
"sys.path.insert(0, '../../../')"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "724c11c6-c98a-47ef-8697-87cd888df7a0",
"metadata": {},
"outputs": [],
"source": [
"from typing import List, Optional, Any\n",
"from langchain.tools import tool, Tool\n",
"from langchain.schema.messages import AIMessage, HumanMessage\n",
"from langchain.chat_models.anthropic import ChatAnthropic\n",
"from langchain.automaton.agent_implementations import xml_agent"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "ed3855cd-06b0-4a45-8b76-88926433fc60",
"metadata": {},
"outputs": [],
"source": [
"def get_tools() -> List[Tool]:\n",
" @tool\n",
" def name() -> str:\n",
" \"\"\"Use to look up the user's name\"\"\"\n",
" return \"Eugene\"\n",
"\n",
" @tool\n",
" def get_weather(city: str) -> str:\n",
" \"\"\"Get weather in a specific city.\"\"\"\n",
" return \"42F and sunny\"\n",
"\n",
" @tool\n",
" def add(x: int, y: int) -> int:\n",
" \"\"\"Use to add two numbers. For example; { \"x\": 2, \"y\": 10}\"\"\"\n",
" return x + y\n",
"\n",
" @tool\n",
" def add_3(x: int, y: int, z: int) -> int:\n",
" \"\"\"Add 3 numbers together.\"\"\"\n",
" return x + y + z\n",
"\n",
" @tool\n",
" def divide(x: float, y: float) -> float:\n",
" \"\"\"Divide x by y\"\"\"\n",
" return x / y\n",
"\n",
" return list(locals().values())\n",
"\n",
"\n",
"tools = get_tools()"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "97a59c88-2f0a-4982-81e3-4ffba4f222da",
"metadata": {},
"outputs": [],
"source": [
"llm = ChatAnthropic(temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "00177716-422e-436c-b07d-b70b2ccb5839",
"metadata": {},
"outputs": [],
"source": [
"start_messages = xml_agent.get_start_state(tools)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "778b2eda-5128-40ea-949f-9727610cc194",
"metadata": {},
"outputs": [],
"source": [
"agent = xml_agent.create_xml_agent(llm, tools)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "3bfcb9a7-af38-4b7a-87cc-84a842d6c072",
"metadata": {},
"outputs": [],
"source": [
"question = HumanMessage(\n",
" content=\"Sum of first 10 numbers starting from 1? use one tool at a time\"\n",
")\n",
"messages = start_messages + [question]"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "29ec7747-5cf7-49e3-b39e-23cf0a0818f3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ai: <action>\n",
"{\n",
" \"action\": \"add\", \n",
" \"action_input\": {\n",
" \"x\": 1,\n",
" \"y\": 2\n",
" } \n",
"}\n",
"</action>\n",
"FunctionCall(name=add, named_arguments={'x': 1, 'y': 2})\n",
"FunctionResult(name=add, result=3, error=None)\n",
"ai: <action>\n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": {\n",
" \"x\": 3, \n",
" \"y\": 3\n",
" }\n",
"} \n",
"</action>\n",
"FunctionCall(name=add, named_arguments={'x': 3, 'y': 3})\n",
"FunctionResult(name=add, result=6, error=None)\n",
"ai: <action>\n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": {\n",
" \"x\": 6,\n",
" \"y\": 4 \n",
" }\n",
"}\n",
"</action>\n",
"FunctionCall(name=add, named_arguments={'x': 6, 'y': 4})\n",
"FunctionResult(name=add, result=10, error=None)\n",
"ai: <action>\n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": {\n",
" \"x\": 10, \n",
" \"y\": 5\n",
" }\n",
"}\n",
"</action> \n",
"FunctionCall(name=add, named_arguments={'x': 10, 'y': 5})\n",
"FunctionResult(name=add, result=15, error=None)\n",
"ai: <action>\n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": {\n",
" \"x\": 15,\n",
" \"y\": 6\n",
" }\n",
"} \n",
"</action>\n",
"FunctionCall(name=add, named_arguments={'x': 15, 'y': 6})\n",
"FunctionResult(name=add, result=21, error=None)\n",
"ai: <action>\n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": {\n",
" \"x\": 21,\n",
" \"y\": 7\n",
" }\n",
"}\n",
"</action>\n",
"FunctionCall(name=add, named_arguments={'x': 21, 'y': 7})\n",
"FunctionResult(name=add, result=28, error=None)\n",
"ai: <action>\n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": {\n",
" \"x\": 28, \n",
" \"y\": 8\n",
" }\n",
"}\n",
"</action>\n",
"FunctionCall(name=add, named_arguments={'x': 28, 'y': 8})\n",
"FunctionResult(name=add, result=36, error=None)\n",
"ai: <action>\n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": {\n",
" \"x\": 36,\n",
" \"y\": 9 \n",
" }\n",
"} \n",
"</action>\n",
"FunctionCall(name=add, named_arguments={'x': 36, 'y': 9})\n",
"FunctionResult(name=add, result=45, error=None)\n",
"ai: <action>\n",
"{\n",
" \"action\": \"add\",\n",
" \"action_input\": {\n",
" \"x\": 45,\n",
" \"y\": 10\n",
" }\n",
"}\n",
"</action>\n",
"FunctionCall(name=add, named_arguments={'x': 45, 'y': 10})\n",
"FunctionResult(name=add, result=55, error=None)\n",
"ai: <action>\n",
"{\n",
" \"action\": \"Final Answer\",\n",
" \"action_input\": {\n",
" \"answer\": \"The sum of the first 10 numbers starting from 1 is 55.\"\n",
" }\n",
"} \n",
"</action>\n",
"AgentFinish(result={'answer': 'The sum of the first 10 numbers starting from 1 is 55.'})\n"
]
}
],
"source": [
"for msg in agent.run(messages):\n",
" print(msg)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "a3a4eeca-9126-4d97-902d-fa4564563130",
"metadata": {},
"outputs": [
{
"ename": "ImportError",
"evalue": "cannot import name 'prompt_generator' from 'langchain.automaton.chat_agent' (/home/eugene/src/langchain/libs/langchain/langchain/automaton/notebooks/../../../langchain/automaton/chat_agent.py)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[40], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Visualize\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mhtml\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mlangchain\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mautomaton\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mchat_agent\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m prompt_generator\n\u001b[1;32m 6\u001b[0m messages \u001b[38;5;241m=\u001b[39m prompt_generator(message_log)\n\u001b[1;32m 8\u001b[0m \u001b[38;5;66;03m# Generate HTML divs\u001b[39;00m\n",
"\u001b[0;31mImportError\u001b[0m: cannot import name 'prompt_generator' from 'langchain.automaton.chat_agent' (/home/eugene/src/langchain/libs/langchain/langchain/automaton/notebooks/../../../langchain/automaton/chat_agent.py)"
]
}
],
"source": [
"# # Visualize\n",
"\n",
"# import html\n",
"# from langchain.automaton.chat_agent import prompt_generator\n",
"\n",
"# messages = prompt_generator(message_log)\n",
"\n",
"# # Generate HTML divs\n",
"# html_divs = []\n",
"# for idx, msg in enumerate(messages):\n",
"# background_color = \"DodgerBlue\" if idx % 2 == 0 else \"Tomato\"\n",
"# div_content = f\"\"\"\n",
"# <div style=\"border: 1px solid black; padding: 10px; background-color:{background_color}\">\n",
"# <div style=\"display: flex;\">\n",
"# <div style=\"font-weight: bold; margin-right: 10px;\">{html.escape(msg.type)}</div>\n",
"# <div>{html.escape(msg.content)}</div>\n",
"# </div>\n",
"# </div>\n",
"# \"\"\"\n",
"# html_divs.append(div_content)\n",
"\n",
"# # Display HTML divs in a Jupyter Notebook cell\n",
"# from IPython.display import HTML, display\n",
"\n",
"# html_output = \"\\n\".join(html_divs)\n",
"# display(HTML(html_output))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,167 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a91c50ff-bcb9-484b-a4bd-f56e75c182ba",
"metadata": {},
"source": [
"# Open AI Chat"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d77b586-fced-4a34-b778-b05141ab32bc",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import sys\n",
"sys.path.insert(0, '../../../')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f73ed8ca-7384-4afc-af8c-bfd5b794df01",
"metadata": {},
"outputs": [],
"source": [
"from typing import List, Optional, Any\n",
"from langchain.tools import tool, Tool\n",
"from langchain.automaton.agent_implementations.openai_agent import create_openai_agent\n",
"from langchain.schema.messages import HumanMessage\n",
"from langchain.chat_models.openai import ChatOpenAI"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ed3855cd-06b0-4a45-8b76-88926433fc60",
"metadata": {},
"outputs": [],
"source": [
"def get_tools() -> List[Tool]:\n",
" @tool\n",
" def name() -> str:\n",
" \"\"\"Use to look up the user's name\"\"\"\n",
" return \"Eugene\"\n",
"\n",
" @tool\n",
" def get_weather(city: str) -> str:\n",
" \"\"\"Get weather in a specific city.\"\"\"\n",
" return \"42F and sunny\"\n",
"\n",
" @tool\n",
" def add(x: int, y: int) -> int:\n",
" \"\"\"Use to add two numbers. For example; { \"x\": 2, \"y\": 10} \"\"\"\n",
" return x + y\n",
"\n",
" @tool\n",
" def add_3(x: int, y: int, z: int) -> int:\n",
" \"\"\"Add 3 numbers together.\"\"\"\n",
" return x + y + z\n",
"\n",
" @tool\n",
" def divide(x: float, y: float) -> float:\n",
" \"\"\"Divide x by y\"\"\"\n",
" return x / y\n",
"\n",
" return list(locals().values())\n",
"\n",
"tools = get_tools()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "a0790c55-eacb-4e0c-970b-5784d62ef5c2",
"metadata": {},
"outputs": [],
"source": [
"llm = ChatOpenAI(temperature=0)\n",
"agent = create_openai_agent(llm, tools)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "e7d1ec69-de12-4c66-abee-3b2bbabeec13",
"metadata": {},
"outputs": [],
"source": [
"a = agent.run([HumanMessage(content=\"Sum of first 10 numbers starting from 1? use one tool at a time\")])"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "75c6cc2c-7604-44c0-b8cf-4513620782f9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 1,\\n \"y\": 2\\n}'}}\n",
"FunctionCall(name=add, named_arguments={'x': 1, 'y': 2})\n",
"FunctionResult(name=add, result=3, error=None)\n",
"ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 3,\\n \"y\": 3\\n}'}}\n",
"FunctionCall(name=add, named_arguments={'x': 3, 'y': 3})\n",
"FunctionResult(name=add, result=6, error=None)\n",
"ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 6,\\n \"y\": 4\\n}'}}\n",
"FunctionCall(name=add, named_arguments={'x': 6, 'y': 4})\n",
"FunctionResult(name=add, result=10, error=None)\n",
"ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 10,\\n \"y\": 5\\n}'}}\n",
"FunctionCall(name=add, named_arguments={'x': 10, 'y': 5})\n",
"FunctionResult(name=add, result=15, error=None)\n",
"ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 15,\\n \"y\": 6\\n}'}}\n",
"FunctionCall(name=add, named_arguments={'x': 15, 'y': 6})\n",
"FunctionResult(name=add, result=21, error=None)\n",
"ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 21,\\n \"y\": 7\\n}'}}\n",
"FunctionCall(name=add, named_arguments={'x': 21, 'y': 7})\n",
"FunctionResult(name=add, result=28, error=None)\n",
"ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 28,\\n \"y\": 8\\n}'}}\n",
"FunctionCall(name=add, named_arguments={'x': 28, 'y': 8})\n",
"FunctionResult(name=add, result=36, error=None)\n",
"ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 36,\\n \"y\": 9\\n}'}}\n",
"FunctionCall(name=add, named_arguments={'x': 36, 'y': 9})\n",
"FunctionResult(name=add, result=45, error=None)\n",
"ai: {'function_call': {'name': 'add', 'arguments': '{\\n \"x\": 45,\\n \"y\": 10\\n}'}}\n",
"FunctionCall(name=add, named_arguments={'x': 45, 'y': 10})\n",
"FunctionResult(name=add, result=55, error=None)\n",
"ai: The sum of the first 10 numbers starting from 1 is 55.\n",
"AgentFinish(result=ai: The sum of the first 10 numbers starting from 1 is 55.)\n"
]
}
],
"source": [
"for msg in a:\n",
" print(msg)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,222 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "9875f554-8f4d-409b-929f-00e8bc70185d",
"metadata": {},
"source": [
"# Think Act Agent"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "80595f69-8b46-494e-bd2f-0c9538ed526a",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import sys\n",
"sys.path.insert(0, '../../../')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "724c11c6-c98a-47ef-8697-87cd888df7a0",
"metadata": {},
"outputs": [],
"source": [
"from typing import List, Optional, Any\n",
"from langchain.tools import tool, Tool\n",
"from langchain.schema.messages import AIMessage, HumanMessage\n",
"from langchain.chat_models.anthropic import ChatAnthropic\n",
"from langchain.llms import OpenAI\n",
"from langchain.automaton.agent_implementations import think_act_agent"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ed3855cd-06b0-4a45-8b76-88926433fc60",
"metadata": {},
"outputs": [],
"source": [
"def get_tools() -> List[Tool]:\n",
" @tool\n",
" def name() -> str:\n",
" \"\"\"Use to look up the user's name\"\"\"\n",
" return \"Eugene\"\n",
"\n",
" @tool\n",
" def get_weather(city: str) -> str:\n",
" \"\"\"Get weather in a specific city.\"\"\"\n",
" return \"42F and sunny\"\n",
"\n",
" @tool\n",
" def add(x: int, y: int) -> int:\n",
" \"\"\"Use to add two numbers. For example; { \"x\": 2, \"y\": 10}\"\"\"\n",
" return x + y\n",
"\n",
" @tool\n",
" def divide(x: float, y: float) -> float:\n",
" \"\"\"Divide x by y\"\"\"\n",
" return x / y\n",
"\n",
" return list(locals().values())\n",
"\n",
"\n",
"tools = get_tools()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "f6abff39-5243-45d1-a250-a862bbd665ff",
"metadata": {},
"outputs": [],
"source": [
"llm = OpenAI(temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "d416bb0e-28e4-4da3-8b10-aa563f87a48e",
"metadata": {},
"outputs": [],
"source": [
"agent = think_act_agent.ThinkActAgent(llm, tools)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f9c1ed10-309b-47a1-b22e-153e53c3a86e",
"metadata": {},
"outputs": [],
"source": [
"start_messages = think_act_agent.generate_memory(tools)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "04488a11-6ef2-428b-b1ae-7800b424480a",
"metadata": {},
"outputs": [],
"source": [
"from langchain.callbacks.manager import (\n",
" trace_as_chain_group,\n",
" atrace_as_chain_group,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "32a3ce50-e993-4f29-b2f8-02261d417e24",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ai: I need to use the divide tool and the get_weather tool\n",
"<action>\n",
"{\n",
" \"action\": \"divide\",\n",
" \"action_input\": { \"x\": 12381237, \"y\": 23 }\n",
"}\n",
"</action>\n",
"\n",
"FunctionCall(name=divide, named_arguments={'x': 12381237, 'y': 23})\n",
"FunctionResult(name=divide, result=538314.6521739131, error=None)\n",
"ai: I need to use the name tool and the get_weather tool\n",
"<action>\n",
"{\n",
" \"action\": \"name\",\n",
" \"action_input\": \"\"\n",
"}\n",
"</action>\n",
"\n",
"FunctionCall(name=name, named_arguments={})\n",
"FunctionResult(name=name, result=Eugene, error=None)\n",
"ai: I need to use the get_weather tool\n",
"<action>\n",
"{\n",
" \"action\": \"get_weather\",\n",
" \"action_input\": \"New York\"\n",
"}\n",
"</action>\n",
"\n",
"AdHocMessage(type=error, data=Invalid action blob \n",
"{\n",
" \"action\": \"get_weather\",\n",
" \"action_input\": \"New York\"\n",
"}\n",
", action_input must be a dict)\n",
"ai: \n",
"Thought:I need to use the get_weather tool\n",
"<action>\n",
"{\n",
" \"action\": \"get_weather\",\n",
" \"action_input\": {\"city\": \"New York\"}\n",
"}\n",
"</action>\n",
"\n",
"FunctionCall(name=get_weather, named_arguments={'city': 'New York'})\n",
"FunctionResult(name=get_weather, result=42F and sunny, error=None)\n",
"ai: I know what to respond\n",
"<action>\n",
"{\n",
" \"action\": \"Final Answer\",\n",
" \"action_input\": \"Hi Eugene, the answer to 12381237/23 is 538314.6521739131. The weather in New York is 42F and sunny.\"\n",
"}\n",
"</action>\n",
"AgentFinish(result=Hi Eugene, the answer to 12381237/23 is 538314.6521739131. The weather in New York is 42F and sunny.)\n"
]
}
],
"source": [
"question = HumanMessage(content=\"what is 12381237/23? What is my name? And what is the weather in new york?\")\n",
"messages = start_messages + [question]\n",
"\n",
"with trace_as_chain_group(\"think_act_agent\") as group_manager:\n",
" for msg in agent.run(messages, config={\"callbacks\": group_manager}):\n",
" print(msg)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "248d63ee-1ec4-49cb-9e6b-84e75a7644ca",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,62 @@
"""Code for processing the working memory."""
import abc
from typing import List, Sequence, Callable
from langchain.schema.messages import get_buffer_string, SystemMessage
from langchain.automaton.typedefs import MessageLike
from langchain.automaton.prompt_generator import PromptGenerator
class WorkingMemoryManager(abc.ABC):
@abc.abstractmethod
def process(self, messages: Sequence[MessageLike]) -> List[MessageLike]:
"""Process the working memory returning a potentially new working memory."""
class TokenBufferProcessor(WorkingMemoryManager):
"""Trim the working memory to a maximum context length (aka token limit)."""
def __init__(
self,
max_token_limit: int,
token_counter: Callable[[str], int],
prompt_generator: PromptGenerator,
skip_system_messages: bool = True,
) -> None:
"""Token counter."""
self.token_counter = token_counter
self.max_token_limit = max_token_limit
self.prompt_generator = prompt_generator
self.skip_system_messages = skip_system_messages
def process(self, messages: Sequence[MessageLike]) -> List[MessageLike]:
"""Update the working memory with the given messages."""
# Work backwards from the end of the buffer dropping messages until
messages = list(messages)
curr_buffer_length = self.count_across_messages(messages)
idx = 0
while curr_buffer_length > self.max_token_limit:
if idx >= len(messages):
raise AssertionError("No messages left in buffer")
if isinstance(messages[idx], SystemMessage):
idx += 1
continue
messages.pop(idx) # Drop the first message
if not messages:
raise AssertionError("No messages left in buffer")
curr_buffer_length = self.count_across_messages(messages)
return messages
def count_across_messages(self, messages: Sequence[MessageLike]) -> int:
"""Count the number of tokens across messages."""
if isinstance(self.prompt_generator, PromptGenerator):
base_messages = self.prompt_generator.to_messages(messages)
else:
base_messages = self.prompt_generator(messages)
buffer_string = get_buffer_string(base_messages)
return self.token_counter(buffer_string)

View File

@@ -0,0 +1,92 @@
"""Prompt generation for the automaton."""
from __future__ import annotations
import abc
from typing import Mapping, Any, Callable, List, Sequence, Optional, Union
from langchain.automaton.typedefs import MessageLike
from langchain.schema import BaseMessage, PromptValue
class BoundPromptValue(PromptValue):
"""A prompt value that is bound to a specific value."""
as_string: Callable[[], str]
as_messages: Callable[[], List[BaseMessage]]
def to_string(self) -> str:
"""Return prompt value as string."""
return self.as_string()
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as a list of Messages."""
return self.as_messages()
class PromptGenerator(abc.ABC):
@abc.abstractmethod
def to_messages(
self, original_messages: Sequence[MessageLike]
) -> List[BaseMessage]:
"""Generate a prompt from message like objects."""
@abc.abstractmethod
def to_string(self, original_messages: Sequence[MessageLike]) -> str:
"""Generate a prompt from message like objects."""
def to_prompt_value(self, original_messages: Sequence[MessageLike]) -> PromptValue:
"""Generate a prompt from message like objects."""
return BoundPromptValue(
as_string=lambda: self.to_string(original_messages),
as_messages=lambda: self.to_messages(original_messages),
)
class AdapterBasedGenerator(PromptGenerator):
def __init__(
self,
*,
msg_adapters: Optional[
Mapping[Any, Callable[[MessageLike], Union[BaseMessage, List[BaseMessage]]]]
] = None,
str_adapters: Optional[Mapping[Any, Callable[[MessageLike], str]]] = None,
) -> None:
"""Initialize the adapter based generator."""
self.msg_adapters = msg_adapters or {}
self.str_adapters = str_adapters or {}
def to_messages(self, messages: Sequence[MessageLike]) -> List[BaseMessage]:
"""Generate a prompt from message like objects."""
new_messages = []
for original_message in messages:
adapter = self.msg_adapters.get(type(original_message), None)
if adapter:
translated = adapter(original_message)
if isinstance(translated, BaseMessage):
new_messages.append(translated)
else:
new_messages.extend(translated)
continue
if isinstance(original_message, BaseMessage):
# Only adds BaseMessages by default,
# internal messages are ignored
new_messages.append(original_message)
return new_messages
def to_string(self, messages: Sequence[MessageLike]) -> str:
"""Generate a prompt from message like objects."""
string_prompts = []
for original_message in messages:
adapter = self.str_adapters.get(type(original_message), None)
if adapter:
string_prompts.extend(adapter(original_message))
continue
else:
raise RuntimeError(
f"String adapter not found for {type(original_message)}"
)
return "\n".join(string_prompts)

View File

@@ -0,0 +1,262 @@
"""Module contains useful runnables for agents."""
from __future__ import annotations
from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union
from langchain.schema.retriever import BaseRetriever
from langchain.automaton.typedefs import (
FunctionCallRequest,
FunctionCallResponse,
MessageLike,
RetrievalResponse,
RetrievalRequest,
)
from langchain.automaton.prompt_generator import PromptGenerator
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import AIMessage, BaseMessage, PromptValue, Document
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.output_parser import BaseOutputParser
from langchain.schema.runnable import (
Runnable,
RunnableConfig,
RunnableLambda,
RunnableMap,
RunnablePassthrough,
patch_config,
)
from langchain.tools import BaseTool
T = TypeVar("T")
def _to_message(result: Union[BaseMessage, str]) -> BaseMessage:
"""Convert to a list of messages."""
if isinstance(result, BaseMessage):
return result
elif isinstance(result, str):
return AIMessage(content=result)
else:
raise NotImplementedError(f"Unsupported type {type(result)}")
def _to_list(element: Union[None, T, List[T]]) -> List[T]:
"""Convert to a sequence."""
if element is None:
return []
elif isinstance(element, list):
return element
else:
return [element]
def _to_runnable_parser(parser: Optional[BaseOutputParser]) -> Runnable:
"""Adapt a parser to a runnable."""
if parser is None:
# Then create a runnable that returns no messages
return RunnableLambda(lambda *args, **kwargs: None)
elif isinstance(parser, Runnable):
return parser
elif isinstance(parser, Callable):
return RunnableLambda(parser)
else:
raise ValueError(f"Expected BaseOutputParser, got {parser}")
def _concatenate_head_and_tail(intermediate_input: Dict[str, Any]) -> List[BaseMessage]:
"""Concatenate head and tail into a single list."""
head = _to_list(intermediate_input["head"])
tail = _to_list(intermediate_input["tail"])
return head + tail
def _apply_and_concat(
head: Union[Runnable, Callable], tail: Union[Runnable, Callable]
) -> Runnable:
"""Apply head and tail and concatenate the results.
Note: Probably generalize to _apply(funcs) and _concatenate runnables
Args:
head: A runnable or callable
tail: A runnable or callable
Returns:
A runnable that applies head and tail and concatenates the results in order.
"""
head_ = head if isinstance(head, Runnable) else RunnableLambda(head)
tail_ = tail if isinstance(tail, Runnable) else RunnableLambda(tail)
return (
RunnableMap(
steps={
"head": head_,
"tail": tail_,
}
)
| _concatenate_head_and_tail
)
def _to_retriever_input(message: MessageLike) -> str:
"""Convert a message to a retriever input."""
if isinstance(message, str):
return message
elif isinstance(message, BaseMessage):
return message.content
else:
raise NotImplementedError(f"Unsupported type {type(message)}")
# PUBLIC API
def create_tool_invoker(
tools: Sequence[BaseTool],
) -> Runnable[MessageLike, Optional[FunctionCallResponse]]:
"""See if possible to re-write with router
TODO:
* re-write with router
* potentially remove hack replace MessageLike with FunctionCall, requires
a branching runnable
"""
tools_by_name = {tool.name: tool for tool in tools}
def func(
function_call: MessageLike,
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
) -> Optional[FunctionCallResponse]:
"""A function that can invoke a tool using .run"""
if not isinstance(
function_call, FunctionCallRequest
): # TODO(Hack): Workaround lack of conditional apply
return None
try:
tool = tools_by_name[function_call.name]
except KeyError:
raise AssertionError(f"No such tool: {function_call.name}")
try:
result = tool.invoke(
function_call.named_arguments or {},
patch_config(config, callbacks=run_manager.get_child()),
)
error = None
except Exception as e:
result = None
error = repr(e) + repr(function_call.named_arguments)
return FunctionCallResponse(name=function_call.name, result=result, error=error)
async def afunc(
function_call: MessageLike,
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
) -> Optional[FunctionCallResponse]:
"""A function that can invoke a tool using .run"""
if not isinstance(
function_call, FunctionCallRequest
): # TODO(Hack): Workaround lack of conditional apply
return None
try:
tool = tools_by_name[function_call.name]
except KeyError:
raise AssertionError(f"No such tool: {function_call.name}")
try:
result = await tool.ainvoke(
function_call.named_arguments or {},
patch_config(config, callbacks=run_manager.get_child()),
)
error = None
except Exception as e:
result = None
error = repr(e) + repr(function_call.named_arguments)
return FunctionCallResponse(name=function_call.name, result=result, error=error)
return RunnableLambda(func=func, afunc=afunc)
def create_llm_program(
llm: BaseLanguageModel,
prompt_generator: Union[
Callable[
[Sequence[MessageLike]], Union[str, PromptValue, Sequence[BaseMessage]]
],
PromptGenerator,
Runnable,
],
*,
tools: Optional[Sequence[BaseTool]] = None,
stop: Optional[Sequence[str]] = None,
parser: Union[
Runnable[Union[BaseMessage, str], MessageLike],
Callable[[Union[BaseMessage, str]], MessageLike],
BaseOutputParser,
None,
] = None,
invoke_tools: bool = True, # TODO(Eugene): Perhaps remove.
) -> Runnable[Sequence[MessageLike], List[MessageLike]]:
"""Create a runnable that provides a generalized interface to an LLM with actions.
Args:
llm: A language model
prompt_generator: A function that takes a list of messages and returns a prompt
tools: A list of tools to invoke
stop: optional list of stop tokens
parser: optional parser to apply to the output of the LLM
invoke_tools: Whether to invoke tools on the output of the LLM
Returns:
A runnable that returns a list of new messages
"""
if isinstance(prompt_generator, PromptGenerator):
_prompt_generator = RunnableLambda(prompt_generator.to_prompt_value)
elif isinstance(prompt_generator, Runnable):
_prompt_generator = prompt_generator
else: # Otherwise attempt to convert to a runnable lambda
_prompt_generator = RunnableLambda(prompt_generator)
if stop:
llm = llm.bind(stop=stop)
# Add parser to the end of the chain and concatenate original llm output
# with the parser output.
# The parser is always created even if it is None, to make sure that
# the _to_message adapter is always applied (regardless of the parser).
_parser = _to_runnable_parser(parser)
chain = _prompt_generator | llm | _apply_and_concat(_to_message, _parser)
# Add tool invoker to the end of the chain.
if invoke_tools and tools:
tool_invoker = create_tool_invoker(tools)
invoke_on_last = RunnableLambda(lambda msgs: msgs[-1]) | tool_invoker
complete_chain = chain | _apply_and_concat(
RunnablePassthrough(), invoke_on_last
)
else:
complete_chain = chain
return complete_chain
def create_retriever(
base_retriever: BaseRetriever,
) -> Runnable[RetrievalRequest, RetrievalResponse]:
"""Create a runnable retriever that uses messages."""
def _from_retrieval_request(request: RetrievalRequest) -> str:
"""Convert a message to a list of documents."""
return request.query
def _to_retrieval_result(docs: List[Document]) -> RetrievalResponse:
"""Convert a list of documents to a message."""
return RetrievalResponse(results=docs)
return (
RunnableLambda(_from_retrieval_request)
| base_retriever
| RunnableLambda(_to_retrieval_result)
)

View File

@@ -0,0 +1,42 @@
from __future__ import annotations
import langchain.automaton.agent_implementations.xml_agent
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.automaton.mrkl_agent import ActionParser
def test_structured_output_chat() -> None:
parser = StructuredChatOutputParser()
output = parser.parse(
"""
```json
{
"action": "hello",
"action_input": {
"a": 2
}
}
```
"""
)
assert output == {}
def test_parser() -> None:
"""Tes the parser."""
sample_text = """
Some text before
<action>
{
"key": "value",
"number": 42
}
</action>
Some text after
"""
action_parser = ActionParser()
action = langchain.automaton.agent_implementations.xml_agent.decode(sample_text)
assert action == {
"key": "value",
"number": 42,
}

View File

@@ -0,0 +1,96 @@
from __future__ import annotations
from typing import List, cast
import pytest
from langchain.automaton.agent_implementations.openai_agent import OpenAIAgent
from langchain.automaton.tests.utils import (
FakeChatModel,
construct_func_invocation_message,
)
from langchain.automaton.typedefs import (
AgentFinish,
FunctionCallRequest,
FunctionCallResponse,
)
from langchain.schema.messages import (
AIMessage,
SystemMessage,
)
from langchain.tools import Tool, tool
from langchain.tools.base import BaseTool
@pytest.fixture()
def tools() -> List[BaseTool]:
@tool
def get_time() -> str:
"""Get time."""
return "9 PM"
@tool
def get_location() -> str:
"""Get location."""
return "the park"
return cast(List[Tool], [get_time, get_location])
def test_openai_agent(tools: List[Tool]) -> None:
get_time, get_location = tools
llm = FakeChatModel(
message_iter=iter(
[
construct_func_invocation_message(get_time, {}),
AIMessage(
content="The time is 9 PM.",
),
]
)
)
agent = OpenAIAgent(llm=llm, tools=tools, max_iterations=10)
message_log = MessageLog(
[
SystemMessage(
content="What time is it?",
)
]
)
expected_messages = [
SystemMessage(
content="What time is it?",
),
AIMessage(
content="",
additional_kwargs={
"function_call": {
"name": "get_time",
"arguments": "{}",
}
},
),
FunctionCallRequest(
name="get_time",
arguments={},
),
FunctionCallResponse(
name="get_time",
result="9 PM",
error=None,
),
AIMessage(
content="The time is 9 PM.",
),
AgentFinish(
AIMessage(
content="The time is 9 PM.",
),
),
]
agent.run(message_log)
assert message_log.messages == expected_messages

View File

@@ -0,0 +1,153 @@
"""Test parsers"""
from typing import Any, List, Optional, cast
import pytest
from langchain.automaton.runnables import (
RunnablePassthrough,
_apply_and_concat,
_to_list,
_to_runnable_parser,
create_llm_program,
)
from langchain.automaton.tests.utils import (
FakeChatModel,
)
from langchain.automaton.typedefs import FunctionCallRequest, FunctionCallResponse, MessageLike
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
from langchain.schema.runnable import RunnableLambda
from langchain.tools import BaseTool, tool
def test_apply_and_concat() -> None:
"""Test apply and concat."""
msg = HumanMessage(content="Hello")
llm = RunnableLambda(lambda *args, **kwargs: msg)
# Test that it works with a runnable
chain = llm | _apply_and_concat(RunnablePassthrough(), RunnablePassthrough())
assert chain.invoke({}) == [msg, msg]
chain = llm | _apply_and_concat(
lambda msg: msg.content[0], lambda msg: msg.content[1]
)
assert chain.invoke({}) == ["H", "e"]
@pytest.mark.parametrize(
"input_value, expected_output",
[
(None, []), # Test when input is None
([1, 2, 3], [1, 2, 3]), # Test when input is a list of integers
(5, [5]), # Test when input is a single integer
(["a", "b", "c"], ["a", "b", "c"]), # Test when input is a list of strings
("xyz", ["xyz"]), # Test when input is a single string
([], []), # Test when input is an empty list
],
)
def test_to_list(input_value: Any, expected_output: List) -> None:
assert _to_list(input_value) == expected_output
@pytest.mark.parametrize(
"parser, expected_output",
[
(None, None),
(RunnablePassthrough(), AIMessage(content="Hello")),
],
)
def to_runnable_parser(parser: Any, expected_output: Optional[BaseMessage]) -> None:
"""To runnable parser."""
parser_ = _to_runnable_parser(parser)
assert parser_.invoke(AIMessage(content="Hello")) == [expected_output]
@pytest.fixture()
def tools() -> List[BaseTool]:
"""Make a tools fixture."""
@tool
def get_time() -> str:
"""Get time."""
return "9 PM"
@tool
def get_location() -> str:
"""Get location."""
return "the park"
return cast(List[BaseTool], [get_time, get_location])
@pytest.fixture()
def fake_llm() -> BaseLanguageModel:
"""Make a fake chat model."""
llm = FakeChatModel(
message_iter=iter(
[
AIMessage(
content="Hello",
),
]
)
)
return llm
def test_simple_llm_program(fake_llm: BaseLanguageModel) -> None:
"""Test simple llm program with no parser or tools."""
get_time, _ = tools
program = create_llm_program(
fake_llm,
prompt_generator=lambda x: x,
)
assert program.invoke("What time is it?") == [AIMessage(content="Hello")]
def test_llm_program_with_parser(fake_llm: BaseLanguageModel) -> None:
"""Test simple llm program with no parser or tools."""
parser = RunnableLambda(lambda msg: AIMessage(content=msg.content + " parsed"))
program = create_llm_program(
fake_llm,
prompt_generator=lambda x: x,
parser=parser,
)
assert program.invoke("What time is it?") == [
AIMessage(content="Hello"),
AIMessage(content="Hello parsed"),
]
@pytest.mark.parametrize(
"parser, output",
[
(
RunnableLambda(lambda msg: AIMessage(content="Goodbye")),
[AIMessage(content="Hello"), AIMessage(content="Goodbye")],
),
(
RunnableLambda(lambda msg: FunctionCallRequest(name="get_time")),
[
AIMessage(content="Hello"),
FunctionCallRequest(name="get_time"),
FunctionCallResponse(result="9 PM", name="get_time"),
],
),
],
)
def test_llm_program_with_parser_and_tools(
tools: List[BaseTool],
fake_llm: BaseLanguageModel,
parser: Any,
output: List[MessageLike],
) -> None:
"""Test simple llm program with no parser or tools."""
program = create_llm_program(
fake_llm,
prompt_generator=lambda x: x,
parser=parser,
tools=tools,
invoke_tools=True,
)
assert program.invoke("What time is it?") == output

View File

@@ -0,0 +1,42 @@
from langchain.automaton.prompt_generation import AdapterBasedTranslator
from langchain.schema.messages import HumanMessage, AIMessage, SystemMessage
from langchain.automaton.typedefs import FunctionCallResponse
def test_adapter_based_generator() -> None:
translator = AdapterBasedTranslator()
assert translator.to_messages(
[
SystemMessage(content="System"),
AIMessage(content="Hi"),
HumanMessage(content="Hello"),
]
) == [
SystemMessage(content="System"),
AIMessage(content="Hi"),
HumanMessage(content="Hello"),
]
translator = AdapterBasedTranslator(
msg_adapters={
HumanMessage: lambda m: AIMessage(content=m.content),
FunctionCallResponse: lambda m: HumanMessage(
content=f"Observation: {m.result}"
),
}
)
assert translator.to_messages(
[
SystemMessage(content="System"),
AIMessage(content="Hi"),
HumanMessage(content="Hello"),
FunctionCallResponse(name="func", result="result"),
]
) == [
SystemMessage(content="System"),
AIMessage(content="Hi"),
AIMessage(content="Hello"),
HumanMessage(content="Observation: result"),
]

View File

@@ -0,0 +1,20 @@
from typing import Sequence, List, Optional
from langchain.automaton.typedefs import MessageLike
from langchain.schema import HumanMessage, AIMessage
from langchain.schema.runnable import RunnableLambda, Runnable
def test_router() -> None:
messages = [AIMessage(content="Hello, world!")]
program_state = RunnableLambda(lambda x: x)
def route(messages: Sequence[MessageLike]) -> Optional[Runnable]:
if isinstance(messages[-1], HumanMessage):
return None
else:
return program_state
router = RunnableLambda(route)
assert router.invoke(messages) == []

View File

@@ -0,0 +1,46 @@
from __future__ import annotations
import json
from typing import Any, Iterator, List, Mapping
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult
from langchain.tools import BaseTool
class FakeChatModel(BaseChatModel):
"""A fake chat model that returns a pre-defined response."""
message_iter: Iterator[BaseMessage]
@property
def _llm_type(self) -> str:
"""The type of the model."""
return "fake-openai-chat-model"
def _generate(
self,
messages: List[BaseMessage],
stop: List[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
"""Generate a response to the given messages."""
message = next(self.message_iter)
return ChatResult(generations=[ChatGeneration(message=message)])
def construct_func_invocation_message(
tool: BaseTool, args: Mapping[str, Any]
) -> AIMessage:
"""Construct a function invocation message."""
return AIMessage(
content="",
additional_kwargs={
"function_call": {
"name": tool.name,
"arguments": json.dumps(args),
}
},
)

View File

@@ -0,0 +1,25 @@
from typing import Sequence, TypedDict
from langchain.tools import BaseTool
def _generate_tools_descriptions(tools: Sequence[BaseTool]) -> str:
"""Generate a description of the tools."""
return "\n".join([f"{tool_.name}: {tool_.description}" for tool_ in tools]) + "\n"
class ToolInfo(TypedDict):
"""A dictionary containing information about a tool."""
tool_names: str
tools_description: str
def generate_tool_info(tools: Sequence[BaseTool]) -> ToolInfo:
"""Generate a string containing the names of the tools and their descriptions."""
tools_description = _generate_tools_descriptions(tools)
tool_names = ", ".join([tool_.name for tool_ in tools])
return {
"tool_names": tool_names,
"tools_description": tools_description,
}

View File

@@ -0,0 +1,110 @@
from __future__ import annotations
import abc
from typing import Any, Iterator, Mapping, Optional, Sequence, Union, List
from langchain.load.serializable import Serializable
from langchain.schema import BaseMessage, Document
from langchain.schema.runnable import RunnableConfig
class InternalMessage(Serializable):
@property
def lc_serializable(self) -> bool:
"""Indicate whether the class is serializable."""
return True
class FunctionCallRequest(
InternalMessage
): # TODO(Eugene): Rename as FunctionCallRequest
"""A request for a function invocation.
This message can be used to request a function invocation
using the function name and the arguments to pass to the function.
"""
name: str
"""The name of the function to invoke."""
named_arguments: Optional[Mapping[str, Any]] = None
"""The named arguments to pass to the function."""
class Config:
extra = "forbid"
def __str__(self) -> str:
"""Return a string representation of the object."""
return f"FunctionCall(name={self.name}, named_arguments={self.named_arguments})"
class FunctionCallResponse(InternalMessage): # Rename as FunctionCallResult
"""A result of a function invocation."""
name: str
result: Any
error: Optional[str] = None
def __str__(self) -> str:
"""Return a string representation of the object."""
return (
f"FunctionResult(name={self.name}, result={self.result}, "
f"error={self.error})"
)
class RetrievalRequest(InternalMessage):
"""A request for a retrieval."""
query: str
"""The query to use for the retrieval."""
def __str__(self) -> str:
"""Return a string representation of the object."""
return f"RetrievalRequest(query={self.query})"
class RetrievalResponse(InternalMessage):
"""A result of a retrieval."""
results: Sequence[Document]
def __str__(self) -> str:
"""Return a string representation of the object."""
return f"RetrievalResults(results={self.results})"
class AdHocMessage(InternalMessage):
"""A message that is used to prime the language model."""
type: str
data: Any # Make sure this is serializable
def __str__(self) -> str:
"""Return a string representation of the object."""
return f"AdHocMessage(type={self.type}, data={self.data})"
class AgentFinish(InternalMessage):
"""A message that indicates that the agent is finished."""
result: Any
def __str__(self) -> str:
"""Return a string representation of the object."""
return f"AgentFinish(result={self.result})"
MessageLike = Union[BaseMessage, InternalMessage]
class Agent(abc.ABC):
@abc.abstractmethod
def run(
self,
messages: Sequence[MessageLike],
*,
config: Optional[dict] = None,
max_iterations: int = 100,
) -> Iterator[MessageLike]:
"""Run the agent."""
raise NotImplementedError()

View File

@@ -79,6 +79,15 @@ class BaseMessage(Serializable):
"""Whether this class is LangChain serializable."""
return True
def __repr__(self):
if self.additional_kwargs:
return f"{self.type}: {self.additional_kwargs}"
else:
return f"{self.type}: {self.content}"
def __str__(self):
return self.__repr__()
def __add__(self, other: Any) -> ChatPromptTemplate:
from langchain.prompts.chat import ChatPromptTemplate