mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-06 13:18:12 +00:00
agent serialization (#4642)
This commit is contained in:
parent
ef49c659f6
commit
fbfa49f2c1
@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
import yaml
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from langchain.agents.agent_types import AgentType
|
||||
from langchain.agents.tools import InvalidTool
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
@ -132,7 +133,11 @@ class BaseSingleActionAgent(BaseModel):
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of agent."""
|
||||
_dict = super().dict()
|
||||
_dict["_type"] = str(self._agent_type)
|
||||
_type = self._agent_type
|
||||
if isinstance(_type, AgentType):
|
||||
_dict["_type"] = str(_type.value)
|
||||
else:
|
||||
_dict["_type"] = _type
|
||||
return _dict
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
@ -307,6 +312,12 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
|
||||
def input_keys(self) -> List[str]:
|
||||
return list(set(self.llm_chain.input_keys) - {"intermediate_steps"})
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of agent."""
|
||||
_dict = super().dict()
|
||||
del _dict["output_parser"]
|
||||
return _dict
|
||||
|
||||
def plan(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
@ -376,6 +387,12 @@ class Agent(BaseSingleActionAgent):
|
||||
output_parser: AgentOutputParser
|
||||
allowed_tools: Optional[List[str]] = None
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of agent."""
|
||||
_dict = super().dict()
|
||||
del _dict["output_parser"]
|
||||
return _dict
|
||||
|
||||
def get_allowed_tools(self) -> Optional[List[str]]:
|
||||
return self.allowed_tools
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
"""Functionality for loading agents."""
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
@ -12,6 +13,8 @@ from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.loading import load_chain, load_chain_from_config
|
||||
from langchain.utilities.loading import try_load_from_hub
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/agents/"
|
||||
|
||||
|
||||
@ -61,6 +64,13 @@ def load_agent_from_config(
|
||||
config["llm_chain"] = load_chain(config.pop("llm_chain_path"))
|
||||
else:
|
||||
raise ValueError("One of `llm_chain` and `llm_chain_path` should be specified.")
|
||||
if "output_parser" in config:
|
||||
logger.warning(
|
||||
"Currently loading output parsers on agent is not supported, "
|
||||
"will just use the default one."
|
||||
)
|
||||
del config["output_parser"]
|
||||
|
||||
combined_config = {**config, **kwargs}
|
||||
return agent_cls(**combined_config) # type: ignore
|
||||
|
||||
|
@ -10,6 +10,7 @@ from langchain.llms.base import BaseLLM
|
||||
from langchain.llms.cerebriumai import CerebriumAI
|
||||
from langchain.llms.cohere import Cohere
|
||||
from langchain.llms.deepinfra import DeepInfra
|
||||
from langchain.llms.fake import FakeListLLM
|
||||
from langchain.llms.forefrontai import ForefrontAI
|
||||
from langchain.llms.google_palm import GooglePalm
|
||||
from langchain.llms.gooseai import GooseAI
|
||||
@ -71,6 +72,7 @@ __all__ = [
|
||||
"PredictionGuard",
|
||||
"HumanInputLLM",
|
||||
"HuggingFaceTextGenInference",
|
||||
"FakeListLLM",
|
||||
]
|
||||
|
||||
type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
|
||||
@ -105,4 +107,5 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
|
||||
"writer": Writer,
|
||||
"rwkv": RWKV,
|
||||
"huggingface_textgen_inference": HuggingFaceTextGenInference,
|
||||
"fake-list": FakeListLLM,
|
||||
}
|
||||
|
@ -29,4 +29,4 @@ class FakeListLLM(LLM):
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
return {}
|
||||
return {"responses": self.responses}
|
||||
|
19
tests/unit_tests/agents/test_serialization.py
Normal file
19
tests/unit_tests/agents/test_serialization.py
Normal file
@ -0,0 +1,19 @@
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
from langchain.agents.agent_types import AgentType
|
||||
from langchain.agents.initialize import initialize_agent, load_agent
|
||||
from langchain.llms.fake import FakeListLLM
|
||||
|
||||
|
||||
def test_mrkl_serialization() -> None:
|
||||
agent = initialize_agent(
|
||||
[],
|
||||
FakeListLLM(responses=[]),
|
||||
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
||||
verbose=True,
|
||||
)
|
||||
with TemporaryDirectory() as tempdir:
|
||||
file = Path(tempdir) / "agent.json"
|
||||
agent.save_agent(file)
|
||||
load_agent(file)
|
Loading…
Reference in New Issue
Block a user