Eugene Yurtsev
2023-09-01 10:19:45 -04:00
parent 06e34e54f5
commit 9bbc5af2a8
11 changed files with 342 additions and 455 deletions

View File

@@ -1,71 +0,0 @@
# from __future__ import annotations
#
# from typing import Any, Sequence
#
# from langchain.automaton.automaton import ExecutedState, State, Automaton
# from langchain.automaton.typedefs import (
# MessageType,
# infer_message_type,
# PromptGenerator,
# )
# from langchain.automaton.well_known_states import LLMProgram, UserInputState
# from langchain.schema.language_model import BaseLanguageModel
# from langchain.tools import BaseTool
#
#
# class ChatAutomaton(Automaton):
# def __init__(
# self,
# llm: BaseLanguageModel,
# tools: Sequence[BaseTool],
# prompt_generator: PromptGenerator,
# ) -> None:
# """Initialize the chat automaton."""
# super().__init__()
# self.llm = llm
# self.tools = tools
# # TODO: Fix mutability of chat template, potentially add factory method
# self.prompt_generator = prompt_generator
# self.llm_program_state = LLMProgram(
# llm=self.llm,
# tools=self.tools,
# prompt_generator=self.prompt_generator,
# )
#
# def get_start_state(self, *args: Any, **kwargs: Any) -> State:
# """Get the start state."""
# return self.llm_program_state
#
# def get_next_state(
# self, executed_state: ExecutedState # Add memory for transition functions?
# ) -> State:
# """Get the next state."""
# previous_state_id = executed_state["id"]
# data = executed_state["data"]
#
# if previous_state_id == "user_input":
# return self.llm_program_state
# elif previous_state_id == "llm_program":
# message_type = infer_message_type(data["message"])
# if message_type == MessageType.AI:
# return UserInputState()
# elif message_type == MessageType.FUNCTION:
# return self.llm_program_state
# else:
# raise AssertionError(f"Unknown message type: {message_type}")
# else:
# raise ValueError(f"Unknown state ID: {previous_state_id}")
#
#
# # This is transition matrix syntax
# # transition_matrix = {
# # ("user_input", "*"): LLMProgram,
# # ("llm_program", MessageType.AI): UserInputState,
# # ("llm_program", MessageType.AI_SELF): LLMProgram,
# # (
# # "llm_program",
# # MessageType.AI_INVOKE,
# # ): FuncInvocationState, # But must add function message
# # ("func_invocation", MessageType.FUNCTION): LLMProgram,
# # }
# #
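A minimal sketch of how such a transition matrix could be resolved (hypothetical helper; the keys mirror the commented-out syntax above, with "*" as a wildcard message type):

def resolve_next_state(transition_matrix, state_id, message_type):
    """Look up the next state for (state_id, message_type), falling back
    to the wildcard "*" entry for that state."""
    try:
        return transition_matrix[(state_id, message_type)]
    except KeyError:
        return transition_matrix[(state_id, "*")]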

View File

@@ -6,7 +6,6 @@ from typing import Sequence, Optional, Union, List
from langchain.automaton.runnables import (
create_llm_program,
create_tool_invoker,
)
from langchain.automaton.typedefs import (
MessageLog,
@@ -14,6 +13,7 @@ from langchain.automaton.typedefs import (
FunctionCall,
FunctionResult,
AgentFinish,
PrimingMessage,
)
from langchain.prompts import SystemMessagePromptTemplate
from langchain.schema import (
@@ -24,8 +24,6 @@ from langchain.schema import (
SystemMessage,
)
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.runnable import Runnable
from langchain.automaton.typedefs import PrimingMessage
from langchain.tools import BaseTool, Tool
TEMPLATE_ = """\
@@ -155,11 +153,15 @@ class ThinkActPromptGenerator(PromptValue):
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as a list of Messages."""
return [
message
for message in self.message_log.messages
if isinstance(message, BaseMessage)
]
messages = []
for message in self.message_log.messages:
if isinstance(message, BaseMessage):
messages.append(message)
elif isinstance(message, FunctionResult):
messages.append(
SystemMessage(content=f"Observation: `{message.result}`")
)
return messages
@classmethod
def from_message_log(cls, message_log: MessageLog):
@@ -172,47 +174,34 @@ class MRKLAgent:
self,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
*,
max_iterations: int = 10,
) -> None:
"""Initialize the chat automaton."""
self.think_act: Runnable[
MessageLog, Sequence[MessageLike]
] = create_llm_program(
self.think_act = create_llm_program(
llm,
prompt_generator=ThinkActPromptGenerator.from_message_log,
stop=["Observation:", "observation:"],
parser=ActionParser().decode,
tools=tools,
invoke_tools=True,
)
self.tool_invoker: Runnable[FunctionCall, FunctionResult] = create_tool_invoker(
tools
)
self.max_iterations = max_iterations
def run(self, message_log: MessageLog) -> None:
"""Run the agent."""
if not message_log.messages:
raise AssertionError()
if not message_log:
raise AssertionError(f"Expected at least one message in message_log")
last_message = message_log.messages[-1]
for _ in range(self.max_iterations):
last_message = message_log[-1]
max_iterations = 10
iteration_num = 0
while True:
if iteration_num > max_iterations:
break
if isinstance(last_message, AgentFinish):
break
elif isinstance(last_message, FunctionCall):
messages = [
self.tool_invoker.invoke(last_message),
# After we have a function result, we want to prime the LLM
# with the word "Thought"
PrimingMessage(content="Thought:"),
]
else:
messages = self.think_act.invoke(message_log)
if not messages:
raise AssertionError(f"No messages returned from last step")
messages = self.think_act.invoke(message_log)
# Prime the LLM to start with "Thought: " after an observation
if isinstance(messages[-1], FunctionResult):
messages.append(PrimingMessage(content="Thought:"))
message_log.add_messages(messages)
last_message = messages[-1]
iteration_num += 1
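For orientation, a minimal usage sketch of the loop above (llm and tools are placeholders; any chat model and sequence of BaseTools would do):

from langchain.automaton.typedefs import AgentFinish, MessageLog
from langchain.schema.messages import HumanMessage

agent = MRKLAgent(llm=llm, tools=tools)  # llm/tools assumed to exist
log = MessageLog([HumanMessage(content="What is 1 + 2?")])
agent.run(log)  # mutates the log in place; run() returns None
assert isinstance(log[-1], AgentFinish)  # holds once the loop finished cleanly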

View File

@@ -1,108 +0,0 @@
from __future__ import annotations
import json
from typing import Sequence
from langchain.base_language import BaseLanguageModel
from langchain.runnables.openai_functions import OpenAIFunctionsRouter
from langchain.schema import BaseMessage, AIMessage
from langchain.schema.runnable import Runnable
from langchain.tools.base import BaseTool
from langchain.tools.convert_to_openai import format_tool_to_openai_function
# def create_action_taking_llm(
# llm: BaseLanguageModel,
# *,
# tools: Sequence[BaseTool] = (),
# stop: Sequence[str] | None = None,
# invoke_function: bool = True,
# ) -> Runnable:
# """A chain that can create an action.
#
# Args:
# llm: The language model to use.
# tools: The tools to use.
# stop: The stop tokens to use.
# invoke_function: Whether to invoke the function.
#
# Returns:
# a segment of a runnable that takes an action.
# """
#
# openai_funcs = [format_tool_to_openai_function(tool_) for tool_ in tools]
#
# def _interpret_message(message: BaseMessage) -> ActingResult:
# """Interpret a message."""
# if (
# isinstance(message, AIMessage)
# and "function_call" in message.additional_kwargs
# ):
# if invoke_function:
# result = invoke_from_function.invoke( # TODO: fixme using invoke
# message
# )
# else:
# result = None
# return {
# "message": message,
# "function_call": {
# "name": message.additional_kwargs["function_call"]["name"],
# "arguments": json.loads(
# message.additional_kwargs["function_call"]["arguments"]
# ),
# "result": result,
# # Check this works.
# # "result": message.additional_kwargs["function_call"]
# # | invoke_from_function,
# },
# }
# else:
# return {
# "message": message,
# "function_call": None,
# }
#
# invoke_from_function = OpenAIFunctionsRouter(
# functions=openai_funcs,
# runnables={
# openai_func["name"]: tool_
# for openai_func, tool_ in zip(openai_funcs, tools)
# },
# )
#
# if stop:
# _llm = llm.bind(stop=stop)
# else:
# _llm = llm
#
# chain = _llm.bind(functions=openai_funcs) | _interpret_message
# return chain
#
#
# def _interpret_message(message: BaseMessage) -> ActingResult:
# """Interpret a message."""
# if isinstance(message, AIMessage) and "function_call" in message.additional_kwargs:
# raise NotImplementedError()
# # if invoke_function:
# # result = invoke_from_function.invoke(message) # TODO: fixme using invoke
# # else:
# # result = None
# # return {
# # "message": message,
# # "function_call": {
# # "name": message.additional_kwargs["function_call"]["name"],
# # "arguments": json.loads(
# # message.additional_kwargs["function_call"]["arguments"]
# # ),
# # "result": result,
# # # Check this works.
# # # "result": message.additional_kwargs["function_call"]
# # # | invoke_from_function,
# # },
# # }
# else:
# return {
# "message": message,
# "function_call": None,
# }

View File

@@ -0,0 +1,72 @@
from __future__ import annotations
from typing import Sequence, List
from langchain.automaton.prompt_generators import MessageLogPromptValue
from langchain.automaton.runnables import create_llm_program
from langchain.automaton.typedefs import (
MessageLog,
AgentFinish,
FunctionCall,
)
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.schema import Generation
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import AIMessage
from langchain.schema.output_parser import BaseGenerationOutputParser
from langchain.tools import BaseTool
class OpenAIFunctionsParser(BaseGenerationOutputParser):
def parse_result(self, result: List[Generation]):
if len(result) != 1:
raise AssertionError(f"Expected exactly one result")
first_result = result[0]
message = first_result.message
if not isinstance(message, AIMessage) or not message.additional_kwargs:
return AgentFinish(result=message)
parser = JsonOutputFunctionsParser(strict=False, args_only=False)
try:
function_request = parser.parse_result(result)
except Exception as e:
raise RuntimeError(f"Error parsing result: {result} {repr(e)}") from e
return FunctionCall(
name=function_request["name"],
arguments=function_request["arguments"],
)
class OpenAIAgent:
def __init__(
self,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
*,
max_iterations: int = 10,
) -> None:
"""Initialize the chat automaton."""
self.llm_program = create_llm_program(
llm,
prompt_generator=MessageLogPromptValue.from_message_log,
tools=tools,
parser=OpenAIFunctionsParser(),
)
self.max_iterations = max_iterations
def run(self, message_log: MessageLog) -> None:
"""Run the agent."""
if not message_log:
raise AssertionError(f"Expected at least one message in message_log")
for _ in range(self.max_iterations):
last_message = message_log[-1]
if isinstance(last_message, AgentFinish):
break
messages = self.llm_program.invoke(message_log)
message_log.add_messages(messages)
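A minimal usage sketch (FakeChatOpenAI and construct_func_invocation_message are the test helpers added later in this commit; get_time is a @tool returning "9 PM"):

llm = FakeChatOpenAI(
    message_iter=iter(
        [
            construct_func_invocation_message(get_time, {}),
            AIMessage(content="The time is 9 PM."),
        ]
    )
)
agent = OpenAIAgent(llm=llm, tools=[get_time])
log = MessageLog([SystemMessage(content="What time is it?")])
agent.run(log)  # appends FunctionCall, FunctionResult, and the final AgentFinish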

View File

@@ -0,0 +1,55 @@
from __future__ import annotations
import json
from typing import List
from langchain.automaton.typedefs import MessageLog, FunctionResult
from langchain.schema import PromptValue, BaseMessage, FunctionMessage
class MessageLogPromptValue(PromptValue):
"""Base abstract class for inputs to any language model.
PromptValues can be converted to both LLM (pure text-generation) inputs and
ChatModel inputs.
"""
message_log: MessageLog
# If True, render FunctionResults as OpenAI function messages.
# TODO(Eugene): replace with adapter, should be generic
use_function_message: bool = False
class Config:
arbitrary_types_allowed = True
def to_string(self) -> str:
"""Return prompt value as string."""
finalized = []
for message in self.to_messages():
prefix = message.type
finalized.append(f"{prefix}: {message.content}")
return "\n".join(finalized) + "\n" + "ai:"
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as a list of Messages."""
messages = []
for message in self.message_log.messages:
if isinstance(message, BaseMessage):
messages.append(message)
elif isinstance(message, FunctionResult):
if self.use_function_message:
messages.append(
FunctionMessage(
name=message.name, content=json.dumps(message.result)
)
)
else:
# Ignore internal messages
pass
return messages
@classmethod
def from_message_log(cls, message_log: MessageLog) -> MessageLogPromptValue:
"""Create a PromptValue from a MessageLog."""
return cls(message_log=message_log)
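A small sketch of how the two renderings differ (values illustrative):

log = MessageLog(
    [
        SystemMessage(content="You are helpful."),
        FunctionResult(name="get_time", result="9 PM", error=None),
    ]
)
pv = MessageLogPromptValue.from_message_log(log)
pv.to_messages()  # [SystemMessage(...)]; FunctionResult is dropped by default
pv.to_string()    # 'system: You are helpful.\nai:' -- trailing "ai:" primes completion models
# With use_function_message=True, the FunctionResult instead becomes
# FunctionMessage(name="get_time", content='"9 PM"').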

View File

@@ -1,7 +1,7 @@
"""Module contains well known runnables for agents."""
"""Module contains useful runnables for agents."""
from __future__ import annotations
from typing import Sequence, Callable, Union, List, Optional
from typing import Sequence, Callable, List, Optional, Union
from langchain.automaton.typedefs import (
MessageLike,
@@ -13,6 +13,10 @@ from langchain.schema import BaseMessage, AIMessage, PromptValue
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.runnable.base import RunnableLambda, Runnable
from langchain.tools import BaseTool
from langchain.tools.convert_to_openai import format_tool_to_openai_function
# PUBLIC API
def create_tool_invoker(
@@ -21,16 +25,19 @@ def create_tool_invoker(
"""Re-write with router."""
tools_by_name = {tool.name: tool for tool in tools}
def func(input: FunctionCall) -> FunctionResult:
def func(function_call: FunctionCall) -> FunctionResult:
"""A function that can invoke a tool using .run"""
tool = tools_by_name[input.name]
try:
result = tool.run(input.arguments or {})
tool = tools_by_name[function_call.name]
except KeyError:
raise AssertionError(f"No such tool: {function_call.name}")
try:
result = tool.run(function_call.arguments or {})
error = None
except Exception as e:
result = None
error = repr(e) + repr(input.arguments)
return FunctionResult(result=result, error=error)
error = repr(e) + repr(function_call.arguments)
return FunctionResult(name=function_call.name, result=result, error=error)
return RunnableLambda(func=func)
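A usage sketch, assuming an `add` tool like the one in the (removed) tests further down:

invoker = create_tool_invoker([add])
ok = invoker.invoke(FunctionCall(name="add", arguments={"x": 1, "y": 2}))
assert ok.result == 3 and ok.error is None
bad = invoker.invoke(FunctionCall(name="add", arguments={"x": "one", "y": 2}))
assert bad.result is None and bad.error is not None  # tool errors are captured, not raised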
@@ -39,17 +46,28 @@ def create_llm_program(
llm: BaseLanguageModel,
prompt_generator: Callable[[MessageLog], PromptValue],
*,
tools: Optional[Sequence[BaseTool]] = None,
stop: Optional[Sequence[str]] = None,
parser: Union[Runnable, Callable] = None,
parser: Union[
Runnable[Union[BaseMessage, str], MessageLike],
Callable[[Union[BaseMessage, str]], MessageLike],
None,
] = None,
invoke_tools: bool = True,
) -> Runnable[MessageLog, List[MessageLike]]:
"""Create a runnable that can update memory."""
tool_invoker = create_tool_invoker(tools) if invoke_tools and tools else None
openai_funcs = [format_tool_to_openai_function(tool_) for tool_ in tools or []]
def _bound(message_log: MessageLog):
messages = []
prompt_value = prompt_generator(message_log)
llm_chain = llm
if stop:
llm_chain = llm_chain.bind(stop=stop)
if tools:
llm_chain = llm_chain.bind(tools=openai_funcs)
result = llm_chain.invoke(prompt_value)
@@ -58,7 +76,7 @@ def create_llm_program(
elif isinstance(result, str):
messages.append(AIMessage(content=result))
else:
raise NotImplementedError(f"Unsupported type {result}")
raise NotImplementedError(f"Unsupported type {type(result)}")
if parser:
if not isinstance(parser, Runnable):
@@ -73,6 +91,12 @@ def create_llm_program(
)
messages.append(parsed_result)
last_message = messages[-1]
if tool_invoker and isinstance(last_message, FunctionCall):
function_result = tool_invoker.invoke(last_message)
messages.append(function_result)
return messages
return RunnableLambda(
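Taken together, one .invoke() of this program performs a single agent step; conceptually (a sketch, not an exact trace):

# MessageLog --prompt_generator--> PromptValue
#            --llm.bind(stop=..., tools=...)--> AIMessage (a bare str becomes AIMessage)
#            --parser--> FunctionCall | AgentFinish | ...
#            --tool_invoker (if a FunctionCall was produced)--> FunctionResult
# so callers may get back e.g.:
#   [AIMessage(...), FunctionCall(name="get_time", arguments={}),
#    FunctionResult(name="get_time", result="9 PM", error=None)]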

View File

@@ -1,201 +0,0 @@
from __future__ import annotations
from typing import Sequence, List
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
# from langchain.automaton.chat_automaton import ChatAutomaton
from langchain.automaton.mrkl_agent import ActionParser
from langchain.automaton.tests.utils import (
FakeChatOpenAI,
construct_func_invocation_message,
)
from langchain.schema import PromptValue
from langchain.schema.messages import (
AIMessage,
BaseMessage,
SystemMessage,
FunctionMessage,
)
from langchain.schema.runnable import RunnableLambda
from langchain.tools import tool, Tool
from langchain.tools.base import tool as tool_maker
def get_tools() -> List[Tool]:
@tool
def name() -> str:
"""Use to look up the user's name"""
return "Eugene"
@tool
def get_weather(city: str) -> str:
"""Get weather in a specific city."""
return "42F and sunny"
@tool
def add(x: int, y: int) -> int:
"""Use to add two numbers."""
return x + y
return list(locals().values())
def test_structured_output_chat() -> None:
parser = StructuredChatOutputParser()
output = parser.parse(
"""
```json
{
"action": "hello",
"action_input": {
"a": 2
}
}
```
"""
)
assert output == {}
class MessageBasedPromptValue(PromptValue):
"""Prompt Value populated from messages."""
messages: List[BaseMessage]
@classmethod
def from_messages(cls, messages: Sequence[BaseMessage]) -> MessageBasedPromptValue:
return cls(messages=messages)
def to_messages(self) -> List[BaseMessage]:
return self.messages
def to_string(self) -> str:
return "\n".join([message.content for message in self.messages])
# def prompt_generator(memory: Memory) -> PromptValue:
# """Generate a prompt."""
# if not memory.messages:
# raise AssertionError("Memory is empty")
# return MessageBasedPromptValue.from_messages(messages=memory.messages)
#
def test_automaton() -> None:
"""Run the automaton."""
@tool_maker
def get_time() -> str:
"""Get time."""
return "9 PM"
@tool_maker
def get_location() -> str:
"""Get location."""
return "the park"
tools = [get_time, get_location]
llm = FakeChatOpenAI(
message_iter=iter(
[
construct_func_invocation_message(get_time, {}),
AIMessage(
content="The time is 9 PM.",
),
]
)
)
# TODO(FIX MUTABILITY)
memory = Memory(
messages=[
SystemMessage(
content=(
"Hello! I'm a chatbot that can help you write a letter. "
"What would you like to do?"
),
)
]
)
chat_automaton = ChatAutomaton(
llm=llm, tools=tools, prompt_generator=prompt_generator
)
executor = Executor(chat_automaton, memory, max_iterations=1)
state, executed_states = executor.run()
assert executed_states == [
{
"data": {
"message": FunctionMessage(
content="9 PM", additional_kwargs={}, name="get_time"
)
},
"id": "llm_program",
}
]
def test_generate_template() -> None:
"""Generate template."""
template = generate_template()
assert template.format_messages(tools="hello", tool_names="hello") == []
def test_parser() -> None:
"""Tes the parser."""
sample_text = """
Some text before
<action>
{
"key": "value",
"number": 42
}
</action>
Some text after
"""
action_parser = ActionParser(strict=False)
action = action_parser.decode(sample_text)
assert action == {
"key": "value",
"number": 42,
}
def test_function_invocation() -> None:
"""test function invocation"""
tools = get_tools()
from langchain.automaton.well_known_states import create_tool_invoker
runnable = create_tool_invoker(tools)
result = runnable.invoke({"name": "add", "inputs": {"x": 1, "y": 2}})
assert result == 3
def test_create_llm_program() -> None:
"""Generate llm program."""
from langchain.automaton.mrkl_automaton import (
_generate_prompt,
_generate_mrkl_memory,
)
tools = get_tools()
llm = FakeChatOpenAI(
message_iter=iter(
[
AIMessage(
content="""Thought: Hello. <action>{"name": "key"}</action>""",
),
]
)
)
program = create_llm_program(
"think-act",
llm,
prompt_generator=_generate_prompt,
stop=["Observation"],
parser=RunnableLambda(ActionParser(strict=False).decode),
)
mrkl_memory = _generate_mrkl_memory(tools)
result = program.invoke(mrkl_memory)
assert result == {"id": "think-act", "data": {}}

View File

@@ -0,0 +1,41 @@
from __future__ import annotations
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.automaton.mrkl_agent import ActionParser
def test_structured_output_chat() -> None:
parser = StructuredChatOutputParser()
output = parser.parse(
"""
```json
{
"action": "hello",
"action_input": {
"a": 2
}
}
```
"""
)
assert output == {}
def test_parser() -> None:
"""Tes the parser."""
sample_text = """
Some text before
<action>
{
"key": "value",
"number": 42
}
</action>
Some text after
"""
action_parser = ActionParser()
action = action_parser.decode(sample_text)
assert action == {
"key": "value",
"number": 42,
}

View File

@@ -0,0 +1,97 @@
from __future__ import annotations
from typing import List, cast
import pytest
from langchain.automaton.openai_agent import OpenAIAgent
from langchain.automaton.tests.utils import (
FakeChatOpenAI,
construct_func_invocation_message,
)
from langchain.automaton.typedefs import (
FunctionCall,
FunctionResult,
MessageLog,
AgentFinish,
)
from langchain.schema.messages import (
AIMessage,
SystemMessage,
)
from langchain.tools import tool, Tool
from langchain.tools.base import BaseTool
@pytest.fixture()
def tools() -> List[BaseTool]:
@tool
def get_time() -> str:
"""Get time."""
return "9 PM"
@tool
def get_location() -> str:
"""Get location."""
return "the park"
return cast(List[Tool], [get_time, get_location])
def test_openai_agent(tools: List[Tool]) -> None:
get_time, get_location = tools
llm = FakeChatOpenAI(
message_iter=iter(
[
construct_func_invocation_message(get_time, {}),
AIMessage(
content="The time is 9 PM.",
),
]
)
)
agent = OpenAIAgent(llm=llm, tools=tools, max_iterations=10)
message_log = MessageLog(
[
SystemMessage(
content="What time is it?",
)
]
)
expected_messages = [
SystemMessage(
content="What time is it?",
),
AIMessage(
content="",
additional_kwargs={
"function_call": {
"name": "get_time",
"arguments": "{}",
}
},
),
FunctionCall(
name="get_time",
arguments={},
),
FunctionResult(
name="get_time",
result="9 PM",
error=None,
),
AIMessage(
content="The time is 9 PM.",
),
AgentFinish(
AIMessage(
content="The time is 9 PM.",
),
),
]
agent.run(message_log)
assert message_log.messages == expected_messages

View File

@@ -1,11 +1,10 @@
from __future__ import annotations
import dataclasses
from typing import Any, Optional, Sequence, List, Mapping, overload, Union
from typing import Any, Optional, Sequence, Mapping, overload, Union
from langchain.schema import (
BaseMessage,
PromptValue,
)
@@ -17,6 +16,7 @@ class FunctionCall:
@dataclasses.dataclass(frozen=True)
class FunctionResult:
name: str
result: Any
error: Optional[str]
@@ -41,7 +41,7 @@ MessageLike = Union[
class MessageLog:
"""A generalized message log for message like items."""
def __init__(self, messages: Sequence[MessageLike]) -> None:
def __init__(self, messages: Sequence[MessageLike] = ()) -> None:
"""Initialize the message log."""
self.messages = list(messages)
@@ -66,35 +66,15 @@ class MessageLog:
else:
return self.messages[index]
def __bool__(self):
return bool(self.messages)
def __len__(self) -> int:
"""Get the length of the chat template."""
return len(self.messages)
class MessageLogPromptValue(PromptValue):
"""Base abstract class for inputs to any language model.
PromptValues can be converted to both LLM (pure text-generation) inputs and
ChatModel inputs.
"""
message_log: MessageLog
class Config:
arbitrary_types_allowed = True
def to_string(self) -> str:
"""Return prompt value as string."""
finalized = []
for message in self.to_messages():
prefix = message.type
finalized.append(f"{prefix}: {message.content}")
return "\n".join(finalized) + "\n" + "ai:"
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as a list of Messages."""
return [
message
for message in self.message_log.messages
if isinstance(message, BaseMessage)
]
class Agent:
def run(self, message_log: MessageLog) -> None:
"""Run the agent on a message."""
raise NotImplementedError
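The new default argument and __bool__ back the empty-log guard clauses in the agents; a quick sketch:

from langchain.automaton.typedefs import MessageLog
from langchain.schema.messages import SystemMessage

log = MessageLog()  # valid now that messages defaults to ()
assert not log      # the new __bool__ is what the agents' assertions check
log.add_messages([SystemMessage(content="hi")])
assert len(log) == 1 and log[-1].content == "hi"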

View File

@@ -79,6 +79,15 @@ class BaseMessage(Serializable):
"""Whether this class is LangChain serializable."""
return True
def __repr__(self):
if self.additional_kwargs:
return f"{self.type}: {self.additional_kwargs}"
else:
return f"{self.type}: {self.content}"
def __str__(self):
return self.__repr__()
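# Illustrative outputs of the repr above (sketch):
#   repr(HumanMessage(content="hi"))  -> "human: hi"
#   repr(AIMessage(content="", additional_kwargs={"function_call": {...}}))
#       -> "ai: {'function_call': {...}}"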
def __add__(self, other: Any) -> ChatPromptTemplate:
from langchain.prompts.chat import ChatPromptTemplate