mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 06:39:52 +00:00
update tools (#13243)
This commit is contained in:
parent
8d6faf5665
commit
7f1d26160d
@ -18,6 +18,8 @@ class YouRetriever(BaseRetriever):
|
|||||||
|
|
||||||
ydc_api_key: str
|
ydc_api_key: str
|
||||||
k: Optional[int] = None
|
k: Optional[int] = None
|
||||||
|
n_hits: Optional[int] = None
|
||||||
|
n_snippets_per_hit: Optional[int] = None
|
||||||
endpoint_type: str = "web"
|
endpoint_type: str = "web"
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
@ -43,8 +45,10 @@ class YouRetriever(BaseRetriever):
|
|||||||
).json()
|
).json()
|
||||||
|
|
||||||
docs = []
|
docs = []
|
||||||
for hit in results["hits"]:
|
n_hits = self.n_hits or len(results["hits"])
|
||||||
for snippet in hit["snippets"]:
|
for hit in results["hits"][:n_hits]:
|
||||||
|
n_snippets_per_hit = self.n_snippets_per_hit or len(hit["snippets"])
|
||||||
|
for snippet in hit["snippets"][:n_snippets_per_hit]:
|
||||||
docs.append(Document(page_content=snippet))
|
docs.append(Document(page_content=snippet))
|
||||||
if self.k is not None and len(docs) >= self.k:
|
if self.k is not None and len(docs) >= self.k:
|
||||||
return docs
|
return docs
|
||||||
|
@ -1,13 +1,17 @@
|
|||||||
"""Tool for the Arxiv API."""
|
"""Tool for the Arxiv API."""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional, Type
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForToolRun
|
from langchain.callbacks.manager import CallbackManagerForToolRun
|
||||||
from langchain.pydantic_v1 import Field
|
from langchain.pydantic_v1 import BaseModel, Field
|
||||||
from langchain.tools.base import BaseTool
|
from langchain.tools.base import BaseTool
|
||||||
from langchain.utilities.arxiv import ArxivAPIWrapper
|
from langchain.utilities.arxiv import ArxivAPIWrapper
|
||||||
|
|
||||||
|
|
||||||
|
class ArxivInput(BaseModel):
|
||||||
|
query: str = Field(description="search query to look up")
|
||||||
|
|
||||||
|
|
||||||
class ArxivQueryRun(BaseTool):
|
class ArxivQueryRun(BaseTool):
|
||||||
"""Tool that searches the Arxiv API."""
|
"""Tool that searches the Arxiv API."""
|
||||||
|
|
||||||
@ -21,6 +25,7 @@ class ArxivQueryRun(BaseTool):
|
|||||||
"Input should be a search query."
|
"Input should be a search query."
|
||||||
)
|
)
|
||||||
api_wrapper: ArxivAPIWrapper = Field(default_factory=ArxivAPIWrapper)
|
api_wrapper: ArxivAPIWrapper = Field(default_factory=ArxivAPIWrapper)
|
||||||
|
args_schema: Type[BaseModel] = ArxivInput
|
||||||
|
|
||||||
def _run(
|
def _run(
|
||||||
self,
|
self,
|
||||||
|
@ -1,7 +1,12 @@
|
|||||||
|
from langchain.pydantic_v1 import BaseModel, Field
|
||||||
from langchain.schema import BaseRetriever
|
from langchain.schema import BaseRetriever
|
||||||
from langchain.tools import Tool
|
from langchain.tools import Tool
|
||||||
|
|
||||||
|
|
||||||
|
class RetrieverInput(BaseModel):
|
||||||
|
query: str = Field(description="query to look up in retriever")
|
||||||
|
|
||||||
|
|
||||||
def create_retriever_tool(
|
def create_retriever_tool(
|
||||||
retriever: BaseRetriever, name: str, description: str
|
retriever: BaseRetriever, name: str, description: str
|
||||||
) -> Tool:
|
) -> Tool:
|
||||||
@ -22,4 +27,5 @@ def create_retriever_tool(
|
|||||||
description=description,
|
description=description,
|
||||||
func=retriever.get_relevant_documents,
|
func=retriever.get_relevant_documents,
|
||||||
coroutine=retriever.aget_relevant_documents,
|
coroutine=retriever.aget_relevant_documents,
|
||||||
|
args_schema=RetrieverInput,
|
||||||
)
|
)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""Tavily Search API toolkit."""
|
"""Tavily Search API toolkit."""
|
||||||
|
|
||||||
from langchain.tools.tavily_search.tool import TavilySearchResults
|
from langchain.tools.tavily_search.tool import TavilyAnswer, TavilySearchResults
|
||||||
|
|
||||||
__all__ = ["TavilySearchResults"]
|
__all__ = ["TavilySearchResults", "TavilyAnswer"]
|
||||||
|
@ -1,26 +1,32 @@
|
|||||||
"""Tool for the Tavily search API."""
|
"""Tool for the Tavily search API."""
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Type, Union
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForToolRun,
|
AsyncCallbackManagerForToolRun,
|
||||||
CallbackManagerForToolRun,
|
CallbackManagerForToolRun,
|
||||||
)
|
)
|
||||||
|
from langchain.pydantic_v1 import BaseModel, Field
|
||||||
from langchain.tools.base import BaseTool
|
from langchain.tools.base import BaseTool
|
||||||
from langchain.utilities.tavily_search import TavilySearchAPIWrapper
|
from langchain.utilities.tavily_search import TavilySearchAPIWrapper
|
||||||
|
|
||||||
|
|
||||||
|
class TavilyInput(BaseModel):
|
||||||
|
query: str = Field(description="search query to look up")
|
||||||
|
|
||||||
|
|
||||||
class TavilySearchResults(BaseTool):
|
class TavilySearchResults(BaseTool):
|
||||||
"""Tool that queries the Tavily Search API and gets back json."""
|
"""Tool that queries the Tavily Search API and gets back json."""
|
||||||
|
|
||||||
name: str = "tavily_search_results_json"
|
name: str = "tavily_search_results_json"
|
||||||
description: str = """"
|
description: str = (
|
||||||
"A search engine optimized for comprehensive, accurate, and trusted results. "
|
"A search engine optimized for comprehensive, accurate, and trusted results. "
|
||||||
"Useful for when you need to answer questions about current events. "
|
"Useful for when you need to answer questions about current events. "
|
||||||
"Input should be a search query."
|
"Input should be a search query."
|
||||||
"""
|
)
|
||||||
api_wrapper: TavilySearchAPIWrapper
|
api_wrapper: TavilySearchAPIWrapper
|
||||||
max_results: int = 5
|
max_results: int = 5
|
||||||
|
args_schema: Type[BaseModel] = TavilyInput
|
||||||
|
|
||||||
def _run(
|
def _run(
|
||||||
self,
|
self,
|
||||||
@ -49,3 +55,50 @@ class TavilySearchResults(BaseTool):
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return repr(e)
|
return repr(e)
|
||||||
|
|
||||||
|
|
||||||
|
class TavilyAnswer(BaseTool):
|
||||||
|
"""Tool that queries the Tavily Search API and gets back an answer."""
|
||||||
|
|
||||||
|
name: str = "tavily_answer"
|
||||||
|
description: str = (
|
||||||
|
"A search engine optimized for comprehensive, accurate, and trusted results. "
|
||||||
|
"Useful for when you need to answer questions about current events. "
|
||||||
|
"Input should be a search query. "
|
||||||
|
"This returns only the answer - not the original source data."
|
||||||
|
)
|
||||||
|
api_wrapper: TavilySearchAPIWrapper
|
||||||
|
args_schema: Type[BaseModel] = TavilyInput
|
||||||
|
|
||||||
|
def _run(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||||
|
) -> Union[List[Dict], str]:
|
||||||
|
"""Use the tool."""
|
||||||
|
try:
|
||||||
|
return self.api_wrapper.raw_results(
|
||||||
|
query,
|
||||||
|
max_results=5,
|
||||||
|
include_answer=True,
|
||||||
|
search_depth="basic",
|
||||||
|
)["answer"]
|
||||||
|
except Exception as e:
|
||||||
|
return repr(e)
|
||||||
|
|
||||||
|
async def _arun(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||||
|
) -> Union[List[Dict], str]:
|
||||||
|
"""Use the tool asynchronously."""
|
||||||
|
try:
|
||||||
|
result = await self.api_wrapper.raw_results_async(
|
||||||
|
query,
|
||||||
|
max_results=5,
|
||||||
|
include_answer=True,
|
||||||
|
search_depth="basic",
|
||||||
|
)
|
||||||
|
return result["answer"]
|
||||||
|
except Exception as e:
|
||||||
|
return repr(e)
|
||||||
|
@ -24,7 +24,17 @@ class TavilySearchAPIWrapper(BaseModel):
|
|||||||
|
|
||||||
extra = Extra.forbid
|
extra = Extra.forbid
|
||||||
|
|
||||||
def _tavily_search_results(
|
@root_validator(pre=True)
|
||||||
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
|
"""Validate that api key and endpoint exists in environment."""
|
||||||
|
tavily_api_key = get_from_dict_or_env(
|
||||||
|
values, "tavily_api_key", "TAVILY_API_KEY"
|
||||||
|
)
|
||||||
|
values["tavily_api_key"] = tavily_api_key
|
||||||
|
|
||||||
|
return values
|
||||||
|
|
||||||
|
def raw_results(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
max_results: Optional[int] = 5,
|
max_results: Optional[int] = 5,
|
||||||
@ -34,7 +44,7 @@ class TavilySearchAPIWrapper(BaseModel):
|
|||||||
include_answer: Optional[bool] = False,
|
include_answer: Optional[bool] = False,
|
||||||
include_raw_content: Optional[bool] = False,
|
include_raw_content: Optional[bool] = False,
|
||||||
include_images: Optional[bool] = False,
|
include_images: Optional[bool] = False,
|
||||||
) -> List[dict]:
|
) -> Dict:
|
||||||
params = {
|
params = {
|
||||||
"api_key": self.tavily_api_key,
|
"api_key": self.tavily_api_key,
|
||||||
"query": query,
|
"query": query,
|
||||||
@ -51,20 +61,8 @@ class TavilySearchAPIWrapper(BaseModel):
|
|||||||
f"{TAVILY_API_URL}/search",
|
f"{TAVILY_API_URL}/search",
|
||||||
json=params,
|
json=params,
|
||||||
)
|
)
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
search_results = response.json()
|
return response.json()
|
||||||
return self.clean_results(search_results["results"])
|
|
||||||
|
|
||||||
@root_validator(pre=True)
|
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
|
||||||
"""Validate that api key and endpoint exists in environment."""
|
|
||||||
tavily_api_key = get_from_dict_or_env(
|
|
||||||
values, "tavily_api_key", "TAVILY_API_KEY"
|
|
||||||
)
|
|
||||||
values["tavily_api_key"] = tavily_api_key
|
|
||||||
|
|
||||||
return values
|
|
||||||
|
|
||||||
def results(
|
def results(
|
||||||
self,
|
self,
|
||||||
@ -88,7 +86,6 @@ class TavilySearchAPIWrapper(BaseModel):
|
|||||||
include_answer: Whether to include the answer in the results.
|
include_answer: Whether to include the answer in the results.
|
||||||
include_raw_content: Whether to include the raw content in the results.
|
include_raw_content: Whether to include the raw content in the results.
|
||||||
include_images: Whether to include images in the results.
|
include_images: Whether to include images in the results.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
query: The query that was searched for.
|
query: The query that was searched for.
|
||||||
follow_up_questions: A list of follow up questions.
|
follow_up_questions: A list of follow up questions.
|
||||||
@ -101,22 +98,20 @@ class TavilySearchAPIWrapper(BaseModel):
|
|||||||
content: The content of the result.
|
content: The content of the result.
|
||||||
score: The score of the result.
|
score: The score of the result.
|
||||||
raw_content: The raw content of the result.
|
raw_content: The raw content of the result.
|
||||||
|
|
||||||
|
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
raw_search_results = self._tavily_search_results(
|
raw_search_results = self.raw_results(
|
||||||
query,
|
query,
|
||||||
max_results,
|
max_results=max_results,
|
||||||
search_depth,
|
search_depth=search_depth,
|
||||||
include_domains,
|
include_domains=include_domains,
|
||||||
exclude_domains,
|
exclude_domains=exclude_domains,
|
||||||
include_answer,
|
include_answer=include_answer,
|
||||||
include_raw_content,
|
include_raw_content=include_raw_content,
|
||||||
include_images,
|
include_images=include_images,
|
||||||
)
|
)
|
||||||
return raw_search_results
|
return self.clean_results(raw_search_results["results"])
|
||||||
|
|
||||||
async def results_async(
|
async def raw_results_async(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
max_results: Optional[int] = 5,
|
max_results: Optional[int] = 5,
|
||||||
@ -126,7 +121,7 @@ class TavilySearchAPIWrapper(BaseModel):
|
|||||||
include_answer: Optional[bool] = False,
|
include_answer: Optional[bool] = False,
|
||||||
include_raw_content: Optional[bool] = False,
|
include_raw_content: Optional[bool] = False,
|
||||||
include_images: Optional[bool] = False,
|
include_images: Optional[bool] = False,
|
||||||
) -> List[Dict]:
|
) -> Dict:
|
||||||
"""Get results from the Tavily Search API asynchronously."""
|
"""Get results from the Tavily Search API asynchronously."""
|
||||||
|
|
||||||
# Function to perform the API call
|
# Function to perform the API call
|
||||||
@ -151,7 +146,29 @@ class TavilySearchAPIWrapper(BaseModel):
|
|||||||
raise Exception(f"Error {res.status}: {res.reason}")
|
raise Exception(f"Error {res.status}: {res.reason}")
|
||||||
|
|
||||||
results_json_str = await fetch()
|
results_json_str = await fetch()
|
||||||
results_json = json.loads(results_json_str)
|
return json.loads(results_json_str)
|
||||||
|
|
||||||
|
async def results_async(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
max_results: Optional[int] = 5,
|
||||||
|
search_depth: Optional[str] = "advanced",
|
||||||
|
include_domains: Optional[List[str]] = [],
|
||||||
|
exclude_domains: Optional[List[str]] = [],
|
||||||
|
include_answer: Optional[bool] = False,
|
||||||
|
include_raw_content: Optional[bool] = False,
|
||||||
|
include_images: Optional[bool] = False,
|
||||||
|
) -> List[Dict]:
|
||||||
|
results_json = await self.raw_results_async(
|
||||||
|
query=query,
|
||||||
|
max_results=max_results,
|
||||||
|
search_depth=search_depth,
|
||||||
|
include_domains=include_domains,
|
||||||
|
exclude_domains=exclude_domains,
|
||||||
|
include_answer=include_answer,
|
||||||
|
include_raw_content=include_raw_content,
|
||||||
|
include_images=include_images,
|
||||||
|
)
|
||||||
return self.clean_results(results_json["results"])
|
return self.clean_results(results_json["results"])
|
||||||
|
|
||||||
def clean_results(self, results: List[Dict]) -> List[Dict]:
|
def clean_results(self, results: List[Dict]) -> List[Dict]:
|
||||||
|
Loading…
Reference in New Issue
Block a user