mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-01 19:03:25 +00:00
APIChain add restrictions to domains (CVE-2023-32786) (#12747)
* Restrict the chain to specific domains by default * This is a breaking change, but it will fail loudly upon object instantiation -- so there should be no silent errors for users * Resolves CVE-2023-32786
This commit is contained in:
parent
4421ba46d7
commit
b1caae62fd
File diff suppressed because one or more lines are too long
@ -1,7 +1,8 @@
|
||||
"""Chain that makes API calls and summarizes the responses to answer a question."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
@ -16,6 +17,38 @@ from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.utilities.requests import TextRequestsWrapper
|
||||
|
||||
|
||||
def _extract_scheme_and_domain(url: str) -> Tuple[str, str]:
|
||||
"""Extract the scheme + domain from a given URL.
|
||||
|
||||
Args:
|
||||
url (str): The input URL.
|
||||
|
||||
Returns:
|
||||
return a 2-tuple of scheme and domain
|
||||
"""
|
||||
parsed_uri = urlparse(url)
|
||||
return parsed_uri.scheme, parsed_uri.netloc
|
||||
|
||||
|
||||
def _check_in_allowed_domain(url: str, limit_to_domains: Sequence[str]) -> bool:
|
||||
"""Check if a URL is in the allowed domains.
|
||||
|
||||
Args:
|
||||
url (str): The input URL.
|
||||
limit_to_domains (Sequence[str]): The allowed domains.
|
||||
|
||||
Returns:
|
||||
bool: True if the URL is in the allowed domains, False otherwise.
|
||||
"""
|
||||
scheme, domain = _extract_scheme_and_domain(url)
|
||||
|
||||
for allowed_domain in limit_to_domains:
|
||||
allowed_scheme, allowed_domain = _extract_scheme_and_domain(allowed_domain)
|
||||
if scheme == allowed_scheme and domain == allowed_domain:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class APIChain(Chain):
|
||||
"""Chain that makes API calls and summarizes the responses to answer a question.
|
||||
|
||||
@ -40,6 +73,19 @@ class APIChain(Chain):
|
||||
api_docs: str
|
||||
question_key: str = "question" #: :meta private:
|
||||
output_key: str = "output" #: :meta private:
|
||||
limit_to_domains: Optional[Sequence[str]]
|
||||
"""Use to limit the domains that can be accessed by the API chain.
|
||||
|
||||
* For example, to limit to just the domain `https://www.example.com`, set
|
||||
`limit_to_domains=["https://www.example.com"]`.
|
||||
|
||||
* The default value is an empty tuple, which means that no domains are
|
||||
allowed by default. By design this will raise an error on instantiation.
|
||||
* Use a None if you want to allow all domains by default -- this is not
|
||||
recommended for security reasons, as it would allow malicious users to
|
||||
make requests to arbitrary URLS including internal APIs accessible from
|
||||
the server.
|
||||
"""
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
@ -68,6 +114,21 @@ class APIChain(Chain):
|
||||
)
|
||||
return values
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_limit_to_domains(cls, values: Dict) -> Dict:
|
||||
"""Check that allowed domains are valid."""
|
||||
if "limit_to_domains" not in values:
|
||||
raise ValueError(
|
||||
"You must specify a list of domains to limit access using "
|
||||
"`limit_to_domains`"
|
||||
)
|
||||
if not values["limit_to_domains"] and values["limit_to_domains"] is not None:
|
||||
raise ValueError(
|
||||
"Please provide a list of domains to limit access using "
|
||||
"`limit_to_domains`."
|
||||
)
|
||||
return values
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_api_answer_prompt(cls, values: Dict) -> Dict:
|
||||
"""Check that api answer prompt expects the right variables."""
|
||||
@ -93,6 +154,12 @@ class APIChain(Chain):
|
||||
)
|
||||
_run_manager.on_text(api_url, color="green", end="\n", verbose=self.verbose)
|
||||
api_url = api_url.strip()
|
||||
if self.limit_to_domains and not _check_in_allowed_domain(
|
||||
api_url, self.limit_to_domains
|
||||
):
|
||||
raise ValueError(
|
||||
f"{api_url} is not in the allowed domains: {self.limit_to_domains}"
|
||||
)
|
||||
api_response = self.requests_wrapper.get(api_url)
|
||||
_run_manager.on_text(
|
||||
api_response, color="yellow", end="\n", verbose=self.verbose
|
||||
@ -122,6 +189,12 @@ class APIChain(Chain):
|
||||
api_url, color="green", end="\n", verbose=self.verbose
|
||||
)
|
||||
api_url = api_url.strip()
|
||||
if self.limit_to_domains and not _check_in_allowed_domain(
|
||||
api_url, self.limit_to_domains
|
||||
):
|
||||
raise ValueError(
|
||||
f"{api_url} is not in the allowed domains: {self.limit_to_domains}"
|
||||
)
|
||||
api_response = await self.requests_wrapper.aget(api_url)
|
||||
await _run_manager.on_text(
|
||||
api_response, color="yellow", end="\n", verbose=self.verbose
|
||||
@ -143,6 +216,7 @@ class APIChain(Chain):
|
||||
headers: Optional[dict] = None,
|
||||
api_url_prompt: BasePromptTemplate = API_URL_PROMPT,
|
||||
api_response_prompt: BasePromptTemplate = API_RESPONSE_PROMPT,
|
||||
limit_to_domains: Optional[Sequence[str]] = tuple(),
|
||||
**kwargs: Any,
|
||||
) -> APIChain:
|
||||
"""Load chain from just an LLM and the api docs."""
|
||||
@ -154,6 +228,7 @@ class APIChain(Chain):
|
||||
api_answer_chain=get_answer_chain,
|
||||
requests_wrapper=requests_wrapper,
|
||||
api_docs=api_docs,
|
||||
limit_to_domains=limit_to_domains,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -22,8 +22,7 @@ class FakeRequestsChain(TextRequestsWrapper):
|
||||
return self.output
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_api_data() -> dict:
|
||||
def get_test_api_data() -> dict:
|
||||
"""Fake api data to use for testing."""
|
||||
api_docs = """
|
||||
This API endpoint will search the notes for a user.
|
||||
@ -48,39 +47,59 @@ def test_api_data() -> dict:
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_llm_api_chain(test_api_data: dict) -> APIChain:
|
||||
def get_api_chain(**kwargs: Any) -> APIChain:
|
||||
"""Fake LLM API chain for testing."""
|
||||
TEST_API_DOCS = test_api_data["api_docs"]
|
||||
TEST_QUESTION = test_api_data["question"]
|
||||
TEST_URL = test_api_data["api_url"]
|
||||
TEST_API_RESPONSE = test_api_data["api_response"]
|
||||
TEST_API_SUMMARY = test_api_data["api_summary"]
|
||||
data = get_test_api_data()
|
||||
test_api_docs = data["api_docs"]
|
||||
test_question = data["question"]
|
||||
test_url = data["api_url"]
|
||||
test_api_response = data["api_response"]
|
||||
test_api_summary = data["api_summary"]
|
||||
|
||||
api_url_query_prompt = API_URL_PROMPT.format(
|
||||
api_docs=TEST_API_DOCS, question=TEST_QUESTION
|
||||
api_docs=test_api_docs, question=test_question
|
||||
)
|
||||
api_response_prompt = API_RESPONSE_PROMPT.format(
|
||||
api_docs=TEST_API_DOCS,
|
||||
question=TEST_QUESTION,
|
||||
api_url=TEST_URL,
|
||||
api_response=TEST_API_RESPONSE,
|
||||
api_docs=test_api_docs,
|
||||
question=test_question,
|
||||
api_url=test_url,
|
||||
api_response=test_api_response,
|
||||
)
|
||||
queries = {api_url_query_prompt: TEST_URL, api_response_prompt: TEST_API_SUMMARY}
|
||||
queries = {api_url_query_prompt: test_url, api_response_prompt: test_api_summary}
|
||||
fake_llm = FakeLLM(queries=queries)
|
||||
api_request_chain = LLMChain(llm=fake_llm, prompt=API_URL_PROMPT)
|
||||
api_answer_chain = LLMChain(llm=fake_llm, prompt=API_RESPONSE_PROMPT)
|
||||
requests_wrapper = FakeRequestsChain(output=TEST_API_RESPONSE)
|
||||
requests_wrapper = FakeRequestsChain(output=test_api_response)
|
||||
return APIChain(
|
||||
api_request_chain=api_request_chain,
|
||||
api_answer_chain=api_answer_chain,
|
||||
requests_wrapper=requests_wrapper,
|
||||
api_docs=TEST_API_DOCS,
|
||||
api_docs=test_api_docs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def test_api_question(fake_llm_api_chain: APIChain, test_api_data: dict) -> None:
|
||||
def test_api_question() -> None:
|
||||
"""Test simple question that needs API access."""
|
||||
question = test_api_data["question"]
|
||||
output = fake_llm_api_chain.run(question)
|
||||
assert output == test_api_data["api_summary"]
|
||||
with pytest.raises(ValueError):
|
||||
get_api_chain()
|
||||
with pytest.raises(ValueError):
|
||||
get_api_chain(limit_to_domains=tuple())
|
||||
|
||||
# All domains allowed (not advised)
|
||||
api_chain = get_api_chain(limit_to_domains=None)
|
||||
data = get_test_api_data()
|
||||
assert api_chain.run(data["question"]) == data["api_summary"]
|
||||
|
||||
# Use a domain that's allowed
|
||||
api_chain = get_api_chain(
|
||||
limit_to_domains=["https://thisapidoesntexist.com/api/notes?q=langchain"]
|
||||
)
|
||||
# Attempts to make a request against a domain that's not allowed
|
||||
assert api_chain.run(data["question"]) == data["api_summary"]
|
||||
|
||||
# Use domains that are not valid
|
||||
api_chain = get_api_chain(limit_to_domains=["h", "*"])
|
||||
with pytest.raises(ValueError):
|
||||
# Attempts to make a request against a domain that's not allowed
|
||||
assert api_chain.run(data["question"]) == data["api_summary"]
|
||||
|
Loading…
Reference in New Issue
Block a user