From 1a33e5b500d2cc06a9b60e7f193bf2602362850c Mon Sep 17 00:00:00 2001
From: Snow
Date: Wed, 29 Nov 2023 12:26:40 +1100
Subject: [PATCH] Repair Wikipedia document loader `load_max_docs` and improve
 test coverage. (#13769)

**Description:** Repair Wikipedia document loader `load_max_docs` and improve test coverage.

**Issue:** The Wikipedia document loader was not respecting the `load_max_docs` parameter (not previously reported) and would always return at most 10 documents. This is because the API wrapper (in `utilities/wikipedia.py`) wasn't passing `top_k_results` to the underlying [Wikipedia library](https://wikipedia.readthedocs.io/en/latest/code.html#module-wikipedia), which returns 10 results by default.

The default number of results for the document loader has been reduced from 100 to 25, because loading 100 results takes a very long time and is an inconvenient default. Arguably it should be 10.

In addition, the documentation for the loader claimed a hard limit (300) on the number of documents returned. In actuality, 300 is the maximum Wikipedia query length in characters, enforced by the API wrapper.

Tests have been added for the document loader (previously untested) and to verify that each class returns the correct number of documents, both by default and when overridden. The `assert_docs` helper has also been repaired: it now correctly tests for the default metadata (which includes `source` in recent releases).
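
To make the failure mode concrete, here is a minimal sketch of the behaviour before and after this fix (illustrative only; the query and counts are arbitrary, and it assumes the [`wikipedia`](https://wikipedia.readthedocs.io/en/latest/code.html#module-wikipedia) package that the wrapper delegates to):

```python
import wikipedia

from langchain.document_loaders import WikipediaLoader

# The underlying library caps search results at 10 unless told otherwise,
# so before this patch any load_max_docs > 10 was silently ignored.
titles = wikipedia.search("HUNTER X HUNTER")
assert len(titles) <= 10  # library default

# The wrapper now forwards top_k_results as the results= argument:
titles = wikipedia.search("HUNTER X HUNTER", results=20)
assert len(titles) <= 20

# At the loader level, load_max_docs is therefore honoured again.
docs = WikipediaLoader(query="HUNTER X HUNTER", load_max_docs=20).load()
assert len(docs) <= 20  # previously capped at 10
```
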
""" @@ -18,7 +18,7 @@ class WikipediaLoader(BaseLoader): self, query: str, lang: str = "en", - load_max_docs: Optional[int] = 100, + load_max_docs: Optional[int] = 25, load_all_available_meta: Optional[bool] = False, doc_content_chars_max: Optional[int] = 4000, ): diff --git a/libs/langchain/langchain/utilities/wikipedia.py b/libs/langchain/langchain/utilities/wikipedia.py index c1f53e7a3d7..37dc064ffb4 100644 --- a/libs/langchain/langchain/utilities/wikipedia.py +++ b/libs/langchain/langchain/utilities/wikipedia.py @@ -43,7 +43,9 @@ class WikipediaAPIWrapper(BaseModel): def run(self, query: str) -> str: """Run Wikipedia search and get page summaries.""" - page_titles = self.wiki_client.search(query[:WIKIPEDIA_MAX_QUERY_LENGTH]) + page_titles = self.wiki_client.search( + query[:WIKIPEDIA_MAX_QUERY_LENGTH], results=self.top_k_results + ) summaries = [] for page_title in page_titles[: self.top_k_results]: if wiki_page := self._fetch_page(page_title): @@ -103,7 +105,9 @@ class WikipediaAPIWrapper(BaseModel): Returns: a list of documents. """ - page_titles = self.wiki_client.search(query[:WIKIPEDIA_MAX_QUERY_LENGTH]) + page_titles = self.wiki_client.search( + query[:WIKIPEDIA_MAX_QUERY_LENGTH], results=self.top_k_results + ) docs = [] for page_title in page_titles[: self.top_k_results]: if wiki_page := self._fetch_page(page_title): diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_wikipedia.py b/libs/langchain/tests/integration_tests/document_loaders/test_wikipedia.py new file mode 100644 index 00000000000..929b68f402a --- /dev/null +++ b/libs/langchain/tests/integration_tests/document_loaders/test_wikipedia.py @@ -0,0 +1,55 @@ +"""Integration test for Wikipedia Document Loader.""" +from typing import List + +from langchain_core.documents import Document + +from langchain.document_loaders import WikipediaLoader + + +def assert_docs(docs: List[Document], all_meta: bool = False) -> None: + for doc in docs: + assert doc.page_content + assert doc.metadata + main_meta = {"title", "summary", "source"} + assert set(doc.metadata).issuperset(main_meta) + if all_meta: + assert len(set(doc.metadata)) > len(main_meta) + else: + assert len(set(doc.metadata)) == len(main_meta) + + +def test_load_success() -> None: + loader = WikipediaLoader(query="HUNTER X HUNTER") + docs = loader.load() + assert len(docs) > 1 + assert len(docs) <= 25 + assert_docs(docs, all_meta=False) + + +def test_load_success_all_meta() -> None: + load_max_docs = 5 + load_all_available_meta = True + loader = WikipediaLoader( + query="HUNTER X HUNTER", + load_max_docs=load_max_docs, + load_all_available_meta=load_all_available_meta, + ) + docs = loader.load() + assert len(docs) == load_max_docs + assert_docs(docs, all_meta=load_all_available_meta) + + +def test_load_success_more() -> None: + load_max_docs = 10 + loader = WikipediaLoader(query="HUNTER X HUNTER", load_max_docs=load_max_docs) + docs = loader.load() + assert len(docs) == load_max_docs + assert_docs(docs, all_meta=False) + + +def test_load_no_result() -> None: + loader = WikipediaLoader( + "NORESULTCALL_NORESULTCALL_NORESULTCALL_NORESULTCALL_NORESULTCALL_NORESULTCALL" + ) + docs = loader.load() + assert not docs diff --git a/libs/langchain/tests/integration_tests/retrievers/test_wikipedia.py b/libs/langchain/tests/integration_tests/retrievers/test_wikipedia.py index f911c219c51..e740bbd5733 100644 --- a/libs/langchain/tests/integration_tests/retrievers/test_wikipedia.py +++ b/libs/langchain/tests/integration_tests/retrievers/test_wikipedia.py @@ -1,4 
-"""Integration test for Wikipedia API Wrapper."""
+"""Integration test for Wikipedia Retriever."""
 from typing import List
 
 import pytest
@@ -16,7 +16,7 @@ def assert_docs(docs: List[Document], all_meta: bool = False) -> None:
     for doc in docs:
         assert doc.page_content
         assert doc.metadata
-        main_meta = {"title", "summary"}
+        main_meta = {"title", "summary", "source"}
         assert set(doc.metadata).issuperset(main_meta)
         if all_meta:
             assert len(set(doc.metadata)) > len(main_meta)
@@ -27,6 +27,7 @@ def assert_docs(docs: List[Document], all_meta: bool = False) -> None:
 def test_load_success(retriever: WikipediaRetriever) -> None:
     docs = retriever.get_relevant_documents("HUNTER X HUNTER")
     assert len(docs) > 1
+    assert len(docs) <= 3
     assert_docs(docs, all_meta=False)
 
 
@@ -34,6 +35,7 @@ def test_load_success_all_meta(retriever: WikipediaRetriever) -> None:
     retriever.load_all_available_meta = True
     docs = retriever.get_relevant_documents("HUNTER X HUNTER")
     assert len(docs) > 1
+    assert len(docs) <= 3
     assert_docs(docs, all_meta=True)
 
 
@@ -46,6 +48,15 @@ def test_load_success_init_args() -> None:
     assert_docs(docs, all_meta=True)
 
 
+def test_load_success_init_args_more() -> None:
+    retriever = WikipediaRetriever(
+        lang="en", top_k_results=20, load_all_available_meta=False
+    )
+    docs = retriever.get_relevant_documents("HUNTER X HUNTER")
+    assert len(docs) == 20
+    assert_docs(docs, all_meta=False)
+
+
 def test_load_no_result(retriever: WikipediaRetriever) -> None:
     docs = retriever.get_relevant_documents(
         "NORESULTCALL_NORESULTCALL_NORESULTCALL_NORESULTCALL_NORESULTCALL_NORESULTCALL"
diff --git a/libs/langchain/tests/integration_tests/utilities/test_wikipedia_api.py b/libs/langchain/tests/integration_tests/utilities/test_wikipedia_api.py
index 1041fdb5a7b..08517d11de5 100644
--- a/libs/langchain/tests/integration_tests/utilities/test_wikipedia_api.py
+++ b/libs/langchain/tests/integration_tests/utilities/test_wikipedia_api.py
@@ -28,7 +28,7 @@ def assert_docs(docs: List[Document], all_meta: bool = False) -> None:
     for doc in docs:
         assert doc.page_content
         assert doc.metadata
-        main_meta = {"title", "summary"}
+        main_meta = {"title", "summary", "source"}
         assert set(doc.metadata).issuperset(main_meta)
         if all_meta:
             assert len(set(doc.metadata)) > len(main_meta)
@@ -39,6 +39,7 @@ def assert_docs(docs: List[Document], all_meta: bool = False) -> None:
 def test_load_success(api_client: WikipediaAPIWrapper) -> None:
     docs = api_client.load("HUNTER X HUNTER")
     assert len(docs) > 1
+    assert len(docs) <= 3
     assert_docs(docs, all_meta=False)
 
 
@@ -46,9 +47,19 @@ def test_load_success_all_meta(api_client: WikipediaAPIWrapper) -> None:
     api_client.load_all_available_meta = True
     docs = api_client.load("HUNTER X HUNTER")
     assert len(docs) > 1
+    assert len(docs) <= 3
     assert_docs(docs, all_meta=True)
 
 
+def test_load_more_docs_success(api_client: WikipediaAPIWrapper) -> None:
+    top_k_results = 20
+    api_client = WikipediaAPIWrapper(top_k_results=top_k_results)
+    docs = api_client.load("HUNTER X HUNTER")
+    assert len(docs) > 10
+    assert len(docs) <= top_k_results
+    assert_docs(docs, all_meta=False)
+
+
 def test_load_no_result(api_client: WikipediaAPIWrapper) -> None:
     docs = api_client.load(
         "NORESULTCALL_NORESULTCALL_NORESULTCALL_NORESULTCALL_NORESULTCALL_NORESULTCALL"
     )
     assert not docs