mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 07:35:18 +00:00
Add multi-input Reddit search tool (#13893)
- **Description:** Added a tool called RedditSearchRun and an accompanying API wrapper, which searches Reddit for posts with support for time filtering, post sorting, query string and subreddit filtering.
- **Issue:** #13891
- **Dependencies:** `praw` module is used to search Reddit
- **Tag maintainer:** @baskaryan , and any of the other maintainers if needed
- **Twitter handle:** None.

Hello,

This is our first PR and we hope that our changes will be helpful to the community. We have run `make format`, `make lint` and `make test` locally before submitting the PR. To our knowledge, our changes do not introduce any new errors.

Our PR integrates the `praw` package which is already used by RedditPostsLoader in LangChain. Nonetheless, we have added integration tests and edited unit tests to test our changes. An example notebook is also provided.

These changes were put together by me, @Anika2000, @CharlesXu123, and @Jeremy-Cheng-stack

Thank you in advance to the maintainers for their time.

---------

Co-authored-by: What-Is-A-Username <49571870+What-Is-A-Username@users.noreply.github.com>
Co-authored-by: Anika2000 <anika.sultana@mail.utoronto.ca>
Co-authored-by: Jeremy Cheng <81793294+Jeremy-Cheng-stack@users.noreply.github.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in: parent 00a6e8962c, commit a00db4b28f
262 docs/docs/integrations/tools/reddit_search.ipynb Normal file
@@ -0,0 +1,262 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Reddit Search\n",
    "\n",
    "In this notebook, we learn how the Reddit search tool works.\n",
    "First, make sure that you have installed praw with the command below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
    "!pip install praw"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then you need to set up the proper API keys and environment variables. You will need a Reddit user account and API credentials, so create a Reddit user account by going to https://www.reddit.com and signing up.\n",
    "Then get your credentials by going to https://www.reddit.com/prefs/apps and creating an app.\n",
    "You should have your client_id and secret from creating the app. Now, you can paste those strings into the client_id and client_secret variables below.\n",
    "Note: You can put any string for user_agent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "client_id = \"\"\n",
    "client_secret = \"\"\n",
    "user_agent = \"\""
   ]
  },
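  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally, instead of passing the strings to the wrapper directly, you can export them as environment variables: per the wrapper's docstring, RedditSearchAPIWrapper also reads REDDIT_CLIENT_ID, REDDIT_CLIENT_SECRET and REDDIT_USER_AGENT. A minimal sketch of that approach, assuming the variables above are filled in:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Export the credentials; RedditSearchAPIWrapper() can then be\n",
    "# constructed with no arguments and will read these variables.\n",
    "os.environ[\"REDDIT_CLIENT_ID\"] = client_id\n",
    "os.environ[\"REDDIT_CLIENT_SECRET\"] = client_secret\n",
    "os.environ[\"REDDIT_USER_AGENT\"] = user_agent"
   ]
  },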
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.tools.reddit_search.tool import RedditSearchRun\n",
    "from langchain.utilities.reddit_search import RedditSearchAPIWrapper\n",
    "\n",
    "search = RedditSearchRun(\n",
    "    api_wrapper=RedditSearchAPIWrapper(\n",
    "        reddit_client_id=client_id,\n",
    "        reddit_client_secret=client_secret,\n",
    "        reddit_user_agent=user_agent,\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can then set your search parameters: for example, what subreddit you want to query, how many posts you want returned, how you would like the results to be sorted, etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.tools.reddit_search.tool import RedditSearchSchema\n",
    "\n",
    "search_params = RedditSearchSchema(\n",
    "    query=\"beginner\", sort=\"new\", time_filter=\"week\", subreddit=\"python\", limit=\"2\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, run the search and get your results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result = search.run(tool_input=search_params.dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is an example of printing the result.\n",
    "Note: You may get different output depending on the newest post in the subreddit, but the formatting should be similar."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "> Searching r/python found 2 posts:\n",
    "> Post Title: 'Setup Github Copilot in Visual Studio Code'\n",
    "> User: Feisty-Recording-715\n",
    "> Subreddit: r/Python:\n",
    "> Text body: 🛠️ This tutorial is perfect for beginners looking to strengthen their understanding of version control or for experienced developers seeking a quick reference for GitHub setup in Visual Studio Code.\n",
    ">\n",
    "> 🎓 By the end of this video, you'll be equipped with the skills to confidently manage your codebase, collaborate with others, and contribute to open-source projects on GitHub.\n",
    ">\n",
    ">\n",
    "> Video link: https://youtu.be/IdT1BhrSfdo?si=mV7xVpiyuhlD8Zrw\n",
    ">\n",
    "> Your feedback is welcome\n",
    "> Post URL: https://www.reddit.com/r/Python/comments/1823wr7/setup_github_copilot_in_visual_studio_code/\n",
    "> Post Category: N/A.\n",
    "> Score: 0\n",
    ">\n",
    "> Post Title: 'A Chinese Checkers game made with pygame and PySide6, with custom bots support'\n",
    "> User: HenryChess\n",
    "> Subreddit: r/Python:\n",
    "> Text body: GitHub link: https://github.com/henrychess/pygame-chinese-checkers\n",
    ">\n",
    "> I'm not sure if this counts as beginner or intermediate. I think I'm still in the beginner zone, so I flair it as beginner.\n",
    ">\n",
    "> This is a Chinese Checkers (aka Sternhalma) game for 2 to 3 players. The bots I wrote are easy to beat, as they're mainly for debugging the game logic part of the code. However, you can write up your own custom bots. There is a guide at the github page.\n",
    "> Post URL: https://www.reddit.com/r/Python/comments/181xq0u/a_chinese_checkers_game_made_with_pygame_and/\n",
    "> Post Category: N/A.\n",
    "> Score: 1\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using tool with an agent chain\n",
    "\n",
    "Reddit search functionality is also provided as a multi-input tool. In this example, we adapt [existing code from the docs](https://python.langchain.com/docs/modules/agents/how_to/sharedmemory_for_tools), and use ChatOpenAI to create an agent chain with memory. This agent chain is able to pull information from Reddit and use these posts to respond to subsequent input.\n",
    "\n",
    "To run the example, add your Reddit API access information and also get an OpenAI key from the [OpenAI API](https://help.openai.com/en/articles/4936850-where-do-i-find-my-api-key)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adapted code from https://python.langchain.com/docs/modules/agents/how_to/sharedmemory_for_tools\n",
    "\n",
    "from langchain.agents import AgentExecutor, StructuredChatAgent, Tool\n",
    "from langchain.chains import LLMChain\n",
    "from langchain.chat_models import ChatOpenAI\n",
    "from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory\n",
    "from langchain.prompts import PromptTemplate\n",
    "from langchain.tools.reddit_search.tool import RedditSearchRun\n",
    "from langchain.utilities.reddit_search import RedditSearchAPIWrapper\n",
    "\n",
    "# Provide keys for Reddit\n",
    "client_id = \"\"\n",
    "client_secret = \"\"\n",
    "user_agent = \"\"\n",
    "# Provide key for OpenAI\n",
    "openai_api_key = \"\"\n",
    "\n",
    "template = \"\"\"This is a conversation between a human and a bot:\n",
    "\n",
    "{chat_history}\n",
    "\n",
    "Write a summary of the conversation for {input}:\n",
    "\"\"\"\n",
    "\n",
    "prompt = PromptTemplate(input_variables=[\"input\", \"chat_history\"], template=template)\n",
    "memory = ConversationBufferMemory(memory_key=\"chat_history\")\n",
    "\n",
    "prefix = \"\"\"Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:\"\"\"\n",
    "suffix = \"\"\"Begin!\n",
    "\n",
    "{chat_history}\n",
    "Question: {input}\n",
    "{agent_scratchpad}\"\"\"\n",
    "\n",
    "tools = [\n",
    "    RedditSearchRun(\n",
    "        api_wrapper=RedditSearchAPIWrapper(\n",
    "            reddit_client_id=client_id,\n",
    "            reddit_client_secret=client_secret,\n",
    "            reddit_user_agent=user_agent,\n",
    "        )\n",
    "    )\n",
    "]\n",
    "\n",
    "prompt = StructuredChatAgent.create_prompt(\n",
    "    prefix=prefix,\n",
    "    tools=tools,\n",
    "    suffix=suffix,\n",
    "    input_variables=[\"input\", \"chat_history\", \"agent_scratchpad\"],\n",
    ")\n",
    "\n",
    "llm = ChatOpenAI(temperature=0, openai_api_key=openai_api_key)\n",
    "\n",
    "llm_chain = LLMChain(llm=llm, prompt=prompt)\n",
    "agent = StructuredChatAgent(llm_chain=llm_chain, verbose=True, tools=tools)\n",
    "agent_chain = AgentExecutor.from_agent_and_tools(\n",
    "    agent=agent, verbose=True, memory=memory, tools=tools\n",
    ")\n",
    "\n",
    "# Answering the first prompt requires usage of the Reddit search tool.\n",
    "agent_chain.run(input=\"What is the newest post on r/langchain for the week?\")\n",
    "# Answering the subsequent prompt uses memory.\n",
    "agent_chain.run(input=\"Who is the author of the post?\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.11.5 64-bit ('langchaindev')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "3929050b09828356c9f5ebaf862d05c053d8228eddbc70f990c168e54dd824ba"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
@@ -64,6 +64,7 @@ from langchain.tools.openweathermap.tool import OpenWeatherMapQueryRun
from langchain.tools.dataforseo_api_search import DataForSeoAPISearchRun
from langchain.tools.dataforseo_api_search import DataForSeoAPISearchResults
from langchain.tools.memorize.tool import Memorize
from langchain.tools.reddit_search.tool import RedditSearchRun
from langchain.utilities.arxiv import ArxivAPIWrapper
from langchain.utilities.golden_query import GoldenQueryAPIWrapper
from langchain.utilities.pubmed import PubMedAPIWrapper
@@ -88,6 +89,7 @@ from langchain.utilities.wikipedia import WikipediaAPIWrapper
from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
from langchain.utilities.openweathermap import OpenWeatherMapAPIWrapper
from langchain.utilities.dataforseo_api_search import DataForSeoAPIWrapper
from langchain.utilities.reddit_search import RedditSearchAPIWrapper


def _get_python_repl() -> BaseTool:
@@ -374,6 +376,10 @@ def _get_google_cloud_texttospeech(**kwargs: Any) -> BaseTool:
    return GoogleCloudTextToSpeechTool(**kwargs)


def _get_reddit_search(**kwargs: Any) -> BaseTool:
    return RedditSearchRun(api_wrapper=RedditSearchAPIWrapper(**kwargs))


_EXTRA_LLM_TOOLS: Dict[
    str,
    Tuple[Callable[[Arg(BaseLanguageModel, "llm"), KwArg(Any)], BaseTool], List[str]],
@@ -454,6 +460,10 @@ _EXTRA_OPTIONAL_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[st
    ),
    "eleven_labs_text2speech": (_get_eleven_labs_text2speech, ["eleven_api_key"]),
    "google_cloud_texttospeech": (_get_google_cloud_texttospeech, []),
    "reddit_search": (
        _get_reddit_search,
        ["reddit_client_id", "reddit_client_secret", "reddit_user_agent"],
    ),
}
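With "reddit_search" registered in _EXTRA_OPTIONAL_TOOLS above, the tool should also be reachable through load_tools. A minimal sketch, not part of this diff, with placeholder credential strings and keyword names taken from the registration above:

    from langchain.agents import load_tools

    tools = load_tools(
        ["reddit_search"],
        reddit_client_id="<client-id>",          # placeholder
        reddit_client_secret="<client-secret>",  # placeholder
        reddit_user_agent="<user-agent>",        # placeholder
    )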
@@ -480,6 +480,12 @@ def _import_python_tool_PythonREPLTool() -> Any:
    )


def _import_reddit_search_RedditSearchRun() -> Any:
    from langchain.tools.reddit_search.tool import RedditSearchRun

    return RedditSearchRun


def _import_render() -> Any:
    from langchain.tools.render import format_tool_to_openai_function

@@ -833,6 +839,8 @@ def __getattr__(name: str) -> Any:
        return _import_python_tool_PythonAstREPLTool()
    elif name == "PythonREPLTool":
        return _import_python_tool_PythonREPLTool()
    elif name == "RedditSearchRun":
        return _import_reddit_search_RedditSearchRun()
    elif name == "format_tool_to_openai_function":
        return _import_render()
    elif name == "BaseRequestsTool":
@@ -983,6 +991,7 @@ __all__ = [
    "OpenAPISpec",
    "OpenWeatherMapQueryRun",
    "PubmedQueryRun",
    "RedditSearchRun",
    "QueryCheckerTool",
    "QueryPowerBITool",
    "QuerySQLCheckerTool",
63 libs/langchain/langchain/tools/reddit_search/tool.py Normal file
@@ -0,0 +1,63 @@
"""Tool for the Reddit search API."""

from typing import Optional, Type

from langchain.callbacks.manager import CallbackManagerForToolRun
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools.base import BaseTool
from langchain.utilities.reddit_search import RedditSearchAPIWrapper


class RedditSearchSchema(BaseModel):
    """Input for Reddit search."""

    query: str = Field(
        description="should be query string that post title should \
        contain, or '*' if anything is allowed."
    )
    sort: str = Field(
        description='should be sort method, which is one of: "relevance" \
        , "hot", "top", "new", or "comments".'
    )
    time_filter: str = Field(
        description='should be time period to filter by, which is \
        one of "all", "day", "hour", "month", "week", or "year"'
    )
    subreddit: str = Field(
        description='should be name of subreddit, like "all" for \
        r/all'
    )
    limit: str = Field(
        description="a positive integer indicating the maximum number \
        of results to return"
    )


class RedditSearchRun(BaseTool):
    """Tool that queries for posts on a subreddit."""

    name: str = "reddit_search"
    description: str = (
        "A tool that searches for posts on Reddit. "
        "Useful when you need to know post information on a subreddit."
    )
    api_wrapper: RedditSearchAPIWrapper = Field(default_factory=RedditSearchAPIWrapper)
    args_schema: Type[BaseModel] = RedditSearchSchema

    def _run(
        self,
        query: str,
        sort: str,
        time_filter: str,
        subreddit: str,
        limit: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool."""
        return self.api_wrapper.run(
            query=query,
            sort=sort,
            time_filter=time_filter,
            subreddit=subreddit,
            limit=int(limit),
        )
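Since api_wrapper uses default_factory=RedditSearchAPIWrapper, the tool can also be constructed with no arguments when credentials are supplied via the environment. A minimal usage sketch, assuming REDDIT_CLIENT_ID, REDDIT_CLIENT_SECRET and REDDIT_USER_AGENT are set:

    from langchain.tools.reddit_search.tool import RedditSearchRun

    # Assumes the three REDDIT_* environment variables are set, so the
    # default-constructed wrapper passes its root validator.
    tool = RedditSearchRun()
    print(
        tool.run(
            tool_input={
                "query": "beginner",
                "sort": "new",
                "time_filter": "week",
                "subreddit": "python",
                "limit": "2",
            }
        )
    )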
121 libs/langchain/langchain/utilities/reddit_search.py Normal file
@@ -0,0 +1,121 @@
"""Wrapper for the Reddit API"""

from typing import Any, Dict, List, Optional

from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.utils import get_from_dict_or_env


class RedditSearchAPIWrapper(BaseModel):
    """Wrapper for Reddit API

    To use, set the environment variables ``REDDIT_CLIENT_ID``,
    ``REDDIT_CLIENT_SECRET``, ``REDDIT_USER_AGENT`` to set the client ID,
    client secret, and user agent, respectively, as given by Reddit's API.
    Alternatively, all three can be supplied as named parameters in the
    constructor: ``reddit_client_id``, ``reddit_client_secret``, and
    ``reddit_user_agent``, respectively.

    Example:
        .. code-block:: python

            from langchain.utilities import RedditSearchAPIWrapper
            reddit_search = RedditSearchAPIWrapper()
    """

    reddit_client: Any

    # Values required to access Reddit API via praw
    reddit_client_id: Optional[str]
    reddit_client_secret: Optional[str]
    reddit_user_agent: Optional[str]

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API ID, secret and user agent exist in the
        environment, and check that the praw module is present.
        """
        reddit_client_id = get_from_dict_or_env(
            values, "reddit_client_id", "REDDIT_CLIENT_ID"
        )
        values["reddit_client_id"] = reddit_client_id

        reddit_client_secret = get_from_dict_or_env(
            values, "reddit_client_secret", "REDDIT_CLIENT_SECRET"
        )
        values["reddit_client_secret"] = reddit_client_secret

        reddit_user_agent = get_from_dict_or_env(
            values, "reddit_user_agent", "REDDIT_USER_AGENT"
        )
        values["reddit_user_agent"] = reddit_user_agent

        try:
            import praw
        except ImportError:
            raise ImportError(
                "praw package not found, please install it with pip install praw"
            )

        reddit_client = praw.Reddit(
            client_id=reddit_client_id,
            client_secret=reddit_client_secret,
            user_agent=reddit_user_agent,
        )
        values["reddit_client"] = reddit_client

        return values

    def run(
        self, query: str, sort: str, time_filter: str, subreddit: str, limit: int
    ) -> str:
        """Search Reddit and return posts as a single string."""
        results: List[Dict] = self.results(
            query=query,
            sort=sort,
            time_filter=time_filter,
            subreddit=subreddit,
            limit=limit,
        )
        if len(results) > 0:
            output: List[str] = [f"Searching r/{subreddit} found {len(results)} posts:"]
            for r in results:
                category = "N/A" if r["post_category"] is None else r["post_category"]
                p = f"Post Title: '{r['post_title']}'\n\
User: {r['post_author']}\n\
Subreddit: {r['post_subreddit']}:\n\
Text body: {r['post_text']}\n\
Post URL: {r['post_url']}\n\
Post Category: {category}.\n\
Score: {r['post_score']}\n"
                output.append(p)
            return "\n".join(output)
        else:
            return f"Searching r/{subreddit} did not find any posts:"

    def results(
        self, query: str, sort: str, time_filter: str, subreddit: str, limit: int
    ) -> List[Dict]:
        """Use praw to search Reddit and return a list of dictionaries,
        one for each post.
        """
        subredditObject = self.reddit_client.subreddit(subreddit)
        search_results = subredditObject.search(
            query=query, sort=sort, time_filter=time_filter, limit=limit
        )
        search_results = [r for r in search_results]
        results_object = []
        for submission in search_results:
            results_object.append(
                {
                    "post_subreddit": submission.subreddit_name_prefixed,
                    "post_category": submission.category,
                    "post_title": submission.title,
                    "post_text": submission.selftext,
                    "post_score": submission.score,
                    "post_id": submission.id,
                    "post_url": submission.url,
                    "post_author": submission.author,
                }
            )
        return results_object
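For callers that want structured output, results() returns one dictionary per post instead of the formatted string produced by run(). A short sketch, again assuming the REDDIT_* environment variables are set:

    from langchain.utilities.reddit_search import RedditSearchAPIWrapper

    wrapper = RedditSearchAPIWrapper()  # credentials read from REDDIT_* env vars
    for post in wrapper.results(
        query="beginner", sort="new", time_filter="week", subreddit="python", limit=2
    ):
        # Each dict carries post_title, post_author, post_url, post_score, etc.
        print(post["post_title"], post["post_url"])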
66 libs/langchain/poetry.lock generated
@@ -6479,6 +6479,49 @@ files = [
[package.extras]
dill = ["dill (>=0.3.7)"]

[[package]]
name = "praw"
version = "7.7.1"
description = "PRAW, an acronym for \"Python Reddit API Wrapper\", is a Python package that allows for simple access to Reddit's API."
optional = true
python-versions = "~=3.7"
files = [
    {file = "praw-7.7.1-py3-none-any.whl", hash = "sha256:9ec5dc943db00c175bc6a53f4e089ce625f3fdfb27305564b616747b767d38ef"},
    {file = "praw-7.7.1.tar.gz", hash = "sha256:f1d7eef414cafe28080dda12ed09253a095a69933d5c8132eca11d4dc8a070bf"},
]

[package.dependencies]
prawcore = ">=2.1,<3"
update-checker = ">=0.18"
websocket-client = ">=0.54.0"

[package.extras]
ci = ["coveralls"]
dev = ["betamax (>=0.8,<0.9)", "betamax-matchers (>=0.3.0,<0.5)", "furo", "packaging", "pre-commit", "pytest (>=2.7.3)", "requests (>=2.20.1,<3)", "sphinx", "urllib3 (==1.26.*)"]
lint = ["furo", "pre-commit", "sphinx"]
readthedocs = ["furo", "sphinx"]
test = ["betamax (>=0.8,<0.9)", "betamax-matchers (>=0.3.0,<0.5)", "pytest (>=2.7.3)", "requests (>=2.20.1,<3)", "urllib3 (==1.26.*)"]

[[package]]
name = "prawcore"
version = "2.4.0"
description = "\"Low-level communication layer for PRAW 4+."
optional = true
python-versions = "~=3.8"
files = [
    {file = "prawcore-2.4.0-py3-none-any.whl", hash = "sha256:29af5da58d85704b439ad3c820873ad541f4535e00bb98c66f0fbcc8c603065a"},
    {file = "prawcore-2.4.0.tar.gz", hash = "sha256:b7b2b5a1d04406e086ab4e79988dc794df16059862f329f4c6a43ed09986c335"},
]

[package.dependencies]
requests = ">=2.6.0,<3.0"

[package.extras]
ci = ["coveralls"]
dev = ["packaging", "prawcore[lint]", "prawcore[test]"]
lint = ["pre-commit", "ruff (>=0.0.291)"]
test = ["betamax (>=0.8,<0.9)", "pytest (>=2.7.3)", "urllib3 (==1.26.*)"]

[[package]]
name = "prometheus-client"
version = "0.17.1"
@@ -10407,6 +10450,25 @@ tzdata = {version = "*", markers = "platform_system == \"Windows\""}
[package.extras]
devenv = ["black", "check-manifest", "flake8", "pyroma", "pytest (>=4.3)", "pytest-cov", "pytest-mock (>=3.3)", "zest.releaser"]

[[package]]
name = "update-checker"
version = "0.18.0"
description = "A python module that will check for package updates."
optional = true
python-versions = "*"
files = [
    {file = "update_checker-0.18.0-py3-none-any.whl", hash = "sha256:cbba64760a36fe2640d80d85306e8fe82b6816659190993b7bdabadee4d4bbfd"},
    {file = "update_checker-0.18.0.tar.gz", hash = "sha256:6a2d45bb4ac585884a6b03f9eade9161cedd9e8111545141e9aa9058932acb13"},
]

[package.dependencies]
requests = ">=2.3.0"

[package.extras]
dev = ["black", "flake8", "pytest (>=2.7.3)"]
lint = ["black", "flake8"]
test = ["pytest (>=2.7.3)"]

[[package]]
name = "upstash-redis"
version = "0.15.0"
@@ -11201,7 +11263,7 @@ cli = ["typer"]
cohere = ["cohere"]
docarray = ["docarray"]
embeddings = ["sentence-transformers"]
-extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "dashvector", "databricks-vectorsearch", "dgml-utils", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-pydantic", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict"]
+extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "dashvector", "databricks-vectorsearch", "dgml-utils", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-pydantic", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict"]
javascript = ["esprima"]
llms = ["clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openlm", "torch", "transformers"]
openai = ["openai", "tiktoken"]
@@ -11211,4 +11273,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
-content-hash = "c540c453e0b1221ed8c1bbb638430264dc9ec52df12d915357c961b31c69e132"
+content-hash = "473f3af6bf8c6e224f9522df2ff78f7ab965ea044c60efe60e7fe0461d2d2014"
@@ -143,6 +143,7 @@ azure-ai-textanalytics = {version = "^5.3.0", optional = true}
google-cloud-documentai = {version = "^2.20.1", optional = true}
fireworks-ai = {version = "^0.6.0", optional = true, python = ">=3.9,<4.0"}
javelin-sdk = {version = "^0.1.8", optional = true}
praw = {version = "^7.7.1", optional = true}
msal = {version = "^1.25.0", optional = true}
databricks-vectorsearch = {version = "^0.21", optional = true}
dgml-utils = {version = "^0.3.0", optional = true}
@@ -383,6 +384,7 @@ extended_testing = [
    "rspace_client",
    "fireworks-ai",
    "javelin-sdk",
    "praw",
    "databricks-vectorsearch",
    "dgml-utils",
    "cohere",
@@ -0,0 +1,67 @@
import pytest

from langchain.utilities.reddit_search import RedditSearchAPIWrapper


@pytest.fixture
def api_client() -> RedditSearchAPIWrapper:
    return RedditSearchAPIWrapper()


def assert_results_exists(results: list) -> None:
    if len(results) > 0:
        for result in results:
            assert "post_title" in result
            assert "post_author" in result
            assert "post_subreddit" in result
            assert "post_text" in result
            assert "post_url" in result
            assert "post_score" in result
            assert "post_category" in result
            assert "post_id" in result
    else:
        assert results == []


@pytest.mark.requires("praw")
def test_run_empty_query(api_client: RedditSearchAPIWrapper) -> None:
    """Test that run gives the correct answer with empty query."""
    search = api_client.run(
        query="", sort="relevance", time_filter="all", subreddit="all", limit=5
    )
    assert search == "Searching r/all did not find any posts:"


@pytest.mark.requires("praw")
def test_run_query(api_client: RedditSearchAPIWrapper) -> None:
    """Test that run gives the correct answer."""
    search = api_client.run(
        query="university",
        sort="relevance",
        time_filter="all",
        subreddit="funny",
        limit=5,
    )
    assert "University" in search


@pytest.mark.requires("praw")
def test_results_exists(api_client: RedditSearchAPIWrapper) -> None:
    """Test that results gives the correct output format."""
    search = api_client.results(
        query="What is the best programming language?",
        sort="relevance",
        time_filter="all",
        subreddit="all",
        limit=10,
    )
    assert_results_exists(search)


@pytest.mark.requires("praw")
def test_results_empty_query(api_client: RedditSearchAPIWrapper) -> None:
    """Test that results gives the correct output with empty query."""
    search = api_client.results(
        query="", sort="relevance", time_filter="all", subreddit="all", limit=10
    )
    assert search == []
@@ -78,6 +78,7 @@ EXPECTED_ALL = [
    "OpenAPISpec",
    "OpenWeatherMapQueryRun",
    "PubmedQueryRun",
    "RedditSearchRun",
    "QueryCheckerTool",
    "QueryPowerBITool",
    "QuerySQLCheckerTool",
@@ -79,6 +79,7 @@ _EXPECTED = [
    "OpenAPISpec",
    "OpenWeatherMapQueryRun",
    "PubmedQueryRun",
    "RedditSearchRun",
    "QueryCheckerTool",
    "QueryPowerBITool",
    "QuerySQLCheckerTool",