diff --git a/langchain/agents/loading.py b/langchain/agents/loading.py index cb00d795801..42731e7fcc6 100644 --- a/langchain/agents/loading.py +++ b/langchain/agents/loading.py @@ -1,8 +1,11 @@ """Functionality for loading agents.""" import json +import os +import tempfile from pathlib import Path from typing import Any, Union +import requests import yaml from langchain.agents.agent import Agent @@ -19,6 +22,8 @@ AGENT_TO_CLASS = { "conversational-react-description": ConversationalAgent, } +URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/agents/" + def load_agent_from_config(config: dict, **kwargs: Any) -> Agent: """Load agent from Config Dict.""" @@ -40,7 +45,32 @@ def load_agent_from_config(config: dict, **kwargs: Any) -> Agent: return agent_cls(**combined_config) # type: ignore -def load_agent(file: Union[str, Path], **kwargs: Any) -> Agent: +def load_agent(path: Union[str, Path], **kwargs: Any) -> Agent: + """Unified method for loading a agent from LangChainHub or local fs.""" + if isinstance(path, str) and path.startswith("lc://agents"): + path = os.path.relpath(path, "lc://agents/") + return _load_from_hub(path, **kwargs) + else: + return _load_agent_from_file(path, **kwargs) + + +def _load_from_hub(path: str, **kwargs: Any) -> Agent: + """Load agent from hub.""" + suffix = path.split(".")[-1] + if suffix not in {"json", "yaml"}: + raise ValueError("Unsupported file type.") + full_url = URL_BASE + path + r = requests.get(full_url) + if r.status_code != 200: + raise ValueError(f"Could not find file at {full_url}") + with tempfile.TemporaryDirectory() as tmpdirname: + file = tmpdirname + "/agent." + suffix + with open(file, "wb") as f: + f.write(r.content) + return _load_agent_from_file(file) + + +def _load_agent_from_file(file: Union[str, Path], **kwargs: Any) -> Agent: """Load agent from file.""" # Convert file to Path object. if isinstance(file, str):