mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-30 19:49:09 +00:00
update tools (#13243)
This commit is contained in:
parent
8d6faf5665
commit
7f1d26160d
@ -18,6 +18,8 @@ class YouRetriever(BaseRetriever):
|
||||
|
||||
ydc_api_key: str
|
||||
k: Optional[int] = None
|
||||
n_hits: Optional[int] = None
|
||||
n_snippets_per_hit: Optional[int] = None
|
||||
endpoint_type: str = "web"
|
||||
|
||||
@root_validator(pre=True)
|
||||
@ -43,8 +45,10 @@ class YouRetriever(BaseRetriever):
|
||||
).json()
|
||||
|
||||
docs = []
|
||||
for hit in results["hits"]:
|
||||
for snippet in hit["snippets"]:
|
||||
n_hits = self.n_hits or len(results["hits"])
|
||||
for hit in results["hits"][:n_hits]:
|
||||
n_snippets_per_hit = self.n_snippets_per_hit or len(hit["snippets"])
|
||||
for snippet in hit["snippets"][:n_snippets_per_hit]:
|
||||
docs.append(Document(page_content=snippet))
|
||||
if self.k is not None and len(docs) >= self.k:
|
||||
return docs
|
||||
|
@ -1,13 +1,17 @@
|
||||
"""Tool for the Arxiv API."""
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, Type
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForToolRun
|
||||
from langchain.pydantic_v1 import Field
|
||||
from langchain.pydantic_v1 import BaseModel, Field
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.utilities.arxiv import ArxivAPIWrapper
|
||||
|
||||
|
||||
class ArxivInput(BaseModel):
|
||||
query: str = Field(description="search query to look up")
|
||||
|
||||
|
||||
class ArxivQueryRun(BaseTool):
|
||||
"""Tool that searches the Arxiv API."""
|
||||
|
||||
@ -21,6 +25,7 @@ class ArxivQueryRun(BaseTool):
|
||||
"Input should be a search query."
|
||||
)
|
||||
api_wrapper: ArxivAPIWrapper = Field(default_factory=ArxivAPIWrapper)
|
||||
args_schema: Type[BaseModel] = ArxivInput
|
||||
|
||||
def _run(
|
||||
self,
|
||||
|
@ -1,7 +1,12 @@
|
||||
from langchain.pydantic_v1 import BaseModel, Field
|
||||
from langchain.schema import BaseRetriever
|
||||
from langchain.tools import Tool
|
||||
|
||||
|
||||
class RetrieverInput(BaseModel):
|
||||
query: str = Field(description="query to look up in retriever")
|
||||
|
||||
|
||||
def create_retriever_tool(
|
||||
retriever: BaseRetriever, name: str, description: str
|
||||
) -> Tool:
|
||||
@ -22,4 +27,5 @@ def create_retriever_tool(
|
||||
description=description,
|
||||
func=retriever.get_relevant_documents,
|
||||
coroutine=retriever.aget_relevant_documents,
|
||||
args_schema=RetrieverInput,
|
||||
)
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""Tavily Search API toolkit."""
|
||||
|
||||
from langchain.tools.tavily_search.tool import TavilySearchResults
|
||||
from langchain.tools.tavily_search.tool import TavilyAnswer, TavilySearchResults
|
||||
|
||||
__all__ = ["TavilySearchResults"]
|
||||
__all__ = ["TavilySearchResults", "TavilyAnswer"]
|
||||
|
@ -1,26 +1,32 @@
|
||||
"""Tool for the Tavily search API."""
|
||||
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional, Type, Union
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain.pydantic_v1 import BaseModel, Field
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.utilities.tavily_search import TavilySearchAPIWrapper
|
||||
|
||||
|
||||
class TavilyInput(BaseModel):
|
||||
query: str = Field(description="search query to look up")
|
||||
|
||||
|
||||
class TavilySearchResults(BaseTool):
|
||||
"""Tool that queries the Tavily Search API and gets back json."""
|
||||
|
||||
name: str = "tavily_search_results_json"
|
||||
description: str = """"
|
||||
description: str = (
|
||||
"A search engine optimized for comprehensive, accurate, and trusted results. "
|
||||
"Useful for when you need to answer questions about current events. "
|
||||
"Input should be a search query."
|
||||
"""
|
||||
)
|
||||
api_wrapper: TavilySearchAPIWrapper
|
||||
max_results: int = 5
|
||||
args_schema: Type[BaseModel] = TavilyInput
|
||||
|
||||
def _run(
|
||||
self,
|
||||
@ -49,3 +55,50 @@ class TavilySearchResults(BaseTool):
|
||||
)
|
||||
except Exception as e:
|
||||
return repr(e)
|
||||
|
||||
|
||||
class TavilyAnswer(BaseTool):
|
||||
"""Tool that queries the Tavily Search API and gets back an answer."""
|
||||
|
||||
name: str = "tavily_answer"
|
||||
description: str = (
|
||||
"A search engine optimized for comprehensive, accurate, and trusted results. "
|
||||
"Useful for when you need to answer questions about current events. "
|
||||
"Input should be a search query. "
|
||||
"This returns only the answer - not the original source data."
|
||||
)
|
||||
api_wrapper: TavilySearchAPIWrapper
|
||||
args_schema: Type[BaseModel] = TavilyInput
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> Union[List[Dict], str]:
|
||||
"""Use the tool."""
|
||||
try:
|
||||
return self.api_wrapper.raw_results(
|
||||
query,
|
||||
max_results=5,
|
||||
include_answer=True,
|
||||
search_depth="basic",
|
||||
)["answer"]
|
||||
except Exception as e:
|
||||
return repr(e)
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> Union[List[Dict], str]:
|
||||
"""Use the tool asynchronously."""
|
||||
try:
|
||||
result = await self.api_wrapper.raw_results_async(
|
||||
query,
|
||||
max_results=5,
|
||||
include_answer=True,
|
||||
search_depth="basic",
|
||||
)
|
||||
return result["answer"]
|
||||
except Exception as e:
|
||||
return repr(e)
|
||||
|
@ -24,7 +24,17 @@ class TavilySearchAPIWrapper(BaseModel):
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def _tavily_search_results(
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and endpoint exists in environment."""
|
||||
tavily_api_key = get_from_dict_or_env(
|
||||
values, "tavily_api_key", "TAVILY_API_KEY"
|
||||
)
|
||||
values["tavily_api_key"] = tavily_api_key
|
||||
|
||||
return values
|
||||
|
||||
def raw_results(
|
||||
self,
|
||||
query: str,
|
||||
max_results: Optional[int] = 5,
|
||||
@ -34,7 +44,7 @@ class TavilySearchAPIWrapper(BaseModel):
|
||||
include_answer: Optional[bool] = False,
|
||||
include_raw_content: Optional[bool] = False,
|
||||
include_images: Optional[bool] = False,
|
||||
) -> List[dict]:
|
||||
) -> Dict:
|
||||
params = {
|
||||
"api_key": self.tavily_api_key,
|
||||
"query": query,
|
||||
@ -51,20 +61,8 @@ class TavilySearchAPIWrapper(BaseModel):
|
||||
f"{TAVILY_API_URL}/search",
|
||||
json=params,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
search_results = response.json()
|
||||
return self.clean_results(search_results["results"])
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and endpoint exists in environment."""
|
||||
tavily_api_key = get_from_dict_or_env(
|
||||
values, "tavily_api_key", "TAVILY_API_KEY"
|
||||
)
|
||||
values["tavily_api_key"] = tavily_api_key
|
||||
|
||||
return values
|
||||
return response.json()
|
||||
|
||||
def results(
|
||||
self,
|
||||
@ -88,7 +86,6 @@ class TavilySearchAPIWrapper(BaseModel):
|
||||
include_answer: Whether to include the answer in the results.
|
||||
include_raw_content: Whether to include the raw content in the results.
|
||||
include_images: Whether to include images in the results.
|
||||
|
||||
Returns:
|
||||
query: The query that was searched for.
|
||||
follow_up_questions: A list of follow up questions.
|
||||
@ -101,22 +98,20 @@ class TavilySearchAPIWrapper(BaseModel):
|
||||
content: The content of the result.
|
||||
score: The score of the result.
|
||||
raw_content: The raw content of the result.
|
||||
|
||||
|
||||
""" # noqa: E501
|
||||
raw_search_results = self._tavily_search_results(
|
||||
raw_search_results = self.raw_results(
|
||||
query,
|
||||
max_results,
|
||||
search_depth,
|
||||
include_domains,
|
||||
exclude_domains,
|
||||
include_answer,
|
||||
include_raw_content,
|
||||
include_images,
|
||||
max_results=max_results,
|
||||
search_depth=search_depth,
|
||||
include_domains=include_domains,
|
||||
exclude_domains=exclude_domains,
|
||||
include_answer=include_answer,
|
||||
include_raw_content=include_raw_content,
|
||||
include_images=include_images,
|
||||
)
|
||||
return raw_search_results
|
||||
return self.clean_results(raw_search_results["results"])
|
||||
|
||||
async def results_async(
|
||||
async def raw_results_async(
|
||||
self,
|
||||
query: str,
|
||||
max_results: Optional[int] = 5,
|
||||
@ -126,7 +121,7 @@ class TavilySearchAPIWrapper(BaseModel):
|
||||
include_answer: Optional[bool] = False,
|
||||
include_raw_content: Optional[bool] = False,
|
||||
include_images: Optional[bool] = False,
|
||||
) -> List[Dict]:
|
||||
) -> Dict:
|
||||
"""Get results from the Tavily Search API asynchronously."""
|
||||
|
||||
# Function to perform the API call
|
||||
@ -151,7 +146,29 @@ class TavilySearchAPIWrapper(BaseModel):
|
||||
raise Exception(f"Error {res.status}: {res.reason}")
|
||||
|
||||
results_json_str = await fetch()
|
||||
results_json = json.loads(results_json_str)
|
||||
return json.loads(results_json_str)
|
||||
|
||||
async def results_async(
|
||||
self,
|
||||
query: str,
|
||||
max_results: Optional[int] = 5,
|
||||
search_depth: Optional[str] = "advanced",
|
||||
include_domains: Optional[List[str]] = [],
|
||||
exclude_domains: Optional[List[str]] = [],
|
||||
include_answer: Optional[bool] = False,
|
||||
include_raw_content: Optional[bool] = False,
|
||||
include_images: Optional[bool] = False,
|
||||
) -> List[Dict]:
|
||||
results_json = await self.raw_results_async(
|
||||
query=query,
|
||||
max_results=max_results,
|
||||
search_depth=search_depth,
|
||||
include_domains=include_domains,
|
||||
exclude_domains=exclude_domains,
|
||||
include_answer=include_answer,
|
||||
include_raw_content=include_raw_content,
|
||||
include_images=include_images,
|
||||
)
|
||||
return self.clean_results(results_json["results"])
|
||||
|
||||
def clean_results(self, results: List[Dict]) -> List[Dict]:
|
||||
|
Loading…
Reference in New Issue
Block a user