APIChain add restrictions to domains (CVE-2023-32786) (#12747)

* Restrict the chain to specific domains by default
* This is a breaking change, but it will fail loudly upon object
instantiation -- so there should be no silent errors for users
* Resolves CVE-2023-32786
This commit is contained in:
Eugene Yurtsev 2023-11-01 18:50:34 -04:00 committed by GitHub
parent 4421ba46d7
commit b1caae62fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 135 additions and 30 deletions

File diff suppressed because one or more lines are too long

View File

@ -1,7 +1,8 @@
"""Chain that makes API calls and summarizes the responses to answer a question.""" """Chain that makes API calls and summarizes the responses to answer a question."""
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Sequence, Tuple
from urllib.parse import urlparse
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun, AsyncCallbackManagerForChainRun,
@ -16,6 +17,38 @@ from langchain.schema.language_model import BaseLanguageModel
from langchain.utilities.requests import TextRequestsWrapper from langchain.utilities.requests import TextRequestsWrapper
def _extract_scheme_and_domain(url: str) -> Tuple[str, str]:
"""Extract the scheme + domain from a given URL.
Args:
url (str): The input URL.
Returns:
return a 2-tuple of scheme and domain
"""
parsed_uri = urlparse(url)
return parsed_uri.scheme, parsed_uri.netloc
def _check_in_allowed_domain(url: str, limit_to_domains: Sequence[str]) -> bool:
    """Check if a URL is in the allowed domains.

    The URL and every entry of ``limit_to_domains`` are reduced to a
    ``(scheme, domain)`` pair with ``_extract_scheme_and_domain``; the URL is
    allowed when its pair equals the pair of at least one entry.

    Args:
        url (str): The input URL.
        limit_to_domains (Sequence[str]): The allowed domains, written as
            URLs such as ``"https://www.example.com"``.

    Returns:
        bool: True if the URL is in the allowed domains, False otherwise.
        A URL without a scheme or without a domain is never allowed.
    """
    scheme, domain = _extract_scheme_and_domain(url)
    if not scheme or not domain:
        # Without this guard a scheme-less URL parses to ("", "") and would
        # match any malformed allow-list entry (e.g. "h" or "*"), which
        # parses to ("", "") as well.
        return False
    return any(
        (scheme, domain) == _extract_scheme_and_domain(allowed_url)
        for allowed_url in limit_to_domains
    )
class APIChain(Chain): class APIChain(Chain):
"""Chain that makes API calls and summarizes the responses to answer a question. """Chain that makes API calls and summarizes the responses to answer a question.
@ -40,6 +73,19 @@ class APIChain(Chain):
api_docs: str api_docs: str
question_key: str = "question" #: :meta private: question_key: str = "question" #: :meta private:
output_key: str = "output" #: :meta private: output_key: str = "output" #: :meta private:
limit_to_domains: Optional[Sequence[str]]
"""Use to limit the domains that can be accessed by the API chain.
* For example, to limit to just the domain `https://www.example.com`, set
`limit_to_domains=["https://www.example.com"]`.
* The default value is an empty tuple, which means that no domains are
allowed by default. By design this will raise an error on instantiation.
* Use `None` if you want to allow all domains by default -- this is not
recommended for security reasons, as it would allow malicious users to
make requests to arbitrary URLs, including internal APIs accessible from
the server.
"""
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> List[str]:
@ -68,6 +114,21 @@ class APIChain(Chain):
) )
return values return values
@root_validator(pre=True)
def validate_limit_to_domains(cls, values: Dict) -> Dict:
    """Ensure ``limit_to_domains`` was supplied and is usable.

    Args:
        values: The raw field values passed to the model.

    Returns:
        ``values`` unchanged when ``limit_to_domains`` is ``None`` or a
        non-empty collection.

    Raises:
        ValueError: If the ``limit_to_domains`` key is absent, or if it is
            present but empty (and not ``None``).
    """
    if "limit_to_domains" in values:
        domains = values["limit_to_domains"]
        # ``None`` is an explicit opt-out and a non-empty value is a real
        # allow-list; only an empty, non-None value is rejected.
        if domains is None or domains:
            return values
        raise ValueError(
            "Please provide a list of domains to limit access using "
            "`limit_to_domains`."
        )
    raise ValueError(
        "You must specify a list of domains to limit access using "
        "`limit_to_domains`"
    )
@root_validator(pre=True) @root_validator(pre=True)
def validate_api_answer_prompt(cls, values: Dict) -> Dict: def validate_api_answer_prompt(cls, values: Dict) -> Dict:
"""Check that api answer prompt expects the right variables.""" """Check that api answer prompt expects the right variables."""
@ -93,6 +154,12 @@ class APIChain(Chain):
) )
_run_manager.on_text(api_url, color="green", end="\n", verbose=self.verbose) _run_manager.on_text(api_url, color="green", end="\n", verbose=self.verbose)
api_url = api_url.strip() api_url = api_url.strip()
if self.limit_to_domains and not _check_in_allowed_domain(
api_url, self.limit_to_domains
):
raise ValueError(
f"{api_url} is not in the allowed domains: {self.limit_to_domains}"
)
api_response = self.requests_wrapper.get(api_url) api_response = self.requests_wrapper.get(api_url)
_run_manager.on_text( _run_manager.on_text(
api_response, color="yellow", end="\n", verbose=self.verbose api_response, color="yellow", end="\n", verbose=self.verbose
@ -122,6 +189,12 @@ class APIChain(Chain):
api_url, color="green", end="\n", verbose=self.verbose api_url, color="green", end="\n", verbose=self.verbose
) )
api_url = api_url.strip() api_url = api_url.strip()
if self.limit_to_domains and not _check_in_allowed_domain(
api_url, self.limit_to_domains
):
raise ValueError(
f"{api_url} is not in the allowed domains: {self.limit_to_domains}"
)
api_response = await self.requests_wrapper.aget(api_url) api_response = await self.requests_wrapper.aget(api_url)
await _run_manager.on_text( await _run_manager.on_text(
api_response, color="yellow", end="\n", verbose=self.verbose api_response, color="yellow", end="\n", verbose=self.verbose
@ -143,6 +216,7 @@ class APIChain(Chain):
headers: Optional[dict] = None, headers: Optional[dict] = None,
api_url_prompt: BasePromptTemplate = API_URL_PROMPT, api_url_prompt: BasePromptTemplate = API_URL_PROMPT,
api_response_prompt: BasePromptTemplate = API_RESPONSE_PROMPT, api_response_prompt: BasePromptTemplate = API_RESPONSE_PROMPT,
limit_to_domains: Optional[Sequence[str]] = tuple(),
**kwargs: Any, **kwargs: Any,
) -> APIChain: ) -> APIChain:
"""Load chain from just an LLM and the api docs.""" """Load chain from just an LLM and the api docs."""
@ -154,6 +228,7 @@ class APIChain(Chain):
api_answer_chain=get_answer_chain, api_answer_chain=get_answer_chain,
requests_wrapper=requests_wrapper, requests_wrapper=requests_wrapper,
api_docs=api_docs, api_docs=api_docs,
limit_to_domains=limit_to_domains,
**kwargs, **kwargs,
) )

View File

@ -22,8 +22,7 @@ class FakeRequestsChain(TextRequestsWrapper):
return self.output return self.output
@pytest.fixture def get_test_api_data() -> dict:
def test_api_data() -> dict:
"""Fake api data to use for testing.""" """Fake api data to use for testing."""
api_docs = """ api_docs = """
This API endpoint will search the notes for a user. This API endpoint will search the notes for a user.
@ -48,39 +47,59 @@ def test_api_data() -> dict:
} }
@pytest.fixture def get_api_chain(**kwargs: Any) -> APIChain:
def fake_llm_api_chain(test_api_data: dict) -> APIChain:
"""Fake LLM API chain for testing.""" """Fake LLM API chain for testing."""
TEST_API_DOCS = test_api_data["api_docs"] data = get_test_api_data()
TEST_QUESTION = test_api_data["question"] test_api_docs = data["api_docs"]
TEST_URL = test_api_data["api_url"] test_question = data["question"]
TEST_API_RESPONSE = test_api_data["api_response"] test_url = data["api_url"]
TEST_API_SUMMARY = test_api_data["api_summary"] test_api_response = data["api_response"]
test_api_summary = data["api_summary"]
api_url_query_prompt = API_URL_PROMPT.format( api_url_query_prompt = API_URL_PROMPT.format(
api_docs=TEST_API_DOCS, question=TEST_QUESTION api_docs=test_api_docs, question=test_question
) )
api_response_prompt = API_RESPONSE_PROMPT.format( api_response_prompt = API_RESPONSE_PROMPT.format(
api_docs=TEST_API_DOCS, api_docs=test_api_docs,
question=TEST_QUESTION, question=test_question,
api_url=TEST_URL, api_url=test_url,
api_response=TEST_API_RESPONSE, api_response=test_api_response,
) )
queries = {api_url_query_prompt: TEST_URL, api_response_prompt: TEST_API_SUMMARY} queries = {api_url_query_prompt: test_url, api_response_prompt: test_api_summary}
fake_llm = FakeLLM(queries=queries) fake_llm = FakeLLM(queries=queries)
api_request_chain = LLMChain(llm=fake_llm, prompt=API_URL_PROMPT) api_request_chain = LLMChain(llm=fake_llm, prompt=API_URL_PROMPT)
api_answer_chain = LLMChain(llm=fake_llm, prompt=API_RESPONSE_PROMPT) api_answer_chain = LLMChain(llm=fake_llm, prompt=API_RESPONSE_PROMPT)
requests_wrapper = FakeRequestsChain(output=TEST_API_RESPONSE) requests_wrapper = FakeRequestsChain(output=test_api_response)
return APIChain( return APIChain(
api_request_chain=api_request_chain, api_request_chain=api_request_chain,
api_answer_chain=api_answer_chain, api_answer_chain=api_answer_chain,
requests_wrapper=requests_wrapper, requests_wrapper=requests_wrapper,
api_docs=TEST_API_DOCS, api_docs=test_api_docs,
**kwargs,
) )
def test_api_question(fake_llm_api_chain: APIChain, test_api_data: dict) -> None: def test_api_question() -> None:
"""Test simple question that needs API access.""" """Test simple question that needs API access."""
question = test_api_data["question"] with pytest.raises(ValueError):
output = fake_llm_api_chain.run(question) get_api_chain()
assert output == test_api_data["api_summary"] with pytest.raises(ValueError):
get_api_chain(limit_to_domains=tuple())
# All domains allowed (not advised)
api_chain = get_api_chain(limit_to_domains=None)
data = get_test_api_data()
assert api_chain.run(data["question"]) == data["api_summary"]
# Use a domain that's allowed
api_chain = get_api_chain(
limit_to_domains=["https://thisapidoesntexist.com/api/notes?q=langchain"]
)
# Makes a request against the allowed domain, so the call succeeds
assert api_chain.run(data["question"]) == data["api_summary"]
# Use domains that are not valid
api_chain = get_api_chain(limit_to_domains=["h", "*"])
with pytest.raises(ValueError):
# Attempts to make a request against a domain that's not allowed
assert api_chain.run(data["question"]) == data["api_summary"]