mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-13 05:25:07 +00:00
Centralize logic for loading from LangChainHub, add ability to pin dependencies (#805)
It's generally considered to be a good practice to pin dependencies to prevent surprise breakages when a new version of a dependency is released. This commit adds the ability to pin dependencies when loading from LangChainHub. Centralizing this logic and using urllib fixes an issue identified by some windows users highlighted in this video - https://youtu.be/aJ6IQUh8MLQ?t=537
This commit is contained in:
@@ -62,13 +62,26 @@
|
||||
"self_ask_with_search.run(\"What is the hometown of the reigning men's U.S. Open champion?\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"id": "3aede965",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Pinning Dependencies\n",
|
||||
"\n",
|
||||
"Specific versions of LangChainHub agents can be pinned with the `lc@<ref>://` syntax."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e679f7b6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"self_ask_with_search = initialize_agent(tools, llm, agent_path=\"lc@2826ef9e8acdf88465e1e5fc8a7bf59e0f9d0a85://agents/self-ask-with-search/agent.json\", verbose=True)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
@@ -1,11 +1,8 @@
|
||||
"""Functionality for loading agents."""
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
|
||||
from langchain.agents.agent import Agent
|
||||
@@ -16,6 +13,7 @@ from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.chains.loading import load_chain, load_chain_from_config
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.utilities.loading import try_load_from_hub
|
||||
|
||||
AGENT_TO_CLASS = {
|
||||
"zero-shot-react-description": ZeroShotAgent,
|
||||
@@ -81,29 +79,14 @@ def load_agent_from_config(
|
||||
|
||||
def load_agent(path: Union[str, Path], **kwargs: Any) -> Agent:
    """Unified method for loading an agent from LangChainHub or local fs.

    Args:
        path: Local filesystem path to an agent config (json/yaml), or a hub
            path of the form ``lc[@<ref>]://agents/...``.
        **kwargs: Extra keyword arguments passed through to the agent loader.

    Returns:
        The loaded Agent.
    """
    # The hub loader returns None for non-hub paths; fall back to local fs.
    # Forward **kwargs so hub-loaded agents honor them too (previously they
    # were silently dropped on the hub path but honored locally).
    if hub_result := try_load_from_hub(
        path, _load_agent_from_file, "agents", {"json", "yaml"}, **kwargs
    ):
        return hub_result
    return _load_agent_from_file(path, **kwargs)
|
||||
|
||||
|
||||
def _load_from_hub(path: str, **kwargs: Any) -> Agent:
|
||||
"""Load agent from hub."""
|
||||
suffix = path.split(".")[-1]
|
||||
if suffix not in {"json", "yaml"}:
|
||||
raise ValueError("Unsupported file type.")
|
||||
full_url = URL_BASE + path
|
||||
r = requests.get(full_url)
|
||||
if r.status_code != 200:
|
||||
raise ValueError(f"Could not find file at {full_url}")
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
file = tmpdirname + "/agent." + suffix
|
||||
with open(file, "wb") as f:
|
||||
f.write(r.content)
|
||||
return _load_agent_from_file(file, **kwargs)
|
||||
|
||||
|
||||
def _load_agent_from_file(file: Union[str, Path], **kwargs: Any) -> Agent:
|
||||
"""Load agent from file."""
|
||||
# Convert file to Path object.
|
||||
|
@@ -1,11 +1,8 @@
|
||||
"""Functionality for loading chains."""
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
|
||||
from langchain.chains.api.base import APIChain
|
||||
@@ -27,6 +24,7 @@ from langchain.chains.sql_database.base import SQLDatabaseChain
|
||||
from langchain.chains.vector_db_qa.base import VectorDBQA
|
||||
from langchain.llms.loading import load_llm, load_llm_from_config
|
||||
from langchain.prompts.loading import load_prompt, load_prompt_from_config
|
||||
from langchain.utilities.loading import try_load_from_hub
|
||||
|
||||
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/chains/"
|
||||
|
||||
@@ -441,9 +439,10 @@ def load_chain_from_config(config: dict, **kwargs: Any) -> Chain:
|
||||
|
||||
def load_chain(path: Union[str, Path], **kwargs: Any) -> Chain:
    """Unified method for loading a chain from LangChainHub or local fs.

    Args:
        path: Local filesystem path to a chain config (json/yaml), or a hub
            path of the form ``lc[@<ref>]://chains/...``.
        **kwargs: Extra keyword arguments passed through to the chain loader.

    Returns:
        The loaded Chain.
    """
    # The hub loader returns None for non-hub paths; fall back to local fs.
    # Forward **kwargs so hub-loaded chains honor them too (previously they
    # were silently dropped on the hub path but honored locally).
    if hub_result := try_load_from_hub(
        path, _load_chain_from_file, "chains", {"json", "yaml"}, **kwargs
    ):
        return hub_result
    return _load_chain_from_file(path, **kwargs)
|
||||
|
||||
@@ -466,19 +465,3 @@ def _load_chain_from_file(file: Union[str, Path], **kwargs: Any) -> Chain:
|
||||
raise ValueError("File type must be json or yaml")
|
||||
# Load the chain from the config now.
|
||||
return load_chain_from_config(config, **kwargs)
|
||||
|
||||
|
||||
def _load_from_hub(path: str, **kwargs: Any) -> Chain:
|
||||
"""Load chain from hub."""
|
||||
suffix = path.split(".")[-1]
|
||||
if suffix not in {"json", "yaml"}:
|
||||
raise ValueError("Unsupported file type.")
|
||||
full_url = URL_BASE + path
|
||||
r = requests.get(full_url)
|
||||
if r.status_code != 200:
|
||||
raise ValueError(f"Could not find file at {full_url}")
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
file = tmpdirname + "/chain." + suffix
|
||||
with open(file, "wb") as f:
|
||||
f.write(r.content)
|
||||
return _load_chain_from_file(file, **kwargs)
|
||||
|
@@ -1,17 +1,15 @@
|
||||
"""Load prompts from disk."""
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
|
||||
from langchain.prompts.base import BasePromptTemplate, RegexParser
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.utilities.loading import try_load_from_hub
|
||||
|
||||
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
|
||||
|
||||
@@ -114,9 +112,10 @@ def _load_prompt(config: dict) -> PromptTemplate:
|
||||
|
||||
def load_prompt(path: Union[str, Path]) -> BasePromptTemplate:
    """Unified method for loading a prompt from LangChainHub or local fs.

    Args:
        path: Local filesystem path to a prompt config (py/json/yaml), or a
            hub path of the form ``lc[@<ref>]://prompts/...``.

    Returns:
        The loaded prompt template.
    """
    # The hub loader returns None for non-hub paths; fall back to local fs.
    # (The pre-refactor ``lc://prompts`` prefix branch is removed — it called
    # the deleted os.path.relpath-based loader and would shadow this path.)
    if hub_result := try_load_from_hub(
        path, _load_prompt_from_file, "prompts", {"py", "json", "yaml"}
    ):
        return hub_result
    return _load_prompt_from_file(path)
|
||||
|
||||
@@ -151,19 +150,3 @@ def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate:
|
||||
raise ValueError(f"Got unsupported file type {file_path.suffix}")
|
||||
# Load the prompt from the config now.
|
||||
return load_prompt_from_config(config)
|
||||
|
||||
|
||||
def _load_from_hub(path: str) -> BasePromptTemplate:
|
||||
"""Load prompt from hub."""
|
||||
suffix = path.split(".")[-1]
|
||||
if suffix not in {"py", "json", "yaml"}:
|
||||
raise ValueError("Unsupported file type.")
|
||||
full_url = URL_BASE + path
|
||||
r = requests.get(full_url)
|
||||
if r.status_code != 200:
|
||||
raise ValueError(f"Could not find file at {full_url}")
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
file = tmpdirname + "/prompt." + suffix
|
||||
with open(file, "wb") as f:
|
||||
f.write(r.content)
|
||||
return _load_prompt_from_file(file)
|
||||
|
49
langchain/utilities/loading.py
Normal file
49
langchain/utilities/loading.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""Utilities for loading configurations from langchain-hub."""
|
||||
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional, Set, TypeVar, Union
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
|
||||
# Default git ref (branch/tag/commit) used when a hub path does not pin one.
DEFAULT_REF = os.environ.get("LANGCHAIN_HUB_DEFAULT_REF", "master")
# Raw-content URL template; ``{ref}`` is substituted per request.
URL_BASE = os.environ.get(
    "LANGCHAIN_HUB_URL_BASE",
    "https://raw.githubusercontent.com/hwchase17/langchain-hub/{ref}/",
)
# Matches hub paths of the form ``lc[@<ref>]://<path>``.
HUB_PATH_RE = re.compile(r"lc(?P<ref>@[^:]+)?://(?P<path>.*)")


T = TypeVar("T")


def try_load_from_hub(
    path: Union[str, Path],
    loader: Callable[[str], T],
    valid_prefix: str,
    valid_suffixes: Set[str],
    **kwargs: Any,
) -> Optional[T]:
    """Load configuration from hub. Returns None if path is not a hub path.

    Args:
        path: Candidate path; only strings matching ``lc[@<ref>]://...`` are
            treated as hub paths.
        loader: Callable invoked with the downloaded local file path.
        valid_prefix: Required first path component (e.g. ``"chains"``).
        valid_suffixes: Allowed file extensions, without the leading dot.
        **kwargs: Extra keyword arguments forwarded to ``loader``.

    Returns:
        The loader's result for a matching hub path, otherwise None.

    Raises:
        ValueError: If the file type is unsupported or the download fails.
    """
    if not isinstance(path, str) or not (match := HUB_PATH_RE.match(path)):
        return None
    ref, remote_path_str = match.groups()
    # An explicit ``@<ref>`` pin wins; otherwise use the default ref.
    ref = ref[1:] if ref else DEFAULT_REF
    remote_path = Path(remote_path_str)
    # Guard against an empty remote path (e.g. bare "lc://"): previously
    # ``parts[0]`` raised IndexError; treat it as a non-hub path instead.
    if not remote_path.parts or remote_path.parts[0] != valid_prefix:
        return None
    if remote_path.suffix[1:] not in valid_suffixes:
        raise ValueError("Unsupported file type.")

    full_url = urljoin(URL_BASE.format(ref=ref), str(remote_path))
    r = requests.get(full_url, timeout=5)
    if r.status_code != 200:
        raise ValueError(f"Could not find file at {full_url}")
    # Download to a temp file so file-based loaders can parse it normally;
    # the directory (and file) is cleaned up once the loader returns.
    with tempfile.TemporaryDirectory() as tmpdirname:
        file = Path(tmpdirname) / remote_path.name
        with open(file, "wb") as f:
            f.write(r.content)
        return loader(str(file), **kwargs)
|
29
poetry.lock
generated
29
poetry.lock
generated
@@ -2270,6 +2270,7 @@ files = [
|
||||
{file = "lxml-4.9.2-cp35-cp35m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ca989b91cf3a3ba28930a9fc1e9aeafc2a395448641df1f387a2d394638943b0"},
|
||||
{file = "lxml-4.9.2-cp35-cp35m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:822068f85e12a6e292803e112ab876bc03ed1f03dddb80154c395f891ca6b31e"},
|
||||
{file = "lxml-4.9.2-cp35-cp35m-win32.whl", hash = "sha256:be7292c55101e22f2a3d4d8913944cbea71eea90792bf914add27454a13905df"},
|
||||
{file = "lxml-4.9.2-cp35-cp35m-win_amd64.whl", hash = "sha256:998c7c41910666d2976928c38ea96a70d1aa43be6fe502f21a651e17483a43c5"},
|
||||
{file = "lxml-4.9.2-cp36-cp36m-macosx_10_15_x86_64.whl", hash = "sha256:b26a29f0b7fc6f0897f043ca366142d2b609dc60756ee6e4e90b5f762c6adc53"},
|
||||
{file = "lxml-4.9.2-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:ab323679b8b3030000f2be63e22cdeea5b47ee0abd2d6a1dc0c8103ddaa56cd7"},
|
||||
{file = "lxml-4.9.2-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:689bb688a1db722485e4610a503e3e9210dcc20c520b45ac8f7533c837be76fe"},
|
||||
@@ -2279,6 +2280,7 @@ files = [
|
||||
{file = "lxml-4.9.2-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:58bfa3aa19ca4c0f28c5dde0ff56c520fbac6f0daf4fac66ed4c8d2fb7f22e74"},
|
||||
{file = "lxml-4.9.2-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:bc718cd47b765e790eecb74d044cc8d37d58562f6c314ee9484df26276d36a38"},
|
||||
{file = "lxml-4.9.2-cp36-cp36m-win32.whl", hash = "sha256:d5bf6545cd27aaa8a13033ce56354ed9e25ab0e4ac3b5392b763d8d04b08e0c5"},
|
||||
{file = "lxml-4.9.2-cp36-cp36m-win_amd64.whl", hash = "sha256:3ab9fa9d6dc2a7f29d7affdf3edebf6ece6fb28a6d80b14c3b2fb9d39b9322c3"},
|
||||
{file = "lxml-4.9.2-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:05ca3f6abf5cf78fe053da9b1166e062ade3fa5d4f92b4ed688127ea7d7b1d03"},
|
||||
{file = "lxml-4.9.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:a5da296eb617d18e497bcf0a5c528f5d3b18dadb3619fbdadf4ed2356ef8d941"},
|
||||
{file = "lxml-4.9.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:04876580c050a8c5341d706dd464ff04fd597095cc8c023252566a8826505726"},
|
||||
@@ -4176,6 +4178,27 @@ urllib3 = ">=1.21.1,<1.27"
|
||||
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
|
||||
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
||||
|
||||
[[package]]
|
||||
name = "responses"
|
||||
version = "0.22.0"
|
||||
description = "A utility library for mocking out the `requests` Python library."
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "responses-0.22.0-py3-none-any.whl", hash = "sha256:dcf294d204d14c436fddcc74caefdbc5764795a40ff4e6a7740ed8ddbf3294be"},
|
||||
{file = "responses-0.22.0.tar.gz", hash = "sha256:396acb2a13d25297789a5866b4881cf4e46ffd49cc26c43ab1117f40b973102e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
requests = ">=2.22.0,<3.0"
|
||||
toml = "*"
|
||||
types-toml = "*"
|
||||
urllib3 = ">=1.25.10"
|
||||
|
||||
[package.extras]
|
||||
tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=7.0.0)", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "types-requests"]
|
||||
|
||||
[[package]]
|
||||
name = "rfc3339-validator"
|
||||
version = "0.1.4"
|
||||
@@ -5010,11 +5033,14 @@ files = [
|
||||
{file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47ef745dbf9f49281e900e9e72915356d69de3a4e4d8a475bda26bfdb5047736"},
|
||||
{file = "tokenizers-0.13.2-cp310-cp310-win32.whl", hash = "sha256:96cedf83864bcc15a3ffd088a6f81a8a8f55b8b188eabd7a7f2a4469477036df"},
|
||||
{file = "tokenizers-0.13.2-cp310-cp310-win_amd64.whl", hash = "sha256:eda77de40a0262690c666134baf19ec5c4f5b8bde213055911d9f5a718c506e1"},
|
||||
{file = "tokenizers-0.13.2-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:9eee037bb5aa14daeb56b4c39956164b2bebbe6ab4ca7779d88aa16b79bd4e17"},
|
||||
{file = "tokenizers-0.13.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d1b079c4c9332048fec4cb9c2055c2373c74fbb336716a5524c9a720206d787e"},
|
||||
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a689654fc745135cce4eea3b15e29c372c3e0b01717c6978b563de5c38af9811"},
|
||||
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3606528c07cda0566cff6cbfbda2b167f923661be595feac95701ffcdcbdbb21"},
|
||||
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:41291d0160946084cbd53c8ec3d029df3dc2af2673d46b25ff1a7f31a9d55d51"},
|
||||
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7892325f9ca1cc5fca0333d5bfd96a19044ce9b092ce2df625652109a3de16b8"},
|
||||
{file = "tokenizers-0.13.2-cp311-cp311-win32.whl", hash = "sha256:93714958d4ebe5362d3de7a6bd73dc86c36b5af5941ebef6c325ac900fa58865"},
|
||||
{file = "tokenizers-0.13.2-cp311-cp311-win_amd64.whl", hash = "sha256:fa7ef7ee380b1f49211bbcfac8a006b1a3fa2fa4c7f4ee134ae384eb4ea5e453"},
|
||||
{file = "tokenizers-0.13.2-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:da521bfa94df6a08a6254bb8214ea04854bb9044d61063ae2529361688b5440a"},
|
||||
{file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a739d4d973d422e1073989769723f3b6ad8b11e59e635a63de99aea4b2208188"},
|
||||
{file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cac01fc0b868e4d0a3aa7c5c53396da0a0a63136e81475d32fcf5c348fcb2866"},
|
||||
@@ -5023,6 +5049,7 @@ files = [
|
||||
{file = "tokenizers-0.13.2-cp37-cp37m-win32.whl", hash = "sha256:a537061ee18ba104b7f3daa735060c39db3a22c8a9595845c55b6c01d36c5e87"},
|
||||
{file = "tokenizers-0.13.2-cp37-cp37m-win_amd64.whl", hash = "sha256:c82fb87b1cbfa984d8f05b2b3c3c73e428b216c1d4f0e286d0a3b27f521b32eb"},
|
||||
{file = "tokenizers-0.13.2-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:ce298605a833ac7f81b8062d3102a42dcd9fa890493e8f756112c346339fe5c5"},
|
||||
{file = "tokenizers-0.13.2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:f44d59bafe3d61e8a56b9e0a963075187c0f0091023120b13fbe37a87936f171"},
|
||||
{file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a51b93932daba12ed07060935978a6779593a59709deab04a0d10e6fd5c29e60"},
|
||||
{file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6969e5ea7ccb909ce7d6d4dfd009115dc72799b0362a2ea353267168667408c4"},
|
||||
{file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:92f040c4d938ea64683526b45dfc81c580e3b35aaebe847e7eec374961231734"},
|
||||
@@ -5666,4 +5693,4 @@ llms = ["manifest-ml", "torch", "transformers"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "537ede877b299a8800eb26a428607be97b59f89435259abda2c7cc86092306c6"
|
||||
content-hash = "bed5e0cb4cfa8b6173dd9574982bd9154ebc61721704002b85474af3c2b675ca"
|
||||
|
@@ -57,6 +57,7 @@ pytest-dotenv = "^0.5.2"
|
||||
duckdb-engine = "^0.6.6"
|
||||
pytest-watcher = "^0.2.6"
|
||||
freezegun = "^1.2.2"
|
||||
responses = "^0.22.0"
|
||||
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
flake8-docstrings = "^1.6.0"
|
||||
|
1
tests/unit_tests/utilities/__init__.py
Normal file
1
tests/unit_tests/utilities/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests utilities module."""
|
93
tests/unit_tests/utilities/test_loading.py
Normal file
93
tests/unit_tests/utilities/test_loading.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Test the functionality of loading from langchain-hub."""
|
||||
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
from unittest.mock import Mock
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import pytest
|
||||
import responses
|
||||
|
||||
from langchain.utilities.loading import DEFAULT_REF, URL_BASE, try_load_from_hub
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def mocked_responses() -> Iterable[responses.RequestsMock]:
    """Fixture mocking requests.get for every test in this module."""
    # autouse=True: any unmocked HTTP call in a test fails loudly.
    with responses.RequestsMock() as mock:
        yield mock
|
||||
|
||||
|
||||
def test_non_hub_path() -> None:
    """Test that a non-hub path returns None."""
    loader = Mock()

    result = try_load_from_hub("chains/some_path", loader, "chains", {"suffix"})

    assert result is None
    loader.assert_not_called()
|
||||
|
||||
|
||||
def test_invalid_prefix() -> None:
    """Test that a hub path with an invalid prefix returns None."""
    loader = Mock()

    result = try_load_from_hub("lc://agents/some_path", loader, "chains", {"suffix"})

    assert result is None
    loader.assert_not_called()
|
||||
|
||||
|
||||
def test_invalid_suffix() -> None:
    """Test that a hub path with an invalid suffix raises an error."""
    loader = Mock()

    with pytest.raises(ValueError, match="Unsupported file type."):
        try_load_from_hub("lc://chains/path.invalid", loader, "chains", {"json"})

    loader.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ref", [None, "v0.3"])
def test_success(mocked_responses: responses.RequestsMock, ref: str) -> None:
    """Test that a valid hub path is loaded correctly with and without a ref."""
    path = "chains/path/chain.json"
    prefix = f"lc@{ref}://" if ref else "lc://"
    body = json.dumps({"foo": "bar"})
    resolved_ref = ref if ref else DEFAULT_REF

    file_contents = None

    def loader(file_path: str) -> None:
        # Record what was downloaded; must be invoked exactly once.
        nonlocal file_contents
        assert file_contents is None
        file_contents = Path(file_path).read_text()

    mocked_responses.get(
        urljoin(URL_BASE.format(ref=resolved_ref), path),
        body=body,
        status=200,
        content_type="application/json",
    )

    try_load_from_hub(f"{prefix}{path}", loader, "chains", {"json"})
    assert file_contents == body
|
||||
|
||||
|
||||
def test_failed_request(mocked_responses: responses.RequestsMock) -> None:
    """Test that a failed request raises an error."""
    path = "chains/path/chain.json"
    loader = Mock()

    # Any non-200 status should surface as a ValueError, not call the loader.
    mocked_responses.get(urljoin(URL_BASE.format(ref=DEFAULT_REF), path), status=500)

    with pytest.raises(ValueError, match=re.compile("Could not find file at .*")):
        try_load_from_hub(f"lc://{path}", loader, "chains", {"json"})
    loader.assert_not_called()
|
Reference in New Issue
Block a user