diff --git a/docs/docs/use_cases/apis.ipynb b/docs/docs/use_cases/apis.ipynb index 11e2e8e825b..42944481299 100644 --- a/docs/docs/use_cases/apis.ipynb +++ b/docs/docs/use_cases/apis.ipynb @@ -205,7 +205,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 4, "id": "4ef0c3d0", "metadata": {}, "outputs": [ @@ -217,7 +217,7 @@ "\n", "\u001b[1m> Entering new APIChain chain...\u001b[0m\n", "\u001b[32;1m\u001b[1;3mhttps://api.open-meteo.com/v1/forecast?latitude=48.1351&longitude=11.5820&hourly=temperature_2m&temperature_unit=fahrenheit¤t_weather=true\u001b[0m\n", - "\u001b[33;1m\u001b[1;3m{\"latitude\":48.14,\"longitude\":11.58,\"generationtime_ms\":1.0769367218017578,\"utc_offset_seconds\":0,\"timezone\":\"GMT\",\"timezone_abbreviation\":\"GMT\",\"elevation\":521.0,\"current_weather\":{\"temperature\":52.9,\"windspeed\":12.6,\"winddirection\":239.0,\"weathercode\":3,\"is_day\":0,\"time\":\"2023-08-07T22:00\"},\"hourly_units\":{\"time\":\"iso8601\",\"temperature_2m\":\"°F\"},\"hourly\":{\"time\":[\"2023-08-07T00:00\",\"2023-08-07T01:00\",\"2023-08-07T02:00\",\"2023-08-07T03:00\",\"2023-08-07T04:00\",\"2023-08-07T05:00\",\"2023-08-07T06:00\",\"2023-08-07T07:00\",\"2023-08-07T08:00\",\"2023-08-07T09:00\",\"2023-08-07T10:00\",\"2023-08-07T11:00\",\"2023-08-07T12:00\",\"2023-08-07T13:00\",\"2023-08-07T14:00\",\"2023-08-07T15:00\",\"2023-08-07T16:00\",\"2023-08-07T17:00\",\"2023-08-07T18:00\",\"2023-08-07T19:00\",\"2023-08-07T20:00\",\"2023-08-07T21:00\",\"2023-08-07T22:00\",\"2023-08-07T23:00\",\"2023-08-08T00:00\",\"2023-08-08T01:00\",\"2023-08-08T02:00\",\"2023-08-08T03:00\",\"2023-08-08T04:00\",\"2023-08-08T05:00\",\"2023-08-08T06:00\",\"2023-08-08T07:00\",\"2023-08-08T08:00\",\"2023-08-08T09:00\",\"2023-08-08T10:00\",\"2023-08-08T11:00\",\"2023-08-08T12:00\",\"2023-08-08T13:00\",\"2023-08-08T14:00\",\"2023-08-08T15:00\",\"2023-08-08T16:00\",\"2023-08-08T17:00\",\"2023-08-08T18:00\",\"2023-08-08T19:00\",\"2023-08-08T20:00\",\"2023-08-08T21:00\",\"2023-08-08T22:00\",\"2023-08-08T23:00\",\"2023-08-09T00:00\",\"2023-08-09T01:00\",\"2023-08-09T02:00\",\"2023-08-09T03:00\",\"2023-08-09T04:00\",\"2023-08-09T05:00\",\"2023-08-09T06:00\",\"2023-08-09T07:00\",\"2023-08-09T08:00\",\"2023-08-09T09:00\",\"2023-08-09T10:00\",\"2023-08-09T11:00\",\"2023-08-09T12:00\",\"2023-08-09T13:00\",\"2023-08-09T14:00\",\"2023-08-09T15:00\",\"2023-08-09T16:00\",\"2023-08-09T17:00\",\"2023-08-09T18:00\",\"2023-08-09T19:00\",\"2023-08-09T20:00\",\"2023-08-09T21:00\",\"2023-08-09T22:00\",\"2023-08-09T23:00\",\"2023-08-10T00:00\",\"2023-08-10T01:00\",\"2023-08-10T02:00\",\"2023-08-10T03:00\",\"2023-08-10T04:00\",\"2023-08-10T05:00\",\"2023-08-10T06:00\",\"2023-08-10T07:00\",\"2023-08-10T08:00\",\"2023-08-10T09:00\",\"2023-08-10T10:00\",\"2023-08-10T11:00\",\"2023-08-10T12:00\",\"2023-08-10T13:00\",\"2023-08-10T14:00\",\"2023-08-10T15:00\",\"2023-08-10T16:00\",\"2023-08-10T17:00\",\"2023-08-10T18:00\",\"2023-08-10T19:00\",\"2023-08-10T20:00\",\"2023-08-10T21:00\",\"2023-08-10T22:00\",\"2023-08-10T23:00\",\"2023-08-11T00:00\",\"2023-08-11T01:00\",\"2023-08-11T02:00\",\"2023-08-11T03:00\",\"2023-08-11T04:00\",\"2023-08-11T05:00\",\"2023-08-11T06:00\",\"2023-08-11T07:00\",\"2023-08-11T08:00\",\"2023-08-11T09:00\",\"2023-08-11T10:00\",\"2023-08-11T11:00\",\"2023-08-11T12:00\",\"2023-08-11T13:00\",\"2023-08-11T14:00\",\"2023-08-11T15:00\",\"2023-08-11T16:00\",\"2023-08-11T17:00\",\"2023-08-11T18:00\",\"2023-08-11T19:00\",\"2023-08-11T20:00\",\"2023-08-11T21:00\",\"2023-08-11T22:00\",\"2023-08-11T23:00\",\"2023-08-12T00:00\",\"2023-08-12T01:00\",\"2023-08-12T02:00\",\"2023-08-12T03:00\",\"2023-08-12T04:00\",\"2023-08-12T05:00\",\"2023-08-12T06:00\",\"2023-08-12T07:00\",\"2023-08-12T08:00\",\"2023-08-12T09:00\",\"2023-08-12T10:00\",\"2023-08-12T11:00\",\"2023-08-12T12:00\",\"2023-08-12T13:00\",\"2023-08-12T14:00\",\"2023-08-12T15:00\",\"2023-08-12T16:00\",\"2023-08-12T17:00\",\"2023-08-12T18:00\",\"2023-08-12T19:00\",\"2023-08-12T20:00\",\"2023-08-12T21:00\",\"2023-08-12T22:00\",\"2023-08-12T23:00\",\"2023-08-13T00:00\",\"2023-08-13T01:00\",\"2023-08-13T02:00\",\"2023-08-13T03:00\",\"2023-08-13T04:00\",\"2023-08-13T05:00\",\"2023-08-13T06:00\",\"2023-08-13T07:00\",\"2023-08-13T08:00\",\"2023-08-13T09:00\",\"2023-08-13T10:00\",\"2023-08-13T11:00\",\"2023-08-13T12:00\",\"2023-08-13T13:00\",\"2023-08-13T14:00\",\"2023-08-13T15:00\",\"2023-08-13T16:00\",\"2023-08-13T17:00\",\"2023-08-13T18:00\",\"2023-08-13T19:00\",\"2023-08-13T20:00\",\"2023-08-13T21:00\",\"2023-08-13T22:00\",\"2023-08-13T23:00\"],\"temperature_2m\":[53.0,51.2,50.9,50.4,50.7,51.3,51.7,52.9,54.3,56.1,57.4,59.3,59.1,60.7,59.7,58.8,58.8,57.8,56.6,55.3,53.9,52.7,52.9,53.2,52.0,51.8,51.3,50.7,50.8,51.5,53.9,57.7,61.2,63.2,64.7,66.6,67.5,67.0,68.7,68.7,67.9,66.2,64.4,61.4,59.8,58.9,57.9,56.3,55.7,55.3,55.5,55.4,55.7,56.5,57.6,58.8,59.7,59.1,58.9,60.6,59.9,59.8,59.9,61.7,63.2,63.6,62.3,58.9,57.3,57.1,57.0,56.5,56.2,56.0,55.3,54.7,54.4,55.2,57.8,60.7,63.0,65.3,66.9,68.2,70.1,72.1,72.6,71.4,69.7,68.6,66.2,63.6,61.8,60.6,59.6,58.9,58.0,57.1,56.3,56.2,56.7,57.9,59.9,63.7,68.4,72.4,75.0,76.8,78.0,78.7,78.9,78.4,76.9,74.8,72.5,70.1,67.6,65.6,64.4,63.9,63.4,62.7,62.2,62.1,62.5,63.4,65.1,68.0,71.7,74.8,76.8,78.2,79.1,79.6,79.7,79.2,77.6,75.3,73.7,68.6,66.8,65.3,64.2,63.4,62.6,61.7,60.9,60.6,60.9,61.6,63.2,65.9,69.3,72.2,74.4,76.2,77.6,78.8,79.6,79.6,78.4,76.4,74.3,72.3,70.4,68.7,67.6,66.8]}}\u001b[0m\n", + "\u001b[33;1m\u001b[1;3m{\"latitude\":48.14,\"longitude\":11.58,\"generationtime_ms\":0.1710653305053711,\"utc_offset_seconds\":0,\"timezone\":\"GMT\",\"timezone_abbreviation\":\"GMT\",\"elevation\":521.0,\"current_weather_units\":{\"time\":\"iso8601\",\"interval\":\"seconds\",\"temperature\":\"°F\",\"windspeed\":\"km/h\",\"winddirection\":\"°\",\"is_day\":\"\",\"weathercode\":\"wmo code\"},\"current_weather\":{\"time\":\"2023-11-01T21:30\",\"interval\":900,\"temperature\":46.5,\"windspeed\":7.7,\"winddirection\":259,\"is_day\":0,\"weathercode\":3},\"hourly_units\":{\"time\":\"iso8601\",\"temperature_2m\":\"°F\"},\"hourly\":{\"time\":[\"2023-11-01T00:00\",\"2023-11-01T01:00\",\"2023-11-01T02:00\",\"2023-11-01T03:00\",\"2023-11-01T04:00\",\"2023-11-01T05:00\",\"2023-11-01T06:00\",\"2023-11-01T07:00\",\"2023-11-01T08:00\",\"2023-11-01T09:00\",\"2023-11-01T10:00\",\"2023-11-01T11:00\",\"2023-11-01T12:00\",\"2023-11-01T13:00\",\"2023-11-01T14:00\",\"2023-11-01T15:00\",\"2023-11-01T16:00\",\"2023-11-01T17:00\",\"2023-11-01T18:00\",\"2023-11-01T19:00\",\"2023-11-01T20:00\",\"2023-11-01T21:00\",\"2023-11-01T22:00\",\"2023-11-01T23:00\",\"2023-11-02T00:00\",\"2023-11-02T01:00\",\"2023-11-02T02:00\",\"2023-11-02T03:00\",\"2023-11-02T04:00\",\"2023-11-02T05:00\",\"2023-11-02T06:00\",\"2023-11-02T07:00\",\"2023-11-02T08:00\",\"2023-11-02T09:00\",\"2023-11-02T10:00\",\"2023-11-02T11:00\",\"2023-11-02T12:00\",\"2023-11-02T13:00\",\"2023-11-02T14:00\",\"2023-11-02T15:00\",\"2023-11-02T16:00\",\"2023-11-02T17:00\",\"2023-11-02T18:00\",\"2023-11-02T19:00\",\"2023-11-02T20:00\",\"2023-11-02T21:00\",\"2023-11-02T22:00\",\"2023-11-02T23:00\",\"2023-11-03T00:00\",\"2023-11-03T01:00\",\"2023-11-03T02:00\",\"2023-11-03T03:00\",\"2023-11-03T04:00\",\"2023-11-03T05:00\",\"2023-11-03T06:00\",\"2023-11-03T07:00\",\"2023-11-03T08:00\",\"2023-11-03T09:00\",\"2023-11-03T10:00\",\"2023-11-03T11:00\",\"2023-11-03T12:00\",\"2023-11-03T13:00\",\"2023-11-03T14:00\",\"2023-11-03T15:00\",\"2023-11-03T16:00\",\"2023-11-03T17:00\",\"2023-11-03T18:00\",\"2023-11-03T19:00\",\"2023-11-03T20:00\",\"2023-11-03T21:00\",\"2023-11-03T22:00\",\"2023-11-03T23:00\",\"2023-11-04T00:00\",\"2023-11-04T01:00\",\"2023-11-04T02:00\",\"2023-11-04T03:00\",\"2023-11-04T04:00\",\"2023-11-04T05:00\",\"2023-11-04T06:00\",\"2023-11-04T07:00\",\"2023-11-04T08:00\",\"2023-11-04T09:00\",\"2023-11-04T10:00\",\"2023-11-04T11:00\",\"2023-11-04T12:00\",\"2023-11-04T13:00\",\"2023-11-04T14:00\",\"2023-11-04T15:00\",\"2023-11-04T16:00\",\"2023-11-04T17:00\",\"2023-11-04T18:00\",\"2023-11-04T19:00\",\"2023-11-04T20:00\",\"2023-11-04T21:00\",\"2023-11-04T22:00\",\"2023-11-04T23:00\",\"2023-11-05T00:00\",\"2023-11-05T01:00\",\"2023-11-05T02:00\",\"2023-11-05T03:00\",\"2023-11-05T04:00\",\"2023-11-05T05:00\",\"2023-11-05T06:00\",\"2023-11-05T07:00\",\"2023-11-05T08:00\",\"2023-11-05T09:00\",\"2023-11-05T10:00\",\"2023-11-05T11:00\",\"2023-11-05T12:00\",\"2023-11-05T13:00\",\"2023-11-05T14:00\",\"2023-11-05T15:00\",\"2023-11-05T16:00\",\"2023-11-05T17:00\",\"2023-11-05T18:00\",\"2023-11-05T19:00\",\"2023-11-05T20:00\",\"2023-11-05T21:00\",\"2023-11-05T22:00\",\"2023-11-05T23:00\",\"2023-11-06T00:00\",\"2023-11-06T01:00\",\"2023-11-06T02:00\",\"2023-11-06T03:00\",\"2023-11-06T04:00\",\"2023-11-06T05:00\",\"2023-11-06T06:00\",\"2023-11-06T07:00\",\"2023-11-06T08:00\",\"2023-11-06T09:00\",\"2023-11-06T10:00\",\"2023-11-06T11:00\",\"2023-11-06T12:00\",\"2023-11-06T13:00\",\"2023-11-06T14:00\",\"2023-11-06T15:00\",\"2023-11-06T16:00\",\"2023-11-06T17:00\",\"2023-11-06T18:00\",\"2023-11-06T19:00\",\"2023-11-06T20:00\",\"2023-11-06T21:00\",\"2023-11-06T22:00\",\"2023-11-06T23:00\",\"2023-11-07T00:00\",\"2023-11-07T01:00\",\"2023-11-07T02:00\",\"2023-11-07T03:00\",\"2023-11-07T04:00\",\"2023-11-07T05:00\",\"2023-11-07T06:00\",\"2023-11-07T07:00\",\"2023-11-07T08:00\",\"2023-11-07T09:00\",\"2023-11-07T10:00\",\"2023-11-07T11:00\",\"2023-11-07T12:00\",\"2023-11-07T13:00\",\"2023-11-07T14:00\",\"2023-11-07T15:00\",\"2023-11-07T16:00\",\"2023-11-07T17:00\",\"2023-11-07T18:00\",\"2023-11-07T19:00\",\"2023-11-07T20:00\",\"2023-11-07T21:00\",\"2023-11-07T22:00\",\"2023-11-07T23:00\"],\"temperature_2m\":[47.9,46.9,47.1,46.6,45.8,45.2,43.4,43.5,46.8,51.5,55.0,56.3,58.1,57.9,57.0,56.6,54.4,52.1,49.1,48.3,47.7,46.9,46.2,45.8,44.4,42.4,41.7,41.7,42.0,42.7,43.6,44.3,45.9,48.0,49.1,50.7,52.2,52.6,51.9,50.3,48.1,47.4,47.1,46.9,46.2,45.7,45.6,45.6,45.7,45.3,45.1,44.2,43.6,43.2,42.8,41.6,41.0,42.1,42.4,42.3,42.7,43.9,44.2,43.6,41.9,40.4,39.0,40.8,40.2,40.1,39.6,38.8,38.2,36.9,35.8,36.4,37.3,38.5,38.9,39.0,41.8,45.4,48.7,50.8,51.7,52.1,51.3,49.8,48.6,47.8,47.0,46.3,45.9,45.6,45.7,46.1,46.3,46.4,46.3,46.3,45.8,45.4,45.5,47.1,49.3,51.2,52.4,53.1,53.5,53.4,53.0,52.4,51.6,50.5,49.6,49.0,48.6,48.1,47.6,47.0,46.4,46.0,45.5,45.1,44.4,43.7,43.9,45.6,48.1,50.3,51.7,52.8,53.5,52.7,51.5,50.2,48.8,47.4,46.2,45.5,45.0,44.6,44.3,44.2,43.9,43.4,43.0,42.6,42.3,42.0,42.2,43.0,44.3,45.5,46.8,48.1,48.9,49.0,48.7,48.1,47.4,46.5,45.7,45.1,44.5,44.3,44.5,45.1]}}\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -225,10 +225,10 @@ { "data": { "text/plain": [ - "' The current temperature in Munich, Germany is 52.9°F.'" + "' The current temperature in Munich, Germany is 46.5°F.'" ] }, - "execution_count": 8, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -240,7 +240,10 @@ "\n", "llm = OpenAI(temperature=0)\n", "chain = APIChain.from_llm_and_api_docs(\n", - " llm, open_meteo_docs.OPEN_METEO_DOCS, verbose=True\n", + " llm,\n", + " open_meteo_docs.OPEN_METEO_DOCS,\n", + " verbose=True,\n", + " limit_to_domains=[\"https://api.open-meteo.com/\"],\n", ")\n", "chain.run(\n", " \"What is the weather like right now in Munich, Germany in degrees Fahrenheit?\"\n", @@ -322,7 +325,11 @@ "\n", "headers = {\"Authorization\": f\"Bearer {os.environ['TMDB_BEARER_TOKEN']}\"}\n", "chain = APIChain.from_llm_and_api_docs(\n", - " llm, tmdb_docs.TMDB_DOCS, headers=headers, verbose=True\n", + " llm,\n", + " tmdb_docs.TMDB_DOCS,\n", + " headers=headers,\n", + " verbose=True,\n", + " limit_to_domains=[\"https://api.themoviedb.org/\"],\n", ")\n", "chain.run(\"Search for 'Avatar'\")" ] @@ -343,7 +350,11 @@ "llm = OpenAI(temperature=0)\n", "headers = {\"X-ListenAPI-Key\": listen_api_key}\n", "chain = APIChain.from_llm_and_api_docs(\n", - " llm, podcast_docs.PODCAST_DOCS, headers=headers, verbose=True\n", + " llm,\n", + " podcast_docs.PODCAST_DOCS,\n", + " headers=headers,\n", + " verbose=True,\n", + " limit_to_domains=[\"https://listen-api.listennotes.com/\"],\n", ")\n", "chain.run(\n", " \"Search for 'silicon valley bank' podcast episodes, audio length is more than 30 minutes, return only 1 results\"\n", @@ -438,7 +449,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.11.4" } }, "nbformat": 4, diff --git a/libs/langchain/langchain/chains/api/base.py b/libs/langchain/langchain/chains/api/base.py index 01a0e24d6e2..b8d0d178711 100644 --- a/libs/langchain/langchain/chains/api/base.py +++ b/libs/langchain/langchain/chains/api/base.py @@ -1,7 +1,8 @@ """Chain that makes API calls and summarizes the responses to answer a question.""" from __future__ import annotations -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Sequence, Tuple +from urllib.parse import urlparse from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, @@ -16,6 +17,38 @@ from langchain.schema.language_model import BaseLanguageModel from langchain.utilities.requests import TextRequestsWrapper +def _extract_scheme_and_domain(url: str) -> Tuple[str, str]: + """Extract the scheme + domain from a given URL. + + Args: + url (str): The input URL. + + Returns: + return a 2-tuple of scheme and domain + """ + parsed_uri = urlparse(url) + return parsed_uri.scheme, parsed_uri.netloc + + +def _check_in_allowed_domain(url: str, limit_to_domains: Sequence[str]) -> bool: + """Check if a URL is in the allowed domains. + + Args: + url (str): The input URL. + limit_to_domains (Sequence[str]): The allowed domains. + + Returns: + bool: True if the URL is in the allowed domains, False otherwise. + """ + scheme, domain = _extract_scheme_and_domain(url) + + for allowed_domain in limit_to_domains: + allowed_scheme, allowed_domain = _extract_scheme_and_domain(allowed_domain) + if scheme == allowed_scheme and domain == allowed_domain: + return True + return False + + class APIChain(Chain): """Chain that makes API calls and summarizes the responses to answer a question. @@ -40,6 +73,19 @@ class APIChain(Chain): api_docs: str question_key: str = "question" #: :meta private: output_key: str = "output" #: :meta private: + limit_to_domains: Optional[Sequence[str]] + """Use to limit the domains that can be accessed by the API chain. + + * For example, to limit to just the domain `https://www.example.com`, set + `limit_to_domains=["https://www.example.com"]`. + + * The default value is an empty tuple, which means that no domains are + allowed by default. By design this will raise an error on instantiation. + * Use a None if you want to allow all domains by default -- this is not + recommended for security reasons, as it would allow malicious users to + make requests to arbitrary URLS including internal APIs accessible from + the server. + """ @property def input_keys(self) -> List[str]: @@ -68,6 +114,21 @@ class APIChain(Chain): ) return values + @root_validator(pre=True) + def validate_limit_to_domains(cls, values: Dict) -> Dict: + """Check that allowed domains are valid.""" + if "limit_to_domains" not in values: + raise ValueError( + "You must specify a list of domains to limit access using " + "`limit_to_domains`" + ) + if not values["limit_to_domains"] and values["limit_to_domains"] is not None: + raise ValueError( + "Please provide a list of domains to limit access using " + "`limit_to_domains`." + ) + return values + @root_validator(pre=True) def validate_api_answer_prompt(cls, values: Dict) -> Dict: """Check that api answer prompt expects the right variables.""" @@ -93,6 +154,12 @@ class APIChain(Chain): ) _run_manager.on_text(api_url, color="green", end="\n", verbose=self.verbose) api_url = api_url.strip() + if self.limit_to_domains and not _check_in_allowed_domain( + api_url, self.limit_to_domains + ): + raise ValueError( + f"{api_url} is not in the allowed domains: {self.limit_to_domains}" + ) api_response = self.requests_wrapper.get(api_url) _run_manager.on_text( api_response, color="yellow", end="\n", verbose=self.verbose @@ -122,6 +189,12 @@ class APIChain(Chain): api_url, color="green", end="\n", verbose=self.verbose ) api_url = api_url.strip() + if self.limit_to_domains and not _check_in_allowed_domain( + api_url, self.limit_to_domains + ): + raise ValueError( + f"{api_url} is not in the allowed domains: {self.limit_to_domains}" + ) api_response = await self.requests_wrapper.aget(api_url) await _run_manager.on_text( api_response, color="yellow", end="\n", verbose=self.verbose @@ -143,6 +216,7 @@ class APIChain(Chain): headers: Optional[dict] = None, api_url_prompt: BasePromptTemplate = API_URL_PROMPT, api_response_prompt: BasePromptTemplate = API_RESPONSE_PROMPT, + limit_to_domains: Optional[Sequence[str]] = tuple(), **kwargs: Any, ) -> APIChain: """Load chain from just an LLM and the api docs.""" @@ -154,6 +228,7 @@ class APIChain(Chain): api_answer_chain=get_answer_chain, requests_wrapper=requests_wrapper, api_docs=api_docs, + limit_to_domains=limit_to_domains, **kwargs, ) diff --git a/libs/langchain/tests/unit_tests/chains/test_api.py b/libs/langchain/tests/unit_tests/chains/test_api.py index 93d38ff6add..9134453cdfb 100644 --- a/libs/langchain/tests/unit_tests/chains/test_api.py +++ b/libs/langchain/tests/unit_tests/chains/test_api.py @@ -22,8 +22,7 @@ class FakeRequestsChain(TextRequestsWrapper): return self.output -@pytest.fixture -def test_api_data() -> dict: +def get_test_api_data() -> dict: """Fake api data to use for testing.""" api_docs = """ This API endpoint will search the notes for a user. @@ -48,39 +47,59 @@ def test_api_data() -> dict: } -@pytest.fixture -def fake_llm_api_chain(test_api_data: dict) -> APIChain: +def get_api_chain(**kwargs: Any) -> APIChain: """Fake LLM API chain for testing.""" - TEST_API_DOCS = test_api_data["api_docs"] - TEST_QUESTION = test_api_data["question"] - TEST_URL = test_api_data["api_url"] - TEST_API_RESPONSE = test_api_data["api_response"] - TEST_API_SUMMARY = test_api_data["api_summary"] + data = get_test_api_data() + test_api_docs = data["api_docs"] + test_question = data["question"] + test_url = data["api_url"] + test_api_response = data["api_response"] + test_api_summary = data["api_summary"] api_url_query_prompt = API_URL_PROMPT.format( - api_docs=TEST_API_DOCS, question=TEST_QUESTION + api_docs=test_api_docs, question=test_question ) api_response_prompt = API_RESPONSE_PROMPT.format( - api_docs=TEST_API_DOCS, - question=TEST_QUESTION, - api_url=TEST_URL, - api_response=TEST_API_RESPONSE, + api_docs=test_api_docs, + question=test_question, + api_url=test_url, + api_response=test_api_response, ) - queries = {api_url_query_prompt: TEST_URL, api_response_prompt: TEST_API_SUMMARY} + queries = {api_url_query_prompt: test_url, api_response_prompt: test_api_summary} fake_llm = FakeLLM(queries=queries) api_request_chain = LLMChain(llm=fake_llm, prompt=API_URL_PROMPT) api_answer_chain = LLMChain(llm=fake_llm, prompt=API_RESPONSE_PROMPT) - requests_wrapper = FakeRequestsChain(output=TEST_API_RESPONSE) + requests_wrapper = FakeRequestsChain(output=test_api_response) return APIChain( api_request_chain=api_request_chain, api_answer_chain=api_answer_chain, requests_wrapper=requests_wrapper, - api_docs=TEST_API_DOCS, + api_docs=test_api_docs, + **kwargs, ) -def test_api_question(fake_llm_api_chain: APIChain, test_api_data: dict) -> None: +def test_api_question() -> None: """Test simple question that needs API access.""" - question = test_api_data["question"] - output = fake_llm_api_chain.run(question) - assert output == test_api_data["api_summary"] + with pytest.raises(ValueError): + get_api_chain() + with pytest.raises(ValueError): + get_api_chain(limit_to_domains=tuple()) + + # All domains allowed (not advised) + api_chain = get_api_chain(limit_to_domains=None) + data = get_test_api_data() + assert api_chain.run(data["question"]) == data["api_summary"] + + # Use a domain that's allowed + api_chain = get_api_chain( + limit_to_domains=["https://thisapidoesntexist.com/api/notes?q=langchain"] + ) + # Attempts to make a request against a domain that's not allowed + assert api_chain.run(data["question"]) == data["api_summary"] + + # Use domains that are not valid + api_chain = get_api_chain(limit_to_domains=["h", "*"]) + with pytest.raises(ValueError): + # Attempts to make a request against a domain that's not allowed + assert api_chain.run(data["question"]) == data["api_summary"]