mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-05 14:43:08 +00:00
format intermediate steps (#10794)
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
parent
386ef1e654
commit
7dec2d399b
16
libs/langchain/langchain/agents/format_scratchpad/log.py
Normal file
16
libs/langchain/langchain/agents/format_scratchpad/log.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
from langchain.schema.agent import AgentAction
|
||||||
|
|
||||||
|
|
||||||
|
def format_log_to_str(
|
||||||
|
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||||
|
observation_prefix: str = "Observation: ",
|
||||||
|
llm_prefix: str = "Thought: ",
|
||||||
|
) -> str:
|
||||||
|
"""Construct the scratchpad that lets the agent continue its thought process."""
|
||||||
|
thoughts = ""
|
||||||
|
for action, observation in intermediate_steps:
|
||||||
|
thoughts += action.log
|
||||||
|
thoughts += f"\n{observation_prefix}{observation}\n{llm_prefix}"
|
||||||
|
return thoughts
|
@ -0,0 +1,19 @@
|
|||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
from langchain.schema.agent import AgentAction
|
||||||
|
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
|
||||||
|
|
||||||
|
|
||||||
|
def format_log_to_messages(
|
||||||
|
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||||
|
template_tool_response: str = "{observation}",
|
||||||
|
) -> List[BaseMessage]:
|
||||||
|
"""Construct the scratchpad that lets the agent continue its thought process."""
|
||||||
|
thoughts: List[BaseMessage] = []
|
||||||
|
for action, observation in intermediate_steps:
|
||||||
|
thoughts.append(AIMessage(content=action.log))
|
||||||
|
human_message = HumanMessage(
|
||||||
|
content=template_tool_response.format(observation=observation)
|
||||||
|
)
|
||||||
|
thoughts.append(human_message)
|
||||||
|
return thoughts
|
@ -0,0 +1,66 @@
|
|||||||
|
import json
|
||||||
|
from typing import List, Sequence, Tuple
|
||||||
|
|
||||||
|
from langchain.schema.agent import AgentAction, AgentActionMessageLog
|
||||||
|
from langchain.schema.messages import AIMessage, BaseMessage, FunctionMessage
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_agent_action_to_messages(
|
||||||
|
agent_action: AgentAction, observation: str
|
||||||
|
) -> List[BaseMessage]:
|
||||||
|
"""Convert an agent action to a message.
|
||||||
|
|
||||||
|
This code is used to reconstruct the original AI message from the agent action.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_action: Agent action to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AIMessage that corresponds to the original tool invocation.
|
||||||
|
"""
|
||||||
|
if isinstance(agent_action, AgentActionMessageLog):
|
||||||
|
return list(agent_action.message_log) + [
|
||||||
|
_create_function_message(agent_action, observation)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
return [AIMessage(content=agent_action.log)]
|
||||||
|
|
||||||
|
|
||||||
|
def _create_function_message(
|
||||||
|
agent_action: AgentAction, observation: str
|
||||||
|
) -> FunctionMessage:
|
||||||
|
"""Convert agent action and observation into a function message.
|
||||||
|
Args:
|
||||||
|
agent_action: the tool invocation request from the agent
|
||||||
|
observation: the result of the tool invocation
|
||||||
|
Returns:
|
||||||
|
FunctionMessage that corresponds to the original tool invocation
|
||||||
|
"""
|
||||||
|
if not isinstance(observation, str):
|
||||||
|
try:
|
||||||
|
content = json.dumps(observation, ensure_ascii=False)
|
||||||
|
except Exception:
|
||||||
|
content = str(observation)
|
||||||
|
else:
|
||||||
|
content = observation
|
||||||
|
return FunctionMessage(
|
||||||
|
name=agent_action.tool,
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def format_to_openai_functions(
|
||||||
|
intermediate_steps: Sequence[Tuple[AgentAction, str]],
|
||||||
|
) -> List[BaseMessage]:
|
||||||
|
"""Format intermediate steps.
|
||||||
|
Args:
|
||||||
|
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||||
|
Returns:
|
||||||
|
list of messages to send to the LLM for the next prediction
|
||||||
|
"""
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
for agent_action, observation in intermediate_steps:
|
||||||
|
messages.extend(_convert_agent_action_to_messages(agent_action, observation))
|
||||||
|
|
||||||
|
return messages
|
15
libs/langchain/langchain/agents/format_scratchpad/xml.py
Normal file
15
libs/langchain/langchain/agents/format_scratchpad/xml.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
from langchain.schema.agent import AgentAction
|
||||||
|
|
||||||
|
|
||||||
|
def format_xml(
|
||||||
|
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||||
|
) -> str:
|
||||||
|
log = ""
|
||||||
|
for action, observation in intermediate_steps:
|
||||||
|
log += (
|
||||||
|
f"<tool>{action.tool}</tool><tool_input>{action.tool_input}"
|
||||||
|
f"</tool_input><observation>{observation}</observation>"
|
||||||
|
)
|
||||||
|
return log
|
@ -1,7 +1,9 @@
|
|||||||
"""Memory used to save agent output AND intermediate steps."""
|
"""Memory used to save agent output AND intermediate steps."""
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps
|
from langchain.agents.format_scratchpad.openai_functions import (
|
||||||
|
format_to_openai_functions,
|
||||||
|
)
|
||||||
from langchain.memory.chat_memory import BaseChatMemory
|
from langchain.memory.chat_memory import BaseChatMemory
|
||||||
from langchain.schema.language_model import BaseLanguageModel
|
from langchain.schema.language_model import BaseLanguageModel
|
||||||
from langchain.schema.messages import BaseMessage, get_buffer_string
|
from langchain.schema.messages import BaseMessage, get_buffer_string
|
||||||
@ -50,7 +52,7 @@ class AgentTokenBufferMemory(BaseChatMemory):
|
|||||||
"""Save context from this conversation to buffer. Pruned."""
|
"""Save context from this conversation to buffer. Pruned."""
|
||||||
input_str, output_str = self._get_input_output(inputs, outputs)
|
input_str, output_str = self._get_input_output(inputs, outputs)
|
||||||
self.chat_memory.add_user_message(input_str)
|
self.chat_memory.add_user_message(input_str)
|
||||||
steps = _format_intermediate_steps(outputs[self.intermediate_steps_key])
|
steps = format_to_openai_functions(outputs[self.intermediate_steps_key])
|
||||||
for msg in steps:
|
for msg in steps:
|
||||||
self.chat_memory.add_message(msg)
|
self.chat_memory.add_message(msg)
|
||||||
self.chat_memory.add_ai_message(output_str)
|
self.chat_memory.add_ai_message(output_str)
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
"""Module implements an agent that uses OpenAI's APIs function enabled API."""
|
"""Module implements an agent that uses OpenAI's APIs function enabled API."""
|
||||||
import json
|
|
||||||
from typing import Any, List, Optional, Sequence, Tuple, Union
|
from typing import Any, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
from langchain.agents import BaseSingleActionAgent
|
from langchain.agents import BaseSingleActionAgent
|
||||||
|
from langchain.agents.format_scratchpad.openai_functions import (
|
||||||
|
format_to_openai_functions,
|
||||||
|
)
|
||||||
from langchain.agents.output_parsers.openai_functions import (
|
from langchain.agents.output_parsers.openai_functions import (
|
||||||
OpenAIFunctionsAgentOutputParser,
|
OpenAIFunctionsAgentOutputParser,
|
||||||
)
|
)
|
||||||
@ -21,82 +23,14 @@ from langchain.schema import (
|
|||||||
AgentFinish,
|
AgentFinish,
|
||||||
BasePromptTemplate,
|
BasePromptTemplate,
|
||||||
)
|
)
|
||||||
from langchain.schema.agent import AgentActionMessageLog
|
|
||||||
from langchain.schema.language_model import BaseLanguageModel
|
from langchain.schema.language_model import BaseLanguageModel
|
||||||
from langchain.schema.messages import (
|
from langchain.schema.messages import (
|
||||||
AIMessage,
|
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
FunctionMessage,
|
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
)
|
)
|
||||||
from langchain.tools import BaseTool
|
from langchain.tools import BaseTool
|
||||||
from langchain.tools.convert_to_openai import format_tool_to_openai_function
|
from langchain.tools.convert_to_openai import format_tool_to_openai_function
|
||||||
|
|
||||||
# For backwards compatibility
|
|
||||||
_FunctionsAgentAction = AgentActionMessageLog
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_agent_action_to_messages(
|
|
||||||
agent_action: AgentAction, observation: str
|
|
||||||
) -> List[BaseMessage]:
|
|
||||||
"""Convert an agent action to a message.
|
|
||||||
|
|
||||||
This code is used to reconstruct the original AI message from the agent action.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_action: Agent action to convert.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AIMessage that corresponds to the original tool invocation.
|
|
||||||
"""
|
|
||||||
if isinstance(agent_action, _FunctionsAgentAction):
|
|
||||||
return list(agent_action.message_log) + [
|
|
||||||
_create_function_message(agent_action, observation)
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
return [AIMessage(content=agent_action.log)]
|
|
||||||
|
|
||||||
|
|
||||||
def _create_function_message(
|
|
||||||
agent_action: AgentAction, observation: str
|
|
||||||
) -> FunctionMessage:
|
|
||||||
"""Convert agent action and observation into a function message.
|
|
||||||
Args:
|
|
||||||
agent_action: the tool invocation request from the agent
|
|
||||||
observation: the result of the tool invocation
|
|
||||||
Returns:
|
|
||||||
FunctionMessage that corresponds to the original tool invocation
|
|
||||||
"""
|
|
||||||
if not isinstance(observation, str):
|
|
||||||
try:
|
|
||||||
content = json.dumps(observation, ensure_ascii=False)
|
|
||||||
except Exception:
|
|
||||||
content = str(observation)
|
|
||||||
else:
|
|
||||||
content = observation
|
|
||||||
return FunctionMessage(
|
|
||||||
name=agent_action.tool,
|
|
||||||
content=content,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _format_intermediate_steps(
|
|
||||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
|
||||||
) -> List[BaseMessage]:
|
|
||||||
"""Format intermediate steps.
|
|
||||||
Args:
|
|
||||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
|
||||||
Returns:
|
|
||||||
list of messages to send to the LLM for the next prediction
|
|
||||||
"""
|
|
||||||
messages = []
|
|
||||||
|
|
||||||
for intermediate_step in intermediate_steps:
|
|
||||||
agent_action, observation = intermediate_step
|
|
||||||
messages.extend(_convert_agent_action_to_messages(agent_action, observation))
|
|
||||||
|
|
||||||
return messages
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIFunctionsAgent(BaseSingleActionAgent):
|
class OpenAIFunctionsAgent(BaseSingleActionAgent):
|
||||||
"""An Agent driven by OpenAIs function powered API.
|
"""An Agent driven by OpenAIs function powered API.
|
||||||
@ -159,7 +93,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
|
|||||||
Returns:
|
Returns:
|
||||||
Action specifying what tool to use.
|
Action specifying what tool to use.
|
||||||
"""
|
"""
|
||||||
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
agent_scratchpad = format_to_openai_functions(intermediate_steps)
|
||||||
selected_inputs = {
|
selected_inputs = {
|
||||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||||
}
|
}
|
||||||
@ -198,7 +132,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
|
|||||||
Returns:
|
Returns:
|
||||||
Action specifying what tool to use.
|
Action specifying what tool to use.
|
||||||
"""
|
"""
|
||||||
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
agent_scratchpad = format_to_openai_functions(intermediate_steps)
|
||||||
selected_inputs = {
|
selected_inputs = {
|
||||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,9 @@ from json import JSONDecodeError
|
|||||||
from typing import Any, List, Optional, Sequence, Tuple, Union
|
from typing import Any, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
from langchain.agents import BaseMultiActionAgent
|
from langchain.agents import BaseMultiActionAgent
|
||||||
|
from langchain.agents.format_scratchpad.openai_functions import (
|
||||||
|
format_to_openai_functions,
|
||||||
|
)
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
from langchain.callbacks.manager import Callbacks
|
from langchain.callbacks.manager import Callbacks
|
||||||
from langchain.chat_models.openai import ChatOpenAI
|
from langchain.chat_models.openai import ChatOpenAI
|
||||||
@ -25,7 +28,6 @@ from langchain.schema.language_model import BaseLanguageModel
|
|||||||
from langchain.schema.messages import (
|
from langchain.schema.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
FunctionMessage,
|
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
)
|
)
|
||||||
from langchain.tools import BaseTool
|
from langchain.tools import BaseTool
|
||||||
@ -34,68 +36,6 @@ from langchain.tools import BaseTool
|
|||||||
_FunctionsAgentAction = AgentActionMessageLog
|
_FunctionsAgentAction = AgentActionMessageLog
|
||||||
|
|
||||||
|
|
||||||
def _convert_agent_action_to_messages(
|
|
||||||
agent_action: AgentAction, observation: str
|
|
||||||
) -> List[BaseMessage]:
|
|
||||||
"""Convert an agent action to a message.
|
|
||||||
|
|
||||||
This code is used to reconstruct the original AI message from the agent action.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_action: Agent action to convert.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AIMessage that corresponds to the original tool invocation.
|
|
||||||
"""
|
|
||||||
if isinstance(agent_action, _FunctionsAgentAction):
|
|
||||||
return list(agent_action.message_log) + [
|
|
||||||
_create_function_message(agent_action, observation)
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
return [AIMessage(content=agent_action.log)]
|
|
||||||
|
|
||||||
|
|
||||||
def _create_function_message(
|
|
||||||
agent_action: AgentAction, observation: str
|
|
||||||
) -> FunctionMessage:
|
|
||||||
"""Convert agent action and observation into a function message.
|
|
||||||
Args:
|
|
||||||
agent_action: the tool invocation request from the agent
|
|
||||||
observation: the result of the tool invocation
|
|
||||||
Returns:
|
|
||||||
FunctionMessage that corresponds to the original tool invocation
|
|
||||||
"""
|
|
||||||
if not isinstance(observation, str):
|
|
||||||
try:
|
|
||||||
content = json.dumps(observation, ensure_ascii=False)
|
|
||||||
except Exception:
|
|
||||||
content = str(observation)
|
|
||||||
else:
|
|
||||||
content = observation
|
|
||||||
return FunctionMessage(
|
|
||||||
name=agent_action.tool,
|
|
||||||
content=content,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _format_intermediate_steps(
|
|
||||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
|
||||||
) -> List[BaseMessage]:
|
|
||||||
"""Format intermediate steps.
|
|
||||||
Args:
|
|
||||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
|
||||||
Returns:
|
|
||||||
list of messages to send to the LLM for the next prediction
|
|
||||||
"""
|
|
||||||
messages = []
|
|
||||||
|
|
||||||
for intermediate_step in intermediate_steps:
|
|
||||||
agent_action, observation = intermediate_step
|
|
||||||
messages.extend(_convert_agent_action_to_messages(agent_action, observation))
|
|
||||||
|
|
||||||
return messages
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_ai_message(message: BaseMessage) -> Union[List[AgentAction], AgentFinish]:
|
def _parse_ai_message(message: BaseMessage) -> Union[List[AgentAction], AgentFinish]:
|
||||||
"""Parse an AI message."""
|
"""Parse an AI message."""
|
||||||
if not isinstance(message, AIMessage):
|
if not isinstance(message, AIMessage):
|
||||||
@ -257,7 +197,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
|
|||||||
Returns:
|
Returns:
|
||||||
Action specifying what tool to use.
|
Action specifying what tool to use.
|
||||||
"""
|
"""
|
||||||
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
agent_scratchpad = format_to_openai_functions(intermediate_steps)
|
||||||
selected_inputs = {
|
selected_inputs = {
|
||||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||||
}
|
}
|
||||||
@ -286,7 +226,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
|
|||||||
Returns:
|
Returns:
|
||||||
Action specifying what tool to use.
|
Action specifying what tool to use.
|
||||||
"""
|
"""
|
||||||
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
agent_scratchpad = format_to_openai_functions(intermediate_steps)
|
||||||
selected_inputs = {
|
selected_inputs = {
|
||||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,40 @@
|
|||||||
|
from langchain.agents.format_scratchpad.log import format_log_to_str
|
||||||
|
from langchain.schema.agent import AgentAction
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_agent_action_observation() -> None:
|
||||||
|
intermediate_steps = [
|
||||||
|
(AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1")
|
||||||
|
]
|
||||||
|
expected_result = "Log1\nObservation: Observation1\nThought: "
|
||||||
|
assert format_log_to_str(intermediate_steps) == expected_result
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_agent_actions_observations() -> None:
|
||||||
|
intermediate_steps = [
|
||||||
|
(AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1"),
|
||||||
|
(AgentAction(tool="Tool2", tool_input="input2", log="Log2"), "Observation2"),
|
||||||
|
(AgentAction(tool="Tool3", tool_input="input3", log="Log3"), "Observation3"),
|
||||||
|
]
|
||||||
|
expected_result = """Log1\nObservation: Observation1\nThought: \
|
||||||
|
Log2\nObservation: Observation2\nThought: Log3\nObservation: \
|
||||||
|
Observation3\nThought: """
|
||||||
|
assert format_log_to_str(intermediate_steps) == expected_result
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_prefixes() -> None:
|
||||||
|
intermediate_steps = [
|
||||||
|
(AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1")
|
||||||
|
]
|
||||||
|
observation_prefix = "Custom Observation: "
|
||||||
|
llm_prefix = "Custom Thought: "
|
||||||
|
expected_result = "Log1\nCustom Observation: Observation1\nCustom Thought: "
|
||||||
|
assert (
|
||||||
|
format_log_to_str(intermediate_steps, observation_prefix, llm_prefix)
|
||||||
|
== expected_result
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_intermediate_steps() -> None:
|
||||||
|
output = format_log_to_str([])
|
||||||
|
assert output == ""
|
@ -0,0 +1,49 @@
|
|||||||
|
from langchain.agents.format_scratchpad.log_to_messages import format_log_to_messages
|
||||||
|
from langchain.schema.agent import AgentAction
|
||||||
|
from langchain.schema.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_intermediate_step_default_response() -> None:
|
||||||
|
intermediate_steps = [
|
||||||
|
(AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1")
|
||||||
|
]
|
||||||
|
expected_result = [AIMessage(content="Log1"), HumanMessage(content="Observation1")]
|
||||||
|
assert format_log_to_messages(intermediate_steps) == expected_result
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_intermediate_steps_default_response() -> None:
|
||||||
|
intermediate_steps = [
|
||||||
|
(AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1"),
|
||||||
|
(AgentAction(tool="Tool2", tool_input="input2", log="Log2"), "Observation2"),
|
||||||
|
(AgentAction(tool="Tool3", tool_input="input3", log="Log3"), "Observation3"),
|
||||||
|
]
|
||||||
|
expected_result = [
|
||||||
|
AIMessage(content="Log1"),
|
||||||
|
HumanMessage(content="Observation1"),
|
||||||
|
AIMessage(content="Log2"),
|
||||||
|
HumanMessage(content="Observation2"),
|
||||||
|
AIMessage(content="Log3"),
|
||||||
|
HumanMessage(content="Observation3"),
|
||||||
|
]
|
||||||
|
assert format_log_to_messages(intermediate_steps) == expected_result
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_template_tool_response() -> None:
|
||||||
|
intermediate_steps = [
|
||||||
|
(AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1")
|
||||||
|
]
|
||||||
|
template_tool_response = "Response: {observation}"
|
||||||
|
expected_result = [
|
||||||
|
AIMessage(content="Log1"),
|
||||||
|
HumanMessage(content="Response: Observation1"),
|
||||||
|
]
|
||||||
|
assert (
|
||||||
|
format_log_to_messages(
|
||||||
|
intermediate_steps, template_tool_response=template_tool_response
|
||||||
|
)
|
||||||
|
== expected_result
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_steps() -> None:
|
||||||
|
assert format_log_to_messages([]) == []
|
@ -0,0 +1,60 @@
|
|||||||
|
from langchain.agents.format_scratchpad.openai_functions import (
|
||||||
|
format_to_openai_functions,
|
||||||
|
)
|
||||||
|
from langchain.schema.agent import AgentActionMessageLog
|
||||||
|
from langchain.schema.messages import AIMessage, FunctionMessage
|
||||||
|
|
||||||
|
|
||||||
|
def test_calls_convert_agent_action_to_messages() -> None:
|
||||||
|
additional_kwargs1 = {
|
||||||
|
"function_call": {
|
||||||
|
"name": "tool1",
|
||||||
|
"arguments": "input1",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
message1 = AIMessage(content="", additional_kwargs=additional_kwargs1)
|
||||||
|
action1 = AgentActionMessageLog(
|
||||||
|
tool="tool1", tool_input="input1", log="log1", message_log=[message1]
|
||||||
|
)
|
||||||
|
additional_kwargs2 = {
|
||||||
|
"function_call": {
|
||||||
|
"name": "tool2",
|
||||||
|
"arguments": "input2",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
message2 = AIMessage(content="", additional_kwargs=additional_kwargs2)
|
||||||
|
action2 = AgentActionMessageLog(
|
||||||
|
tool="tool2", tool_input="input2", log="log2", message_log=[message2]
|
||||||
|
)
|
||||||
|
|
||||||
|
additional_kwargs3 = {
|
||||||
|
"function_call": {
|
||||||
|
"name": "tool3",
|
||||||
|
"arguments": "input3",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
message3 = AIMessage(content="", additional_kwargs=additional_kwargs3)
|
||||||
|
action3 = AgentActionMessageLog(
|
||||||
|
tool="tool3", tool_input="input3", log="log3", message_log=[message3]
|
||||||
|
)
|
||||||
|
|
||||||
|
intermediate_steps = [
|
||||||
|
(action1, "observation1"),
|
||||||
|
(action2, "observation2"),
|
||||||
|
(action3, "observation3"),
|
||||||
|
]
|
||||||
|
expected_messages = [
|
||||||
|
message1,
|
||||||
|
FunctionMessage(name="tool1", content="observation1"),
|
||||||
|
message2,
|
||||||
|
FunctionMessage(name="tool2", content="observation2"),
|
||||||
|
message3,
|
||||||
|
FunctionMessage(name="tool3", content="observation3"),
|
||||||
|
]
|
||||||
|
output = format_to_openai_functions(intermediate_steps)
|
||||||
|
assert output == expected_messages
|
||||||
|
|
||||||
|
|
||||||
|
def test_handles_empty_input_list() -> None:
|
||||||
|
output = format_to_openai_functions([])
|
||||||
|
assert output == []
|
@ -0,0 +1,40 @@
|
|||||||
|
from langchain.agents.format_scratchpad.xml import format_xml
|
||||||
|
from langchain.schema.agent import AgentAction
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_agent_action_observation() -> None:
|
||||||
|
# Arrange
|
||||||
|
agent_action = AgentAction(tool="Tool1", tool_input="Input1", log="Log1")
|
||||||
|
observation = "Observation1"
|
||||||
|
intermediate_steps = [(agent_action, observation)]
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = format_xml(intermediate_steps)
|
||||||
|
expected_result = """<tool>Tool1</tool><tool_input>Input1\
|
||||||
|
</tool_input><observation>Observation1</observation>"""
|
||||||
|
# Assert
|
||||||
|
assert result == expected_result
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_agent_actions_observations() -> None:
|
||||||
|
# Arrange
|
||||||
|
agent_action1 = AgentAction(tool="Tool1", tool_input="Input1", log="Log1")
|
||||||
|
agent_action2 = AgentAction(tool="Tool2", tool_input="Input2", log="Log2")
|
||||||
|
observation1 = "Observation1"
|
||||||
|
observation2 = "Observation2"
|
||||||
|
intermediate_steps = [(agent_action1, observation1), (agent_action2, observation2)]
|
||||||
|
|
||||||
|
# Act
|
||||||
|
result = format_xml(intermediate_steps)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
expected_result = """<tool>Tool1</tool><tool_input>Input1\
|
||||||
|
</tool_input><observation>Observation1</observation><tool>\
|
||||||
|
Tool2</tool><tool_input>Input2</tool_input><observation>\
|
||||||
|
Observation2</observation>"""
|
||||||
|
assert result == expected_result
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_list_agent_actions() -> None:
|
||||||
|
result = format_xml([])
|
||||||
|
assert result == ""
|
Loading…
Reference in New Issue
Block a user