feat: add support for arxiv identifier in ArxivAPIWrapper() (#9318)

- Description: this PR adds support for arXiv identifiers to
`ArxivAPIWrapper`. I modified the `run()` and `load()` functions in
`arxiv.py`, using a regex to recognize whether the query is in the form of
an arXiv identifier (see
[https://info.arxiv.org/help/find/index.html](https://info.arxiv.org/help/find/index.html)).
If so, it directly looks up the paper corresponding to that arXiv
identifier. I also modified and added tests in `test_arxiv.py`.
  - Issue: #9047 
  - Dependencies: N/A
  - Tag maintainer: N/A

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Mincoolee 2023-09-28 08:35:16 +08:00 committed by GitHub
parent d3c2ca5656
commit 05b75f3f13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 754 additions and 71 deletions

View File

@ -1,6 +1,7 @@
"""Util that calls Arxiv.""" """Util that calls Arxiv."""
import logging import logging
import os import os
import re
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from langchain.pydantic_v1 import BaseModel, root_validator from langchain.pydantic_v1 import BaseModel, root_validator
@ -17,6 +18,9 @@ class ArxivAPIWrapper(BaseModel):
This wrapper will use the Arxiv API to conduct searches and This wrapper will use the Arxiv API to conduct searches and
fetch document summaries. By default, it will return the document summaries fetch document summaries. By default, it will return the document summaries
of the top-k results. of the top-k results.
If the query is in the form of arxiv identifier
(see https://info.arxiv.org/help/find/index.html), it will return the paper
corresponding to the arxiv identifier.
It limits the Document content by doc_content_chars_max. It limits the Document content by doc_content_chars_max.
Set doc_content_chars_max=None if you don't want to limit the content size. Set doc_content_chars_max=None if you don't want to limit the content size.
@ -54,6 +58,18 @@ class ArxivAPIWrapper(BaseModel):
load_all_available_meta: bool = False load_all_available_meta: bool = False
doc_content_chars_max: Optional[int] = 4000 doc_content_chars_max: Optional[int] = 4000
def is_arxiv_identifier(self, query: str) -> bool:
    """Check if a query consists solely of arxiv identifiers.

    Only the first ``self.ARXIV_MAX_QUERY_LENGTH`` characters of ``query`` are
    inspected. That prefix is split on whitespace and every resulting token
    must match the identifier pattern in full.

    Args:
        query: a single identifier or several whitespace-separated ones.

    Returns:
        True if the query has at least one token and every token matches
        the identifier pattern; False otherwise, including for an empty or
        whitespace-only query.
    """
    # Two accepted shapes:
    #   * two digits, a 01-12 pair, a dot, 4-5 digits, optional "v<digits>"
    #   * seven digits followed by anything
    # NOTE(review): the trailing ".*" of the second shape accepts arbitrary
    # suffixes (e.g. "1234567abc") -- confirm whether only a version suffix
    # was intended before tightening it.
    arxiv_identifier_pattern = r"\d{2}(0[1-9]|1[0-2])\.\d{4,5}(v\d+|)|\d{7}.*"
    query_items = query[: self.ARXIV_MAX_QUERY_LENGTH].split()
    # With no tokens the all() below would be vacuously true, so a blank
    # query would be reported as an identifier; reject it explicitly.
    if not query_items:
        return False
    # fullmatch requires the whole token to match, replacing the previous
    # "match, then compare group(0) to the token" two-step check.
    return all(
        re.fullmatch(arxiv_identifier_pattern, query_item) is not None
        for query_item in query_items
    )
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment.""" """Validate that the python package exists in environment."""
@ -88,9 +104,15 @@ class ArxivAPIWrapper(BaseModel):
query: a plaintext search query query: a plaintext search query
""" # noqa: E501 """ # noqa: E501
try: try:
results = self.arxiv_search( # type: ignore if self.is_arxiv_identifier(query):
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results results = self.arxiv_search(
).results() id_list=query.split(),
max_results=self.top_k_results,
).results()
else:
results = self.arxiv_search( # type: ignore
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
).results()
except self.arxiv_exceptions as ex: except self.arxiv_exceptions as ex:
return f"Arxiv exception: {ex}" return f"Arxiv exception: {ex}"
docs = [ docs = [
@ -129,9 +151,15 @@ class ArxivAPIWrapper(BaseModel):
try: try:
# Remove the ":" and "-" from the query, as they can cause search problems # Remove the ":" and "-" from the query, as they can cause search problems
query = query.replace(":", "").replace("-", "") query = query.replace(":", "").replace("-", "")
results = self.arxiv_search( # type: ignore if self.is_arxiv_identifier(query):
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.load_max_docs results = self.arxiv_search(
).results() id_list=query[: self.ARXIV_MAX_QUERY_LENGTH].split(),
max_results=self.load_max_docs,
).results()
else:
results = self.arxiv_search( # type: ignore
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.load_max_docs
).results()
except self.arxiv_exceptions as ex: except self.arxiv_exceptions as ex:
logger.debug("Error on arxiv: %s", ex) logger.debug("Error on arxiv: %s", ex)
return [] return []

File diff suppressed because it is too large Load Diff

View File

@ -346,6 +346,7 @@ extended_testing = [
"faiss-cpu", "faiss-cpu",
"openapi-schema-pydantic", "openapi-schema-pydantic",
"markdownify", "markdownify",
"arxiv",
"dashvector", "dashvector",
"sqlite-vss", "sqlite-vss",
"timescale-vector", "timescale-vector",

View File

@ -15,13 +15,38 @@ def api_client() -> ArxivAPIWrapper:
return ArxivAPIWrapper() return ArxivAPIWrapper()
def test_run_success(api_client: ArxivAPIWrapper) -> None: def test_run_success_paper_name(api_client: ArxivAPIWrapper) -> None:
"""Test that returns the correct answer""" """Test a query of paper name that returns the correct answer"""
output = api_client.run("1605.08386") output = api_client.run("Heat-bath random walks with Markov bases")
assert "Probability distributions for Markov chains based quantum walks" in output
assert (
"Transformations of random walks on groups via Markov stopping times" in output
)
assert (
"Recurrence of Multidimensional Persistent Random Walks. Fourier and Series "
"Criteria" in output
)
def test_run_success_arxiv_identifier(api_client: ArxivAPIWrapper) -> None:
"""Test a query of an arxiv identifier returns the correct answer"""
output = api_client.run("1605.08386v1")
assert "Heat-bath random walks with Markov bases" in output assert "Heat-bath random walks with Markov bases" in output
def test_run_success_multiple_arxiv_identifiers(api_client: ArxivAPIWrapper) -> None:
    """Run a query of three arxiv identifiers and expect each paper's title."""
    expected_titles = (
        "Heat-bath random walks with Markov bases",
        "Scaling Language-Image Pre-training via Masking",
        "Ultra-low mass PBHs in the early universe can explain the PTA signal",
    )
    output = api_client.run("1605.08386v1 2212.00794v2 2308.07912")
    # One title per requested identifier, checked in the order requested.
    for title in expected_titles:
        assert title in output
def test_run_returns_several_docs(api_client: ArxivAPIWrapper) -> None: def test_run_returns_several_docs(api_client: ArxivAPIWrapper) -> None:
"""Test that returns several docs""" """Test that returns several docs"""
@ -43,14 +68,30 @@ def assert_docs(docs: List[Document]) -> None:
assert set(doc.metadata) == {"Published", "Title", "Authors", "Summary"} assert set(doc.metadata) == {"Published", "Title", "Authors", "Summary"}
def test_load_success(api_client: ArxivAPIWrapper) -> None: def test_load_success_paper_name(api_client: ArxivAPIWrapper) -> None:
"""Test that returns one document""" """Test a query of paper name that returns one document"""
docs = api_client.load("1605.08386") docs = api_client.load("Heat-bath random walks with Markov bases")
assert len(docs) == 3
assert_docs(docs)
def test_load_success_arxiv_identifier(api_client: ArxivAPIWrapper) -> None:
"""Test a query of an arxiv identifier that returns one document"""
docs = api_client.load("1605.08386v1")
assert len(docs) == 1 assert len(docs) == 1
assert_docs(docs) assert_docs(docs)
def test_load_success_multiple_arxiv_identifiers(api_client: ArxivAPIWrapper) -> None:
    """Load three arxiv identifiers at once and expect three documents."""
    identifiers = "1605.08386v1 2212.00794v2 2308.07912"
    loaded = api_client.load(identifiers)
    # One document per identifier in the query.
    assert len(loaded) == 3
    assert_docs(loaded)
def test_load_returns_no_result(api_client: ArxivAPIWrapper) -> None: def test_load_returns_no_result(api_client: ArxivAPIWrapper) -> None:
"""Test that returns no docs""" """Test that returns no docs"""

View File

@ -0,0 +1,17 @@
import pytest as pytest
from langchain.utilities import ArxivAPIWrapper
@pytest.mark.requires("arxiv")
def test_is_arxiv_identifier() -> None:
    """Check is_arxiv_identifier on well-formed and malformed queries."""
    wrapper = ArxivAPIWrapper()
    # Queries the wrapper must recognise as arxiv identifiers.
    accepted = (
        "1605.08386v1",
        "0705.0123",
        "2308.07912",
        "9603067 2308.07912 2308.07912",
    )
    # Queries the wrapper must reject.
    rejected = (
        "12345",
        "0705.012",
        "0705.012300",
        "1605.08386w1",
    )
    for candidate in accepted:
        assert wrapper.is_arxiv_identifier(candidate)
    for candidate in rejected:
        assert not wrapper.is_arxiv_identifier(candidate)

191
poetry.lock generated

File diff suppressed because it is too large Load Diff