mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-31 20:19:43 +00:00
format intermediate steps (#10794)
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
parent
386ef1e654
commit
7dec2d399b
16
libs/langchain/langchain/agents/format_scratchpad/log.py
Normal file
16
libs/langchain/langchain/agents/format_scratchpad/log.py
Normal file
@ -0,0 +1,16 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
from langchain.schema.agent import AgentAction
|
||||
|
||||
|
||||
def format_log_to_str(
    intermediate_steps: List[Tuple[AgentAction, str]],
    observation_prefix: str = "Observation: ",
    llm_prefix: str = "Thought: ",
) -> str:
    """Construct the scratchpad that lets the agent continue its thought process.

    Each step contributes the action's log followed by the prefixed observation
    and the LLM prefix; steps are concatenated in order. Empty input yields "".
    """
    return "".join(
        f"{action.log}\n{observation_prefix}{observation}\n{llm_prefix}"
        for action, observation in intermediate_steps
    )
|
@ -0,0 +1,19 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
from langchain.schema.agent import AgentAction
|
||||
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
|
||||
|
||||
|
||||
def format_log_to_messages(
    intermediate_steps: List[Tuple[AgentAction, str]],
    template_tool_response: str = "{observation}",
) -> List[BaseMessage]:
    """Construct the scratchpad that lets the agent continue its thought process.

    Every (action, observation) pair becomes an AI message carrying the action's
    log, followed by a human message with the observation substituted into
    *template_tool_response*.
    """
    messages: List[BaseMessage] = []
    for step_action, step_observation in intermediate_steps:
        messages.extend(
            [
                AIMessage(content=step_action.log),
                HumanMessage(
                    content=template_tool_response.format(
                        observation=step_observation
                    )
                ),
            ]
        )
    return messages
|
@ -0,0 +1,66 @@
|
||||
import json
|
||||
from typing import List, Sequence, Tuple
|
||||
|
||||
from langchain.schema.agent import AgentAction, AgentActionMessageLog
|
||||
from langchain.schema.messages import AIMessage, BaseMessage, FunctionMessage
|
||||
|
||||
|
||||
def _convert_agent_action_to_messages(
    agent_action: AgentAction, observation: str
) -> List[BaseMessage]:
    """Convert an agent action to a message.

    This code is used to reconstruct the original AI message from the agent action.

    Args:
        agent_action: Agent action to convert.
        observation: the result of the tool invocation.

    Returns:
        Messages that correspond to the original tool invocation.
    """
    # Plain actions carry no message log: fall back to wrapping the raw log text.
    if not isinstance(agent_action, AgentActionMessageLog):
        return [AIMessage(content=agent_action.log)]
    # Replay the recorded messages, then append the tool result as a
    # function message.
    messages = list(agent_action.message_log)
    messages.append(_create_function_message(agent_action, observation))
    return messages
|
||||
|
||||
|
||||
def _create_function_message(
    agent_action: AgentAction, observation: str
) -> FunctionMessage:
    """Convert agent action and observation into a function message.

    Args:
        agent_action: the tool invocation request from the agent
        observation: the result of the tool invocation

    Returns:
        FunctionMessage that corresponds to the original tool invocation
    """
    if isinstance(observation, str):
        content = observation
    else:
        # Non-string observations are serialized as JSON when possible;
        # anything unserializable falls back to its str() form.
        try:
            content = json.dumps(observation, ensure_ascii=False)
        except Exception:
            content = str(observation)
    return FunctionMessage(name=agent_action.tool, content=content)
|
||||
|
||||
|
||||
def format_to_openai_functions(
    intermediate_steps: Sequence[Tuple[AgentAction, str]],
) -> List[BaseMessage]:
    """Format intermediate steps.

    Args:
        intermediate_steps: Steps the LLM has taken to date, along with observations

    Returns:
        list of messages to send to the LLM for the next prediction
    """
    # Flatten each step's converted messages into a single ordered list.
    return [
        message
        for agent_action, observation in intermediate_steps
        for message in _convert_agent_action_to_messages(agent_action, observation)
    ]
|
15
libs/langchain/langchain/agents/format_scratchpad/xml.py
Normal file
15
libs/langchain/langchain/agents/format_scratchpad/xml.py
Normal file
@ -0,0 +1,15 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
from langchain.schema.agent import AgentAction
|
||||
|
||||
|
||||
def format_xml(
    intermediate_steps: List[Tuple[AgentAction, str]],
) -> str:
    """Render intermediate steps as an XML-tagged transcript string.

    Each step emits <tool>, <tool_input>, and <observation> elements back to
    back; steps are concatenated in order with no separator.
    """
    fragments = [
        f"<tool>{action.tool}</tool><tool_input>{action.tool_input}"
        f"</tool_input><observation>{observation}</observation>"
        for action, observation in intermediate_steps
    ]
    return "".join(fragments)
|
@ -1,7 +1,9 @@
|
||||
"""Memory used to save agent output AND intermediate steps."""
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps
|
||||
from langchain.agents.format_scratchpad.openai_functions import (
|
||||
format_to_openai_functions,
|
||||
)
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import BaseMessage, get_buffer_string
|
||||
@ -50,7 +52,7 @@ class AgentTokenBufferMemory(BaseChatMemory):
|
||||
"""Save context from this conversation to buffer. Pruned."""
|
||||
input_str, output_str = self._get_input_output(inputs, outputs)
|
||||
self.chat_memory.add_user_message(input_str)
|
||||
steps = _format_intermediate_steps(outputs[self.intermediate_steps_key])
|
||||
steps = format_to_openai_functions(outputs[self.intermediate_steps_key])
|
||||
for msg in steps:
|
||||
self.chat_memory.add_message(msg)
|
||||
self.chat_memory.add_ai_message(output_str)
|
||||
|
@ -1,8 +1,10 @@
|
||||
"""Module implements an agent that uses OpenAI's APIs function enabled API."""
|
||||
import json
|
||||
from typing import Any, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from langchain.agents import BaseSingleActionAgent
|
||||
from langchain.agents.format_scratchpad.openai_functions import (
|
||||
format_to_openai_functions,
|
||||
)
|
||||
from langchain.agents.output_parsers.openai_functions import (
|
||||
OpenAIFunctionsAgentOutputParser,
|
||||
)
|
||||
@ -21,82 +23,14 @@ from langchain.schema import (
|
||||
AgentFinish,
|
||||
BasePromptTemplate,
|
||||
)
|
||||
from langchain.schema.agent import AgentActionMessageLog
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
FunctionMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain.tools import BaseTool
|
||||
from langchain.tools.convert_to_openai import format_tool_to_openai_function
|
||||
|
||||
# For backwards compatibility
|
||||
_FunctionsAgentAction = AgentActionMessageLog
|
||||
|
||||
|
||||
def _convert_agent_action_to_messages(
|
||||
agent_action: AgentAction, observation: str
|
||||
) -> List[BaseMessage]:
|
||||
"""Convert an agent action to a message.
|
||||
|
||||
This code is used to reconstruct the original AI message from the agent action.
|
||||
|
||||
Args:
|
||||
agent_action: Agent action to convert.
|
||||
|
||||
Returns:
|
||||
AIMessage that corresponds to the original tool invocation.
|
||||
"""
|
||||
if isinstance(agent_action, _FunctionsAgentAction):
|
||||
return list(agent_action.message_log) + [
|
||||
_create_function_message(agent_action, observation)
|
||||
]
|
||||
else:
|
||||
return [AIMessage(content=agent_action.log)]
|
||||
|
||||
|
||||
def _create_function_message(
|
||||
agent_action: AgentAction, observation: str
|
||||
) -> FunctionMessage:
|
||||
"""Convert agent action and observation into a function message.
|
||||
Args:
|
||||
agent_action: the tool invocation request from the agent
|
||||
observation: the result of the tool invocation
|
||||
Returns:
|
||||
FunctionMessage that corresponds to the original tool invocation
|
||||
"""
|
||||
if not isinstance(observation, str):
|
||||
try:
|
||||
content = json.dumps(observation, ensure_ascii=False)
|
||||
except Exception:
|
||||
content = str(observation)
|
||||
else:
|
||||
content = observation
|
||||
return FunctionMessage(
|
||||
name=agent_action.tool,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
def _format_intermediate_steps(
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
) -> List[BaseMessage]:
|
||||
"""Format intermediate steps.
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
Returns:
|
||||
list of messages to send to the LLM for the next prediction
|
||||
"""
|
||||
messages = []
|
||||
|
||||
for intermediate_step in intermediate_steps:
|
||||
agent_action, observation = intermediate_step
|
||||
messages.extend(_convert_agent_action_to_messages(agent_action, observation))
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
class OpenAIFunctionsAgent(BaseSingleActionAgent):
|
||||
"""An Agent driven by OpenAIs function powered API.
|
||||
@ -159,7 +93,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
||||
agent_scratchpad = format_to_openai_functions(intermediate_steps)
|
||||
selected_inputs = {
|
||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||
}
|
||||
@ -198,7 +132,7 @@ class OpenAIFunctionsAgent(BaseSingleActionAgent):
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
||||
agent_scratchpad = format_to_openai_functions(intermediate_steps)
|
||||
selected_inputs = {
|
||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||
}
|
||||
|
@ -4,6 +4,9 @@ from json import JSONDecodeError
|
||||
from typing import Any, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from langchain.agents import BaseMultiActionAgent
|
||||
from langchain.agents.format_scratchpad.openai_functions import (
|
||||
format_to_openai_functions,
|
||||
)
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.chat_models.openai import ChatOpenAI
|
||||
@ -25,7 +28,6 @@ from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
FunctionMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain.tools import BaseTool
|
||||
@ -34,68 +36,6 @@ from langchain.tools import BaseTool
|
||||
_FunctionsAgentAction = AgentActionMessageLog
|
||||
|
||||
|
||||
def _convert_agent_action_to_messages(
|
||||
agent_action: AgentAction, observation: str
|
||||
) -> List[BaseMessage]:
|
||||
"""Convert an agent action to a message.
|
||||
|
||||
This code is used to reconstruct the original AI message from the agent action.
|
||||
|
||||
Args:
|
||||
agent_action: Agent action to convert.
|
||||
|
||||
Returns:
|
||||
AIMessage that corresponds to the original tool invocation.
|
||||
"""
|
||||
if isinstance(agent_action, _FunctionsAgentAction):
|
||||
return list(agent_action.message_log) + [
|
||||
_create_function_message(agent_action, observation)
|
||||
]
|
||||
else:
|
||||
return [AIMessage(content=agent_action.log)]
|
||||
|
||||
|
||||
def _create_function_message(
|
||||
agent_action: AgentAction, observation: str
|
||||
) -> FunctionMessage:
|
||||
"""Convert agent action and observation into a function message.
|
||||
Args:
|
||||
agent_action: the tool invocation request from the agent
|
||||
observation: the result of the tool invocation
|
||||
Returns:
|
||||
FunctionMessage that corresponds to the original tool invocation
|
||||
"""
|
||||
if not isinstance(observation, str):
|
||||
try:
|
||||
content = json.dumps(observation, ensure_ascii=False)
|
||||
except Exception:
|
||||
content = str(observation)
|
||||
else:
|
||||
content = observation
|
||||
return FunctionMessage(
|
||||
name=agent_action.tool,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
def _format_intermediate_steps(
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
) -> List[BaseMessage]:
|
||||
"""Format intermediate steps.
|
||||
Args:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
Returns:
|
||||
list of messages to send to the LLM for the next prediction
|
||||
"""
|
||||
messages = []
|
||||
|
||||
for intermediate_step in intermediate_steps:
|
||||
agent_action, observation = intermediate_step
|
||||
messages.extend(_convert_agent_action_to_messages(agent_action, observation))
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def _parse_ai_message(message: BaseMessage) -> Union[List[AgentAction], AgentFinish]:
|
||||
"""Parse an AI message."""
|
||||
if not isinstance(message, AIMessage):
|
||||
@ -257,7 +197,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
||||
agent_scratchpad = format_to_openai_functions(intermediate_steps)
|
||||
selected_inputs = {
|
||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||
}
|
||||
@ -286,7 +226,7 @@ class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
|
||||
Returns:
|
||||
Action specifying what tool to use.
|
||||
"""
|
||||
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
||||
agent_scratchpad = format_to_openai_functions(intermediate_steps)
|
||||
selected_inputs = {
|
||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||
}
|
||||
|
@ -0,0 +1,40 @@
|
||||
from langchain.agents.format_scratchpad.log import format_log_to_str
|
||||
from langchain.schema.agent import AgentAction
|
||||
|
||||
|
||||
def test_single_agent_action_observation() -> None:
    """A single (action, observation) pair renders with the default prefixes."""
    steps = [
        (AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1")
    ]
    assert format_log_to_str(steps) == "Log1\nObservation: Observation1\nThought: "
|
||||
|
||||
|
||||
def test_multiple_agent_actions_observations() -> None:
    """Several pairs are concatenated in order, each with default prefixes."""
    steps = [
        (
            AgentAction(tool=f"Tool{i}", tool_input=f"input{i}", log=f"Log{i}"),
            f"Observation{i}",
        )
        for i in (1, 2, 3)
    ]
    expected = (
        "Log1\nObservation: Observation1\nThought: "
        "Log2\nObservation: Observation2\nThought: "
        "Log3\nObservation: Observation3\nThought: "
    )
    assert format_log_to_str(steps) == expected
|
||||
|
||||
|
||||
def test_custom_prefixes() -> None:
    """Caller-supplied observation/LLM prefixes replace the defaults."""
    steps = [
        (AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1")
    ]
    result = format_log_to_str(steps, "Custom Observation: ", "Custom Thought: ")
    assert result == "Log1\nCustom Observation: Observation1\nCustom Thought: "
|
||||
|
||||
|
||||
def test_empty_intermediate_steps() -> None:
    """No steps yields an empty scratchpad string."""
    assert format_log_to_str([]) == ""
|
@ -0,0 +1,49 @@
|
||||
from langchain.agents.format_scratchpad.log_to_messages import format_log_to_messages
|
||||
from langchain.schema.agent import AgentAction
|
||||
from langchain.schema.messages import AIMessage, HumanMessage
|
||||
|
||||
|
||||
def test_single_intermediate_step_default_response() -> None:
    """One step becomes an AI message followed by a human message."""
    steps = [
        (AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1")
    ]
    assert format_log_to_messages(steps) == [
        AIMessage(content="Log1"),
        HumanMessage(content="Observation1"),
    ]
|
||||
|
||||
|
||||
def test_multiple_intermediate_steps_default_response() -> None:
    """Steps flatten into alternating AI/human messages, preserving order."""
    steps = [
        (
            AgentAction(tool=f"Tool{i}", tool_input=f"input{i}", log=f"Log{i}"),
            f"Observation{i}",
        )
        for i in (1, 2, 3)
    ]
    expected = []
    for i in (1, 2, 3):
        expected.append(AIMessage(content=f"Log{i}"))
        expected.append(HumanMessage(content=f"Observation{i}"))
    assert format_log_to_messages(steps) == expected
|
||||
|
||||
|
||||
def test_custom_template_tool_response() -> None:
    """The observation is substituted into a caller-supplied template."""
    steps = [
        (AgentAction(tool="Tool1", tool_input="input1", log="Log1"), "Observation1")
    ]
    result = format_log_to_messages(
        steps, template_tool_response="Response: {observation}"
    )
    assert result == [
        AIMessage(content="Log1"),
        HumanMessage(content="Response: Observation1"),
    ]
|
||||
|
||||
|
||||
def test_empty_steps() -> None:
    """No steps produces an empty message list."""
    assert format_log_to_messages([]) == []
|
@ -0,0 +1,60 @@
|
||||
from langchain.agents.format_scratchpad.openai_functions import (
|
||||
format_to_openai_functions,
|
||||
)
|
||||
from langchain.schema.agent import AgentActionMessageLog
|
||||
from langchain.schema.messages import AIMessage, FunctionMessage
|
||||
|
||||
|
||||
def test_calls_convert_agent_action_to_messages() -> None:
    """Each message-log action expands to its logged AI message plus a
    FunctionMessage carrying the observation."""
    steps = []
    expected_messages = []
    for i in (1, 2, 3):
        kwargs = {
            "function_call": {
                "name": f"tool{i}",
                "arguments": f"input{i}",
            }
        }
        message = AIMessage(content="", additional_kwargs=kwargs)
        action = AgentActionMessageLog(
            tool=f"tool{i}",
            tool_input=f"input{i}",
            log=f"log{i}",
            message_log=[message],
        )
        steps.append((action, f"observation{i}"))
        expected_messages.append(message)
        expected_messages.append(
            FunctionMessage(name=f"tool{i}", content=f"observation{i}")
        )
    assert format_to_openai_functions(steps) == expected_messages
|
||||
|
||||
|
||||
def test_handles_empty_input_list() -> None:
    """No steps yields no messages."""
    assert format_to_openai_functions([]) == []
|
@ -0,0 +1,40 @@
|
||||
from langchain.agents.format_scratchpad.xml import format_xml
|
||||
from langchain.schema.agent import AgentAction
|
||||
|
||||
|
||||
def test_single_agent_action_observation() -> None:
    """One step renders as a single tool/tool_input/observation triple."""
    steps = [
        (AgentAction(tool="Tool1", tool_input="Input1", log="Log1"), "Observation1")
    ]
    assert format_xml(steps) == (
        "<tool>Tool1</tool><tool_input>Input1"
        "</tool_input><observation>Observation1</observation>"
    )
|
||||
|
||||
|
||||
def test_multiple_agent_actions_observations() -> None:
    """Multiple steps are concatenated in order with no separator."""
    steps = [
        (AgentAction(tool="Tool1", tool_input="Input1", log="Log1"), "Observation1"),
        (AgentAction(tool="Tool2", tool_input="Input2", log="Log2"), "Observation2"),
    ]
    expected = (
        "<tool>Tool1</tool><tool_input>Input1</tool_input>"
        "<observation>Observation1</observation>"
        "<tool>Tool2</tool><tool_input>Input2</tool_input>"
        "<observation>Observation2</observation>"
    )
    assert format_xml(steps) == expected
|
||||
|
||||
|
||||
def test_empty_list_agent_actions() -> None:
    """No steps renders an empty string."""
    assert format_xml([]) == ""
|
Loading…
Reference in New Issue
Block a user