diff --git a/libs/langchain/langchain/agents/loading.py b/libs/langchain/langchain/agents/loading.py index 7b3137268dd..8a0c4a9b3f0 100644 --- a/libs/langchain/langchain/agents/loading.py +++ b/libs/langchain/langchain/agents/loading.py @@ -97,8 +97,9 @@ def load_agent( Returns: An agent executor. """ + valid_suffixes = {"json", "yaml"} if hub_result := try_load_from_hub( - path, _load_agent_from_file, "agents", {"json", "yaml"} + path, _load_agent_from_file, "agents", valid_suffixes ): return hub_result else: @@ -109,19 +110,20 @@ def _load_agent_from_file( file: Union[str, Path], **kwargs: Any ) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]: """Load agent from file.""" + valid_suffixes = {"json", "yaml"} # Convert file to Path object. if isinstance(file, str): file_path = Path(file) else: file_path = file # Load from either json or yaml. - if file_path.suffix == ".json": + if file_path.suffix[1:] == "json": with open(file_path) as f: config = json.load(f) - elif file_path.suffix == ".yaml": + elif file_path.suffix[1:] == "yaml": with open(file_path, "r") as f: config = yaml.safe_load(f) else: - raise ValueError("File type must be json or yaml") + raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.") # Load the agent from the config now. return load_agent_from_config(config, **kwargs) diff --git a/libs/langchain/langchain/utilities/loading.py b/libs/langchain/langchain/utilities/loading.py index f694c1a1bd2..60f3e3cf7d4 100644 --- a/libs/langchain/langchain/utilities/loading.py +++ b/libs/langchain/langchain/utilities/loading.py @@ -35,7 +35,7 @@ def try_load_from_hub( if remote_path.parts[0] != valid_prefix: return None if remote_path.suffix[1:] not in valid_suffixes: - raise ValueError("Unsupported file type.") + raise ValueError(f"Unsupported file type, must be one of {valid_suffixes}.") # Using Path with URLs is not recommended, because on Windows # the backslash is used as the path separator, which can cause issues diff --git a/libs/langchain/tests/unit_tests/utilities/test_loading.py b/libs/langchain/tests/unit_tests/utilities/test_loading.py index 468cc6590eb..f9a275fd819 100644 --- a/libs/langchain/tests/unit_tests/utilities/test_loading.py +++ b/libs/langchain/tests/unit_tests/utilities/test_loading.py @@ -48,7 +48,9 @@ def test_invalid_suffix() -> None: loader = Mock() valid_suffixes = {"json"} - with pytest.raises(ValueError, match="Unsupported file type."): + with pytest.raises( + ValueError, match=f"Unsupported file type, must be one of {valid_suffixes}." + ): try_load_from_hub(path, loader, "chains", valid_suffixes) loader.assert_not_called()