make agent action serializable (#10797)

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Authored by Harrison Chase on 2023-09-19 16:16:14 -07:00, committed by GitHub
parent 21b236e5e4
commit d60145229b
3 changed files with 60 additions and 19 deletions

View File

@@ -1,6 +1,5 @@
 """Module implements an agent that uses OpenAI's APIs function enabled API."""
 import json
-from dataclasses import dataclass
 from json import JSONDecodeError
 from typing import Any, List, Optional, Sequence, Tuple, Union
@@ -21,6 +20,7 @@ from langchain.schema import (
     BasePromptTemplate,
     OutputParserException,
 )
+from langchain.schema.agent import AgentActionMessageLog
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.schema.messages import (
     AIMessage,
@@ -31,10 +31,8 @@ from langchain.schema.messages import (
 from langchain.tools import BaseTool
 from langchain.tools.convert_to_openai import format_tool_to_openai_function

-@dataclass
-class _FunctionsAgentAction(AgentAction):
-    message_log: List[BaseMessage]
+# For backwards compatibility
+_FunctionsAgentAction = AgentActionMessageLog


 def _convert_agent_action_to_messages(
@@ -51,7 +49,7 @@ def _convert_agent_action_to_messages(
         AIMessage that corresponds to the original tool invocation.
     """
     if isinstance(agent_action, _FunctionsAgentAction):
-        return agent_action.message_log + [
+        return list(agent_action.message_log) + [
             _create_function_message(agent_action, observation)
         ]
     else:
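
Two things are worth noting in this hunk. The dataclass is replaced by a module-level alias, so existing imports and isinstance checks against _FunctionsAgentAction keep working unchanged; and message_log is typed as Sequence[BaseMessage] on the new class, so it is wrapped in list(...) before being concatenated with a list. A minimal illustrative sketch of the compatibility property (not part of the diff itself):

    from langchain.schema.agent import AgentActionMessageLog

    # The alias binds a second name to the very same class object, so any
    # pre-existing isinstance(..., _FunctionsAgentAction) check still matches.
    _FunctionsAgentAction = AgentActionMessageLog
    assert _FunctionsAgentAction is AgentActionMessageLog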

View File

@@ -1,6 +1,5 @@
 """Module implements an agent that uses OpenAI's APIs function enabled API."""
 import json
-from dataclasses import dataclass
 from json import JSONDecodeError
 from typing import Any, List, Optional, Sequence, Tuple, Union
@@ -21,6 +20,7 @@ from langchain.schema import (
     BasePromptTemplate,
     OutputParserException,
 )
+from langchain.schema.agent import AgentActionMessageLog
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.schema.messages import (
     AIMessage,
@@ -30,10 +30,8 @@ from langchain.schema.messages import (
 )
 from langchain.tools import BaseTool

-@dataclass
-class _FunctionsAgentAction(AgentAction):
-    message_log: List[BaseMessage]
+# For backwards compatibility
+_FunctionsAgentAction = AgentActionMessageLog


 def _convert_agent_action_to_messages(
@@ -50,7 +48,7 @@ def _convert_agent_action_to_messages(
         AIMessage that corresponds to the original tool invocation.
     """
     if isinstance(agent_action, _FunctionsAgentAction):
-        return agent_action.message_log + [
+        return list(agent_action.message_log) + [
             _create_function_message(agent_action, observation)
         ]
     else:

View File

@@ -1,11 +1,12 @@
 from __future__ import annotations

-from dataclasses import dataclass
-from typing import NamedTuple, Union
+from typing import Any, Sequence, Union

+from langchain.load.serializable import Serializable
+from langchain.schema.messages import BaseMessage

-@dataclass
-class AgentAction:
+
+class AgentAction(Serializable):
     """A full description of an action for an ActionAgent to execute."""

     tool: str
@@ -13,13 +14,57 @@ class AgentAction:
     tool_input: Union[str, dict]
     """The input to pass in to the Tool."""
     log: str
-    """Additional information to log about the action."""
+    """Additional information to log about the action.
+    This log can be used in a few ways. First, it can be used to audit
+    what exactly the LLM predicted to lead to this (tool, tool_input).
+    Second, it can be used in future iterations to show the LLMs prior
+    thoughts. This is useful when (tool, tool_input) does not contain
+    full information about the LLM prediction (for example, any `thought`
+    before the tool/tool_input)."""
+
+    def __init__(
+        self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any
+    ):
+        super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)
+
+    @property
+    def lc_serializable(self) -> bool:
+        """
+        Return whether or not the class is serializable.
+        """
+        return True


-class AgentFinish(NamedTuple):
+class AgentActionMessageLog(AgentAction):
+    message_log: Sequence[BaseMessage]
+    """Similar to log, this can be used to pass along extra
+    information about what exact messages were predicted by the LLM
+    before parsing out the (tool, tool_input). This is again useful
+    if (tool, tool_input) cannot be used to fully recreate the LLM
+    prediction, and you need that LLM prediction (for future agent iteration).
+    Compared to `log`, this is useful when the underlying LLM is a
+    ChatModel (and therefore returns messages rather than a string)."""
+
+
+class AgentFinish(Serializable):
     """The final return value of an ActionAgent."""

     return_values: dict
     """Dictionary of return values."""
     log: str
-    """Additional information to log about the return value"""
+    """Additional information to log about the return value.
+    This is used to pass along the full LLM prediction, not just the parsed out
+    return value. For example, if the full LLM prediction was
+    `Final Answer: 2` you may want to just return `2` as a return value, but pass
+    along the full string as a `log` (for debugging or observability purposes).
+    """
+
+    def __init__(self, return_values: dict, log: str, **kwargs: Any):
+        super().__init__(return_values=return_values, log=log, **kwargs)
+
+    @property
+    def lc_serializable(self) -> bool:
+        """
+        Return whether or not the class is serializable.
+        """
+        return True
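
With AgentAction and AgentFinish now extending Serializable (and opting in via the lc_serializable property), agent actions can be dumped to and revived from JSON with LangChain's load helpers. Below is a sketch of what the change enables, assuming the langchain.load.dump.dumps and langchain.load.load.loads helpers available at the time of this commit; the tool name, input, and messages are made-up examples:

    from langchain.load.dump import dumps
    from langchain.load.load import loads
    from langchain.schema.agent import AgentAction, AgentActionMessageLog
    from langchain.schema.messages import AIMessage

    # Positional construction still works thanks to the explicit __init__.
    action = AgentAction("search", {"query": "weather in SF"}, "Invoking search")

    # Round-trip through JSON; loads revives classes from the langchain namespace.
    restored = loads(dumps(action, pretty=True))
    assert restored.tool == action.tool

    # The new subclass additionally carries the raw chat messages that produced
    # the (tool, tool_input) pair, e.g. an OpenAI function-call message.
    action_with_log = AgentActionMessageLog(
        tool="search",
        tool_input={"query": "weather in SF"},
        log="Invoking search",
        message_log=[
            AIMessage(
                content="",
                additional_kwargs={
                    "function_call": {
                        "name": "search",
                        "arguments": '{"query": "weather in SF"}',
                    }
                },
            )
        ],
    )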