mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-13 08:27:03 +00:00
make agent action serializable (#10797)
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
21b236e5e4
commit
d60145229b
@ -1,6 +1,5 @@
|
|||||||
"""Module implements an agent that uses OpenAI's APIs function enabled API."""
|
"""Module implements an agent that uses OpenAI's APIs function enabled API."""
|
||||||
import json
|
import json
|
||||||
from dataclasses import dataclass
|
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from typing import Any, List, Optional, Sequence, Tuple, Union
|
from typing import Any, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
@ -21,6 +20,7 @@ from langchain.schema import (
|
|||||||
BasePromptTemplate,
|
BasePromptTemplate,
|
||||||
OutputParserException,
|
OutputParserException,
|
||||||
)
|
)
|
||||||
|
from langchain.schema.agent import AgentActionMessageLog
|
||||||
from langchain.schema.language_model import BaseLanguageModel
|
from langchain.schema.language_model import BaseLanguageModel
|
||||||
from langchain.schema.messages import (
|
from langchain.schema.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
@ -31,10 +31,8 @@ from langchain.schema.messages import (
|
|||||||
from langchain.tools import BaseTool
|
from langchain.tools import BaseTool
|
||||||
from langchain.tools.convert_to_openai import format_tool_to_openai_function
|
from langchain.tools.convert_to_openai import format_tool_to_openai_function
|
||||||
|
|
||||||
|
# For backwards compatibility
|
||||||
@dataclass
|
_FunctionsAgentAction = AgentActionMessageLog
|
||||||
class _FunctionsAgentAction(AgentAction):
|
|
||||||
message_log: List[BaseMessage]
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_agent_action_to_messages(
|
def _convert_agent_action_to_messages(
|
||||||
@ -51,7 +49,7 @@ def _convert_agent_action_to_messages(
|
|||||||
AIMessage that corresponds to the original tool invocation.
|
AIMessage that corresponds to the original tool invocation.
|
||||||
"""
|
"""
|
||||||
if isinstance(agent_action, _FunctionsAgentAction):
|
if isinstance(agent_action, _FunctionsAgentAction):
|
||||||
return agent_action.message_log + [
|
return list(agent_action.message_log) + [
|
||||||
_create_function_message(agent_action, observation)
|
_create_function_message(agent_action, observation)
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
"""Module implements an agent that uses OpenAI's APIs function enabled API."""
|
"""Module implements an agent that uses OpenAI's APIs function enabled API."""
|
||||||
import json
|
import json
|
||||||
from dataclasses import dataclass
|
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from typing import Any, List, Optional, Sequence, Tuple, Union
|
from typing import Any, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
@ -21,6 +20,7 @@ from langchain.schema import (
|
|||||||
BasePromptTemplate,
|
BasePromptTemplate,
|
||||||
OutputParserException,
|
OutputParserException,
|
||||||
)
|
)
|
||||||
|
from langchain.schema.agent import AgentActionMessageLog
|
||||||
from langchain.schema.language_model import BaseLanguageModel
|
from langchain.schema.language_model import BaseLanguageModel
|
||||||
from langchain.schema.messages import (
|
from langchain.schema.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
@ -30,10 +30,8 @@ from langchain.schema.messages import (
|
|||||||
)
|
)
|
||||||
from langchain.tools import BaseTool
|
from langchain.tools import BaseTool
|
||||||
|
|
||||||
|
# For backwards compatibility
|
||||||
@dataclass
|
_FunctionsAgentAction = AgentActionMessageLog
|
||||||
class _FunctionsAgentAction(AgentAction):
|
|
||||||
message_log: List[BaseMessage]
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_agent_action_to_messages(
|
def _convert_agent_action_to_messages(
|
||||||
@ -50,7 +48,7 @@ def _convert_agent_action_to_messages(
|
|||||||
AIMessage that corresponds to the original tool invocation.
|
AIMessage that corresponds to the original tool invocation.
|
||||||
"""
|
"""
|
||||||
if isinstance(agent_action, _FunctionsAgentAction):
|
if isinstance(agent_action, _FunctionsAgentAction):
|
||||||
return agent_action.message_log + [
|
return list(agent_action.message_log) + [
|
||||||
_create_function_message(agent_action, observation)
|
_create_function_message(agent_action, observation)
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from typing import Any, Sequence, Union
|
||||||
from typing import NamedTuple, Union
|
|
||||||
|
from langchain.load.serializable import Serializable
|
||||||
|
from langchain.schema.messages import BaseMessage
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class AgentAction(Serializable):
|
||||||
class AgentAction:
|
|
||||||
"""A full description of an action for an ActionAgent to execute."""
|
"""A full description of an action for an ActionAgent to execute."""
|
||||||
|
|
||||||
tool: str
|
tool: str
|
||||||
@ -13,13 +14,57 @@ class AgentAction:
|
|||||||
tool_input: Union[str, dict]
|
tool_input: Union[str, dict]
|
||||||
"""The input to pass in to the Tool."""
|
"""The input to pass in to the Tool."""
|
||||||
log: str
|
log: str
|
||||||
"""Additional information to log about the action."""
|
"""Additional information to log about the action.
|
||||||
|
This log can be used in a few ways. First, it can be used to audit
|
||||||
|
what exactly the LLM predicted to lead to this (tool, tool_input).
|
||||||
|
Second, it can be used in future iterations to show the LLMs prior
|
||||||
|
thoughts. This is useful when (tool, tool_input) does not contain
|
||||||
|
full information about the LLM prediction (for example, any `thought`
|
||||||
|
before the tool/tool_input)."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any
|
||||||
|
):
|
||||||
|
super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_serializable(self) -> bool:
|
||||||
|
"""
|
||||||
|
Return whether or not the class is serializable.
|
||||||
|
"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class AgentFinish(NamedTuple):
|
class AgentActionMessageLog(AgentAction):
|
||||||
|
message_log: Sequence[BaseMessage]
|
||||||
|
"""Similar to log, this can be used to pass along extra
|
||||||
|
information about what exact messages were predicted by the LLM
|
||||||
|
before parsing out the (tool, tool_input). This is again useful
|
||||||
|
if (tool, tool_input) cannot be used to fully recreate the LLM
|
||||||
|
prediction, and you need that LLM prediction (for future agent iteration).
|
||||||
|
Compared to `log`, this is useful when the underlying LLM is a
|
||||||
|
ChatModel (and therefore returns messages rather than a string)."""
|
||||||
|
|
||||||
|
|
||||||
|
class AgentFinish(Serializable):
|
||||||
"""The final return value of an ActionAgent."""
|
"""The final return value of an ActionAgent."""
|
||||||
|
|
||||||
return_values: dict
|
return_values: dict
|
||||||
"""Dictionary of return values."""
|
"""Dictionary of return values."""
|
||||||
log: str
|
log: str
|
||||||
"""Additional information to log about the return value"""
|
"""Additional information to log about the return value.
|
||||||
|
This is used to pass along the full LLM prediction, not just the parsed out
|
||||||
|
return value. For example, if the full LLM prediction was
|
||||||
|
`Final Answer: 2` you may want to just return `2` as a return value, but pass
|
||||||
|
along the full string as a `log` (for debugging or observability purposes).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, return_values: dict, log: str, **kwargs: Any):
|
||||||
|
super().__init__(return_values=return_values, log=log, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_serializable(self) -> bool:
|
||||||
|
"""
|
||||||
|
Return whether or not the class is serializable.
|
||||||
|
"""
|
||||||
|
return True
|
||||||
|
Loading…
Reference in New Issue
Block a user