From 96ad09fa2d9deaae42bb93b0e880d4beb4f19448 Mon Sep 17 00:00:00 2001 From: Mohammad Mohtashim <45242107+keenborder786@users.noreply.github.com> Date: Thu, 13 Feb 2025 09:12:07 +0500 Subject: [PATCH] (Community): Added API Key for Jina Search API Wrapper (#29622) - **Description:** Simple change for adding the API Key for Jina Search API Wrapper - **Issue:** #29596 --- .../docs/integrations/tools/jina_search.ipynb | 5 +++- .../tools/jina_search/tool.py | 2 +- .../utilities/jina_search.py | 21 +++++++++++++-- .../unit_tests/tools/jina_search/__init__.py | 0 .../tools/jina_search/test_jina_search.py | 27 +++++++++++++++++++ 5 files changed, 51 insertions(+), 4 deletions(-) create mode 100644 libs/community/tests/unit_tests/tools/jina_search/__init__.py create mode 100644 libs/community/tests/unit_tests/tools/jina_search/test_jina_search.py diff --git a/docs/docs/integrations/tools/jina_search.ipynb b/docs/docs/integrations/tools/jina_search.ipynb index 03af86055f0..c07e9904ff2 100644 --- a/docs/docs/integrations/tools/jina_search.ipynb +++ b/docs/docs/integrations/tools/jina_search.ipynb @@ -64,7 +64,10 @@ "outputs": [], "source": [ "import getpass\n", - "import os" + "import os\n", + "\n", + "if not os.environ.get(\"JINA_API_KEY\"):\n", + " os.environ[\"JINA_API_KEY\"] = getpass.getpass(\"Jina API key:\\n\")" ] }, { diff --git a/libs/community/langchain_community/tools/jina_search/tool.py b/libs/community/langchain_community/tools/jina_search/tool.py index 42a2a5e7e7a..dbb707f7f6a 100644 --- a/libs/community/langchain_community/tools/jina_search/tool.py +++ b/libs/community/langchain_community/tools/jina_search/tool.py @@ -30,7 +30,7 @@ class JinaSearch(BaseTool): # type: ignore[override] "each in clean, LLM-friendly text. This way, you can always keep your LLM " "up-to-date, improve its factuality, and reduce hallucinations." ) - search_wrapper: JinaSearchAPIWrapper = Field(default_factory=JinaSearchAPIWrapper) + search_wrapper: JinaSearchAPIWrapper = Field(default_factory=JinaSearchAPIWrapper) # type: ignore[arg-type] def _run( self, diff --git a/libs/community/langchain_community/utilities/jina_search.py b/libs/community/langchain_community/utilities/jina_search.py index b35b31f6830..d879b0ebd35 100644 --- a/libs/community/langchain_community/utilities/jina_search.py +++ b/libs/community/langchain_community/utilities/jina_search.py @@ -1,18 +1,34 @@ import json -from typing import List +from typing import Any, Dict, List import requests from langchain_core.documents import Document -from pydantic import BaseModel +from langchain_core.utils import get_from_dict_or_env +from pydantic import BaseModel, ConfigDict, SecretStr, model_validator from yarl import URL class JinaSearchAPIWrapper(BaseModel): """Wrapper around the Jina search engine.""" + api_key: SecretStr + base_url: str = "https://s.jina.ai/" """The base URL for the Jina search engine.""" + model_config = ConfigDict( + extra="forbid", + ) + + @model_validator(mode="before") + @classmethod + def validate_environment(cls, values: Dict) -> Any: + """Validate that api key and endpoint exists in environment.""" + api_key = get_from_dict_or_env(values, "api_key", "JINA_API_KEY") + values["api_key"] = api_key + + return values + def run(self, query: str) -> str: """Query the Jina search engine and return the results as a JSON string. @@ -59,6 +75,7 @@ class JinaSearchAPIWrapper(BaseModel): def _search_request(self, query: str) -> List[dict]: headers = { "Accept": "application/json", + "Authorization": f"Bearer {self.api_key.get_secret_value()}", } url = str(URL(self.base_url + query)) response = requests.get(url, headers=headers) diff --git a/libs/community/tests/unit_tests/tools/jina_search/__init__.py b/libs/community/tests/unit_tests/tools/jina_search/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/libs/community/tests/unit_tests/tools/jina_search/test_jina_search.py b/libs/community/tests/unit_tests/tools/jina_search/test_jina_search.py new file mode 100644 index 00000000000..89174714523 --- /dev/null +++ b/libs/community/tests/unit_tests/tools/jina_search/test_jina_search.py @@ -0,0 +1,27 @@ +import os +import unittest +from typing import Any +from unittest.mock import patch + +from langchain_community.tools.jina_search.tool import JinaSearch +from langchain_community.utilities.jina_search import JinaSearchAPIWrapper + +os.environ["JINA_API_KEY"] = "test_key" + + +class TestJinaSearchTool(unittest.TestCase): + @patch( + "langchain_community.tools.jina_search.tool.JinaSearch.invoke", + return_value="mocked_result", + ) + def test_invoke(self, mock_run: Any) -> None: + query = "Test query text" + wrapper = JinaSearchAPIWrapper(api_key="test_key") # type: ignore[arg-type] + jina_search_tool = JinaSearch(api_wrapper=wrapper) # type: ignore[call-arg] + results = jina_search_tool.invoke(query) + expected_result = "mocked_result" + assert results == expected_result + + +if __name__ == "__main__": + unittest.main()