langchain/libs/community/langchain_community/utilities/serpapi.py
Bagatur a0c2281540
infra: update mypy 1.10, ruff 0.5 (#23721)
```python
"""python scripts/update_mypy_ruff.py"""
import glob
import re
import subprocess
import tomllib
from pathlib import Path

import toml

ROOT_DIR = Path(__file__).parents[1]


def main():
    for path in glob.glob(str(ROOT_DIR / "libs/**/pyproject.toml"), recursive=True):
        print(path)
        with open(path, "rb") as f:
            pyproject = tomllib.load(f)
        try:
            pyproject["tool"]["poetry"]["group"]["typing"]["dependencies"]["mypy"] = (
                "^1.10"
            )
            pyproject["tool"]["poetry"]["group"]["lint"]["dependencies"]["ruff"] = (
                "^0.5"
            )
        except KeyError:
            continue
        with open(path, "w") as f:
            toml.dump(pyproject, f)
        cwd = "/".join(path.split("/")[:-1])
        completed = subprocess.run(
            "poetry lock --no-update; poetry install --with typing; poetry run mypy . --no-color",
            cwd=cwd,
            shell=True,
            capture_output=True,
            text=True,
        )
        logs = completed.stdout.split("\n")

        to_ignore = {}
        for log_line in logs:
            # Use a raw string so \d and \[ are regex escapes, not (invalid)
            # string escapes, and match each log line only once.
            match = re.match(r"^(.*):(\d+): error:.*\[(.*)\]", log_line)
            if match:
                error_path, line_no, error_type = match.groups()
                if (error_path, line_no) in to_ignore:
                    to_ignore[(error_path, line_no)].append(error_type)
                else:
                    to_ignore[(error_path, line_no)] = [error_type]
        print(len(to_ignore))
        for (error_path, line_no), error_types in to_ignore.items():
            all_errors = ", ".join(error_types)
            full_path = f"{cwd}/{error_path}"
            try:
                with open(full_path, "r") as f:
                    file_lines = f.readlines()
            except FileNotFoundError:
                continue
            file_lines[int(line_no) - 1] = (
                file_lines[int(line_no) - 1][:-1] + f"  # type: ignore[{all_errors}]\n"
            )
            with open(full_path, "w") as f:
                f.write("".join(file_lines))

        subprocess.run(
            "poetry run ruff format .; poetry run ruff --select I --fix .",
            cwd=cwd,
            shell=True,
            capture_output=True,
            text=True,
        )


if __name__ == "__main__":
    main()

```
2024-07-03 10:33:27 -07:00
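For reference, the annotation step in the script above rewrites each line that mypy flags by appending the accumulated error codes as a `# type: ignore[...]` comment. A hypothetical before/after (the file name, function, and error code here are made up):

```python
# Before: mypy reports "utilities/foo.py:12: error: ... [arg-type]"
value = compute(x)
# After the script runs, line 12 of utilities/foo.py reads:
value = compute(x)  # type: ignore[arg-type]
```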

227 lines
8.5 KiB
Python

"""Chain that calls SerpAPI.
Heavily borrowed from https://github.com/ofirpress/self-ask
"""
import os
import sys
from typing import Any, Dict, Optional, Tuple
import aiohttp
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
from langchain_core.utils import get_from_dict_or_env


class HiddenPrints:
    """Context manager to hide prints."""

    def __enter__(self) -> None:
        """Open file to pipe stdout to."""
        self._original_stdout = sys.stdout
        sys.stdout = open(os.devnull, "w")

    def __exit__(self, *_: Any) -> None:
        """Close file that stdout was piped to."""
        sys.stdout.close()
        sys.stdout = self._original_stdout
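A minimal usage sketch of `HiddenPrints` (not part of the module; `results()` below uses it to silence the serpapi client's stdout chatter):

```python
with HiddenPrints():
    print("suppressed: this goes to os.devnull")
print("visible: stdout is restored on exit")
```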


class SerpAPIWrapper(BaseModel):
    """Wrapper around SerpAPI.

    To use, you should have the ``google-search-results`` python package installed,
    and the environment variable ``SERPAPI_API_KEY`` set with your API key, or pass
    `serpapi_api_key` as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.utilities import SerpAPIWrapper
            serpapi = SerpAPIWrapper()
    """

    search_engine: Any  #: :meta private:
    params: dict = Field(
        default={
            "engine": "google",
            "google_domain": "google.com",
            "gl": "us",
            "hl": "en",
        }
    )
    serpapi_api_key: Optional[str] = None
    aiosession: Optional[aiohttp.ClientSession] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the api key and python package exist in the environment."""
        serpapi_api_key = get_from_dict_or_env(
            values, "serpapi_api_key", "SERPAPI_API_KEY"
        )
        values["serpapi_api_key"] = serpapi_api_key
        try:
            from serpapi import GoogleSearch

            values["search_engine"] = GoogleSearch
        except ImportError:
            raise ImportError(
                "Could not import serpapi python package. "
                "Please install it with `pip install google-search-results`."
            )
        return values
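A hedged sketch of the two ways the validator above can locate the key; the key values are placeholders:

```python
import os

from langchain_community.utilities import SerpAPIWrapper

# Option 1: let get_from_dict_or_env pick the key up from the environment.
os.environ["SERPAPI_API_KEY"] = "your-api-key"  # placeholder
search = SerpAPIWrapper()

# Option 2: pass the key explicitly as a constructor argument.
search = SerpAPIWrapper(serpapi_api_key="your-api-key")  # placeholder
```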

    async def arun(self, query: str, **kwargs: Any) -> str:
        """Run query through SerpAPI and parse result async."""
        return self._process_response(await self.aresults(query))

    def run(self, query: str, **kwargs: Any) -> str:
        """Run query through SerpAPI and parse result."""
        return self._process_response(self.results(query))

    def results(self, query: str) -> dict:
        """Run query through SerpAPI and return the raw result."""
        params = self.get_params(query)
        with HiddenPrints():
            search = self.search_engine(params)
            res = search.get_dict()
        return res

    async def aresults(self, query: str) -> dict:
        """Use aiohttp to run query through SerpAPI and return the results async."""

        def construct_url_and_params() -> Tuple[str, Dict[str, str]]:
            params = self.get_params(query)
            params["source"] = "python"
            if self.serpapi_api_key:
                params["serp_api_key"] = self.serpapi_api_key
            params["output"] = "json"
            url = "https://serpapi.com/search"
            return url, params

        url, params = construct_url_and_params()
        if not self.aiosession:
            async with aiohttp.ClientSession() as session:
                async with session.get(url, params=params) as response:
                    res = await response.json()
        else:
            async with self.aiosession.get(url, params=params) as response:
                res = await response.json()

        return res
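A usage sketch for the entry points above (the query text is illustrative): `results()`/`aresults()` return the raw JSON dict, while `run()`/`arun()` feed it through `_process_response` to get a string.

```python
import asyncio

search = SerpAPIWrapper()

raw = search.results("capital of France")  # raw SerpAPI JSON as a dict
answer = search.run("capital of France")  # parsed string answer

# The async variants go through aiohttp rather than the
# google-search-results client.
answer = asyncio.run(search.arun("capital of France"))
```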

    def get_params(self, query: str) -> Dict[str, str]:
        """Get parameters for SerpAPI."""
        _params = {
            "api_key": self.serpapi_api_key,
            "q": query,
        }
        params = {**self.params, **_params}
        return params
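Note the merge order in `get_params`: `_params` is unpacked last, so `api_key` and `q` always override anything set in `self.params`. A small illustration with hypothetical values:

```python
defaults = {"engine": "google", "q": "stale query"}
_params = {"api_key": "key", "q": "fresh query"}
merged = {**defaults, **_params}
assert merged == {"engine": "google", "q": "fresh query", "api_key": "key"}
```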

    @staticmethod
    def _process_response(res: dict) -> str:
        """Process response from SerpAPI."""
        if "error" in res.keys():
            raise ValueError(f"Got error from SerpAPI: {res['error']}")
        if "answer_box_list" in res.keys():
            res["answer_box"] = res["answer_box_list"]
        if "answer_box" in res.keys():
            answer_box = res["answer_box"]
            if isinstance(answer_box, list):
                answer_box = answer_box[0]
            if "result" in answer_box.keys():
                return answer_box["result"]
            elif "answer" in answer_box.keys():
                return answer_box["answer"]
            elif "snippet" in answer_box.keys():
                return answer_box["snippet"]
            elif "snippet_highlighted_words" in answer_box.keys():
                return answer_box["snippet_highlighted_words"]
            else:
                answer = {}
                for key, value in answer_box.items():
                    if not isinstance(value, (list, dict)) and not (
                        isinstance(value, str) and value.startswith("http")
                    ):
                        answer[key] = value
                return str(answer)
        elif "events_results" in res.keys():
            return res["events_results"][:10]
        elif "sports_results" in res.keys():
            return res["sports_results"]
        elif "top_stories" in res.keys():
            return res["top_stories"]
        elif "news_results" in res.keys():
            return res["news_results"]
        elif "jobs_results" in res.keys() and "jobs" in res["jobs_results"].keys():
            return res["jobs_results"]["jobs"]
        elif (
            "shopping_results" in res.keys()
            and "title" in res["shopping_results"][0].keys()
        ):
            return res["shopping_results"][:3]
        elif "questions_and_answers" in res.keys():
            return res["questions_and_answers"]
        elif (
            "popular_destinations" in res.keys()
            and "destinations" in res["popular_destinations"].keys()
        ):
            return res["popular_destinations"]["destinations"]
        elif "top_sights" in res.keys() and "sights" in res["top_sights"].keys():
            return res["top_sights"]["sights"]
        elif (
            "images_results" in res.keys()
            and "thumbnail" in res["images_results"][0].keys()
        ):
            return str([item["thumbnail"] for item in res["images_results"][:10]])

        snippets = []
        if "knowledge_graph" in res.keys():
            knowledge_graph = res["knowledge_graph"]
            title = knowledge_graph["title"] if "title" in knowledge_graph else ""
            if "description" in knowledge_graph.keys():
                snippets.append(knowledge_graph["description"])
            for key, value in knowledge_graph.items():
                if (
                    isinstance(key, str)
                    and isinstance(value, str)
                    and key not in ["title", "description"]
                    and not key.endswith("_stick")
                    and not key.endswith("_link")
                    and not value.startswith("http")
                ):
                    snippets.append(f"{title} {key}: {value}.")
        for organic_result in res.get("organic_results", []):
            if "snippet" in organic_result.keys():
                snippets.append(organic_result["snippet"])
            elif "snippet_highlighted_words" in organic_result.keys():
                snippets.append(organic_result["snippet_highlighted_words"])
            elif "rich_snippet" in organic_result.keys():
                snippets.append(organic_result["rich_snippet"])
            elif "rich_snippet_table" in organic_result.keys():
                snippets.append(organic_result["rich_snippet_table"])
            elif "link" in organic_result.keys():
                snippets.append(organic_result["link"])

        if "buying_guide" in res.keys():
            snippets.append(res["buying_guide"])
        if "local_results" in res and isinstance(res["local_results"], list):
            snippets += res["local_results"]
        if (
            "local_results" in res.keys()
            and isinstance(res["local_results"], dict)
            and "places" in res["local_results"].keys()
        ):
            snippets.append(res["local_results"]["places"])

        if len(snippets) > 0:
            return str(snippets)
        else:
            return "No good search result found"