Add Tavily Search API as a Tool (#12103)

Adding Tavily Search API as a tool. I will be the maintainer and
assaf_elovic is the twitter handler.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Rotem Weiss
2023-10-21 18:23:21 +03:00
committed by GitHub
parent 85302a9ec1
commit 78d186fb44
4 changed files with 363 additions and 0 deletions

View File

@@ -0,0 +1,5 @@
"""Tavily Search API toolkit."""
from langchain.tools.tavily_search.tool import TavilySearchResults
__all__ = ["TavilySearchResults"]

View File

@@ -0,0 +1,51 @@
"""Tool for the Tavily search API."""
from typing import Dict, List, Optional, Union
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.tools.base import BaseTool
from langchain.utilities.tavily_search import TavilySearchAPIWrapper
class TavilySearchResults(BaseTool):
"""Tool that queries the Tavily Search API and gets back json."""
name: str = "tavily_search_results_json"
description: str = """"
"A search engine optimized for comprehensive, accurate, and trusted results. "
"Useful for when you need to answer questions about current events. "
"Input should be a search query."
"""
api_wrapper: TavilySearchAPIWrapper
max_results: int = 5
def _run(
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> Union[List[Dict], str]:
"""Use the tool."""
try:
return self.api_wrapper.results(
query,
self.max_results,
)
except Exception as e:
return repr(e)
async def _arun(
self,
query: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> Union[List[Dict], str]:
"""Use the tool asynchronously."""
try:
return await self.api_wrapper.results_async(
query,
self.max_results,
)
except Exception as e:
return repr(e)

View File

@@ -0,0 +1,167 @@
"""Util that calls Tavily Search API.
In order to set this up, follow instructions at:
"""
import json
from typing import Dict, List, Optional
import aiohttp
import requests
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
from langchain.utils import get_from_dict_or_env
TAVILY_API_URL = "https://api.tavily.com"
class TavilySearchAPIWrapper(BaseModel):
"""Wrapper for Tavily Search API."""
tavily_api_key: str
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def _tavily_search_results(
self,
query: str,
max_results: Optional[int] = 5,
search_depth: Optional[str] = "advanced",
include_domains: Optional[List[str]] = [],
exclude_domains: Optional[List[str]] = [],
include_answer: Optional[bool] = False,
include_raw_content: Optional[bool] = False,
include_images: Optional[bool] = False,
) -> List[dict]:
params = {
"api_key": self.tavily_api_key,
"query": query,
"max_results": max_results,
"search_depth": search_depth,
"include_domains": include_domains,
"exclude_domains": exclude_domains,
"include_answer": include_answer,
"include_raw_content": include_raw_content,
"include_images": include_images,
}
response = requests.post(
# type: ignore
f"{TAVILY_API_URL}/search",
json=params,
)
response.raise_for_status()
search_results = response.json()
return self.clean_results(search_results["results"])
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and endpoint exists in environment."""
tavily_api_key = get_from_dict_or_env(
values, "tavily_api_key", "TAVILY_API_KEY"
)
values["tavily_api_key"] = tavily_api_key
return values
def results(
self,
query: str,
max_results: Optional[int] = 5,
search_depth: Optional[str] = "advanced",
include_domains: Optional[List[str]] = [],
exclude_domains: Optional[List[str]] = [],
include_answer: Optional[bool] = False,
include_raw_content: Optional[bool] = False,
include_images: Optional[bool] = False,
) -> List[Dict]:
"""Run query through Tavily Search and return metadata.
Args:
query: The query to search for.
max_results: The maximum number of results to return.
search_depth: The depth of the search. Can be "basic" or "advanced".
include_domains: A list of domains to include in the search.
exclude_domains: A list of domains to exclude from the search.
include_answer: Whether to include the answer in the results.
include_raw_content: Whether to include the raw content in the results.
include_images: Whether to include images in the results.
Returns:
query: The query that was searched for.
follow_up_questions: A list of follow up questions.
response_time: The response time of the query.
answer: The answer to the query.
images: A list of images.
results: A list of dictionaries containing the results:
title: The title of the result.
url: The url of the result.
content: The content of the result.
score: The score of the result.
raw_content: The raw content of the result.
""" # noqa: E501
raw_search_results = self._tavily_search_results(
query,
max_results,
search_depth,
include_domains,
exclude_domains,
include_answer,
include_raw_content,
include_images,
)
return raw_search_results
async def results_async(
self,
query: str,
max_results: Optional[int] = 5,
search_depth: Optional[str] = "advanced",
include_domains: Optional[List[str]] = [],
exclude_domains: Optional[List[str]] = [],
include_answer: Optional[bool] = False,
include_raw_content: Optional[bool] = False,
include_images: Optional[bool] = False,
) -> List[Dict]:
"""Get results from the Tavily Search API asynchronously."""
# Function to perform the API call
async def fetch() -> str:
params = {
"api_key": self.tavily_api_key,
"query": query,
"max_results": max_results,
"search_depth": search_depth,
"include_domains": include_domains,
"exclude_domains": exclude_domains,
"include_answer": include_answer,
"include_raw_content": include_raw_content,
"include_images": include_images,
}
async with aiohttp.ClientSession() as session:
async with session.post(f"{TAVILY_API_URL}/search", json=params) as res:
if res.status == 200:
data = await res.text()
return data
else:
raise Exception(f"Error {res.status}: {res.reason}")
results_json_str = await fetch()
results_json = json.loads(results_json_str)
return self.clean_results(results_json["results"])
def clean_results(self, results: List[Dict]) -> List[Dict]:
"""Clean results from Tavily Search API."""
clean_results = []
for result in results:
clean_results.append(
{
"url": result["url"],
"content": result["content"],
}
)
return clean_results