mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-06 05:08:20 +00:00
Harrison/serper api bug (#4902)
Co-authored-by: Jerry Luan <xmaswillyou@gmail.com>
This commit is contained in:
parent
c998569c8f
commit
9e2227ba11
@ -5,6 +5,7 @@ import aiohttp
|
|||||||
import requests
|
import requests
|
||||||
from pydantic.class_validators import root_validator
|
from pydantic.class_validators import root_validator
|
||||||
from pydantic.main import BaseModel
|
from pydantic.main import BaseModel
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
|
||||||
@ -28,7 +29,16 @@ class GoogleSerperAPIWrapper(BaseModel):
|
|||||||
k: int = 10
|
k: int = 10
|
||||||
gl: str = "us"
|
gl: str = "us"
|
||||||
hl: str = "en"
|
hl: str = "en"
|
||||||
type: str = "search" # search, images, places, news
|
# "places" and "images" is available from Serper but not implemented in the
|
||||||
|
# parser of run(). They can be used in results()
|
||||||
|
type: Literal["news", "search", "places", "images"] = "search"
|
||||||
|
result_key_for_type = {
|
||||||
|
"news": "news",
|
||||||
|
"places": "places",
|
||||||
|
"images": "images",
|
||||||
|
"search": "organic",
|
||||||
|
}
|
||||||
|
|
||||||
tbs: Optional[str] = None
|
tbs: Optional[str] = None
|
||||||
serper_api_key: Optional[str] = None
|
serper_api_key: Optional[str] = None
|
||||||
aiosession: Optional[aiohttp.ClientSession] = None
|
aiosession: Optional[aiohttp.ClientSession] = None
|
||||||
@ -50,7 +60,7 @@ class GoogleSerperAPIWrapper(BaseModel):
|
|||||||
|
|
||||||
def results(self, query: str, **kwargs: Any) -> Dict:
|
def results(self, query: str, **kwargs: Any) -> Dict:
|
||||||
"""Run query through GoogleSearch."""
|
"""Run query through GoogleSearch."""
|
||||||
return self._google_serper_search_results(
|
return self._google_serper_api_results(
|
||||||
query,
|
query,
|
||||||
gl=self.gl,
|
gl=self.gl,
|
||||||
hl=self.hl,
|
hl=self.hl,
|
||||||
@ -62,7 +72,7 @@ class GoogleSerperAPIWrapper(BaseModel):
|
|||||||
|
|
||||||
def run(self, query: str, **kwargs: Any) -> str:
|
def run(self, query: str, **kwargs: Any) -> str:
|
||||||
"""Run query through GoogleSearch and parse result."""
|
"""Run query through GoogleSearch and parse result."""
|
||||||
results = self._google_serper_search_results(
|
results = self._google_serper_api_results(
|
||||||
query,
|
query,
|
||||||
gl=self.gl,
|
gl=self.gl,
|
||||||
hl=self.hl,
|
hl=self.hl,
|
||||||
@ -125,7 +135,7 @@ class GoogleSerperAPIWrapper(BaseModel):
|
|||||||
for attribute, value in kg.get("attributes", {}).items():
|
for attribute, value in kg.get("attributes", {}).items():
|
||||||
snippets.append(f"{title} {attribute}: {value}.")
|
snippets.append(f"{title} {attribute}: {value}.")
|
||||||
|
|
||||||
for result in results["organic"][: self.k]:
|
for result in results[self.result_key_for_type[self.type]][: self.k]:
|
||||||
if "snippet" in result:
|
if "snippet" in result:
|
||||||
snippets.append(result["snippet"])
|
snippets.append(result["snippet"])
|
||||||
for attribute, value in result.get("attributes", {}).items():
|
for attribute, value in result.get("attributes", {}).items():
|
||||||
@ -138,7 +148,7 @@ class GoogleSerperAPIWrapper(BaseModel):
|
|||||||
def _parse_results(self, results: dict) -> str:
|
def _parse_results(self, results: dict) -> str:
|
||||||
return " ".join(self._parse_snippets(results))
|
return " ".join(self._parse_snippets(results))
|
||||||
|
|
||||||
def _google_serper_search_results(
|
def _google_serper_api_results(
|
||||||
self, search_term: str, search_type: str = "search", **kwargs: Any
|
self, search_term: str, search_type: str = "search", **kwargs: Any
|
||||||
) -> dict:
|
) -> dict:
|
||||||
headers = {
|
headers = {
|
||||||
|
@ -4,13 +4,20 @@ import pytest
|
|||||||
from langchain.utilities.google_serper import GoogleSerperAPIWrapper
|
from langchain.utilities.google_serper import GoogleSerperAPIWrapper
|
||||||
|
|
||||||
|
|
||||||
def test_call() -> None:
|
def test_search_call() -> None:
|
||||||
"""Test that call gives the correct answer."""
|
"""Test that call gives the correct answer from search."""
|
||||||
search = GoogleSerperAPIWrapper()
|
search = GoogleSerperAPIWrapper()
|
||||||
output = search.run("What was Obama's first name?")
|
output = search.run("What was Obama's first name?")
|
||||||
assert "Barack Hussein Obama II" in output
|
assert "Barack Hussein Obama II" in output
|
||||||
|
|
||||||
|
|
||||||
|
def test_news_call() -> None:
|
||||||
|
"""Test that call gives the correct answer from news search."""
|
||||||
|
search = GoogleSerperAPIWrapper(type="news")
|
||||||
|
output = search.run("What's new with stock market?").lower()
|
||||||
|
assert "stock" in output or "market" in output
|
||||||
|
|
||||||
|
|
||||||
async def test_results() -> None:
|
async def test_results() -> None:
|
||||||
"""Test that call gives the correct answer."""
|
"""Test that call gives the correct answer."""
|
||||||
search = GoogleSerperAPIWrapper()
|
search = GoogleSerperAPIWrapper()
|
||||||
|
Loading…
Reference in New Issue
Block a user