mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-03 15:55:44 +00:00
x
This commit is contained in:
@@ -1,71 +0,0 @@
|
||||
# from __future__ import annotations
|
||||
#
|
||||
# from typing import Any, Sequence
|
||||
#
|
||||
# from langchain.automaton.automaton import ExecutedState, State, Automaton
|
||||
# from langchain.automaton.typedefs import (
|
||||
# MessageType,
|
||||
# infer_message_type,
|
||||
# PromptGenerator,
|
||||
# )
|
||||
# from langchain.automaton.well_known_states import LLMProgram, UserInputState
|
||||
# from langchain.schema.language_model import BaseLanguageModel
|
||||
# from langchain.tools import BaseTool
|
||||
#
|
||||
#
|
||||
# class ChatAutomaton(Automaton):
|
||||
# def __init__(
|
||||
# self,
|
||||
# llm: BaseLanguageModel,
|
||||
# tools: Sequence[BaseTool],
|
||||
# prompt_generator: PromptGenerator,
|
||||
# ) -> None:
|
||||
# """Initialize the chat automaton."""
|
||||
# super().__init__()
|
||||
# self.llm = llm
|
||||
# self.tools = tools
|
||||
# # TODO: Fix mutability of chat template, potentially add factory method
|
||||
# self.prompt_generator = prompt_generator
|
||||
# self.llm_program_state = LLMProgram(
|
||||
# llm=self.llm,
|
||||
# tools=self.tools,
|
||||
# prompt_generator=self.prompt_generator,
|
||||
# )
|
||||
#
|
||||
# def get_start_state(self, *args: Any, **kwargs: Any) -> State:
|
||||
# """Get the start state."""
|
||||
# return self.llm_program_state
|
||||
#
|
||||
# def get_next_state(
|
||||
# self, executed_state: ExecutedState # Add memory for transition functions?
|
||||
# ) -> State:
|
||||
# """Get the next state."""
|
||||
# previous_state_id = executed_state["id"]
|
||||
# data = executed_state["data"]
|
||||
#
|
||||
# if previous_state_id == "user_input":
|
||||
# return self.llm_program_state
|
||||
# elif previous_state_id == "llm_program":
|
||||
# message_type = infer_message_type(data["message"])
|
||||
# if message_type == MessageType.AI:
|
||||
# return UserInputState()
|
||||
# elif message_type == MessageType.FUNCTION:
|
||||
# return self.llm_program_state
|
||||
# else:
|
||||
# raise AssertionError(f"Unknown message type: {message_type}")
|
||||
# else:
|
||||
# raise ValueError(f"Unknown state ID: {previous_state_id}")
|
||||
#
|
||||
#
|
||||
# # This is transition matrix syntax
|
||||
# # transition_matrix = {
|
||||
# # ("user_input", "*"): LLMProgram,
|
||||
# # ("llm_program", MessageType.AI): UserInputState,
|
||||
# # ("llm_program", MessageType.AI_SELF): LLMProgram,
|
||||
# # (
|
||||
# # "llm_program",
|
||||
# # MessageType.AI_INVOKE,
|
||||
# # ): FuncInvocationState, # But must add function message
|
||||
# # ("func_invocation", MessageType.FUNCTION): LLMProgram,
|
||||
# # }
|
||||
# #
|
||||
@@ -6,7 +6,6 @@ from typing import Sequence, Optional, Union, List
|
||||
|
||||
from langchain.automaton.runnables import (
|
||||
create_llm_program,
|
||||
create_tool_invoker,
|
||||
)
|
||||
from langchain.automaton.typedefs import (
|
||||
MessageLog,
|
||||
@@ -14,6 +13,7 @@ from langchain.automaton.typedefs import (
|
||||
FunctionCall,
|
||||
FunctionResult,
|
||||
AgentFinish,
|
||||
PrimingMessage,
|
||||
)
|
||||
from langchain.prompts import SystemMessagePromptTemplate
|
||||
from langchain.schema import (
|
||||
@@ -24,8 +24,6 @@ from langchain.schema import (
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.runnable import Runnable
|
||||
from langchain.automaton.typedefs import PrimingMessage
|
||||
from langchain.tools import BaseTool, Tool
|
||||
|
||||
TEMPLATE_ = """\
|
||||
@@ -155,11 +153,15 @@ class ThinkActPromptGenerator(PromptValue):
|
||||
|
||||
def to_messages(self) -> List[BaseMessage]:
|
||||
"""Return prompt as a list of Messages."""
|
||||
return [
|
||||
message
|
||||
for message in self.message_log.messages
|
||||
if isinstance(message, BaseMessage)
|
||||
]
|
||||
messages = []
|
||||
for message in self.message_log.messages:
|
||||
if isinstance(message, BaseMessage):
|
||||
messages.append(message)
|
||||
elif isinstance(message, FunctionResult):
|
||||
messages.append(
|
||||
SystemMessage(content=f"Observation: `{message.result}`")
|
||||
)
|
||||
return messages
|
||||
|
||||
@classmethod
|
||||
def from_message_log(cls, message_log: MessageLog):
|
||||
@@ -172,47 +174,34 @@ class MRKLAgent:
|
||||
self,
|
||||
llm: BaseLanguageModel,
|
||||
tools: Sequence[BaseTool],
|
||||
*,
|
||||
max_iterations: int = 10,
|
||||
) -> None:
|
||||
"""Initialize the chat automaton."""
|
||||
self.think_act: Runnable[
|
||||
MessageLog, Sequence[MessageLike]
|
||||
] = create_llm_program(
|
||||
self.think_act = create_llm_program(
|
||||
llm,
|
||||
prompt_generator=ThinkActPromptGenerator.from_message_log,
|
||||
stop=["Observation:", "observation:"],
|
||||
parser=ActionParser().decode,
|
||||
tools=tools,
|
||||
invoke_tools=True,
|
||||
)
|
||||
self.tool_invoker: Runnable[FunctionCall, FunctionResult] = create_tool_invoker(
|
||||
tools
|
||||
)
|
||||
self.max_iterations = max_iterations
|
||||
|
||||
def run(self, message_log: MessageLog) -> None:
|
||||
"""Run the agent."""
|
||||
if not message_log.messages:
|
||||
raise AssertionError()
|
||||
if not message_log:
|
||||
raise AssertionError(f"Expected at least one message in message_log")
|
||||
|
||||
last_message = message_log.messages[-1]
|
||||
for _ in range(self.max_iterations):
|
||||
last_message = message_log[-1]
|
||||
|
||||
max_iterations = 10
|
||||
iteration_num = 0
|
||||
|
||||
while True:
|
||||
if iteration_num > max_iterations:
|
||||
break
|
||||
if isinstance(last_message, AgentFinish):
|
||||
break
|
||||
elif isinstance(last_message, FunctionCall):
|
||||
messages = [
|
||||
self.tool_invoker.invoke(last_message),
|
||||
# After we have a function result, we want to prime the LLM with the word "
|
||||
# Thought"
|
||||
PrimingMessage(content="Thought:"),
|
||||
]
|
||||
else:
|
||||
messages = self.think_act.invoke(message_log)
|
||||
|
||||
if not messages:
|
||||
raise AssertionError(f"No messages returned from last step")
|
||||
messages = self.think_act.invoke(message_log)
|
||||
# Prime the LLM to start with "Thought: " after an observation
|
||||
if isinstance(messages[-1], FunctionResult):
|
||||
messages.append(PrimingMessage(content="Thought:"))
|
||||
|
||||
message_log.add_messages(messages)
|
||||
last_message = messages[-1]
|
||||
iteration_num += 1
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Sequence
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.runnables.openai_functions import OpenAIFunctionsRouter
|
||||
from langchain.schema import BaseMessage, AIMessage
|
||||
from langchain.schema.runnable import Runnable
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.convert_to_openai import format_tool_to_openai_function
|
||||
|
||||
|
||||
# def create_action_taking_llm(
|
||||
# llm: BaseLanguageModel,
|
||||
# *,
|
||||
# tools: Sequence[BaseTool] = (),
|
||||
# stop: Sequence[str] | None = None,
|
||||
# invoke_function: bool = True,
|
||||
# ) -> Runnable:
|
||||
# """A chain that can create an action.
|
||||
#
|
||||
# Args:
|
||||
# llm: The language model to use.
|
||||
# tools: The tools to use.
|
||||
# stop: The stop tokens to use.
|
||||
# invoke_function: Whether to invoke the function.
|
||||
#
|
||||
# Returns:
|
||||
# a segment of a runnable that take an action.
|
||||
# """
|
||||
#
|
||||
# openai_funcs = [format_tool_to_openai_function(tool_) for tool_ in tools]
|
||||
#
|
||||
# def _interpret_message(message: BaseMessage) -> ActingResult:
|
||||
# """Interpret a message."""
|
||||
# if (
|
||||
# isinstance(message, AIMessage)
|
||||
# and "function_call" in message.additional_kwargs
|
||||
# ):
|
||||
# if invoke_function:
|
||||
# result = invoke_from_function.invoke( # TODO: fixme using invoke
|
||||
# message
|
||||
# )
|
||||
# else:
|
||||
# result = None
|
||||
# return {
|
||||
# "message": message,
|
||||
# "function_call": {
|
||||
# "name": message.additional_kwargs["function_call"]["name"],
|
||||
# "arguments": json.loads(
|
||||
# message.additional_kwargs["function_call"]["arguments"]
|
||||
# ),
|
||||
# "result": result,
|
||||
# # Check this works.
|
||||
# # "result": message.additional_kwargs["function_call"]
|
||||
# # | invoke_from_function,
|
||||
# },
|
||||
# }
|
||||
# else:
|
||||
# return {
|
||||
# "message": message,
|
||||
# "function_call": None,
|
||||
# }
|
||||
#
|
||||
# invoke_from_function = OpenAIFunctionsRouter(
|
||||
# functions=openai_funcs,
|
||||
# runnables={
|
||||
# openai_func["name"]: tool_
|
||||
# for openai_func, tool_ in zip(openai_funcs, tools)
|
||||
# },
|
||||
# )
|
||||
#
|
||||
# if stop:
|
||||
# _llm = llm.bind(stop=stop)
|
||||
# else:
|
||||
# _llm = llm
|
||||
#
|
||||
# chain = _llm.bind(functions=openai_funcs) | _interpret_message
|
||||
# return chain
|
||||
#
|
||||
#
|
||||
# def _interpret_message(message: BaseMessage) -> ActingResult:
|
||||
# """Interpret a message."""
|
||||
# if isinstance(message, AIMessage) and "function_call" in message.additional_kwargs:
|
||||
# raise NotImplementedError()
|
||||
# # if invoke_function:
|
||||
# # result = invoke_from_function.invoke(message) # TODO: fixme using invoke
|
||||
# # else:
|
||||
# # result = None
|
||||
# # return {
|
||||
# # "message": message,
|
||||
# # "function_call": {
|
||||
# # "name": message.additional_kwargs["function_call"]["name"],
|
||||
# # "arguments": json.loads(
|
||||
# # message.additional_kwargs["function_call"]["arguments"]
|
||||
# # ),
|
||||
# # "result": result,
|
||||
# # # Check this works.
|
||||
# # # "result": message.additional_kwargs["function_call"]
|
||||
# # # | invoke_from_function,
|
||||
# # },
|
||||
# # }
|
||||
# else:
|
||||
# return {
|
||||
# "message": message,
|
||||
# "function_call": None,
|
||||
# }
|
||||
72
libs/langchain/langchain/automaton/openai_agent.py
Normal file
72
libs/langchain/langchain/automaton/openai_agent.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, List
|
||||
|
||||
from langchain.automaton.prompt_generators import MessageLogPromptValue
|
||||
from langchain.automaton.runnables import create_llm_program
|
||||
from langchain.automaton.typedefs import (
|
||||
MessageLog,
|
||||
AgentFinish,
|
||||
FunctionCall,
|
||||
)
|
||||
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
||||
from langchain.schema import Generation
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import AIMessage
|
||||
from langchain.schema.output_parser import BaseGenerationOutputParser
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
|
||||
class OpenAIFunctionsParser(BaseGenerationOutputParser):
|
||||
def parse_result(self, result: List[Generation]):
|
||||
if len(result) != 1:
|
||||
raise AssertionError(f"Expected exactly one result")
|
||||
first_result = result[0]
|
||||
|
||||
message = first_result.message
|
||||
|
||||
if not isinstance(message, AIMessage) or not message.additional_kwargs:
|
||||
return AgentFinish(result=message)
|
||||
|
||||
parser = JsonOutputFunctionsParser(strict=False, args_only=False)
|
||||
try:
|
||||
function_request = parser.parse_result(result)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error parsing result: {result} {repr(e)}") from e
|
||||
|
||||
return FunctionCall(
|
||||
name=function_request["name"],
|
||||
arguments=function_request["arguments"],
|
||||
)
|
||||
|
||||
|
||||
class OpenAIAgent:
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseLanguageModel,
|
||||
tools: Sequence[BaseTool],
|
||||
*,
|
||||
max_iterations: int = 10,
|
||||
) -> None:
|
||||
"""Initialize the chat automaton."""
|
||||
self.llm_program = create_llm_program(
|
||||
llm,
|
||||
prompt_generator=MessageLogPromptValue.from_message_log,
|
||||
tools=tools,
|
||||
parser=OpenAIFunctionsParser(),
|
||||
)
|
||||
self.max_iterations = max_iterations
|
||||
|
||||
def run(self, message_log: MessageLog) -> None:
|
||||
"""Run the agent."""
|
||||
if not message_log:
|
||||
raise AssertionError(f"Expected at least one message in message_log")
|
||||
|
||||
for _ in range(self.max_iterations):
|
||||
last_message = message_log[-1]
|
||||
|
||||
if isinstance(last_message, AgentFinish):
|
||||
break
|
||||
|
||||
messages = self.llm_program.invoke(message_log)
|
||||
message_log.add_messages(messages)
|
||||
55
libs/langchain/langchain/automaton/prompt_generators.py
Normal file
55
libs/langchain/langchain/automaton/prompt_generators.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
from langchain.automaton.typedefs import MessageLog, FunctionResult
|
||||
from langchain.schema import PromptValue, BaseMessage, FunctionMessage
|
||||
|
||||
|
||||
class MessageLogPromptValue(PromptValue):
|
||||
"""Base abstract class for inputs to any language model.
|
||||
|
||||
PromptValues can be converted to both LLM (pure text-generation) inputs and
|
||||
ChatModel inputs.
|
||||
"""
|
||||
|
||||
message_log: MessageLog
|
||||
# If True will use the OpenAI function method
|
||||
use_function_message: bool = (
|
||||
False # TODO(Eugene): replace with adapter, should be generic
|
||||
)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""Return prompt value as string."""
|
||||
finalized = []
|
||||
for message in self.to_messages():
|
||||
prefix = message.type
|
||||
finalized.append(f"{prefix}: {message.content}")
|
||||
return "\n".join(finalized) + "\n" + "ai:"
|
||||
|
||||
def to_messages(self) -> List[BaseMessage]:
|
||||
"""Return prompt as a list of Messages."""
|
||||
messages = []
|
||||
for message in self.message_log.messages:
|
||||
if isinstance(message, BaseMessage):
|
||||
messages.append(message)
|
||||
elif isinstance(message, FunctionResult):
|
||||
if self.use_function_message:
|
||||
messages.append(
|
||||
FunctionMessage(
|
||||
name=message.name, content=json.dumps(message.result)
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Ignore internal messages
|
||||
pass
|
||||
return messages
|
||||
|
||||
@classmethod
|
||||
def from_message_log(cls, message_log: MessageLog) -> MessageLogPromptValue:
|
||||
"""Create a PromptValue from a MessageLog."""
|
||||
return cls(message_log=message_log)
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Module contains well known runnables for agents."""
|
||||
"""Module contains useful runnables for agents."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Callable, Union, List, Optional
|
||||
from typing import Sequence, Callable, List, Optional, Union
|
||||
|
||||
from langchain.automaton.typedefs import (
|
||||
MessageLike,
|
||||
@@ -13,6 +13,10 @@ from langchain.schema import BaseMessage, AIMessage, PromptValue
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.runnable.base import RunnableLambda, Runnable
|
||||
from langchain.tools import BaseTool
|
||||
from langchain.tools.convert_to_openai import format_tool_to_openai_function
|
||||
|
||||
|
||||
# PUBLIC API
|
||||
|
||||
|
||||
def create_tool_invoker(
|
||||
@@ -21,16 +25,19 @@ def create_tool_invoker(
|
||||
"""Re-write with router."""
|
||||
tools_by_name = {tool.name: tool for tool in tools}
|
||||
|
||||
def func(input: FunctionCall) -> FunctionResult:
|
||||
def func(function_call: FunctionCall) -> FunctionResult:
|
||||
"""A function that can invoke a tool using .run"""
|
||||
tool = tools_by_name[input.name]
|
||||
try:
|
||||
result = tool.run(input.arguments or {})
|
||||
tool = tools_by_name[function_call.name]
|
||||
except KeyError:
|
||||
raise AssertionError(f"No such tool: {function_call.name}")
|
||||
try:
|
||||
result = tool.run(function_call.arguments or {})
|
||||
error = None
|
||||
except Exception as e:
|
||||
result = None
|
||||
error = repr(e) + repr(input.arguments)
|
||||
return FunctionResult(result=result, error=error)
|
||||
error = repr(e) + repr(function_call.arguments)
|
||||
return FunctionResult(name=function_call.name, result=result, error=error)
|
||||
|
||||
return RunnableLambda(func=func)
|
||||
|
||||
@@ -39,17 +46,28 @@ def create_llm_program(
|
||||
llm: BaseLanguageModel,
|
||||
prompt_generator: Callable[[MessageLog], PromptValue],
|
||||
*,
|
||||
tools: Optional[Sequence[BaseTool]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
parser: Union[Runnable, Callable] = None,
|
||||
parser: Union[
|
||||
Runnable[Union[BaseMessage, str], MessageLike],
|
||||
Callable[[Union[BaseMessage, str]], MessageLike],
|
||||
None,
|
||||
] = None,
|
||||
invoke_tools: bool = True,
|
||||
) -> Runnable[MessageLog, List[MessageLike]]:
|
||||
"""Create a runnable that can update memory."""
|
||||
|
||||
tool_invoker = create_tool_invoker(tools) if invoke_tools else None
|
||||
openai_funcs = [format_tool_to_openai_function(tool_) for tool_ in tools]
|
||||
|
||||
def _bound(message_log: MessageLog):
|
||||
messages = []
|
||||
prompt_value = prompt_generator(message_log)
|
||||
llm_chain = llm
|
||||
if stop:
|
||||
llm_chain = llm_chain.bind(stop=stop)
|
||||
if tools:
|
||||
llm_chain = llm_chain.bind(tools=openai_funcs)
|
||||
|
||||
result = llm_chain.invoke(prompt_value)
|
||||
|
||||
@@ -58,7 +76,7 @@ def create_llm_program(
|
||||
elif isinstance(result, str):
|
||||
messages.append(AIMessage(content=result))
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported type {result}")
|
||||
raise NotImplementedError(f"Unsupported type {type(result)}")
|
||||
|
||||
if parser:
|
||||
if not isinstance(parser, Runnable):
|
||||
@@ -73,6 +91,12 @@ def create_llm_program(
|
||||
)
|
||||
messages.append(parsed_result)
|
||||
|
||||
last_message = messages[-1]
|
||||
|
||||
if tool_invoker and isinstance(last_message, FunctionCall):
|
||||
function_result = tool_invoker.invoke(last_message)
|
||||
messages.append(function_result)
|
||||
|
||||
return messages
|
||||
|
||||
return RunnableLambda(
|
||||
|
||||
@@ -1,201 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, List
|
||||
|
||||
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
|
||||
# from langchain.automaton.chat_automaton import ChatAutomaton
|
||||
from langchain.automaton.mrkl_agent import ActionParser
|
||||
from langchain.automaton.tests.utils import (
|
||||
FakeChatOpenAI,
|
||||
construct_func_invocation_message,
|
||||
)
|
||||
from langchain.schema import PromptValue
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
SystemMessage,
|
||||
FunctionMessage,
|
||||
)
|
||||
from langchain.schema.runnable import RunnableLambda
|
||||
from langchain.tools import tool, Tool
|
||||
from langchain.tools.base import tool as tool_maker
|
||||
|
||||
|
||||
def get_tools() -> List[Tool]:
|
||||
@tool
|
||||
def name() -> str:
|
||||
"""Use to look up the user's name"""
|
||||
return "Eugene"
|
||||
|
||||
@tool
|
||||
def get_weather(city: str) -> str:
|
||||
"""Get weather in a specific city."""
|
||||
return "42F and sunny"
|
||||
|
||||
@tool
|
||||
def add(x: int, y: int) -> int:
|
||||
"""Use to add two numbers."""
|
||||
return x + y
|
||||
|
||||
return list(locals().values())
|
||||
|
||||
|
||||
def test_structured_output_chat() -> None:
|
||||
parser = StructuredChatOutputParser()
|
||||
output = parser.parse(
|
||||
"""
|
||||
```json
|
||||
{
|
||||
"action": "hello",
|
||||
"action_input": {
|
||||
"a": 2
|
||||
}
|
||||
}
|
||||
```
|
||||
"""
|
||||
)
|
||||
assert output == {}
|
||||
|
||||
|
||||
class MessageBasedPromptValue(PromptValue):
|
||||
"""Prompt Value populated from messages."""
|
||||
|
||||
messages: List[BaseMessage]
|
||||
|
||||
@classmethod
|
||||
def from_messages(cls, messages: Sequence[BaseMessage]) -> MessageBasedPromptValue:
|
||||
return cls(messages=messages)
|
||||
|
||||
def to_messages(self) -> List[BaseMessage]:
|
||||
return self.messages
|
||||
|
||||
def to_string(self) -> str:
|
||||
return "\n".join([message.content for message in self.messages])
|
||||
|
||||
|
||||
|
||||
# def prompt_generator(memory: Memory) -> PromptValue:
|
||||
# """Generate a prompt."""
|
||||
# if not memory.messages:
|
||||
# raise AssertionError("Memory is empty")
|
||||
# return MessageBasedPromptValue.from_messages(messages=memory.messages)
|
||||
#
|
||||
|
||||
def test_automaton() -> None:
|
||||
"""Run the automaton."""
|
||||
|
||||
@tool_maker
|
||||
def get_time() -> str:
|
||||
"""Get time."""
|
||||
return "9 PM"
|
||||
|
||||
@tool_maker
|
||||
def get_location() -> str:
|
||||
"""Get location."""
|
||||
return "the park"
|
||||
|
||||
tools = [get_time, get_location]
|
||||
llm = FakeChatOpenAI(
|
||||
message_iter=iter(
|
||||
[
|
||||
construct_func_invocation_message(get_time, {}),
|
||||
AIMessage(
|
||||
content="The time is 9 PM.",
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# TODO(FIX MUTABILITY)
|
||||
|
||||
memory = Memory(
|
||||
messages=[
|
||||
SystemMessage(
|
||||
content=(
|
||||
"Hello! I'm a chatbot that can help you write a letter. "
|
||||
"What would you like to do?"
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
chat_automaton = ChatAutomaton(
|
||||
llm=llm, tools=tools, prompt_generator=prompt_generator
|
||||
)
|
||||
executor = Executor(chat_automaton, memory, max_iterations=1)
|
||||
state, executed_states = executor.run()
|
||||
assert executed_states == [
|
||||
{
|
||||
"data": {
|
||||
"message": FunctionMessage(
|
||||
content="9 PM", additional_kwargs={}, name="get_time"
|
||||
)
|
||||
},
|
||||
"id": "llm_program",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_generate_template() -> None:
|
||||
"""Generate template."""
|
||||
template = generate_template()
|
||||
assert template.format_messages(tools="hello", tool_names="hello") == []
|
||||
|
||||
|
||||
def test_parser() -> None:
|
||||
"""Tes the parser."""
|
||||
sample_text = """
|
||||
Some text before
|
||||
<action>
|
||||
{
|
||||
"key": "value",
|
||||
"number": 42
|
||||
}
|
||||
</action>
|
||||
Some text after
|
||||
"""
|
||||
action_parser = ActionParser(strict=False)
|
||||
action = action_parser.decode(sample_text)
|
||||
assert action == {
|
||||
"key": "value",
|
||||
"number": 42,
|
||||
}
|
||||
|
||||
def test_function_invocation() -> None:
|
||||
"""test function invocation"""
|
||||
tools = get_tools()
|
||||
from langchain.automaton.well_known_states import create_tool_invoker
|
||||
runnable = create_tool_invoker(tools)
|
||||
result = runnable.invoke({"name": "add", "inputs": {"x": 1, "y": 2}})
|
||||
assert result == 3
|
||||
|
||||
|
||||
|
||||
def test_create_llm_program() -> None:
|
||||
"""Generate llm program."""
|
||||
from langchain.automaton.mrkl_automaton import (
|
||||
_generate_prompt,
|
||||
_generate_mrkl_memory,
|
||||
)
|
||||
|
||||
tools = get_tools()
|
||||
llm = FakeChatOpenAI(
|
||||
message_iter=iter(
|
||||
[
|
||||
AIMessage(
|
||||
content="""Thought: Hello. <action>{"name": "key"}</action>""",
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
program = create_llm_program(
|
||||
"think-act",
|
||||
llm,
|
||||
prompt_generator=_generate_prompt,
|
||||
stop=["Observation"],
|
||||
parser=RunnableLambda(ActionParser(strict=False).decode),
|
||||
)
|
||||
|
||||
mrkl_memory = _generate_mrkl_memory(tools)
|
||||
result = program.invoke(mrkl_memory)
|
||||
assert result == {"id": "think-act", "data": {}}
|
||||
41
libs/langchain/langchain/automaton/tests/test_mrkl_agent.py
Normal file
41
libs/langchain/langchain/automaton/tests/test_mrkl_agent.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
|
||||
from langchain.automaton.mrkl_agent import ActionParser
|
||||
|
||||
|
||||
def test_structured_output_chat() -> None:
|
||||
parser = StructuredChatOutputParser()
|
||||
output = parser.parse(
|
||||
"""
|
||||
```json
|
||||
{
|
||||
"action": "hello",
|
||||
"action_input": {
|
||||
"a": 2
|
||||
}
|
||||
}
|
||||
```
|
||||
"""
|
||||
)
|
||||
assert output == {}
|
||||
|
||||
|
||||
def test_parser() -> None:
|
||||
"""Tes the parser."""
|
||||
sample_text = """
|
||||
Some text before
|
||||
<action>
|
||||
{
|
||||
"key": "value",
|
||||
"number": 42
|
||||
}
|
||||
</action>
|
||||
Some text after
|
||||
"""
|
||||
action_parser = ActionParser()
|
||||
action = action_parser.decode(sample_text)
|
||||
assert action == {
|
||||
"key": "value",
|
||||
"number": 42,
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, cast
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.automaton.openai_agent import OpenAIAgent
|
||||
from langchain.automaton.tests.utils import (
|
||||
FakeChatOpenAI,
|
||||
construct_func_invocation_message,
|
||||
)
|
||||
from langchain.automaton.typedefs import (
|
||||
FunctionCall,
|
||||
FunctionResult,
|
||||
MessageLog,
|
||||
AgentFinish,
|
||||
)
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain.tools import tool, Tool
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def tools() -> List[BaseTool]:
|
||||
@tool
|
||||
def get_time() -> str:
|
||||
"""Get time."""
|
||||
return "9 PM"
|
||||
|
||||
@tool
|
||||
def get_location() -> str:
|
||||
"""Get location."""
|
||||
return "the park"
|
||||
|
||||
return cast(List[Tool], [get_time, get_location])
|
||||
|
||||
|
||||
def test_openai_agent(tools: List[Tool]) -> None:
|
||||
get_time, get_location = tools
|
||||
llm = FakeChatOpenAI(
|
||||
message_iter=iter(
|
||||
[
|
||||
construct_func_invocation_message(get_time, {}),
|
||||
AIMessage(
|
||||
content="The time is 9 PM.",
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
agent = OpenAIAgent(llm=llm, tools=tools, max_iterations=10)
|
||||
|
||||
message_log = MessageLog(
|
||||
[
|
||||
SystemMessage(
|
||||
content="What time is it?",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
expected_messages = [
|
||||
SystemMessage(
|
||||
content="What time is it?",
|
||||
),
|
||||
AIMessage(
|
||||
content="",
|
||||
additional_kwargs={
|
||||
"function_call": {
|
||||
"name": "get_time",
|
||||
"arguments": "{}",
|
||||
}
|
||||
},
|
||||
),
|
||||
FunctionCall(
|
||||
name="get_time",
|
||||
arguments={},
|
||||
),
|
||||
FunctionResult(
|
||||
name="get_time",
|
||||
result="9 PM",
|
||||
error=None,
|
||||
),
|
||||
AIMessage(
|
||||
content="The time is 9 PM.",
|
||||
),
|
||||
AgentFinish(
|
||||
AIMessage(
|
||||
content="The time is 9 PM.",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
agent.run(message_log)
|
||||
assert message_log.messages == expected_messages
|
||||
@@ -1,11 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, Optional, Sequence, List, Mapping, overload, Union
|
||||
from typing import Any, Optional, Sequence, Mapping, overload, Union
|
||||
|
||||
from langchain.schema import (
|
||||
BaseMessage,
|
||||
PromptValue,
|
||||
)
|
||||
|
||||
|
||||
@@ -17,6 +16,7 @@ class FunctionCall:
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class FunctionResult:
|
||||
name: str
|
||||
result: Any
|
||||
error: Optional[str]
|
||||
|
||||
@@ -41,7 +41,7 @@ MessageLike = Union[
|
||||
class MessageLog:
|
||||
"""A generalized message log for message like items."""
|
||||
|
||||
def __init__(self, messages: Sequence[MessageLike]) -> None:
|
||||
def __init__(self, messages: Sequence[MessageLike] = ()) -> None:
|
||||
"""Initialize the message log."""
|
||||
self.messages = list(messages)
|
||||
|
||||
@@ -66,35 +66,15 @@ class MessageLog:
|
||||
else:
|
||||
return self.messages[index]
|
||||
|
||||
def __bool__(self):
|
||||
return bool(self.messages)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Get the length of the chat template."""
|
||||
return len(self.messages)
|
||||
|
||||
|
||||
class MessageLogPromptValue(PromptValue):
|
||||
"""Base abstract class for inputs to any language model.
|
||||
|
||||
PromptValues can be converted to both LLM (pure text-generation) inputs and
|
||||
ChatModel inputs.
|
||||
"""
|
||||
|
||||
message_log: MessageLog
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""Return prompt value as string."""
|
||||
finalized = []
|
||||
for message in self.to_messages():
|
||||
prefix = message.type
|
||||
finalized.append(f"{prefix}: {message.content}")
|
||||
return "\n".join(finalized) + "\n" + "ai:"
|
||||
|
||||
def to_messages(self) -> List[BaseMessage]:
|
||||
"""Return prompt as a list of Messages."""
|
||||
return [
|
||||
message
|
||||
for message in self.message_log.messages
|
||||
if isinstance(message, BaseMessage)
|
||||
]
|
||||
class Agent:
|
||||
def run(self, message_log: MessageLog) -> None:
|
||||
"""Run the agent on a message."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -79,6 +79,15 @@ class BaseMessage(Serializable):
|
||||
"""Whether this class is LangChain serializable."""
|
||||
return True
|
||||
|
||||
def __repr__(self):
|
||||
if self.additional_kwargs:
|
||||
return f"{self.type}: {self.additional_kwargs}"
|
||||
else:
|
||||
return f"{self.type}: {self.content}"
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
def __add__(self, other: Any) -> ChatPromptTemplate:
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
|
||||
|
||||
Reference in New Issue
Block a user