mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-09 06:24:47 +00:00)
parent c14a8df2ee
commit 17b5090c18
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Any, Sequence, Union
+from typing import Any, Literal, Sequence, Union
 
 from langchain.load.serializable import Serializable
 from langchain.schema.messages import BaseMessage
@@ -21,10 +21,12 @@ class AgentAction(Serializable):
     thoughts. This is useful when (tool, tool_input) does not contain
     full information about the LLM prediction (for example, any `thought`
     before the tool/tool_input)."""
+    type: Literal["AgentAction"] = "AgentAction"
+
     def __init__(
         self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any
     ):
         """Override init to support instantiation by position for backward compat."""
         super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)
 
     @classmethod
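After this hunk, AgentAction (in the langchain.schema.agent module) still supports positional construction through the overridden __init__, and the new Literal field appears when the object is dumped to a dict. A minimal sketch, not part of the commit; the argument values are illustrative:

from langchain.schema import AgentAction

# Positional construction is preserved for backward compatibility by the custom __init__.
action = AgentAction("search", "what is langchain", "thought: I should search")
assert action.tool == "search"
# The new discriminator field is included in the pydantic dict output.
assert action.dict()["type"] == "AgentAction"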
@@ -42,6 +44,10 @@ class AgentActionMessageLog(AgentAction):
     prediction, and you need that LLM prediction (for future agent iteration).
     Compared to `log`, this is useful when the underlying LLM is a
     ChatModel (and therefore returns messages rather than a string)."""
+    # Ignoring type because we're overriding the type from AgentAction.
+    # And this is the correct thing to do in this case.
+    # The type literal is used for serialization purposes.
+    type: Literal["AgentActionMessageLog"] = "AgentActionMessageLog"  # type: ignore
 
 
 class AgentFinish(Serializable):
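As the docstring above notes, AgentActionMessageLog keeps the raw chat-model messages alongside the plain-text log. A minimal sketch of constructing one, not part of the commit (the tool name and message contents are illustrative):

from langchain.schema.agent import AgentActionMessageLog
from langchain.schema.messages import AIMessage

action = AgentActionMessageLog(
    tool="calculator",
    tool_input="2 + 2",
    log="I should use the calculator",
    # message_log keeps the original chat messages that produced this action.
    message_log=[AIMessage(content="I should use the calculator")],
)
# The overridden Literal is what distinguishes it from a plain AgentAction.
assert action.dict()["type"] == "AgentActionMessageLog"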
@@ -56,8 +62,10 @@ class AgentFinish(Serializable):
     `Final Answer: 2` you may want to just return `2` as a return value, but pass
     along the full string as a `log` (for debugging or observability purposes).
     """
+    type: Literal["AgentFinish"] = "AgentFinish"
+
     def __init__(self, return_values: dict, log: str, **kwargs: Any):
         """Override init to support instantiation by position for backward compat."""
         super().__init__(return_values=return_values, log=log, **kwargs)
 
     @classmethod
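The point of the three Literal fields is serialization: a plain dict produced by .dict() now carries enough information to recover the concrete class. A minimal sketch of that round trip, assuming pydantic v1 union validation as re-exported through langchain.pydantic_v1; WellKnown is a hypothetical wrapper name (the commit's test wraps a much larger union of well-known objects):

from typing import Union

from langchain.pydantic_v1 import BaseModel
from langchain.schema import AgentAction, AgentFinish


class WellKnown(BaseModel):
    # Hypothetical wrapper: a pydantic v1 custom root type over a union of schema objects.
    __root__: Union[AgentAction, AgentFinish]


d = AgentFinish(return_values={"output": "2"}, log="Final Answer: 2").dict()
assert d["type"] == "AgentFinish"  # the new Literal field survives serialization
# AgentAction fails validation (wrong literal, missing fields), so the union
# falls through and resolves to AgentFinish.
obj = WellKnown.parse_obj(d).__root__
assert isinstance(obj, AgentFinish)

The remaining hunks extend test_serialization_of_wellknown_objects to cover the three agent objects.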
@@ -6,7 +6,8 @@ from typing import Union
 from langchain.prompts.base import StringPromptValue
 from langchain.prompts.chat import ChatPromptValueConcrete
 from langchain.pydantic_v1 import BaseModel
-from langchain.schema import Document
+from langchain.schema import AgentAction, AgentFinish, Document
+from langchain.schema.agent import AgentActionMessageLog
 from langchain.schema.messages import (
     AIMessage,
     AIMessageChunk,
@@ -104,6 +105,9 @@ def test_serialization_of_wellknown_objects() -> None:
         AIMessageChunk,
         StringPromptValue,
         ChatPromptValueConcrete,
+        AgentFinish,
+        AgentAction,
+        AgentActionMessageLog,
     ]
 
     lc_objects = [
@@ -132,6 +136,14 @@ def test_serialization_of_wellknown_objects() -> None:
         StringPromptValue(text="hello"),
         ChatPromptValueConcrete(messages=[HumanMessage(content="human")]),
         Document(page_content="hello"),
+        AgentFinish(return_values={}, log=""),
+        AgentAction(tool="tool", tool_input="input", log=""),
+        AgentActionMessageLog(
+            tool="tool",
+            tool_input="input",
+            log="",
+            message_log=[HumanMessage(content="human")],
+        ),
     ]
 
     for lc_object in lc_objects: