agent serialization (#4642)
commit fbfa49f2c1
parent ef49c659f6
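This change lets a constructed agent be written to disk and re-created with load_agent(). A minimal usage sketch, mirroring the new test at the bottom of this diff (FakeListLLM and the zero-shot agent type are taken from the changes below):

from pathlib import Path

from langchain.agents.agent_types import AgentType
from langchain.agents.initialize import initialize_agent, load_agent
from langchain.llms.fake import FakeListLLM

# Build a trivial agent (no tools, canned fake LLM) and round-trip it through JSON.
agent = initialize_agent(
    [],
    FakeListLLM(responses=[]),
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)
agent.save_agent(Path("agent.json"))  # serializes via the dict() overrides added below
restored = load_agent(Path("agent.json"))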
langchain/agents/agent.py
@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 import yaml
 from pydantic import BaseModel, root_validator
 
+from langchain.agents.agent_types import AgentType
 from langchain.agents.tools import InvalidTool
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
@@ -132,7 +133,11 @@ class BaseSingleActionAgent(BaseModel):
     def dict(self, **kwargs: Any) -> Dict:
         """Return dictionary representation of agent."""
         _dict = super().dict()
-        _dict["_type"] = str(self._agent_type)
+        _type = self._agent_type
+        if isinstance(_type, AgentType):
+            _dict["_type"] = str(_type.value)
+        else:
+            _dict["_type"] = _type
         return _dict
 
     def save(self, file_path: Union[Path, str]) -> None:
@@ -307,6 +312,12 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
     def input_keys(self) -> List[str]:
         return list(set(self.llm_chain.input_keys) - {"intermediate_steps"})
 
+    def dict(self, **kwargs: Any) -> Dict:
+        """Return dictionary representation of agent."""
+        _dict = super().dict()
+        del _dict["output_parser"]
+        return _dict
+
     def plan(
         self,
         intermediate_steps: List[Tuple[AgentAction, str]],
@@ -376,6 +387,12 @@ class Agent(BaseSingleActionAgent):
     output_parser: AgentOutputParser
     allowed_tools: Optional[List[str]] = None
 
+    def dict(self, **kwargs: Any) -> Dict:
+        """Return dictionary representation of agent."""
+        _dict = super().dict()
+        del _dict["output_parser"]
+        return _dict
+
     def get_allowed_tools(self) -> Optional[List[str]]:
         return self.allowed_tools
 
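Taken together, these dict() overrides record the agent type under a "_type" key (using the enum value when _agent_type is an AgentType) and drop the output_parser field, which cannot be round-tripped yet. A hedged sketch of the kind of dictionary this might produce for an executor built as in the sketch near the top of this diff (the exact keys depend on the concrete agent's fields):

# Illustrative only; the key set varies by agent subclass.
serialized = agent.agent.dict()  # the executor's underlying Agent
# {
#     "llm_chain": {...},
#     "allowed_tools": [],
#     "_type": "zero-shot-react-description",
# }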
langchain/agents/loading.py
@@ -1,5 +1,6 @@
 """Functionality for loading agents."""
 import json
+import logging
 from pathlib import Path
 from typing import Any, List, Optional, Union
 
@@ -12,6 +13,8 @@ from langchain.base_language import BaseLanguageModel
 from langchain.chains.loading import load_chain, load_chain_from_config
 from langchain.utilities.loading import try_load_from_hub
 
+logger = logging.getLogger(__file__)
+
 URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/agents/"
 
 
@@ -61,6 +64,13 @@ def load_agent_from_config(
         config["llm_chain"] = load_chain(config.pop("llm_chain_path"))
     else:
         raise ValueError("One of `llm_chain` and `llm_chain_path` should be specified.")
+    if "output_parser" in config:
+        logger.warning(
+            "Currently loading output parsers on agent is not supported, "
+            "will just use the default one."
+        )
+        del config["output_parser"]
+
     combined_config = {**config, **kwargs}
     return agent_cls(**combined_config)  # type: ignore
 
langchain/llms/__init__.py
@@ -10,6 +10,7 @@ from langchain.llms.base import BaseLLM
 from langchain.llms.cerebriumai import CerebriumAI
 from langchain.llms.cohere import Cohere
 from langchain.llms.deepinfra import DeepInfra
+from langchain.llms.fake import FakeListLLM
 from langchain.llms.forefrontai import ForefrontAI
 from langchain.llms.google_palm import GooglePalm
 from langchain.llms.gooseai import GooseAI
@@ -71,6 +72,7 @@ __all__ = [
     "PredictionGuard",
     "HumanInputLLM",
     "HuggingFaceTextGenInference",
+    "FakeListLLM",
 ]
 
 type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
@@ -105,4 +107,5 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "writer": Writer,
     "rwkv": RWKV,
     "huggingface_textgen_inference": HuggingFaceTextGenInference,
+    "fake-list": FakeListLLM,
 }
langchain/llms/fake.py
@@ -29,4 +29,4 @@ class FakeListLLM(LLM):
 
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
-        return {}
+        return {"responses": self.responses}
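The FakeListLLM change means the fake LLM's canned responses now show up in its identifying parameters, so an agent built on it has something meaningful to serialize and compare. A small sketch of the effect (constructor usage taken from the new test below):

from langchain.llms.fake import FakeListLLM

# FakeListLLM replays the given responses in order; after this change its
# _identifying_params expose them, e.g. {"responses": ["final answer"]}.
llm = FakeListLLM(responses=["final answer"])
print(llm._identifying_params)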
tests/unit_tests/agents/test_serialization.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+from langchain.agents.agent_types import AgentType
+from langchain.agents.initialize import initialize_agent, load_agent
+from langchain.llms.fake import FakeListLLM
+
+
+def test_mrkl_serialization() -> None:
+    agent = initialize_agent(
+        [],
+        FakeListLLM(responses=[]),
+        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
+        verbose=True,
+    )
+    with TemporaryDirectory() as tempdir:
+        file = Path(tempdir) / "agent.json"
+        agent.save_agent(file)
+        load_agent(file)