Repair Wikipedia document loader load_max_docs and improve test coverage. (#13769)

**Description:** 

Repair Wikipedia document loader `load_max_docs` and improve test
coverage.

**Issue:** 

The Wikipedia document loader was not respecting the `load_max_docs`
parameter (not previously reported) and would always return a maximum of 10
documents. This is because the API wrapper (in `utilities/wikipedia.py`)
wasn't passing `top_k_results` to the underlying [Wikipedia
library](https://wikipedia.readthedocs.io/en/latest/code.html#module-wikipedia).
By default this library returns 10 results.
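
For illustration, a minimal sketch of the underlying behaviour, assuming the `wikipedia` package is installed (the query string is an arbitrary example; `search` accepts a `results` keyword that defaults to 10):

```python
import wikipedia  # the library wrapped by WikipediaAPIWrapper

# Without `results`, the library returns at most 10 page titles,
# which is why load_max_docs appeared to be capped at 10.
titles_default = wikipedia.search("HUNTER X HUNTER")

# Passing `results` explicitly is the essence of the fix applied
# in the API wrapper below.
titles_more = wikipedia.search("HUNTER X HUNTER", results=25)

print(len(titles_default), len(titles_more))
```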

The default number of results for the document loader has been reduced
from 100 to 25, because loading 100 results takes a very long time and
makes for an inconvenient default. Arguably it should be 10.
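
As a usage sketch (query chosen arbitrarily; exact counts depend on what Wikipedia returns):

```python
from langchain.document_loaders import WikipediaLoader

# With the new default, at most 25 documents are loaded.
docs = WikipediaLoader(query="HUNTER X HUNTER").load()
assert len(docs) <= 25

# An explicit override is now respected end to end.
docs = WikipediaLoader(query="HUNTER X HUNTER", load_max_docs=5).load()
assert len(docs) <= 5
```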

In addition, the loader's documentation stated that there was a hard
limit (300) on the number of documents returned. In actuality, 300 is
the maximum query length in characters, enforced by the API wrapper.
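
A minimal sketch of what that limit actually does (`WIKIPEDIA_MAX_QUERY_LENGTH` is the wrapper's module-level constant; the helper function below is hypothetical, for illustration only):

```python
WIKIPEDIA_MAX_QUERY_LENGTH = 300  # constant defined in utilities/wikipedia.py

def truncate_query(query: str) -> str:
    """Illustrative stand-in for the wrapper's inline truncation."""
    # The wrapper slices the *query string*; it does not cap document count.
    return query[:WIKIPEDIA_MAX_QUERY_LENGTH]

assert len(truncate_query("x" * 1000)) == 300
```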

Tests have been added for the document loader (previously missing) and
to verify that each class returns the correct number of documents, both
by default and when overridden. The `assert_docs` helper has also been
repaired: it now checks for the default metadata, which includes
`source` in recent releases.
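
For reference, a sketch of the metadata shape the tests now assert (keys as in the diffs below; the values here are purely illustrative):

```python
from langchain_core.documents import Document

# Default metadata now includes "source" alongside "title" and "summary".
doc = Document(
    page_content="Hunter × Hunter is a Japanese manga series...",
    metadata={
        "title": "Hunter × Hunter",
        "summary": "Hunter × Hunter is a Japanese manga series...",
        "source": "https://en.wikipedia.org/wiki/Hunter_%C3%97_Hunter",
    },
)
assert {"title", "summary", "source"}.issubset(doc.metadata)
```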

**Dependencies:** 
nil

**Tag maintainer:**
@leo-gan

**Twitter handle:**
@queenvictoria
Committed by Snow on 2023-11-29 12:26:40 +11:00 via GitHub (commit `1a33e5b500`, parent `04c4878306`).
5 changed files with 88 additions and 7 deletions

`WikipediaLoader` (document loader) — docstring and `load_max_docs` default:

```diff
@@ -9,7 +9,7 @@ from langchain.utilities.wikipedia import WikipediaAPIWrapper
 class WikipediaLoader(BaseLoader):
     """Load from `Wikipedia`.
 
-    The hard limit on the number of downloaded Documents is 300 for now.
+    The hard limit on the length of the query is 300 for now.
 
     Each wiki page represents one Document.
     """
@@ -18,7 +18,7 @@ class WikipediaLoader(BaseLoader):
         self,
         query: str,
         lang: str = "en",
-        load_max_docs: Optional[int] = 100,
+        load_max_docs: Optional[int] = 25,
         load_all_available_meta: Optional[bool] = False,
         doc_content_chars_max: Optional[int] = 4000,
     ):
```

`WikipediaAPIWrapper` (`utilities/wikipedia.py`) — pass `top_k_results` through to the library's `search()`:

```diff
@@ -43,7 +43,9 @@ class WikipediaAPIWrapper(BaseModel):
     def run(self, query: str) -> str:
         """Run Wikipedia search and get page summaries."""
-        page_titles = self.wiki_client.search(query[:WIKIPEDIA_MAX_QUERY_LENGTH])
+        page_titles = self.wiki_client.search(
+            query[:WIKIPEDIA_MAX_QUERY_LENGTH], results=self.top_k_results
+        )
         summaries = []
         for page_title in page_titles[: self.top_k_results]:
             if wiki_page := self._fetch_page(page_title):
@@ -103,7 +105,9 @@ class WikipediaAPIWrapper(BaseModel):
         Returns: a list of documents.
         """
-        page_titles = self.wiki_client.search(query[:WIKIPEDIA_MAX_QUERY_LENGTH])
+        page_titles = self.wiki_client.search(
+            query[:WIKIPEDIA_MAX_QUERY_LENGTH], results=self.top_k_results
+        )
         docs = []
         for page_title in page_titles[: self.top_k_results]:
             if wiki_page := self._fetch_page(page_title):
```

New integration test for the Wikipedia document loader:

```diff
@@ -0,0 +1,55 @@
+"""Integration test for Wikipedia Document Loader."""
+from typing import List
+
+from langchain_core.documents import Document
+
+from langchain.document_loaders import WikipediaLoader
+
+
+def assert_docs(docs: List[Document], all_meta: bool = False) -> None:
+    for doc in docs:
+        assert doc.page_content
+        assert doc.metadata
+        main_meta = {"title", "summary", "source"}
+        assert set(doc.metadata).issuperset(main_meta)
+        if all_meta:
+            assert len(set(doc.metadata)) > len(main_meta)
+        else:
+            assert len(set(doc.metadata)) == len(main_meta)
+
+
+def test_load_success() -> None:
+    loader = WikipediaLoader(query="HUNTER X HUNTER")
+    docs = loader.load()
+    assert len(docs) > 1
+    assert len(docs) <= 25
+    assert_docs(docs, all_meta=False)
+
+
+def test_load_success_all_meta() -> None:
+    load_max_docs = 5
+    load_all_available_meta = True
+    loader = WikipediaLoader(
+        query="HUNTER X HUNTER",
+        load_max_docs=load_max_docs,
+        load_all_available_meta=load_all_available_meta,
+    )
+    docs = loader.load()
+    assert len(docs) == load_max_docs
+    assert_docs(docs, all_meta=load_all_available_meta)
+
+
+def test_load_success_more() -> None:
+    load_max_docs = 10
+    loader = WikipediaLoader(query="HUNTER X HUNTER", load_max_docs=load_max_docs)
+    docs = loader.load()
+    assert len(docs) == load_max_docs
+    assert_docs(docs, all_meta=False)
+
+
+def test_load_no_result() -> None:
+    loader = WikipediaLoader(
+        "NORESULTCALL_NORESULTCALL_NORESULTCALL_NORESULTCALL_NORESULTCALL_NORESULTCALL"
+    )
+    docs = loader.load()
+    assert not docs
```

Integration test for the Wikipedia retriever:

```diff
@@ -1,4 +1,4 @@
-"""Integration test for Wikipedia API Wrapper."""
+"""Integration test for Wikipedia Retriever."""
 from typing import List
 
 import pytest
@@ -16,7 +16,7 @@ def assert_docs(docs: List[Document], all_meta: bool = False) -> None:
     for doc in docs:
         assert doc.page_content
         assert doc.metadata
-        main_meta = {"title", "summary"}
+        main_meta = {"title", "summary", "source"}
         assert set(doc.metadata).issuperset(main_meta)
         if all_meta:
             assert len(set(doc.metadata)) > len(main_meta)
@@ -27,6 +27,7 @@ def assert_docs(docs: List[Document], all_meta: bool = False) -> None:
 def test_load_success(retriever: WikipediaRetriever) -> None:
     docs = retriever.get_relevant_documents("HUNTER X HUNTER")
     assert len(docs) > 1
+    assert len(docs) <= 3
     assert_docs(docs, all_meta=False)
@@ -34,6 +35,7 @@ def test_load_success_all_meta(retriever: WikipediaRetriever) -> None:
     retriever.load_all_available_meta = True
     docs = retriever.get_relevant_documents("HUNTER X HUNTER")
     assert len(docs) > 1
+    assert len(docs) <= 3
     assert_docs(docs, all_meta=True)
@@ -46,6 +48,15 @@ def test_load_success_init_args() -> None:
     assert_docs(docs, all_meta=True)
 
 
+def test_load_success_init_args_more() -> None:
+    retriever = WikipediaRetriever(
+        lang="en", top_k_results=20, load_all_available_meta=False
+    )
+    docs = retriever.get_relevant_documents("HUNTER X HUNTER")
+    assert len(docs) == 20
+    assert_docs(docs, all_meta=False)
+
+
 def test_load_no_result(retriever: WikipediaRetriever) -> None:
     docs = retriever.get_relevant_documents(
         "NORESULTCALL_NORESULTCALL_NORESULTCALL_NORESULTCALL_NORESULTCALL_NORESULTCALL"
```

Integration test for the Wikipedia API wrapper:

```diff
@@ -28,7 +28,7 @@ def assert_docs(docs: List[Document], all_meta: bool = False) -> None:
     for doc in docs:
         assert doc.page_content
         assert doc.metadata
-        main_meta = {"title", "summary"}
+        main_meta = {"title", "summary", "source"}
         assert set(doc.metadata).issuperset(main_meta)
         if all_meta:
             assert len(set(doc.metadata)) > len(main_meta)
@@ -39,6 +39,7 @@ def assert_docs(docs: List[Document], all_meta: bool = False) -> None:
 def test_load_success(api_client: WikipediaAPIWrapper) -> None:
     docs = api_client.load("HUNTER X HUNTER")
     assert len(docs) > 1
+    assert len(docs) <= 3
     assert_docs(docs, all_meta=False)
@@ -46,9 +47,19 @@ def test_load_success_all_meta(api_client: WikipediaAPIWrapper) -> None:
     api_client.load_all_available_meta = True
     docs = api_client.load("HUNTER X HUNTER")
     assert len(docs) > 1
+    assert len(docs) <= 3
     assert_docs(docs, all_meta=True)
 
 
+def test_load_more_docs_success(api_client: WikipediaAPIWrapper) -> None:
+    top_k_results = 20
+    api_client = WikipediaAPIWrapper(top_k_results=top_k_results)
+    docs = api_client.load("HUNTER X HUNTER")
+    assert len(docs) > 10
+    assert len(docs) <= top_k_results
+    assert_docs(docs, all_meta=False)
+
+
 def test_load_no_result(api_client: WikipediaAPIWrapper) -> None:
     docs = api_client.load(
         "NORESULTCALL_NORESULTCALL_NORESULTCALL_NORESULTCALL_NORESULTCALL_NORESULTCALL"
```