Add multi-input Reddit search tool (#13893)

- **Description:** Added a tool called RedditSearchRun and an
accompanying API wrapper, which searches Reddit for posts, with support
for time filtering, post sorting, query-string matching, and subreddit filtering.
  - **Issue:** #13891 
  - **Dependencies:** `praw` module is used to search Reddit
- **Tag maintainer:** @baskaryan, and any of the other maintainers if
needed
  - **Twitter handle:** None.

  Hello,

This is our first PR and we hope that our changes will be helpful to the
community. We have run `make format`, `make lint` and `make test`
locally before submitting the PR. To our knowledge, our changes do not
introduce any new errors.

Our PR integrates the `praw` package, which is already used by
RedditPostsLoader in LangChain. We have added integration
tests and edited unit tests to cover our changes. An example notebook is
also provided. These changes were put together by me, @Anika2000,
@CharlesXu123, and @Jeremy-Cheng-stack.

Thank you in advance to the maintainers for their time.
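
For reference, here is a minimal usage sketch of the new tool (the credentials are placeholders; the included notebook walks through this in full):

```python
from langchain.tools.reddit_search.tool import RedditSearchRun
from langchain.utilities.reddit_search import RedditSearchAPIWrapper

# Placeholder credentials, as issued at https://www.reddit.com/prefs/apps
search = RedditSearchRun(
    api_wrapper=RedditSearchAPIWrapper(
        reddit_client_id="<client-id>",
        reddit_client_secret="<client-secret>",
        reddit_user_agent="<user-agent>",
    )
)

# Search r/python for new posts from the past week whose titles mention "beginner".
print(
    search.run(
        tool_input={
            "query": "beginner",
            "sort": "new",
            "time_filter": "week",
            "subreddit": "python",
            "limit": "2",
        }
    )
)
```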

---------

Co-authored-by: What-Is-A-Username <49571870+What-Is-A-Username@users.noreply.github.com>
Co-authored-by: Anika2000 <anika.sultana@mail.utoronto.ca>
Co-authored-by: Jeremy Cheng <81793294+Jeremy-Cheng-stack@users.noreply.github.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Commit a00db4b28f by Cheng (William) Huang, 2023-11-29 20:16:40 -05:00 (parent 00a6e8962c), committed by GitHub.
10 changed files with 600 additions and 2 deletions


@@ -0,0 +1,262 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Reddit Search \n",
"\n",
"In this notebook, we learn how the Reddit search tool works. \n",
"First make sure that you have installed praw with the command below: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "shellscript"
}
},
"outputs": [],
"source": [
"!pip install praw"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then you need to set you need to set up the proper API keys and environment variables. You would need to create a Reddit user account and get credentials. So, create a Reddit user account by going to https://www.reddit.com and signing up. \n",
"Then get your credentials by going to https://www.reddit.com/prefs/apps and creating an app. \n",
"You should have your client_id and secret from creating the app. Now, you can paste those strings in client_id and client_secret variable. \n",
"Note: You can put any string for user_agent "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"client_id = \"\"\n",
"client_secret = \"\"\n",
"user_agent = \"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.tools.reddit_search.tool import RedditSearchRun\n",
"from langchain.utilities.reddit_search import RedditSearchAPIWrapper\n",
"\n",
"search = RedditSearchRun(\n",
" api_wrapper=RedditSearchAPIWrapper(\n",
" reddit_client_id=client_id,\n",
" reddit_client_secret=client_secret,\n",
" reddit_user_agent=user_agent,\n",
" )\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can then set your queries for example, what subreddit you want to query, how many posts you want to be returned, how you would like the result to be sorted etc."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.tools.reddit_search.tool import RedditSearchSchema\n",
"\n",
"search_params = RedditSearchSchema(\n",
" query=\"beginner\", sort=\"new\", time_filter=\"week\", subreddit=\"python\", limit=\"2\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally run the search and get your results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"result = search.run(tool_input=search_params.dict())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(result)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is an example of printing the result. \n",
"Note: You may get different output depending on the newest post in the subreddit but the formatting should be similar."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"> Searching r/python found 2 posts:\n",
"> Post Title: 'Setup Github Copilot in Visual Studio Code'\n",
"> User: Feisty-Recording-715\n",
"> Subreddit: r/Python:\n",
"> Text body: 🛠️ This tutorial is perfect for beginners looking to strengthen their understanding of version control or for experienced developers seeking a quick reference for GitHub setup in Visual Studio Code.\n",
">\n",
">🎓 By the end of this video, you'll be equipped with the skills to confidently manage your codebase, collaborate with others, and contribute to open-source projects on GitHub.\n",
">\n",
">\n",
">Video link: https://youtu.be/IdT1BhrSfdo?si=mV7xVpiyuhlD8Zrw\n",
">\n",
">Your feedback is welcome\n",
"> Post URL: https://www.reddit.com/r/Python/comments/1823wr7/setup_github_copilot_in_visual_studio_code/\n",
"> Post Category: N/A.\n",
"> Score: 0\n",
">\n",
">Post Title: 'A Chinese Checkers game made with pygame and PySide6, with custom bots support'\n",
">User: HenryChess\n",
">Subreddit: r/Python:\n",
"> Text body: GitHub link: https://github.com/henrychess/pygame-chinese-checkers\n",
">\n",
">I'm not sure if this counts as beginner or intermediate. I think I'm still in the beginner zone, so I flair it as beginner.\n",
">\n",
">This is a Chinese Checkers (aka Sternhalma) game for 2 to 3 players. The bots I wrote are easy to beat, as they're mainly for debugging the game logic part of the code. However, you can write up your own custom bots. There is a guide at the github page.\n",
"> Post URL: https://www.reddit.com/r/Python/comments/181xq0u/a_chinese_checkers_game_made_with_pygame_and/\n",
"> Post Category: N/A.\n",
" > Score: 1\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using tool with an agent chain\n",
"\n",
"Reddit search functionality is also provided as a multi-input tool. In this example, we adapt [existing code from the docs](https://python.langchain.com/docs/modules/agents/how_to/sharedmemory_for_tools), and use ChatOpenAI to create an agent chain with memory. This agent chain is able to pull information from Reddit and use these posts to respond to subsequent input. \n",
"\n",
"To run the example, add your reddit API access information and also get an OpenAI key from the [OpenAI API](https://help.openai.com/en/articles/4936850-where-do-i-find-my-api-key)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Adapted code from https://python.langchain.com/docs/modules/agents/how_to/sharedmemory_for_tools\n",
"\n",
"from langchain.agents import AgentExecutor, StructuredChatAgent, Tool\n",
"from langchain.chains import LLMChain\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.tools.reddit_search.tool import RedditSearchRun\n",
"from langchain.utilities.reddit_search import RedditSearchAPIWrapper\n",
"\n",
"# Provide keys for Reddit\n",
"client_id = \"\"\n",
"client_secret = \"\"\n",
"user_agent = \"\"\n",
"# Provide key for OpenAI\n",
"openai_api_key = \"\"\n",
"\n",
"template = \"\"\"This is a conversation between a human and a bot:\n",
"\n",
"{chat_history}\n",
"\n",
"Write a summary of the conversation for {input}:\n",
"\"\"\"\n",
"\n",
"prompt = PromptTemplate(input_variables=[\"input\", \"chat_history\"], template=template)\n",
"memory = ConversationBufferMemory(memory_key=\"chat_history\")\n",
"\n",
"prefix = \"\"\"Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:\"\"\"\n",
"suffix = \"\"\"Begin!\"\n",
"\n",
"{chat_history}\n",
"Question: {input}\n",
"{agent_scratchpad}\"\"\"\n",
"\n",
"tools = [\n",
" RedditSearchRun(\n",
" api_wrapper=RedditSearchAPIWrapper(\n",
" reddit_client_id=client_id,\n",
" reddit_client_secret=client_secret,\n",
" reddit_user_agent=user_agent,\n",
" )\n",
" )\n",
"]\n",
"\n",
"prompt = StructuredChatAgent.create_prompt(\n",
" prefix=prefix,\n",
" tools=tools,\n",
" suffix=suffix,\n",
" input_variables=[\"input\", \"chat_history\", \"agent_scratchpad\"],\n",
")\n",
"\n",
"llm = ChatOpenAI(temperature=0, openai_api_key=openai_api_key)\n",
"\n",
"llm_chain = LLMChain(llm=llm, prompt=prompt)\n",
"agent = StructuredChatAgent(llm_chain=llm_chain, verbose=True, tools=tools)\n",
"agent_chain = AgentExecutor.from_agent_and_tools(\n",
" agent=agent, verbose=True, memory=memory, tools=tools\n",
")\n",
"\n",
"# Answering the first prompt requires usage of the Reddit search tool.\n",
"agent_chain.run(input=\"What is the newest post on r/langchain for the week?\")\n",
"# Answering the subsequent prompt uses memory.\n",
"agent_chain.run(input=\"Who is the author of the post?\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.11.5 64-bit ('langchaindev')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
},
"vscode": {
"interpreter": {
"hash": "3929050b09828356c9f5ebaf862d05c053d8228eddbc70f990c168e54dd824ba"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}


@@ -64,6 +64,7 @@ from langchain.tools.openweathermap.tool import OpenWeatherMapQueryRun
from langchain.tools.dataforseo_api_search import DataForSeoAPISearchRun
from langchain.tools.dataforseo_api_search import DataForSeoAPISearchResults
from langchain.tools.memorize.tool import Memorize
from langchain.tools.reddit_search.tool import RedditSearchRun
from langchain.utilities.arxiv import ArxivAPIWrapper
from langchain.utilities.golden_query import GoldenQueryAPIWrapper
from langchain.utilities.pubmed import PubMedAPIWrapper
@@ -88,6 +89,7 @@ from langchain.utilities.wikipedia import WikipediaAPIWrapper
from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
from langchain.utilities.openweathermap import OpenWeatherMapAPIWrapper
from langchain.utilities.dataforseo_api_search import DataForSeoAPIWrapper
from langchain.utilities.reddit_search import RedditSearchAPIWrapper
def _get_python_repl() -> BaseTool:
@@ -374,6 +376,10 @@ def _get_google_cloud_texttospeech(**kwargs: Any) -> BaseTool:
    return GoogleCloudTextToSpeechTool(**kwargs)


def _get_reddit_search(**kwargs: Any) -> BaseTool:
    return RedditSearchRun(api_wrapper=RedditSearchAPIWrapper(**kwargs))


_EXTRA_LLM_TOOLS: Dict[
    str,
    Tuple[Callable[[Arg(BaseLanguageModel, "llm"), KwArg(Any)], BaseTool], List[str]],
@@ -454,6 +460,10 @@ _EXTRA_OPTIONAL_TOOLS: Dict[str, Tuple[Callable[[KwArg(Any)], BaseTool], List[st
    ),
    "eleven_labs_text2speech": (_get_eleven_labs_text2speech, ["eleven_api_key"]),
    "google_cloud_texttospeech": (_get_google_cloud_texttospeech, []),
    "reddit_search": (
        _get_reddit_search,
        ["reddit_client_id", "reddit_client_secret", "reddit_user_agent"],
    ),
}
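
With this registration in place, the tool should also be constructible by name through `load_tools`; a minimal sketch, assuming the three credential kwargs are forwarded to the wrapper as listed above:

```python
from langchain.agents import load_tools

# "reddit_search" resolves to _get_reddit_search, which forwards the three
# registered credential kwargs to RedditSearchAPIWrapper.
tools = load_tools(
    ["reddit_search"],
    reddit_client_id="<client-id>",
    reddit_client_secret="<client-secret>",
    reddit_user_agent="<user-agent>",
)
```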


@@ -480,6 +480,12 @@ def _import_python_tool_PythonREPLTool() -> Any:
    )


def _import_reddit_search_RedditSearchRun() -> Any:
    from langchain.tools.reddit_search.tool import RedditSearchRun

    return RedditSearchRun


def _import_render() -> Any:
    from langchain.tools.render import format_tool_to_openai_function
@@ -833,6 +839,8 @@ def __getattr__(name: str) -> Any:
        return _import_python_tool_PythonAstREPLTool()
    elif name == "PythonREPLTool":
        return _import_python_tool_PythonREPLTool()
    elif name == "RedditSearchRun":
        return _import_reddit_search_RedditSearchRun()
    elif name == "format_tool_to_openai_function":
        return _import_render()
    elif name == "BaseRequestsTool":
@@ -983,6 +991,7 @@ __all__ = [
    "OpenAPISpec",
    "OpenWeatherMapQueryRun",
    "PubmedQueryRun",
    "RedditSearchRun",
    "QueryCheckerTool",
    "QueryPowerBITool",
    "QuerySQLCheckerTool",


@@ -0,0 +1,63 @@
"""Tool for the Reddit search API."""
from typing import Optional, Type

from langchain.callbacks.manager import CallbackManagerForToolRun
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools.base import BaseTool
from langchain.utilities.reddit_search import RedditSearchAPIWrapper


class RedditSearchSchema(BaseModel):
    """Input for Reddit search."""

    query: str = Field(
        description="should be the query string that the post title should "
        "contain, or '*' if anything is allowed."
    )
    sort: str = Field(
        description='should be the sort method, which is one of: "relevance", '
        '"hot", "top", "new", or "comments".'
    )
    time_filter: str = Field(
        description='should be the time period to filter by, which is one of '
        '"all", "day", "hour", "month", "week", or "year".'
    )
    subreddit: str = Field(
        description='should be the name of the subreddit, like "all" for r/all.'
    )
    limit: str = Field(
        description="a positive integer indicating the maximum number "
        "of results to return"
    )


class RedditSearchRun(BaseTool):
    """Tool that queries for posts on a subreddit."""

    name: str = "reddit_search"
    description: str = (
        "A tool that searches for posts on Reddit. "
        "Useful when you need to know post information on a subreddit."
    )
    api_wrapper: RedditSearchAPIWrapper = Field(default_factory=RedditSearchAPIWrapper)
    args_schema: Type[BaseModel] = RedditSearchSchema

    def _run(
        self,
        query: str,
        sort: str,
        time_filter: str,
        subreddit: str,
        limit: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool."""
        return self.api_wrapper.run(
            query=query,
            sort=sort,
            time_filter=time_filter,
            subreddit=subreddit,
            limit=int(limit),
        )
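
As a usage note, tool input is validated against `RedditSearchSchema` before `_run` is called, so the tool can be driven from a schema instance as well as a plain dict; a sketch with placeholder credentials:

```python
from langchain.tools.reddit_search.tool import RedditSearchRun, RedditSearchSchema
from langchain.utilities.reddit_search import RedditSearchAPIWrapper

search = RedditSearchRun(
    api_wrapper=RedditSearchAPIWrapper(
        reddit_client_id="<client-id>",
        reddit_client_secret="<client-secret>",
        reddit_user_agent="<user-agent>",
    )
)

# limit is declared as a string in the schema and cast to int inside _run.
params = RedditSearchSchema(
    query="*", sort="top", time_filter="day", subreddit="all", limit="1"
)
print(search.run(tool_input=params.dict()))
```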


@@ -0,0 +1,121 @@
"""Wrapper for the Reddit API."""
from typing import Any, Dict, List, Optional

from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.utils import get_from_dict_or_env


class RedditSearchAPIWrapper(BaseModel):
    """Wrapper for the Reddit API.

    To use, set the environment variables ``REDDIT_CLIENT_ID``,
    ``REDDIT_CLIENT_SECRET``, ``REDDIT_USER_AGENT`` to set the client ID,
    client secret, and user agent, respectively, as given by Reddit's API.
    Alternatively, all three can be supplied as named parameters in the
    constructor: ``reddit_client_id``, ``reddit_client_secret``, and
    ``reddit_user_agent``, respectively.

    Example:
        .. code-block:: python

            from langchain.utilities.reddit_search import RedditSearchAPIWrapper
            reddit_search = RedditSearchAPIWrapper()
    """

    reddit_client: Any

    # Values required to access the Reddit API via praw
    reddit_client_id: Optional[str]
    reddit_client_secret: Optional[str]
    reddit_user_agent: Optional[str]

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API ID, secret and user agent exist in the
        environment and that the praw module is present.
        """
        reddit_client_id = get_from_dict_or_env(
            values, "reddit_client_id", "REDDIT_CLIENT_ID"
        )
        values["reddit_client_id"] = reddit_client_id

        reddit_client_secret = get_from_dict_or_env(
            values, "reddit_client_secret", "REDDIT_CLIENT_SECRET"
        )
        values["reddit_client_secret"] = reddit_client_secret

        reddit_user_agent = get_from_dict_or_env(
            values, "reddit_user_agent", "REDDIT_USER_AGENT"
        )
        values["reddit_user_agent"] = reddit_user_agent

        try:
            import praw
        except ImportError:
            raise ImportError(
                "praw package not found, please install it with pip install praw"
            )

        reddit_client = praw.Reddit(
            client_id=reddit_client_id,
            client_secret=reddit_client_secret,
            user_agent=reddit_user_agent,
        )
        values["reddit_client"] = reddit_client

        return values

    def run(
        self, query: str, sort: str, time_filter: str, subreddit: str, limit: int
    ) -> str:
        """Search Reddit and return posts as a single string."""
        results: List[Dict] = self.results(
            query=query,
            sort=sort,
            time_filter=time_filter,
            subreddit=subreddit,
            limit=limit,
        )
        if len(results) > 0:
            output: List[str] = [f"Searching r/{subreddit} found {len(results)} posts:"]
            for r in results:
                category = "N/A" if r["post_category"] is None else r["post_category"]
                p = (
                    f"Post Title: '{r['post_title']}'\n"
                    f"User: {r['post_author']}\n"
                    f"Subreddit: {r['post_subreddit']}:\n"
                    f"Text body: {r['post_text']}\n"
                    f"Post URL: {r['post_url']}\n"
                    f"Post Category: {category}.\n"
                    f"Score: {r['post_score']}\n"
                )
                output.append(p)
            return "\n".join(output)
        else:
            return f"Searching r/{subreddit} did not find any posts:"

    def results(
        self, query: str, sort: str, time_filter: str, subreddit: str, limit: int
    ) -> List[Dict]:
        """Use praw to search Reddit and return a list of dictionaries,
        one for each post.
        """
        subreddit_object = self.reddit_client.subreddit(subreddit)
        search_results = list(
            subreddit_object.search(
                query=query, sort=sort, time_filter=time_filter, limit=limit
            )
        )
        results_object = []
        for submission in search_results:
            results_object.append(
                {
                    "post_subreddit": submission.subreddit_name_prefixed,
                    "post_category": submission.category,
                    "post_title": submission.title,
                    "post_text": submission.selftext,
                    "post_score": submission.score,
                    "post_id": submission.id,
                    "post_url": submission.url,
                    "post_author": submission.author,
                }
            )
        return results_object
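
The wrapper is also usable on its own; a minimal sketch of calling `results` directly, assuming the three ``REDDIT_*`` environment variables from the docstring are set:

```python
from langchain.utilities.reddit_search import RedditSearchAPIWrapper

# With no constructor kwargs, the root validator falls back to
# REDDIT_CLIENT_ID, REDDIT_CLIENT_SECRET and REDDIT_USER_AGENT.
wrapper = RedditSearchAPIWrapper()
posts = wrapper.results(
    query="python",
    sort="relevance",
    time_filter="all",
    subreddit="learnpython",
    limit=3,
)
for post in posts:
    print(post["post_title"], post["post_url"])
```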


@@ -6479,6 +6479,49 @@ files = [
[package.extras]
dill = ["dill (>=0.3.7)"]
[[package]]
name = "praw"
version = "7.7.1"
description = "PRAW, an acronym for \"Python Reddit API Wrapper\", is a Python package that allows for simple access to Reddit's API."
optional = true
python-versions = "~=3.7"
files = [
{file = "praw-7.7.1-py3-none-any.whl", hash = "sha256:9ec5dc943db00c175bc6a53f4e089ce625f3fdfb27305564b616747b767d38ef"},
{file = "praw-7.7.1.tar.gz", hash = "sha256:f1d7eef414cafe28080dda12ed09253a095a69933d5c8132eca11d4dc8a070bf"},
]
[package.dependencies]
prawcore = ">=2.1,<3"
update-checker = ">=0.18"
websocket-client = ">=0.54.0"
[package.extras]
ci = ["coveralls"]
dev = ["betamax (>=0.8,<0.9)", "betamax-matchers (>=0.3.0,<0.5)", "furo", "packaging", "pre-commit", "pytest (>=2.7.3)", "requests (>=2.20.1,<3)", "sphinx", "urllib3 (==1.26.*)"]
lint = ["furo", "pre-commit", "sphinx"]
readthedocs = ["furo", "sphinx"]
test = ["betamax (>=0.8,<0.9)", "betamax-matchers (>=0.3.0,<0.5)", "pytest (>=2.7.3)", "requests (>=2.20.1,<3)", "urllib3 (==1.26.*)"]
[[package]]
name = "prawcore"
version = "2.4.0"
description = "\"Low-level communication layer for PRAW 4+."
optional = true
python-versions = "~=3.8"
files = [
{file = "prawcore-2.4.0-py3-none-any.whl", hash = "sha256:29af5da58d85704b439ad3c820873ad541f4535e00bb98c66f0fbcc8c603065a"},
{file = "prawcore-2.4.0.tar.gz", hash = "sha256:b7b2b5a1d04406e086ab4e79988dc794df16059862f329f4c6a43ed09986c335"},
]
[package.dependencies]
requests = ">=2.6.0,<3.0"
[package.extras]
ci = ["coveralls"]
dev = ["packaging", "prawcore[lint]", "prawcore[test]"]
lint = ["pre-commit", "ruff (>=0.0.291)"]
test = ["betamax (>=0.8,<0.9)", "pytest (>=2.7.3)", "urllib3 (==1.26.*)"]
[[package]]
name = "prometheus-client"
version = "0.17.1"
@@ -10407,6 +10450,25 @@ tzdata = {version = "*", markers = "platform_system == \"Windows\""}
[package.extras]
devenv = ["black", "check-manifest", "flake8", "pyroma", "pytest (>=4.3)", "pytest-cov", "pytest-mock (>=3.3)", "zest.releaser"]
[[package]]
name = "update-checker"
version = "0.18.0"
description = "A python module that will check for package updates."
optional = true
python-versions = "*"
files = [
{file = "update_checker-0.18.0-py3-none-any.whl", hash = "sha256:cbba64760a36fe2640d80d85306e8fe82b6816659190993b7bdabadee4d4bbfd"},
{file = "update_checker-0.18.0.tar.gz", hash = "sha256:6a2d45bb4ac585884a6b03f9eade9161cedd9e8111545141e9aa9058932acb13"},
]
[package.dependencies]
requests = ">=2.3.0"
[package.extras]
dev = ["black", "flake8", "pytest (>=2.7.3)"]
lint = ["black", "flake8"]
test = ["pytest (>=2.7.3)"]
[[package]]
name = "upstash-redis"
version = "0.15.0"
@@ -11201,7 +11263,7 @@ cli = ["typer"]
cohere = ["cohere"]
docarray = ["docarray"]
embeddings = ["sentence-transformers"]
extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "dashvector", "databricks-vectorsearch", "dgml-utils", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-pydantic", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict"]
extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "dashvector", "databricks-vectorsearch", "dgml-utils", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-pydantic", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict"]
javascript = ["esprima"]
llms = ["clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openlm", "torch", "transformers"]
openai = ["openai", "tiktoken"]
@@ -11211,4 +11273,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "c540c453e0b1221ed8c1bbb638430264dc9ec52df12d915357c961b31c69e132"
content-hash = "473f3af6bf8c6e224f9522df2ff78f7ab965ea044c60efe60e7fe0461d2d2014"


@@ -143,6 +143,7 @@ azure-ai-textanalytics = {version = "^5.3.0", optional = true}
google-cloud-documentai = {version = "^2.20.1", optional = true}
fireworks-ai = {version = "^0.6.0", optional = true, python = ">=3.9,<4.0"}
javelin-sdk = {version = "^0.1.8", optional = true}
praw = {version = "^7.7.1", optional = true}
msal = {version = "^1.25.0", optional = true}
databricks-vectorsearch = {version = "^0.21", optional = true}
dgml-utils = {version = "^0.3.0", optional = true}
@@ -383,6 +384,7 @@ extended_testing = [
"rspace_client",
"fireworks-ai",
"javelin-sdk",
"praw",
"databricks-vectorsearch",
"dgml-utils",
"cohere",


@@ -0,0 +1,67 @@
import pytest

from langchain.utilities.reddit_search import RedditSearchAPIWrapper


@pytest.fixture
def api_client() -> RedditSearchAPIWrapper:
    return RedditSearchAPIWrapper()


def assert_results_exists(results: list) -> None:
    if len(results) > 0:
        for result in results:
            assert "post_title" in result
            assert "post_author" in result
            assert "post_subreddit" in result
            assert "post_text" in result
            assert "post_url" in result
            assert "post_score" in result
            assert "post_category" in result
            assert "post_id" in result
    else:
        assert results == []


@pytest.mark.requires("praw")
def test_run_empty_query(api_client: RedditSearchAPIWrapper) -> None:
    """Test that run gives the correct answer with an empty query."""
    search = api_client.run(
        query="", sort="relevance", time_filter="all", subreddit="all", limit=5
    )
    assert search == "Searching r/all did not find any posts:"


@pytest.mark.requires("praw")
def test_run_query(api_client: RedditSearchAPIWrapper) -> None:
    """Test that run gives the correct answer."""
    search = api_client.run(
        query="university",
        sort="relevance",
        time_filter="all",
        subreddit="funny",
        limit=5,
    )
    assert "University" in search


@pytest.mark.requires("praw")
def test_results_exists(api_client: RedditSearchAPIWrapper) -> None:
    """Test that results gives the correct output format."""
    search = api_client.results(
        query="What is the best programming language?",
        sort="relevance",
        time_filter="all",
        subreddit="all",
        limit=10,
    )
    assert_results_exists(search)


@pytest.mark.requires("praw")
def test_results_empty_query(api_client: RedditSearchAPIWrapper) -> None:
    """Test that results gives the correct output with an empty query."""
    search = api_client.results(
        query="", sort="relevance", time_filter="all", subreddit="all", limit=10
    )
    assert search == []


@@ -78,6 +78,7 @@ EXPECTED_ALL = [
    "OpenAPISpec",
    "OpenWeatherMapQueryRun",
    "PubmedQueryRun",
    "RedditSearchRun",
    "QueryCheckerTool",
    "QueryPowerBITool",
    "QuerySQLCheckerTool",

@@ -79,6 +79,7 @@ _EXPECTED = [
    "OpenAPISpec",
    "OpenWeatherMapQueryRun",
    "PubmedQueryRun",
    "RedditSearchRun",
    "QueryCheckerTool",
    "QueryPowerBITool",
    "QuerySQLCheckerTool",