Compare commits


21 Commits

Author SHA1 Message Date
Eugene Yurtsev  15d5c49076  Merge branch 'master' into eugene/automaton_variant_1  2023-08-07 10:40:09 -04:00
Eugene Yurtsev  65660535bc  x  2023-08-07 10:39:56 -04:00
Eugene Yurtsev  6c41dd82f0  x  2023-08-07 10:38:14 -04:00
Eugene Yurtsev  78d788c28c  x  2023-08-07 10:37:41 -04:00
Eugene Yurtsev  e9deeab37f  x  2023-08-07 10:32:53 -04:00
Eugene Yurtsev  4d595eec5b  x  2023-08-06 23:13:32 -04:00
Eugene Yurtsev  047b001336  x  2023-08-06 22:50:46 -04:00
Eugene Yurtsev  840e936c7c  x  2023-08-06 22:30:33 -04:00
Eugene Yurtsev  6cc6b490be  x  2023-08-06 22:23:51 -04:00
Eugene Yurtsev  f45d1ed4f5  x  2023-08-06 22:20:29 -04:00
Eugene Yurtsev  4ffc417858  x  2023-08-05 23:10:15 -04:00
Eugene Yurtsev  9e74a70859  x  2023-08-05 22:39:04 -04:00
Eugene Yurtsev  0997f2c0f1  x  2023-08-05 22:37:16 -04:00
Eugene Yurtsev  437b545426  x  2023-08-05 16:01:12 -04:00
Eugene Yurtsev  9b9d07572b  x  2023-08-05 14:47:17 -04:00
Eugene Yurtsev  6a90c6c2c8  Merge branch 'eugene/fix_mutation_in_place' into eugene/automaton  2023-08-04 10:56:24 -04:00
Eugene Yurtsev  8371187689  x  2023-08-04 10:46:56 -04:00
Eugene Yurtsev  4309c17ffa  x  2023-08-04 10:44:57 -04:00
Eugene Yurtsev  183a9d4e66  x  2023-08-04 10:19:10 -04:00
Eugene Yurtsev  c1b444e1e7  x  2023-08-03 15:39:45 -04:00
Eugene Yurtsev  5f117384c0  x  2023-08-03 12:28:01 -04:00
15 changed files with 757 additions and 26 deletions

View File

@@ -0,0 +1,39 @@
"""Module containing an automaton definition."""
from __future__ import annotations
import abc
from typing import (
TypedDict,
Mapping,
Any,
Protocol,
)
from langchain.automaton.typedefs import Memory
class ExecutedState(TypedDict):
"""The response of an action taking LLM."""
id: str # the ID of the state that was just executed
data: Mapping[str, Any]
class State(Protocol):
"""Automaton state protocol."""
def execute(self, memory: Memory) -> ExecutedState:
"""Execute the state, returning the result."""
...
class Automaton(abc.ABC):
    """Abstract automaton interface: provides a start state and state transitions."""
@abc.abstractmethod
def get_start_state(self, *args: Any, **kwargs: Any) -> State:
"""Get the start state."""
raise NotImplementedError()
@abc.abstractmethod
def get_next_state(self, executed_state: ExecutedState) -> State:
"""Get the next state."""
raise NotImplementedError()
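
# Any object with a matching ``execute`` method satisfies the State protocol.
# A minimal illustrative implementation (hypothetical, not part of this module)
# might look like:
#
#   class EchoState:
#       """Return the most recent message without touching memory."""
#
#       def execute(self, memory: Memory) -> ExecutedState:
#           return {"id": "echo", "data": {"message": memory.messages[-1]}}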

View File

@@ -0,0 +1,71 @@
from __future__ import annotations
from typing import Any, Sequence
from langchain.automaton.automaton import ExecutedState, State, Automaton
from langchain.automaton.typedefs import (
MessageType,
infer_message_type,
PromptGenerator,
)
from langchain.automaton.well_known_states import LLMProgram, UserInputState
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
class ChatAutomaton(Automaton):
def __init__(
self,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
prompt_generator: PromptGenerator,
) -> None:
"""Initialize the chat automaton."""
super().__init__()
self.llm = llm
self.tools = tools
# TODO: Fix mutability of chat template, potentially add factory method
self.prompt_generator = prompt_generator
self.llm_program_state = LLMProgram(
llm=self.llm,
tools=self.tools,
prompt_generator=self.prompt_generator,
)
def get_start_state(self, *args: Any, **kwargs: Any) -> State:
"""Get the start state."""
return self.llm_program_state
def get_next_state(
self, executed_state: ExecutedState # Add memory for transition functions?
) -> State:
"""Get the next state."""
previous_state_id = executed_state["id"]
data = executed_state["data"]
if previous_state_id == "user_input":
return self.llm_program_state
elif previous_state_id == "llm_program":
message_type = infer_message_type(data["message"])
if message_type == MessageType.AI:
return UserInputState()
elif message_type == MessageType.FUNCTION:
return self.llm_program_state
else:
raise AssertionError(f"Unknown message type: {message_type}")
else:
raise ValueError(f"Unknown state ID: {previous_state_id}")
# This is transition matrix syntax
# transition_matrix = {
# ("user_input", "*"): LLMProgram,
# ("llm_program", MessageType.AI): UserInputState,
# ("llm_program", MessageType.AI_SELF): LLMProgram,
# (
# "llm_program",
# MessageType.AI_INVOKE,
# ): FuncInvocationState, # But must add function message
# ("func_invocation", MessageType.FUNCTION): LLMProgram,
# }
#
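# If such a table were adopted, get_next_state could be driven by it instead of
# if/elif branches. A rough sketch (illustrative only; the matrix and the state
# classes named above are not implemented in this PR):
#
#   def get_next_state(self, executed_state: ExecutedState) -> State:
#       message_type = infer_message_type(executed_state["data"]["message"])
#       for (state_id, matched_type), next_state in transition_matrix.items():
#           if state_id == executed_state["id"] and matched_type in ("*", message_type):
#               return next_state()
#       raise ValueError(f"No transition from {executed_state['id']!r} on {message_type}")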

View File

@@ -0,0 +1,38 @@
from __future__ import annotations
from typing import (
List,
Tuple,
)
from langchain.automaton.automaton import ExecutedState, State, Automaton
from langchain.automaton.typedefs import Memory
# Need to make into runnable
# This is a for looping runnable... :)
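# A minimal sketch of that wrapping (hypothetical, not part of this PR), using
# the existing RunnableLambda from langchain.schema.runnable:
#
#   from langchain.schema.runnable import RunnableLambda
#
#   def as_runnable(automaton: Automaton, memory: Memory) -> RunnableLambda:
#       return RunnableLambda(
#           lambda _: Executor(automaton, memory, max_iterations=10).run()
#       )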
class Executor:
def __init__(
self, automaton: Automaton, memory: Memory, max_iterations: int
) -> None:
"""Initialize the executor."""
self.automaton = automaton
self.max_iterations = max_iterations
self.memory = memory
def run(self) -> Tuple[State, List[ExecutedState]]:
"""Run the automaton.
Returns:
            The final state and the list of executed states.
"""
state = self.automaton.get_start_state()
executed_states = []
for _ in range(self.max_iterations):
executed_state = state.execute(memory=self.memory)
executed_states.append(executed_state)
# Should the transition function get memory?
state = self.automaton.get_next_state(executed_state)
return state, executed_states

View File

@@ -0,0 +1,136 @@
from __future__ import annotations
from operator import itemgetter
from typing import Any, Callable, List, Mapping, TypedDict, Union, Sequence, Optional
from langchain.base_language import BaseLanguageModel
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.schema import BaseMessage, AIMessage
from langchain.schema.output import ChatGeneration
from langchain.schema.runnable import RouterRunnable, Runnable, RunnableBinding
from langchain.tools.base import BaseTool
from langchain.tools.convert_to_openai import format_tool_to_openai_function
import json
class OpenAIFunction(TypedDict):
"""A function to call on the OpenAI API."""
name: str
"""The name of the function."""
description: str
"""The description of the function."""
parameters: dict
"""The parameters to the function."""
class OpenAIFunctionsRouter(RunnableBinding[ChatGeneration, Any]):
"""A runnable that routes to the selected function."""
functions: List[OpenAIFunction]
def __init__(
self,
functions: List[OpenAIFunction],
runnables: Mapping[
str,
Union[
Runnable[dict, Any],
Callable[[dict], Any],
],
],
):
assert len(functions) == len(runnables)
assert all(func["name"] in runnables for func in functions)
router = (
JsonOutputFunctionsParser(args_only=False)
| {"key": itemgetter("name"), "input": itemgetter("arguments")}
| RouterRunnable(runnables)
)
super().__init__(bound=router, kwargs={}, functions=functions)
class FunctionCall(TypedDict):
name: str
"""The name of the function."""
arguments: dict
"""The arguments to the function."""
result: Any # Need to denote not invoked yet as well
"""The result of the function call"""
class ActingResult(TypedDict):
"""The result of an action."""
message: BaseMessage
"""The message that was passed to the action."""
function_call: Optional[FunctionCall]
def create_action_taking_llm(
llm: BaseLanguageModel,
*,
tools: Sequence[BaseTool] = (),
stop: Sequence[str] | None = None,
invoke_function: bool = True,
) -> Runnable:
"""A chain that can create an action.
Args:
llm: The language model to use.
tools: The tools to use.
stop: The stop tokens to use.
invoke_function: Whether to invoke the function.
Returns:
        A segment of a runnable that takes an action.
"""
openai_funcs = [format_tool_to_openai_function(tool_) for tool_ in tools]
def _interpret_message(message: BaseMessage) -> ActingResult:
"""Interpret a message."""
if (
isinstance(message, AIMessage)
and "function_call" in message.additional_kwargs
):
if invoke_function:
result = invoke_from_function.invoke( # TODO: fixme using invoke
message
)
else:
result = None
return {
"message": message,
"function_call": {
"name": message.additional_kwargs["function_call"]["name"],
"arguments": json.loads(
message.additional_kwargs["function_call"]["arguments"]
),
"result": result,
# Check this works.
# "result": message.additional_kwargs["function_call"]
# | invoke_from_function,
},
}
else:
return {
"message": message,
"function_call": None,
}
invoke_from_function = OpenAIFunctionsRouter(
functions=openai_funcs,
runnables={
openai_func["name"]: tool_
for openai_func, tool_ in zip(openai_funcs, tools)
},
)
if stop:
_llm = llm.bind(stop=stop)
else:
_llm = llm
chain = _llm.bind(functions=openai_funcs) | _interpret_message
return chain
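
# Typical use (illustrative only): bind an LLM and tools, then invoke the chain
# with a prompt value; the output is an ActingResult dict.
#
#   chain = create_action_taking_llm(llm, tools=[some_tool])
#   acting_result = chain.invoke(prompt_value)
#   acting_result["message"]        # the raw model message
#   acting_result["function_call"]  # populated only when a tool was selected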

View File

@@ -0,0 +1,99 @@
from __future__ import annotations
from typing import (
List,
Sequence,
)
from langchain.automaton.chat_automaton import ChatAutomaton
from langchain.automaton.executor import Executor
from langchain.automaton.tests.utils import (
FakeChatOpenAI,
construct_func_invocation_message,
)
from langchain.automaton.typedefs import Memory
from langchain.schema import PromptValue
from langchain.schema.messages import (
AIMessage,
BaseMessage,
SystemMessage,
FunctionMessage,
)
from langchain.tools.base import tool as tool_maker
class MessageBasedPromptValue(PromptValue):
"""Prompt Value populated from messages."""
messages: List[BaseMessage]
@classmethod
def from_messages(cls, messages: Sequence[BaseMessage]) -> MessageBasedPromptValue:
return cls(messages=messages)
def to_messages(self) -> List[BaseMessage]:
return self.messages
def to_string(self) -> str:
return " ".join([message.content for message in self.messages])
def prompt_generator(memory: Memory) -> PromptValue:
"""Generate a prompt."""
if not memory.messages:
raise AssertionError("Memory is empty")
return MessageBasedPromptValue.from_messages(messages=memory.messages)
def test_automaton() -> None:
"""Run the automaton."""
@tool_maker
def get_time() -> str:
"""Get time."""
return "9 PM"
@tool_maker
def get_location() -> str:
"""Get location."""
return "the park"
tools = [get_time, get_location]
llm = FakeChatOpenAI(
message_iter=iter(
[
construct_func_invocation_message(get_time, {}),
AIMessage(
content="The time is 9 PM.",
),
]
)
)
# TODO(FIX MUTABILITY)
memory = Memory(
messages=[
SystemMessage(
content=(
"Hello! I'm a chatbot that can help you write a letter. "
"What would you like to do?"
),
)
]
)
chat_automaton = ChatAutomaton(
llm=llm, tools=tools, prompt_generator=prompt_generator
)
executor = Executor(chat_automaton, memory, max_iterations=1)
state, executed_states = executor.run()
assert executed_states == [
{
"data": {
"message": FunctionMessage(
content="9 PM", additional_kwargs={}, name="get_time"
)
},
"id": "llm_program",
}
]

View File

@@ -0,0 +1,73 @@
from __future__ import annotations
from langchain.automaton.open_ai_functions import OpenAIFunctionsRouter
from langchain.automaton.tests.utils import FakeChatOpenAI
from langchain.schema import AIMessage
from langchain.schema.runnable import RunnableLambda
def test_openai_functions_router() -> None:
"""Test the OpenAIFunctionsRouter."""
def revise(notes: str) -> str:
"""Revises the draft."""
return f"Revised draft: {notes}!"
def accept(draft: str) -> str:
"""Accepts the draft."""
return f"Accepted draft: {draft}!"
router = OpenAIFunctionsRouter(
functions=[
{
"name": "revise",
"description": "Sends the draft for revision.",
"parameters": {
"type": "object",
"properties": {
"notes": {
"type": "string",
"description": "The editor's notes to guide the revision.",
},
},
},
},
{
"name": "accept",
"description": "Accepts the draft.",
"parameters": {
"type": "object",
"properties": {
"draft": {
"type": "string",
"description": "The draft to accept.",
},
},
},
},
],
runnables={
"revise": RunnableLambda(lambda x: revise(x["revise"])),
"accept": RunnableLambda(lambda x: accept(x["draft"])),
},
)
model = FakeChatOpenAI(
message_iter=iter(
[
AIMessage(
content="",
additional_kwargs={
"function_call": {
"name": "accept",
"arguments": '{\n "draft": "turtles"\n}',
}
},
)
]
)
)
chain = model.bind(functions=router.functions) | router
assert chain.invoke("Something about turtles?") == "Accepted draft: turtles!"

View File

@@ -0,0 +1,46 @@
from __future__ import annotations
import json
from typing import Iterator, List, Any, Mapping
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage, ChatResult, ChatGeneration, AIMessage
from langchain.tools import BaseTool
class FakeChatOpenAI(BaseChatModel):
"""A fake chat model that returns a pre-defined response."""
message_iter: Iterator[BaseMessage]
@property
def _llm_type(self) -> str:
"""The type of the model."""
return "fake-openai-chat-model"
def _generate(
self,
messages: List[BaseMessage],
stop: List[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs: Any,
) -> ChatResult:
"""Generate a response to the given messages."""
message = next(self.message_iter)
return ChatResult(generations=[ChatGeneration(message=message)])
def construct_func_invocation_message(
tool: BaseTool, args: Mapping[str, Any]
) -> AIMessage:
"""Construct a function invocation message."""
return AIMessage(
content="",
additional_kwargs={
"function_call": {
"name": tool.name,
"arguments": json.dumps(args),
}
},
)

View File

@@ -0,0 +1,61 @@
from __future__ import annotations
import dataclasses
import enum
from typing import Callable, List
from langchain.schema import (
BaseMessage,
FunctionMessage,
AIMessage,
SystemMessage,
HumanMessage,
BaseChatMessageHistory,
PromptValue,
)
class MessageType(enum.Enum):
"""The type of message."""
SYSTEM = enum.auto()
USER = enum.auto()
FUNCTION = enum.auto()
AI = enum.auto()
AI_INVOKE = enum.auto()
def infer_message_type(message: BaseMessage) -> MessageType:
"""Infer the message type."""
if isinstance(message, FunctionMessage):
return MessageType.FUNCTION
elif isinstance(message, AIMessage):
if message.additional_kwargs:
return MessageType.AI_INVOKE
else:
return MessageType.AI
elif isinstance(message, SystemMessage):
return MessageType.SYSTEM
elif isinstance(message, HumanMessage):
return MessageType.USER
else:
raise ValueError(f"Unknown message type: {type(message)}")
class Memory(BaseChatMessageHistory):
"""A memory for the automaton."""
    def __init__(self, messages: List[BaseMessage]) -> None:
        """Initialize the memory with an existing list of messages."""
        self.messages = messages
def add_message(self, message: BaseMessage) -> None:
"""Add a message to the memory."""
self.messages.append(message)
def clear(self) -> None:
"""Clear the memory."""
self.messages = []
# Interface that takes memory and returns a prompt value
PromptGenerator = Callable[[Memory], PromptValue]
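
# For example (illustrative only, assuming ChatPromptValue from
# langchain.prompts.chat), a generator that replays the full message history:
#
#   def full_history_prompt(memory: Memory) -> PromptValue:
#       return ChatPromptValue(messages=memory.messages)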

View File

@@ -0,0 +1,79 @@
from __future__ import annotations
import dataclasses
from typing import Sequence
from langchain.automaton.automaton import State, ExecutedState
from langchain.automaton.open_ai_functions import create_action_taking_llm
from langchain.automaton.typedefs import (
Memory,
PromptGenerator,
infer_message_type,
MessageType,
)
from langchain.schema import HumanMessage, FunctionMessage, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
@dataclasses.dataclass
class LLMProgram(State):
"""A state that executes an LLM program."""
llm: BaseLanguageModel
tools: Sequence[BaseTool]
prompt_generator: PromptGenerator
def execute(self, memory: Memory) -> ExecutedState:
"""Execute LLM program."""
action_taking_llm = create_action_taking_llm(self.llm, tools=self.tools)
prompt_value = self.prompt_generator(memory)
result = action_taking_llm.invoke(prompt_value)
# Memory is mutable
message = result["message"]
if not isinstance(message, AIMessage):
raise AssertionError(
f"LLM program should return an AI message. Got a {type(message)}."
)
memory.add_message(message)
if infer_message_type(message) == MessageType.AI_INVOKE:
function_call = result["function_call"]
function_message = FunctionMessage(
name=function_call["name"],
content=function_call["result"],
)
memory.add_message(function_message)
routing_message = function_message
else:
routing_message = message
        # What information should the state return in this case?
        # It may not matter; callers can use the data or ignore it.
return {
"id": "llm_program",
"data": {
"message": routing_message, # Last message
},
}
@dataclasses.dataclass
class UserInputState(State):
"""A state that prompts the user for input from stdin.
This is primarily useful for interactive development.
"""
def execute(self, memory: Memory) -> ExecutedState:
"""Execute user input state."""
user_input = input("Enter your input: ")
message = HumanMessage(content=user_input)
memory.add_message(message)
return {
"id": "user_input",
"data": {
"message": message,
},
}

View File

@@ -25,8 +25,8 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
)
message = generation.message
try:
func_call = message.additional_kwargs["function_call"]
except ValueError as exc:
func_call = message.additional_kwargs["function_call"].copy()
except KeyError as exc:
raise OutputParserException(f"Could not parse function call: {exc}")
if self.args_only:
@@ -38,11 +38,16 @@ class JsonOutputFunctionsParser(OutputFunctionsParser):
"""Parse an output as the Json object."""
def parse_result(self, result: List[Generation]) -> Any:
-        func = super().parse_result(result)
+        function_call_info = super().parse_result(result)
         if self.args_only:
-            return json.loads(func)
-        func["arguments"] = json.loads(func["arguments"])
-        return func
+            try:
+                return json.loads(function_call_info)
+            except (json.JSONDecodeError, TypeError) as exc:
+                raise OutputParserException(
+                    f"Could not parse function call data: {exc}"
+                )
+        function_call_info["arguments"] = json.loads(function_call_info["arguments"])
+        return function_call_info
class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):

View File

@@ -316,6 +316,16 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
"""Format kwargs into a list of messages."""
MessageLike = Union[
BaseMessagePromptTemplate,
BaseChatPromptTemplate,
BaseMessage,
Tuple[str, str],
Tuple[Type, str],
str,
]
class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
"""A prompt template for chat models.
@@ -363,6 +373,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate)
):
return ChatPromptTemplate(messages=self.messages + [other])
elif isinstance(other, (list, tuple)):
_other = ChatPromptTemplate.from_messages(other)
return ChatPromptTemplate(messages=self.messages + _other.messages)
elif isinstance(other, str):
prompt = HumanMessagePromptTemplate.from_template(other)
return ChatPromptTemplate(messages=self.messages + [prompt])
@@ -453,17 +466,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
@classmethod
def from_messages(
-        cls,
-        messages: Sequence[
-            Union[
-                BaseMessagePromptTemplate,
-                BaseChatPromptTemplate,
-                BaseMessage,
-                Tuple[str, str],
-                Tuple[Type, str],
-                str,
-            ]
-        ],
+        cls, messages: Sequence[MessageLike],
) -> ChatPromptTemplate:
"""Create a chat prompt template from a variety of message formats.
@@ -589,6 +592,18 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
return type(self)(**prompt_dict)
def append(self, message: MessageLike) -> None:
"""Append message to the end of the chat template.
Args:
message: representation of a message to append.
"""
self.messages.append(_convert_to_message(message))
def extend(self, messages: Sequence[MessageLike]) -> None:
"""Extend the chat template with a sequence of messages."""
self.messages.extend([_convert_to_message(message) for message in messages])
@property
def _prompt_type(self) -> str:
"""Name of prompt type."""
@@ -632,14 +647,7 @@ def _create_template_from_message_type(
def _convert_to_message(
-    message: Union[
-        BaseMessagePromptTemplate,
-        BaseChatPromptTemplate,
-        BaseMessage,
-        Tuple[str, str],
-        Tuple[Type, str],
-        str,
-    ]
+    message: MessageLike
) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
"""Instantiate a message from a variety of message formats.

View File

@@ -685,7 +685,7 @@ def tool(
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
infer_schema: bool = True,
-) -> Callable:
+) -> Callable[[Callable], BaseTool]:
"""Make tools out of functions, can be used with or without arguments.
Args:

View File

@@ -0,0 +1,76 @@
import json
import pytest
from langchain.output_parsers.openai_functions import (
JsonOutputFunctionsParser,
)
from langchain.schema import BaseMessage, ChatGeneration, OutputParserException
from langchain.schema.messages import AIMessage, HumanMessage
@pytest.fixture
def ai_message() -> AIMessage:
"""Return a simple AIMessage."""
content = "This is a test message"
args = json.dumps(
{
"arg1": "value1",
}
)
function_call = {"name": "function_name", "arguments": args}
additional_kwargs = {"function_call": function_call}
return AIMessage(content=content, additional_kwargs=additional_kwargs)
def test_json_output_function_parser(ai_message: AIMessage) -> None:
"""Test that the JsonOutputFunctionsParser with full output."""
chat_generation = ChatGeneration(message=ai_message)
# Full output
parser = JsonOutputFunctionsParser(args_only=False)
result = parser.parse_result([chat_generation])
assert result == {"arguments": {"arg1": "value1"}, "name": "function_name"}
# Args only
parser = JsonOutputFunctionsParser(args_only=True)
result = parser.parse_result([chat_generation])
assert result == {"arg1": "value1"}
# Verify that the original message is not modified
assert ai_message.additional_kwargs == {
"function_call": {"name": "function_name", "arguments": '{"arg1": "value1"}'}
}
@pytest.mark.parametrize(
"bad_message",
[
# Human message has no function call
HumanMessage(content="This is a test message"),
# AIMessage has no function call information.
AIMessage(content="This is a test message", additional_kwargs={}),
# Bad function call information (arguments should be a string)
AIMessage(
content="This is a test message",
additional_kwargs={
"function_call": {"name": "function_name", "arguments": {}}
},
),
# Bad function call information (arguments should be proper json)
AIMessage(
content="This is a test message",
additional_kwargs={
"function_call": {"name": "function_name", "arguments": "noqweqwe"}
},
),
],
)
def test_exceptions_raised_while_parsing(bad_message: BaseMessage) -> None:
"""Test exceptions raised correctly while using JSON parser."""
chat_generation = ChatGeneration(message=bad_message)
with pytest.raises(OutputParserException):
JsonOutputFunctionsParser().parse_result([chat_generation])