From 608629225228a1d9fb380a21f54e392524b37637 Mon Sep 17 00:00:00 2001 From: Roy Williams Date: Mon, 30 Jan 2023 17:52:17 -0500 Subject: [PATCH] Centralize logic for loading from LangChainHub, add ability to pin dependencies (#805) It's generally considered to be a good practice to pin dependencies to prevent surprise breakages when a new version of a dependency is released. This commit adds the ability to pin dependencies when loading from LangChainHub. Centralizing this logic and using urllib fixes an issue identified by some windows users highlighted in this video - https://youtu.be/aJ6IQUh8MLQ?t=537 --- .../agents/examples/load_from_hub.ipynb | 15 ++- langchain/agents/loading.py | 27 +----- langchain/chains/loading.py | 27 +----- langchain/prompts/loading.py | 27 +----- langchain/utilities/loading.py | 49 ++++++++++ poetry.lock | 29 +++++- pyproject.toml | 1 + tests/unit_tests/utilities/__init__.py | 1 + tests/unit_tests/utilities/test_loading.py | 93 +++++++++++++++++++ 9 files changed, 201 insertions(+), 68 deletions(-) create mode 100644 langchain/utilities/loading.py create mode 100644 tests/unit_tests/utilities/__init__.py create mode 100644 tests/unit_tests/utilities/test_loading.py diff --git a/docs/modules/agents/examples/load_from_hub.ipynb b/docs/modules/agents/examples/load_from_hub.ipynb index 7b5e0b1b0ce..bc1bc760e9f 100644 --- a/docs/modules/agents/examples/load_from_hub.ipynb +++ b/docs/modules/agents/examples/load_from_hub.ipynb @@ -62,13 +62,26 @@ "self_ask_with_search.run(\"What is the hometown of the reigning men's U.S. Open champion?\")" ] }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "3aede965", + "metadata": {}, + "source": [ + "# Pinning Dependencies\n", + "\n", + "Specific versions of LangChainHub agents can be pinned with the `lc@://` syntax." 
+ ] + }, { "cell_type": "code", "execution_count": null, "id": "e679f7b6", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "self_ask_with_search = initialize_agent(tools, llm, agent_path=\"lc@2826ef9e8acdf88465e1e5fc8a7bf59e0f9d0a85://agents/self-ask-with-search/agent.json\", verbose=True)" + ] } ], "metadata": { diff --git a/langchain/agents/loading.py b/langchain/agents/loading.py index bdf04c02415..3298895b539 100644 --- a/langchain/agents/loading.py +++ b/langchain/agents/loading.py @@ -1,11 +1,8 @@ """Functionality for loading agents.""" import json -import os -import tempfile from pathlib import Path from typing import Any, List, Optional, Union -import requests import yaml from langchain.agents.agent import Agent @@ -16,6 +13,7 @@ from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent from langchain.agents.tools import Tool from langchain.chains.loading import load_chain, load_chain_from_config from langchain.llms.base import BaseLLM +from langchain.utilities.loading import try_load_from_hub AGENT_TO_CLASS = { "zero-shot-react-description": ZeroShotAgent, @@ -81,29 +79,14 @@ def load_agent_from_config( def load_agent(path: Union[str, Path], **kwargs: Any) -> Agent: """Unified method for loading a agent from LangChainHub or local fs.""" - if isinstance(path, str) and path.startswith("lc://agents"): - path = os.path.relpath(path, "lc://agents/") - return _load_from_hub(path, **kwargs) + if hub_result := try_load_from_hub( + path, _load_agent_from_file, "agents", {"json", "yaml"} + ): + return hub_result else: return _load_agent_from_file(path, **kwargs) -def _load_from_hub(path: str, **kwargs: Any) -> Agent: - """Load agent from hub.""" - suffix = path.split(".")[-1] - if suffix not in {"json", "yaml"}: - raise ValueError("Unsupported file type.") - full_url = URL_BASE + path - r = requests.get(full_url) - if r.status_code != 200: - raise ValueError(f"Could not find file at {full_url}") - with tempfile.TemporaryDirectory() 
as tmpdirname: - file = tmpdirname + "/agent." + suffix - with open(file, "wb") as f: - f.write(r.content) - return _load_agent_from_file(file, **kwargs) - - def _load_agent_from_file(file: Union[str, Path], **kwargs: Any) -> Agent: """Load agent from file.""" # Convert file to Path object. diff --git a/langchain/chains/loading.py b/langchain/chains/loading.py index 10095d02c90..7e42969c3cc 100644 --- a/langchain/chains/loading.py +++ b/langchain/chains/loading.py @@ -1,11 +1,8 @@ """Functionality for loading chains.""" import json -import os -import tempfile from pathlib import Path from typing import Any, Union -import requests import yaml from langchain.chains.api.base import APIChain @@ -27,6 +24,7 @@ from langchain.chains.sql_database.base import SQLDatabaseChain from langchain.chains.vector_db_qa.base import VectorDBQA from langchain.llms.loading import load_llm, load_llm_from_config from langchain.prompts.loading import load_prompt, load_prompt_from_config +from langchain.utilities.loading import try_load_from_hub URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/chains/" @@ -441,9 +439,10 @@ def load_chain_from_config(config: dict, **kwargs: Any) -> Chain: def load_chain(path: Union[str, Path], **kwargs: Any) -> Chain: """Unified method for loading a chain from LangChainHub or local fs.""" - if isinstance(path, str) and path.startswith("lc://chains"): - path = os.path.relpath(path, "lc://chains/") - return _load_from_hub(path, **kwargs) + if hub_result := try_load_from_hub( + path, _load_chain_from_file, "chains", {"json", "yaml"} + ): + return hub_result else: return _load_chain_from_file(path, **kwargs) @@ -466,19 +465,3 @@ def _load_chain_from_file(file: Union[str, Path], **kwargs: Any) -> Chain: raise ValueError("File type must be json or yaml") # Load the chain from the config now. 
return load_chain_from_config(config, **kwargs) - - -def _load_from_hub(path: str, **kwargs: Any) -> Chain: - """Load chain from hub.""" - suffix = path.split(".")[-1] - if suffix not in {"json", "yaml"}: - raise ValueError("Unsupported file type.") - full_url = URL_BASE + path - r = requests.get(full_url) - if r.status_code != 200: - raise ValueError(f"Could not find file at {full_url}") - with tempfile.TemporaryDirectory() as tmpdirname: - file = tmpdirname + "/chain." + suffix - with open(file, "wb") as f: - f.write(r.content) - return _load_chain_from_file(file, **kwargs) diff --git a/langchain/prompts/loading.py b/langchain/prompts/loading.py index eaae6ad6c5f..30be599a7a1 100644 --- a/langchain/prompts/loading.py +++ b/langchain/prompts/loading.py @@ -1,17 +1,15 @@ """Load prompts from disk.""" import importlib import json -import os -import tempfile from pathlib import Path from typing import Union -import requests import yaml from langchain.prompts.base import BasePromptTemplate, RegexParser from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.prompt import PromptTemplate +from langchain.utilities.loading import try_load_from_hub URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/" @@ -114,9 +112,10 @@ def _load_prompt(config: dict) -> PromptTemplate: def load_prompt(path: Union[str, Path]) -> BasePromptTemplate: """Unified method for loading a prompt from LangChainHub or local fs.""" - if isinstance(path, str) and path.startswith("lc://prompts"): - path = os.path.relpath(path, "lc://prompts/") - return _load_from_hub(path) + if hub_result := try_load_from_hub( + path, _load_prompt_from_file, "prompts", {"py", "json", "yaml"} + ): + return hub_result else: return _load_prompt_from_file(path) @@ -151,19 +150,3 @@ def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate: raise ValueError(f"Got unsupported file type {file_path.suffix}") # Load the prompt from the config now. 
return load_prompt_from_config(config) - - -def _load_from_hub(path: str) -> BasePromptTemplate: - """Load prompt from hub.""" - suffix = path.split(".")[-1] - if suffix not in {"py", "json", "yaml"}: - raise ValueError("Unsupported file type.") - full_url = URL_BASE + path - r = requests.get(full_url) - if r.status_code != 200: - raise ValueError(f"Could not find file at {full_url}") - with tempfile.TemporaryDirectory() as tmpdirname: - file = tmpdirname + "/prompt." + suffix - with open(file, "wb") as f: - f.write(r.content) - return _load_prompt_from_file(file) diff --git a/langchain/utilities/loading.py b/langchain/utilities/loading.py new file mode 100644 index 00000000000..6b70318d450 --- /dev/null +++ b/langchain/utilities/loading.py @@ -0,0 +1,49 @@ +"""Utilities for loading configurations from langchain-hub.""" + +import os +import re +import tempfile +from pathlib import Path +from typing import Any, Callable, Optional, Set, TypeVar, Union +from urllib.parse import urljoin + +import requests + +DEFAULT_REF = os.environ.get("LANGCHAIN_HUB_DEFAULT_REF", "master") +URL_BASE = os.environ.get( + "LANGCHAIN_HUB_URL_BASE", + "https://raw.githubusercontent.com/hwchase17/langchain-hub/{ref}/", +) +HUB_PATH_RE = re.compile(r"lc(?P<ref>@[^:]+)?://(?P<path>.*)") + + +T = TypeVar("T") + + +def try_load_from_hub( + path: Union[str, Path], + loader: Callable[[str], T], + valid_prefix: str, + valid_suffixes: Set[str], + **kwargs: Any, +) -> Optional[T]: + """Load configuration from hub. 
Returns None if path is not a hub path.""" + if not isinstance(path, str) or not (match := HUB_PATH_RE.match(path)): + return None + ref, remote_path_str = match.groups() + ref = ref[1:] if ref else DEFAULT_REF + remote_path = Path(remote_path_str) + if remote_path.parts[0] != valid_prefix: + return None + if remote_path.suffix[1:] not in valid_suffixes: + raise ValueError("Unsupported file type.") + + full_url = urljoin(URL_BASE.format(ref=ref), str(remote_path)) + r = requests.get(full_url, timeout=5) + if r.status_code != 200: + raise ValueError(f"Could not find file at {full_url}") + with tempfile.TemporaryDirectory() as tmpdirname: + file = Path(tmpdirname) / remote_path.name + with open(file, "wb") as f: + f.write(r.content) + return loader(str(file), **kwargs) diff --git a/poetry.lock b/poetry.lock index 1e3ffe4e491..dc2adb7c7a7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2270,6 +2270,7 @@ files = [ {file = "lxml-4.9.2-cp35-cp35m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ca989b91cf3a3ba28930a9fc1e9aeafc2a395448641df1f387a2d394638943b0"}, {file = "lxml-4.9.2-cp35-cp35m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:822068f85e12a6e292803e112ab876bc03ed1f03dddb80154c395f891ca6b31e"}, {file = "lxml-4.9.2-cp35-cp35m-win32.whl", hash = "sha256:be7292c55101e22f2a3d4d8913944cbea71eea90792bf914add27454a13905df"}, + {file = "lxml-4.9.2-cp35-cp35m-win_amd64.whl", hash = "sha256:998c7c41910666d2976928c38ea96a70d1aa43be6fe502f21a651e17483a43c5"}, {file = "lxml-4.9.2-cp36-cp36m-macosx_10_15_x86_64.whl", hash = "sha256:b26a29f0b7fc6f0897f043ca366142d2b609dc60756ee6e4e90b5f762c6adc53"}, {file = "lxml-4.9.2-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:ab323679b8b3030000f2be63e22cdeea5b47ee0abd2d6a1dc0c8103ddaa56cd7"}, {file = "lxml-4.9.2-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:689bb688a1db722485e4610a503e3e9210dcc20c520b45ac8f7533c837be76fe"}, @@ -2279,6 +2280,7 
@@ files = [ {file = "lxml-4.9.2-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:58bfa3aa19ca4c0f28c5dde0ff56c520fbac6f0daf4fac66ed4c8d2fb7f22e74"}, {file = "lxml-4.9.2-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:bc718cd47b765e790eecb74d044cc8d37d58562f6c314ee9484df26276d36a38"}, {file = "lxml-4.9.2-cp36-cp36m-win32.whl", hash = "sha256:d5bf6545cd27aaa8a13033ce56354ed9e25ab0e4ac3b5392b763d8d04b08e0c5"}, + {file = "lxml-4.9.2-cp36-cp36m-win_amd64.whl", hash = "sha256:3ab9fa9d6dc2a7f29d7affdf3edebf6ece6fb28a6d80b14c3b2fb9d39b9322c3"}, {file = "lxml-4.9.2-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:05ca3f6abf5cf78fe053da9b1166e062ade3fa5d4f92b4ed688127ea7d7b1d03"}, {file = "lxml-4.9.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:a5da296eb617d18e497bcf0a5c528f5d3b18dadb3619fbdadf4ed2356ef8d941"}, {file = "lxml-4.9.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:04876580c050a8c5341d706dd464ff04fd597095cc8c023252566a8826505726"}, @@ -4176,6 +4178,27 @@ urllib3 = ">=1.21.1,<1.27" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "responses" +version = "0.22.0" +description = "A utility library for mocking out the `requests` Python library." 
+category = "dev" +optional = false +python-versions = ">=3.7" +files = [ + {file = "responses-0.22.0-py3-none-any.whl", hash = "sha256:dcf294d204d14c436fddcc74caefdbc5764795a40ff4e6a7740ed8ddbf3294be"}, + {file = "responses-0.22.0.tar.gz", hash = "sha256:396acb2a13d25297789a5866b4881cf4e46ffd49cc26c43ab1117f40b973102e"}, +] + +[package.dependencies] +requests = ">=2.22.0,<3.0" +toml = "*" +types-toml = "*" +urllib3 = ">=1.25.10" + +[package.extras] +tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=7.0.0)", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "types-requests"] + [[package]] name = "rfc3339-validator" version = "0.1.4" @@ -5010,11 +5033,14 @@ files = [ {file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47ef745dbf9f49281e900e9e72915356d69de3a4e4d8a475bda26bfdb5047736"}, {file = "tokenizers-0.13.2-cp310-cp310-win32.whl", hash = "sha256:96cedf83864bcc15a3ffd088a6f81a8a8f55b8b188eabd7a7f2a4469477036df"}, {file = "tokenizers-0.13.2-cp310-cp310-win_amd64.whl", hash = "sha256:eda77de40a0262690c666134baf19ec5c4f5b8bde213055911d9f5a718c506e1"}, + {file = "tokenizers-0.13.2-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:9eee037bb5aa14daeb56b4c39956164b2bebbe6ab4ca7779d88aa16b79bd4e17"}, + {file = "tokenizers-0.13.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d1b079c4c9332048fec4cb9c2055c2373c74fbb336716a5524c9a720206d787e"}, {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a689654fc745135cce4eea3b15e29c372c3e0b01717c6978b563de5c38af9811"}, {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3606528c07cda0566cff6cbfbda2b167f923661be595feac95701ffcdcbdbb21"}, {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:41291d0160946084cbd53c8ec3d029df3dc2af2673d46b25ff1a7f31a9d55d51"}, {file = 
"tokenizers-0.13.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7892325f9ca1cc5fca0333d5bfd96a19044ce9b092ce2df625652109a3de16b8"}, {file = "tokenizers-0.13.2-cp311-cp311-win32.whl", hash = "sha256:93714958d4ebe5362d3de7a6bd73dc86c36b5af5941ebef6c325ac900fa58865"}, + {file = "tokenizers-0.13.2-cp311-cp311-win_amd64.whl", hash = "sha256:fa7ef7ee380b1f49211bbcfac8a006b1a3fa2fa4c7f4ee134ae384eb4ea5e453"}, {file = "tokenizers-0.13.2-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:da521bfa94df6a08a6254bb8214ea04854bb9044d61063ae2529361688b5440a"}, {file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a739d4d973d422e1073989769723f3b6ad8b11e59e635a63de99aea4b2208188"}, {file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cac01fc0b868e4d0a3aa7c5c53396da0a0a63136e81475d32fcf5c348fcb2866"}, @@ -5023,6 +5049,7 @@ files = [ {file = "tokenizers-0.13.2-cp37-cp37m-win32.whl", hash = "sha256:a537061ee18ba104b7f3daa735060c39db3a22c8a9595845c55b6c01d36c5e87"}, {file = "tokenizers-0.13.2-cp37-cp37m-win_amd64.whl", hash = "sha256:c82fb87b1cbfa984d8f05b2b3c3c73e428b216c1d4f0e286d0a3b27f521b32eb"}, {file = "tokenizers-0.13.2-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:ce298605a833ac7f81b8062d3102a42dcd9fa890493e8f756112c346339fe5c5"}, + {file = "tokenizers-0.13.2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:f44d59bafe3d61e8a56b9e0a963075187c0f0091023120b13fbe37a87936f171"}, {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a51b93932daba12ed07060935978a6779593a59709deab04a0d10e6fd5c29e60"}, {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6969e5ea7ccb909ce7d6d4dfd009115dc72799b0362a2ea353267168667408c4"}, {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:92f040c4d938ea64683526b45dfc81c580e3b35aaebe847e7eec374961231734"}, @@ -5666,4 +5693,4 @@ llms = ["manifest-ml", "torch", "transformers"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "537ede877b299a8800eb26a428607be97b59f89435259abda2c7cc86092306c6" +content-hash = "bed5e0cb4cfa8b6173dd9574982bd9154ebc61721704002b85474af3c2b675ca" diff --git a/pyproject.toml b/pyproject.toml index 39cb771a452..d093d5b3baf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ pytest-dotenv = "^0.5.2" duckdb-engine = "^0.6.6" pytest-watcher = "^0.2.6" freezegun = "^1.2.2" +responses = "^0.22.0" [tool.poetry.group.lint.dependencies] flake8-docstrings = "^1.6.0" diff --git a/tests/unit_tests/utilities/__init__.py b/tests/unit_tests/utilities/__init__.py new file mode 100644 index 00000000000..c3bc6c0518f --- /dev/null +++ b/tests/unit_tests/utilities/__init__.py @@ -0,0 +1 @@ +"""Tests utilities module.""" diff --git a/tests/unit_tests/utilities/test_loading.py b/tests/unit_tests/utilities/test_loading.py new file mode 100644 index 00000000000..380297f740c --- /dev/null +++ b/tests/unit_tests/utilities/test_loading.py @@ -0,0 +1,93 @@ +"""Test the functionality of loading from langchain-hub.""" + +import json +import re +from pathlib import Path +from typing import Iterable +from unittest.mock import Mock +from urllib.parse import urljoin + +import pytest +import responses + +from langchain.utilities.loading import DEFAULT_REF, URL_BASE, try_load_from_hub + + +@pytest.fixture(autouse=True) +def mocked_responses() -> Iterable[responses.RequestsMock]: + """Fixture mocking requests.get.""" + with responses.RequestsMock() as rsps: + yield rsps + + +def test_non_hub_path() -> None: + """Test that a non-hub path returns None.""" + path = "chains/some_path" + loader = Mock() + valid_suffixes = {"suffix"} + result = try_load_from_hub(path, loader, "chains", valid_suffixes) + + assert result is None + loader.assert_not_called() + + +def 
test_invalid_prefix() -> None: + """Test that a hub path with an invalid prefix returns None.""" + path = "lc://agents/some_path" + loader = Mock() + valid_suffixes = {"suffix"} + result = try_load_from_hub(path, loader, "chains", valid_suffixes) + + assert result is None + loader.assert_not_called() + + +def test_invalid_suffix() -> None: + """Test that a hub path with an invalid suffix raises an error.""" + path = "lc://chains/path.invalid" + loader = Mock() + valid_suffixes = {"json"} + + with pytest.raises(ValueError, match="Unsupported file type."): + try_load_from_hub(path, loader, "chains", valid_suffixes) + + loader.assert_not_called() + + +@pytest.mark.parametrize("ref", [None, "v0.3"]) +def test_success(mocked_responses: responses.RequestsMock, ref: str) -> None: + """Test that a valid hub path is loaded correctly with and without a ref.""" + path = "chains/path/chain.json" + lc_path_prefix = f"lc{('@' + ref) if ref else ''}://" + valid_suffixes = {"json"} + body = json.dumps({"foo": "bar"}) + ref = ref or DEFAULT_REF + + file_contents = None + + def loader(file_path: str) -> None: + nonlocal file_contents + assert file_contents is None + file_contents = Path(file_path).read_text() + + mocked_responses.get( + urljoin(URL_BASE.format(ref=ref), path), + body=body, + status=200, + content_type="application/json", + ) + + try_load_from_hub(f"{lc_path_prefix}{path}", loader, "chains", valid_suffixes) + assert file_contents == body + + +def test_failed_request(mocked_responses: responses.RequestsMock) -> None: + """Test that a failed request raises an error.""" + path = "chains/path/chain.json" + loader = Mock() + + mocked_responses.get(urljoin(URL_BASE.format(ref=DEFAULT_REF), path), status=500) + + with pytest.raises(ValueError, match=re.compile("Could not find file at .*")): + try_load_from_hub(f"lc://{path}", loader, "chains", {"json"}) + loader.assert_not_called()