Mirror of https://github.com/hwchase17/langchain.git, synced 2025-10-31 16:08:59 +00:00.
			
		
		
		
	Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
		
			
				
	
	
		
			65 lines
		
	
	
		
			2.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			65 lines
		
	
	
		
			2.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """Tool for the Reddit search API."""
 | |
| 
 | |
| from typing import Optional, Type
 | |
| 
 | |
| from langchain_core.callbacks import CallbackManagerForToolRun
 | |
| from langchain_core.tools import BaseTool
 | |
| from pydantic import BaseModel, Field
 | |
| 
 | |
| from langchain_community.utilities.reddit_search import RedditSearchAPIWrapper
 | |
| 
 | |
| 
 | |
class RedditSearchSchema(BaseModel):
    """Input for Reddit search.

    This model is used as ``args_schema`` by ``RedditSearchRun``. Field
    descriptions are therefore part of the tool's argument schema.

    The description strings use implicit string concatenation rather than
    backslash line continuations. A continuation inside a string literal
    would copy the next line's leading indentation into the description text.
    """

    query: str = Field(
        description=(
            "should be query string that post title should contain, "
            "or '*' if anything is allowed."
        )
    )
    sort: str = Field(
        description=(
            'should be sort method, which is one of: "relevance", '
            '"hot", "top", "new", or "comments".'
        )
    )
    time_filter: str = Field(
        description=(
            "should be time period to filter by, which is one of "
            '"all", "day", "hour", "month", "week", or "year"'
        )
    )
    subreddit: str = Field(
        description='should be name of subreddit, like "all" for r/all'
    )
    # Declared as ``int`` rather than ``str``. Under pydantic v2 a ``str``
    # field rejects a JSON integer such as ``5``, yet the description asks
    # for an integer. Numeric strings such as ``"5"`` are still accepted,
    # because pydantic's default (lax) mode coerces them to ``int``.
    limit: int = Field(
        description=(
            "a positive integer indicating the maximum number "
            "of results to return"
        )
    )
 | |
| 
 | |
| 
 | |
class RedditSearchRun(BaseTool):  # type: ignore[override]
    """Tool that queries for posts on a subreddit.

    Arguments are described by ``RedditSearchSchema``. The search itself is
    delegated to ``RedditSearchAPIWrapper.run``.
    """

    name: str = "reddit_search"
    # The two sentences are joined by implicit concatenation. The trailing
    # space keeps them from running together ("Reddit.Useful").
    description: str = (
        "A tool that searches for posts on Reddit. "
        "Useful when you need to know post information on a subreddit."
    )
    api_wrapper: RedditSearchAPIWrapper = Field(default_factory=RedditSearchAPIWrapper)  # type: ignore[arg-type]
    args_schema: Type[BaseModel] = RedditSearchSchema

    def _run(
        self,
        query: str,
        sort: str,
        time_filter: str,
        subreddit: str,
        limit: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool.

        Args:
            query: Query string passed through to the API wrapper.
            sort: Sort method passed through to the API wrapper.
            time_filter: Time period passed through to the API wrapper.
            subreddit: Subreddit name passed through to the API wrapper.
            limit: Maximum number of results, converted with ``int()``.
                Both numeric strings and ints are therefore accepted.
            run_manager: Callback manager; not used by this implementation.

        Returns:
            Whatever ``self.api_wrapper.run`` returns for these arguments.

        Raises:
            ValueError: If ``limit`` cannot be converted by ``int()``.
        """
        return self.api_wrapper.run(
            query=query,
            sort=sort,
            time_filter=time_filter,
            subreddit=subreddit,
            limit=int(limit),
        )
 |