diff --git a/docs/modules/agents/examples/load_from_hub.ipynb b/docs/modules/agents/examples/load_from_hub.ipynb index 7b5e0b1b0ce..bc1bc760e9f 100644 --- a/docs/modules/agents/examples/load_from_hub.ipynb +++ b/docs/modules/agents/examples/load_from_hub.ipynb @@ -62,13 +62,26 @@ "self_ask_with_search.run(\"What is the hometown of the reigning men's U.S. Open champion?\")" ] }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "3aede965", + "metadata": {}, + "source": [ + "# Pinning Dependencies\n", + "\n", + "Specific versions of LangChainHub agents can be pinned with the `lc@<ref>://` syntax." + ] + }, { "cell_type": "code", "execution_count": null, "id": "e679f7b6", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "self_ask_with_search = initialize_agent(tools, llm, agent_path=\"lc@2826ef9e8acdf88465e1e5fc8a7bf59e0f9d0a85://agents/self-ask-with-search/agent.json\", verbose=True)" + ] } ], "metadata": { diff --git a/langchain/agents/loading.py b/langchain/agents/loading.py index bdf04c02415..3298895b539 100644 --- a/langchain/agents/loading.py +++ b/langchain/agents/loading.py @@ -1,11 +1,8 @@ """Functionality for loading agents.""" import json -import os -import tempfile from pathlib import Path from typing import Any, List, Optional, Union -import requests import yaml from langchain.agents.agent import Agent @@ -16,6 +13,7 @@ from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent from langchain.agents.tools import Tool from langchain.chains.loading import load_chain, load_chain_from_config from langchain.llms.base import BaseLLM +from langchain.utilities.loading import try_load_from_hub AGENT_TO_CLASS = { "zero-shot-react-description": ZeroShotAgent, @@ -81,29 +79,14 @@ def load_agent_from_config( def load_agent(path: Union[str, Path], **kwargs: Any) -> Agent: """Unified method for loading a agent from LangChainHub or local fs.""" - if isinstance(path, str) and path.startswith("lc://agents"): - path = 
os.path.relpath(path, "lc://agents/") - return _load_from_hub(path, **kwargs) + if hub_result := try_load_from_hub( + path, _load_agent_from_file, "agents", {"json", "yaml"} + ): + return hub_result else: return _load_agent_from_file(path, **kwargs) -def _load_from_hub(path: str, **kwargs: Any) -> Agent: - """Load agent from hub.""" - suffix = path.split(".")[-1] - if suffix not in {"json", "yaml"}: - raise ValueError("Unsupported file type.") - full_url = URL_BASE + path - r = requests.get(full_url) - if r.status_code != 200: - raise ValueError(f"Could not find file at {full_url}") - with tempfile.TemporaryDirectory() as tmpdirname: - file = tmpdirname + "/agent." + suffix - with open(file, "wb") as f: - f.write(r.content) - return _load_agent_from_file(file, **kwargs) - - def _load_agent_from_file(file: Union[str, Path], **kwargs: Any) -> Agent: """Load agent from file.""" # Convert file to Path object. diff --git a/langchain/chains/loading.py b/langchain/chains/loading.py index 10095d02c90..7e42969c3cc 100644 --- a/langchain/chains/loading.py +++ b/langchain/chains/loading.py @@ -1,11 +1,8 @@ """Functionality for loading chains.""" import json -import os -import tempfile from pathlib import Path from typing import Any, Union -import requests import yaml from langchain.chains.api.base import APIChain @@ -27,6 +24,7 @@ from langchain.chains.sql_database.base import SQLDatabaseChain from langchain.chains.vector_db_qa.base import VectorDBQA from langchain.llms.loading import load_llm, load_llm_from_config from langchain.prompts.loading import load_prompt, load_prompt_from_config +from langchain.utilities.loading import try_load_from_hub URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/chains/" @@ -441,9 +439,10 @@ def load_chain_from_config(config: dict, **kwargs: Any) -> Chain: def load_chain(path: Union[str, Path], **kwargs: Any) -> Chain: """Unified method for loading a chain from LangChainHub or local fs.""" - if isinstance(path, str) 
and path.startswith("lc://chains"): - path = os.path.relpath(path, "lc://chains/") - return _load_from_hub(path, **kwargs) + if hub_result := try_load_from_hub( + path, _load_chain_from_file, "chains", {"json", "yaml"} + ): + return hub_result else: return _load_chain_from_file(path, **kwargs) @@ -466,19 +465,3 @@ def _load_chain_from_file(file: Union[str, Path], **kwargs: Any) -> Chain: raise ValueError("File type must be json or yaml") # Load the chain from the config now. return load_chain_from_config(config, **kwargs) - - -def _load_from_hub(path: str, **kwargs: Any) -> Chain: - """Load chain from hub.""" - suffix = path.split(".")[-1] - if suffix not in {"json", "yaml"}: - raise ValueError("Unsupported file type.") - full_url = URL_BASE + path - r = requests.get(full_url) - if r.status_code != 200: - raise ValueError(f"Could not find file at {full_url}") - with tempfile.TemporaryDirectory() as tmpdirname: - file = tmpdirname + "/chain." + suffix - with open(file, "wb") as f: - f.write(r.content) - return _load_chain_from_file(file, **kwargs) diff --git a/langchain/prompts/loading.py b/langchain/prompts/loading.py index eaae6ad6c5f..30be599a7a1 100644 --- a/langchain/prompts/loading.py +++ b/langchain/prompts/loading.py @@ -1,17 +1,15 @@ """Load prompts from disk.""" import importlib import json -import os -import tempfile from pathlib import Path from typing import Union -import requests import yaml from langchain.prompts.base import BasePromptTemplate, RegexParser from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.prompt import PromptTemplate +from langchain.utilities.loading import try_load_from_hub URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/" @@ -114,9 +112,10 @@ def _load_prompt(config: dict) -> PromptTemplate: def load_prompt(path: Union[str, Path]) -> BasePromptTemplate: """Unified method for loading a prompt from LangChainHub or local fs.""" - if isinstance(path, str) and 
path.startswith("lc://prompts"): - path = os.path.relpath(path, "lc://prompts/") - return _load_from_hub(path) + if hub_result := try_load_from_hub( + path, _load_prompt_from_file, "prompts", {"py", "json", "yaml"} + ): + return hub_result else: return _load_prompt_from_file(path) @@ -151,19 +150,3 @@ def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate: raise ValueError(f"Got unsupported file type {file_path.suffix}") # Load the prompt from the config now. return load_prompt_from_config(config) - - -def _load_from_hub(path: str) -> BasePromptTemplate: - """Load prompt from hub.""" - suffix = path.split(".")[-1] - if suffix not in {"py", "json", "yaml"}: - raise ValueError("Unsupported file type.") - full_url = URL_BASE + path - r = requests.get(full_url) - if r.status_code != 200: - raise ValueError(f"Could not find file at {full_url}") - with tempfile.TemporaryDirectory() as tmpdirname: - file = tmpdirname + "/prompt." + suffix - with open(file, "wb") as f: - f.write(r.content) - return _load_prompt_from_file(file) diff --git a/langchain/utilities/loading.py b/langchain/utilities/loading.py new file mode 100644 index 00000000000..6b70318d450 --- /dev/null +++ b/langchain/utilities/loading.py @@ -0,0 +1,49 @@ +"""Utilities for loading configurations from langchain-hub.""" + +import os +import re +import tempfile +from pathlib import Path +from typing import Any, Callable, Optional, Set, TypeVar, Union +from urllib.parse import urljoin + +import requests + +DEFAULT_REF = os.environ.get("LANGCHAIN_HUB_DEFAULT_REF", "master") +URL_BASE = os.environ.get( + "LANGCHAIN_HUB_URL_BASE", + "https://raw.githubusercontent.com/hwchase17/langchain-hub/{ref}/", +) +HUB_PATH_RE = re.compile(r"lc(?P<ref>@[^:]+)?://(?P<path>.*)") + + +T = TypeVar("T") + + +def try_load_from_hub( + path: Union[str, Path], + loader: Callable[[str], T], + valid_prefix: str, + valid_suffixes: Set[str], + **kwargs: Any, +) -> Optional[T]: + """Load configuration from hub. 
Returns None if path is not a hub path.""" + if not isinstance(path, str) or not (match := HUB_PATH_RE.match(path)): + return None + ref, remote_path_str = match.groups() + ref = ref[1:] if ref else DEFAULT_REF + remote_path = Path(remote_path_str) + if remote_path.parts[0] != valid_prefix: + return None + if remote_path.suffix[1:] not in valid_suffixes: + raise ValueError("Unsupported file type.") + + full_url = urljoin(URL_BASE.format(ref=ref), str(remote_path)) + r = requests.get(full_url, timeout=5) + if r.status_code != 200: + raise ValueError(f"Could not find file at {full_url}") + with tempfile.TemporaryDirectory() as tmpdirname: + file = Path(tmpdirname) / remote_path.name + with open(file, "wb") as f: + f.write(r.content) + return loader(str(file), **kwargs) diff --git a/poetry.lock b/poetry.lock index 1e3ffe4e491..dc2adb7c7a7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2270,6 +2270,7 @@ files = [ {file = "lxml-4.9.2-cp35-cp35m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ca989b91cf3a3ba28930a9fc1e9aeafc2a395448641df1f387a2d394638943b0"}, {file = "lxml-4.9.2-cp35-cp35m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:822068f85e12a6e292803e112ab876bc03ed1f03dddb80154c395f891ca6b31e"}, {file = "lxml-4.9.2-cp35-cp35m-win32.whl", hash = "sha256:be7292c55101e22f2a3d4d8913944cbea71eea90792bf914add27454a13905df"}, + {file = "lxml-4.9.2-cp35-cp35m-win_amd64.whl", hash = "sha256:998c7c41910666d2976928c38ea96a70d1aa43be6fe502f21a651e17483a43c5"}, {file = "lxml-4.9.2-cp36-cp36m-macosx_10_15_x86_64.whl", hash = "sha256:b26a29f0b7fc6f0897f043ca366142d2b609dc60756ee6e4e90b5f762c6adc53"}, {file = "lxml-4.9.2-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:ab323679b8b3030000f2be63e22cdeea5b47ee0abd2d6a1dc0c8103ddaa56cd7"}, {file = "lxml-4.9.2-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:689bb688a1db722485e4610a503e3e9210dcc20c520b45ac8f7533c837be76fe"}, @@ -2279,6 +2280,7 
@@ files = [ {file = "lxml-4.9.2-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:58bfa3aa19ca4c0f28c5dde0ff56c520fbac6f0daf4fac66ed4c8d2fb7f22e74"}, {file = "lxml-4.9.2-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:bc718cd47b765e790eecb74d044cc8d37d58562f6c314ee9484df26276d36a38"}, {file = "lxml-4.9.2-cp36-cp36m-win32.whl", hash = "sha256:d5bf6545cd27aaa8a13033ce56354ed9e25ab0e4ac3b5392b763d8d04b08e0c5"}, + {file = "lxml-4.9.2-cp36-cp36m-win_amd64.whl", hash = "sha256:3ab9fa9d6dc2a7f29d7affdf3edebf6ece6fb28a6d80b14c3b2fb9d39b9322c3"}, {file = "lxml-4.9.2-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:05ca3f6abf5cf78fe053da9b1166e062ade3fa5d4f92b4ed688127ea7d7b1d03"}, {file = "lxml-4.9.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:a5da296eb617d18e497bcf0a5c528f5d3b18dadb3619fbdadf4ed2356ef8d941"}, {file = "lxml-4.9.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:04876580c050a8c5341d706dd464ff04fd597095cc8c023252566a8826505726"}, @@ -4176,6 +4178,27 @@ urllib3 = ">=1.21.1,<1.27" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "responses" +version = "0.22.0" +description = "A utility library for mocking out the `requests` Python library." 
+category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "responses-0.22.0-py3-none-any.whl", hash = "sha256:dcf294d204d14c436fddcc74caefdbc5764795a40ff4e6a7740ed8ddbf3294be"}, + {file = "responses-0.22.0.tar.gz", hash = "sha256:396acb2a13d25297789a5866b4881cf4e46ffd49cc26c43ab1117f40b973102e"}, +] + +[package.dependencies] +requests = ">=2.22.0,<3.0" +toml = "*" +types-toml = "*" +urllib3 = ">=1.25.10" + +[package.extras] +tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=7.0.0)", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "types-requests"] + [[package]] name = "rfc3339-validator" version = "0.1.4" @@ -5010,11 +5033,14 @@ files = [ {file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47ef745dbf9f49281e900e9e72915356d69de3a4e4d8a475bda26bfdb5047736"}, {file = "tokenizers-0.13.2-cp310-cp310-win32.whl", hash = "sha256:96cedf83864bcc15a3ffd088a6f81a8a8f55b8b188eabd7a7f2a4469477036df"}, {file = "tokenizers-0.13.2-cp310-cp310-win_amd64.whl", hash = "sha256:eda77de40a0262690c666134baf19ec5c4f5b8bde213055911d9f5a718c506e1"}, + {file = "tokenizers-0.13.2-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:9eee037bb5aa14daeb56b4c39956164b2bebbe6ab4ca7779d88aa16b79bd4e17"}, + {file = "tokenizers-0.13.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d1b079c4c9332048fec4cb9c2055c2373c74fbb336716a5524c9a720206d787e"}, {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a689654fc745135cce4eea3b15e29c372c3e0b01717c6978b563de5c38af9811"}, {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3606528c07cda0566cff6cbfbda2b167f923661be595feac95701ffcdcbdbb21"}, {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:41291d0160946084cbd53c8ec3d029df3dc2af2673d46b25ff1a7f31a9d55d51"}, {file = 
"tokenizers-0.13.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7892325f9ca1cc5fca0333d5bfd96a19044ce9b092ce2df625652109a3de16b8"}, {file = "tokenizers-0.13.2-cp311-cp311-win32.whl", hash = "sha256:93714958d4ebe5362d3de7a6bd73dc86c36b5af5941ebef6c325ac900fa58865"}, + {file = "tokenizers-0.13.2-cp311-cp311-win_amd64.whl", hash = "sha256:fa7ef7ee380b1f49211bbcfac8a006b1a3fa2fa4c7f4ee134ae384eb4ea5e453"}, {file = "tokenizers-0.13.2-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:da521bfa94df6a08a6254bb8214ea04854bb9044d61063ae2529361688b5440a"}, {file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a739d4d973d422e1073989769723f3b6ad8b11e59e635a63de99aea4b2208188"}, {file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cac01fc0b868e4d0a3aa7c5c53396da0a0a63136e81475d32fcf5c348fcb2866"}, @@ -5023,6 +5049,7 @@ files = [ {file = "tokenizers-0.13.2-cp37-cp37m-win32.whl", hash = "sha256:a537061ee18ba104b7f3daa735060c39db3a22c8a9595845c55b6c01d36c5e87"}, {file = "tokenizers-0.13.2-cp37-cp37m-win_amd64.whl", hash = "sha256:c82fb87b1cbfa984d8f05b2b3c3c73e428b216c1d4f0e286d0a3b27f521b32eb"}, {file = "tokenizers-0.13.2-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:ce298605a833ac7f81b8062d3102a42dcd9fa890493e8f756112c346339fe5c5"}, + {file = "tokenizers-0.13.2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:f44d59bafe3d61e8a56b9e0a963075187c0f0091023120b13fbe37a87936f171"}, {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a51b93932daba12ed07060935978a6779593a59709deab04a0d10e6fd5c29e60"}, {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6969e5ea7ccb909ce7d6d4dfd009115dc72799b0362a2ea353267168667408c4"}, {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:92f040c4d938ea64683526b45dfc81c580e3b35aaebe847e7eec374961231734"}, @@ -5666,4 +5693,4 @@ llms = ["manifest-ml", "torch", "transformers"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "537ede877b299a8800eb26a428607be97b59f89435259abda2c7cc86092306c6" +content-hash = "bed5e0cb4cfa8b6173dd9574982bd9154ebc61721704002b85474af3c2b675ca" diff --git a/pyproject.toml b/pyproject.toml index 39cb771a452..d093d5b3baf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ pytest-dotenv = "^0.5.2" duckdb-engine = "^0.6.6" pytest-watcher = "^0.2.6" freezegun = "^1.2.2" +responses = "^0.22.0" [tool.poetry.group.lint.dependencies] flake8-docstrings = "^1.6.0" diff --git a/tests/unit_tests/utilities/__init__.py b/tests/unit_tests/utilities/__init__.py new file mode 100644 index 00000000000..c3bc6c0518f --- /dev/null +++ b/tests/unit_tests/utilities/__init__.py @@ -0,0 +1 @@ +"""Tests utilities module.""" diff --git a/tests/unit_tests/utilities/test_loading.py b/tests/unit_tests/utilities/test_loading.py new file mode 100644 index 00000000000..380297f740c --- /dev/null +++ b/tests/unit_tests/utilities/test_loading.py @@ -0,0 +1,93 @@ +"""Test the functionality of loading from langchain-hub.""" + +import json +import re +from pathlib import Path +from typing import Iterable +from unittest.mock import Mock +from urllib.parse import urljoin + +import pytest +import responses + +from langchain.utilities.loading import DEFAULT_REF, URL_BASE, try_load_from_hub + + +@pytest.fixture(autouse=True) +def mocked_responses() -> Iterable[responses.RequestsMock]: + """Fixture mocking requests.get.""" + with responses.RequestsMock() as rsps: + yield rsps + + +def test_non_hub_path() -> None: + """Test that a non-hub path returns None.""" + path = "chains/some_path" + loader = Mock() + valid_suffixes = {"suffix"} + result = try_load_from_hub(path, loader, "chains", valid_suffixes) + + assert result is None + loader.assert_not_called() + + +def 
test_invalid_prefix() -> None: + """Test that a hub path with an invalid prefix returns None.""" + path = "lc://agents/some_path" + loader = Mock() + valid_suffixes = {"suffix"} + result = try_load_from_hub(path, loader, "chains", valid_suffixes) + + assert result is None + loader.assert_not_called() + + +def test_invalid_suffix() -> None: + """Test that a hub path with an invalid suffix raises an error.""" + path = "lc://chains/path.invalid" + loader = Mock() + valid_suffixes = {"json"} + + with pytest.raises(ValueError, match="Unsupported file type."): + try_load_from_hub(path, loader, "chains", valid_suffixes) + + loader.assert_not_called() + + +@pytest.mark.parametrize("ref", [None, "v0.3"]) +def test_success(mocked_responses: responses.RequestsMock, ref: str) -> None: + """Test that a valid hub path is loaded correctly with and without a ref.""" + path = "chains/path/chain.json" + lc_path_prefix = f"lc{('@' + ref) if ref else ''}://" + valid_suffixes = {"json"} + body = json.dumps({"foo": "bar"}) + ref = ref or DEFAULT_REF + + file_contents = None + + def loader(file_path: str) -> None: + nonlocal file_contents + assert file_contents is None + file_contents = Path(file_path).read_text() + + mocked_responses.get( + urljoin(URL_BASE.format(ref=ref), path), + body=body, + status=200, + content_type="application/json", + ) + + try_load_from_hub(f"{lc_path_prefix}{path}", loader, "chains", valid_suffixes) + assert file_contents == body + + +def test_failed_request(mocked_responses: responses.RequestsMock) -> None: + """Test that a failed request raises an error.""" + path = "chains/path/chain.json" + loader = Mock() + + mocked_responses.get(urljoin(URL_BASE.format(ref=DEFAULT_REF), path), status=500) + + with pytest.raises(ValueError, match=re.compile("Could not find file at .*")): + try_load_from_hub(f"lc://{path}", loader, "chains", {"json"}) + loader.assert_not_called()