diff --git a/docs/docs/integrations/tools/brave_search.ipynb b/docs/docs/integrations/tools/brave_search.ipynb index 2afe244dd4f..ecf7910282e 100644 --- a/docs/docs/integrations/tools/brave_search.ipynb +++ b/docs/docs/integrations/tools/brave_search.ipynb @@ -43,12 +43,21 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "5b14008a", "metadata": {}, "outputs": [], "source": [ - "tool = BraveSearch.from_api_key(api_key=api_key, search_kwargs={\"count\": 3})" + "tool = BraveSearch.from_api_key(api_key=api_key, search_kwargs={\"count\": 3})\n", + "\n", + "# or if you want to get the api key from environment variable BRAVE_SEARCH_API_KEY, and leave search_kwargs empty\n", + "# tool = BraveSearch()\n", + "\n", + "# or if you want to provide just the api key, and leave search_kwargs empty\n", + "# tool = BraveSearch.from_api_key(api_key=api_key)\n", + "\n", + "# or if you want to provide just the search_kwargs and read the api key from the BRAVE_SEARCH_API_KEY environment variable\n", + "# tool = BraveSearch.from_search_kwargs(search_kwargs={\"count\": 3})" ] }, { diff --git a/libs/community/langchain_community/document_loaders/brave_search.py b/libs/community/langchain_community/document_loaders/brave_search.py index 7fa5ff13624..6759e0d6359 100644 --- a/libs/community/langchain_community/document_loaders/brave_search.py +++ b/libs/community/langchain_community/document_loaders/brave_search.py @@ -1,6 +1,7 @@ from typing import Iterator, List, Optional from langchain_core.documents import Document +from pydantic import SecretStr from langchain_community.document_loaders.base import BaseLoader from langchain_community.utilities.brave_search import BraveSearchWrapper @@ -23,7 +24,7 @@ class BraveSearchLoader(BaseLoader): def load(self) -> List[Document]: brave_client = BraveSearchWrapper( - api_key=self.api_key, + api_key=SecretStr(self.api_key), search_kwargs=self.search_kwargs, ) return brave_client.download_documents(self.query) diff --git a/libs/community/langchain_community/tools/brave_search/tool.py b/libs/community/langchain_community/tools/brave_search/tool.py index be8d8ae41d1..9d170494d6d 100644 --- a/libs/community/langchain_community/tools/brave_search/tool.py +++ b/libs/community/langchain_community/tools/brave_search/tool.py @@ -4,12 +4,38 @@ from typing import Any, Optional from langchain_core.callbacks import CallbackManagerForToolRun from langchain_core.tools import BaseTool +from pydantic import Field, SecretStr from langchain_community.utilities.brave_search import BraveSearchWrapper class BraveSearch(BaseTool): # type: ignore[override] - """Tool that queries the BraveSearch.""" + """Tool that queries the BraveSearch. + + Api key can be provided as an environment variable BRAVE_SEARCH_API_KEY + or as a parameter. + + + Example usages: + .. code-block:: python + # uses BRAVE_SEARCH_API_KEY from environment + tool = BraveSearch() + + .. code-block:: python + # uses the provided api key + tool = BraveSearch.from_api_key("your-api-key") + + .. code-block:: python + # uses the provided api key and search kwargs + tool = BraveSearch.from_api_key( + api_key = "your-api-key", + search_kwargs={"max_results": 5} + ) + + .. code-block:: python + # uses BRAVE_SEARCH_API_KEY from environment + tool = BraveSearch.from_search_kwargs({"max_results": 5}) + """ name: str = "brave_search" description: str = ( @@ -17,7 +43,7 @@ class BraveSearch(BaseTool): # type: ignore[override] "useful for when you need to answer questions about current events." " input should be a search query." ) - search_wrapper: BraveSearchWrapper + search_wrapper: BraveSearchWrapper = Field(default_factory=BraveSearchWrapper) @classmethod def from_api_key( @@ -33,7 +59,28 @@ class BraveSearch(BaseTool): # type: ignore[override] Returns: A tool. """ - wrapper = BraveSearchWrapper(api_key=api_key, search_kwargs=search_kwargs or {}) + wrapper = BraveSearchWrapper( + api_key=SecretStr(api_key), search_kwargs=search_kwargs or {} + ) + return cls(search_wrapper=wrapper, **kwargs) + + @classmethod + def from_search_kwargs(cls, search_kwargs: dict, **kwargs: Any) -> BraveSearch: + """Create a tool from search kwargs. + + Uses the environment variable BRAVE_SEARCH_API_KEY for api key. + + Args: + search_kwargs: Any additional kwargs to pass to the search wrapper. + **kwargs: Any additional kwargs to pass to the tool. + + Returns: + A tool. + """ + # we can not provide api key because it's calculated in the wrapper, + # so the ignore is needed for linter + # not ideal but needed to keep the tool code changes non-breaking + wrapper = BraveSearchWrapper(search_kwargs=search_kwargs) return cls(search_wrapper=wrapper, **kwargs) def _run( diff --git a/libs/community/langchain_community/utilities/brave_search.py b/libs/community/langchain_community/utilities/brave_search.py index 41dd9015f2e..15f00c81cb6 100644 --- a/libs/community/langchain_community/utilities/brave_search.py +++ b/libs/community/langchain_community/utilities/brave_search.py @@ -3,13 +3,16 @@ from typing import List import requests from langchain_core.documents import Document -from pydantic import BaseModel, Field +from langchain_core.utils import secret_from_env +from pydantic import BaseModel, Field, SecretStr class BraveSearchWrapper(BaseModel): """Wrapper around the Brave search engine.""" - api_key: str + api_key: SecretStr = Field( + default_factory=secret_from_env(["BRAVE_SEARCH_API_KEY"]) + ) """The API key to use for the Brave search engine.""" search_kwargs: dict = Field(default_factory=dict) """Additional keyword arguments to pass to the search request.""" @@ -64,7 +67,7 @@ class BraveSearchWrapper(BaseModel): def _search_request(self, query: str) -> List[dict]: headers = { - "X-Subscription-Token": self.api_key, + "X-Subscription-Token": self.api_key.get_secret_value(), "Accept": "application/json", } req = requests.PreparedRequest() diff --git a/libs/community/tests/unit_tests/utilities/test_brave_search.py b/libs/community/tests/unit_tests/utilities/test_brave_search.py new file mode 100644 index 00000000000..e1c9bdc15a7 --- /dev/null +++ b/libs/community/tests/unit_tests/utilities/test_brave_search.py @@ -0,0 +1,32 @@ +from typing import Any + +import pytest +from pydantic import SecretStr + +from langchain_community.utilities.brave_search import BraveSearchWrapper + + +def test_api_key_explicit() -> None: + """Test that the API key is correctly set when provided explicitly.""" + explicit_key = "explicit-api-key" + wrapper = BraveSearchWrapper(api_key=SecretStr(explicit_key), search_kwargs={}) + assert wrapper.api_key.get_secret_value() == explicit_key + + +def test_api_key_from_env(monkeypatch: Any) -> None: + """Test that the API key is correctly obtained from the environment variable.""" + env_key = "env-api-key" + monkeypatch.setenv("BRAVE_SEARCH_API_KEY", env_key) + # Do not pass the api_key explicitly + wrapper = BraveSearchWrapper() # type: ignore[call-arg] + assert wrapper.api_key.get_secret_value() == env_key + + +def test_api_key_missing(monkeypatch: Any) -> None: + """Test that instantiation fails when no API key is provided + either explicitly or via environment.""" + # Ensure that the environment variable is not set + monkeypatch.delenv("BRAVE_SEARCH_API_KEY", raising=False) + with pytest.raises(ValueError): + # This should raise an error because no api_key is available. + BraveSearchWrapper() # type: ignore[call-arg]