mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-08 14:05:16 +00:00
feat: add support for arxiv identifier in ArxivAPIWrapper() (#9318)
- Description: This PR adds support for arXiv identifiers to `ArxivAPIWrapper`. I modified the `run()` and `load()` functions in `arxiv.py` to use a regex that recognizes whether the query has the form of an arXiv identifier (see [https://info.arxiv.org/help/find/index.html](https://info.arxiv.org/help/find/index.html)). If so, they directly look up the paper(s) corresponding to the given identifier(s). I also modified and added tests in `test_arxiv.py`. - Issue: #9047 - Dependencies: N/A - Tag maintainer: N/A --------- Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
d3c2ca5656
commit
05b75f3f13
@ -1,6 +1,7 @@
|
||||
"""Util that calls Arxiv."""
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.pydantic_v1 import BaseModel, root_validator
|
||||
@ -17,6 +18,9 @@ class ArxivAPIWrapper(BaseModel):
|
||||
This wrapper will use the Arxiv API to conduct searches and
|
||||
fetch document summaries. By default, it will return the document summaries
|
||||
of the top-k results.
|
||||
If the query is in the form of arxiv identifier
|
||||
(see https://info.arxiv.org/help/find/index.html), it will return the paper
|
||||
corresponding to the arxiv identifier.
|
||||
It limits the Document content by doc_content_chars_max.
|
||||
Set doc_content_chars_max=None if you don't want to limit the content size.
|
||||
|
||||
@ -54,6 +58,18 @@ class ArxivAPIWrapper(BaseModel):
|
||||
load_all_available_meta: bool = False
|
||||
doc_content_chars_max: Optional[int] = 4000
|
||||
|
||||
def is_arxiv_identifier(self, query: str) -> bool:
    """Check if a query consists solely of arxiv identifiers.

    The query is truncated to ``self.ARXIV_MAX_QUERY_LENGTH`` characters and
    split on whitespace; every resulting token must be an identifier in one
    of these forms:

    * ``YYMM.NNNN`` or ``YYMM.NNNNN`` (``MM`` in 01-12), optionally followed
      by a version suffix such as ``v2``;
    * seven digits, optionally followed by a version suffix.

    Args:
        query: the raw query string, possibly holding several
            whitespace-separated identifiers.

    Returns:
        True if the query has at least one token and every token is an
        identifier; False otherwise (including for an empty or
        whitespace-only query).
    """
    # Built once, outside the per-token check. The seven-digit branch only
    # allows an optional version suffix: a trailing ``.*`` would let tokens
    # such as "1234567garbage" pass as identifiers.
    identifier = re.compile(
        r"\d{2}(?:0[1-9]|1[0-2])\.\d{4,5}(?:v\d+)?|\d{7}(?:v\d+)?"
    )
    tokens = query[: self.ARXIV_MAX_QUERY_LENGTH].split()
    # An empty token list must not count as "all identifiers": callers would
    # otherwise issue an id lookup with an empty id list.
    if not tokens:
        return False
    # fullmatch requires the whole token to be an identifier, not a prefix.
    return all(identifier.fullmatch(token) is not None for token in tokens)
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that the python package exists in environment."""
|
||||
@ -88,9 +104,15 @@ class ArxivAPIWrapper(BaseModel):
|
||||
query: a plaintext search query
|
||||
""" # noqa: E501
|
||||
try:
|
||||
results = self.arxiv_search( # type: ignore
|
||||
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
|
||||
).results()
|
||||
if self.is_arxiv_identifier(query):
|
||||
results = self.arxiv_search(
|
||||
id_list=query.split(),
|
||||
max_results=self.top_k_results,
|
||||
).results()
|
||||
else:
|
||||
results = self.arxiv_search( # type: ignore
|
||||
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
|
||||
).results()
|
||||
except self.arxiv_exceptions as ex:
|
||||
return f"Arxiv exception: {ex}"
|
||||
docs = [
|
||||
@ -129,9 +151,15 @@ class ArxivAPIWrapper(BaseModel):
|
||||
try:
|
||||
# Remove the ":" and "-" from the query, as they can cause search problems
|
||||
query = query.replace(":", "").replace("-", "")
|
||||
results = self.arxiv_search( # type: ignore
|
||||
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.load_max_docs
|
||||
).results()
|
||||
if self.is_arxiv_identifier(query):
|
||||
results = self.arxiv_search(
|
||||
id_list=query[: self.ARXIV_MAX_QUERY_LENGTH].split(),
|
||||
max_results=self.load_max_docs,
|
||||
).results()
|
||||
else:
|
||||
results = self.arxiv_search( # type: ignore
|
||||
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.load_max_docs
|
||||
).results()
|
||||
except self.arxiv_exceptions as ex:
|
||||
logger.debug("Error on arxiv: %s", ex)
|
||||
return []
|
||||
|
523
libs/langchain/poetry.lock
generated
523
libs/langchain/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -346,6 +346,7 @@ extended_testing = [
|
||||
"faiss-cpu",
|
||||
"openapi-schema-pydantic",
|
||||
"markdownify",
|
||||
"arxiv",
|
||||
"dashvector",
|
||||
"sqlite-vss",
|
||||
"timescale-vector",
|
||||
|
@ -15,13 +15,38 @@ def api_client() -> ArxivAPIWrapper:
|
||||
return ArxivAPIWrapper()
|
||||
|
||||
|
||||
def test_run_success(api_client: ArxivAPIWrapper) -> None:
|
||||
"""Test that returns the correct answer"""
|
||||
def test_run_success_paper_name(api_client: ArxivAPIWrapper) -> None:
|
||||
"""Test a query of paper name that returns the correct answer"""
|
||||
|
||||
output = api_client.run("1605.08386")
|
||||
output = api_client.run("Heat-bath random walks with Markov bases")
|
||||
assert "Probability distributions for Markov chains based quantum walks" in output
|
||||
assert (
|
||||
"Transformations of random walks on groups via Markov stopping times" in output
|
||||
)
|
||||
assert (
|
||||
"Recurrence of Multidimensional Persistent Random Walks. Fourier and Series "
|
||||
"Criteria" in output
|
||||
)
|
||||
|
||||
|
||||
def test_run_success_arxiv_identifier(api_client: ArxivAPIWrapper) -> None:
    """Run a single versioned arxiv identifier and check the expected title."""
    expected_title = "Heat-bath random walks with Markov bases"

    summary = api_client.run("1605.08386v1")

    assert expected_title in summary
|
||||
|
||||
|
||||
def test_run_success_multiple_arxiv_identifiers(api_client: ArxivAPIWrapper) -> None:
    """Run several space-separated arxiv identifiers and check every title."""
    expected_titles = (
        "Heat-bath random walks with Markov bases",
        "Scaling Language-Image Pre-training via Masking",
        "Ultra-low mass PBHs in the early universe can explain the PTA signal",
    )

    summary = api_client.run("1605.08386v1 2212.00794v2 2308.07912")

    # One title per requested identifier, checked in the order requested.
    for title in expected_titles:
        assert title in summary
|
||||
|
||||
|
||||
def test_run_returns_several_docs(api_client: ArxivAPIWrapper) -> None:
|
||||
"""Test that returns several docs"""
|
||||
|
||||
@ -43,14 +68,30 @@ def assert_docs(docs: List[Document]) -> None:
|
||||
assert set(doc.metadata) == {"Published", "Title", "Authors", "Summary"}
|
||||
|
||||
|
||||
def test_load_success(api_client: ArxivAPIWrapper) -> None:
|
||||
"""Test that returns one document"""
|
||||
def test_load_success_paper_name(api_client: ArxivAPIWrapper) -> None:
|
||||
"""Test a query of paper name that returns one document"""
|
||||
|
||||
docs = api_client.load("1605.08386")
|
||||
docs = api_client.load("Heat-bath random walks with Markov bases")
|
||||
assert len(docs) == 3
|
||||
assert_docs(docs)
|
||||
|
||||
|
||||
def test_load_success_arxiv_identifier(api_client: ArxivAPIWrapper) -> None:
    """Load a single arxiv identifier and expect exactly one document."""
    expected_count = 1

    loaded = api_client.load("1605.08386v1")

    assert len(loaded) == expected_count
    assert_docs(loaded)
|
||||
|
||||
|
||||
def test_load_success_multiple_arxiv_identifiers(api_client: ArxivAPIWrapper) -> None:
    """Load three space-separated arxiv identifiers and expect three documents."""
    expected_count = 3

    loaded = api_client.load("1605.08386v1 2212.00794v2 2308.07912")

    assert len(loaded) == expected_count
    assert_docs(loaded)
|
||||
|
||||
|
||||
def test_load_returns_no_result(api_client: ArxivAPIWrapper) -> None:
|
||||
"""Test that returns no docs"""
|
||||
|
||||
|
17
libs/langchain/tests/unit_tests/utilities/test_arxiv.py
Normal file
17
libs/langchain/tests/unit_tests/utilities/test_arxiv.py
Normal file
@ -0,0 +1,17 @@
|
||||
import pytest as pytest
|
||||
|
||||
from langchain.utilities import ArxivAPIWrapper
|
||||
|
||||
|
||||
@pytest.mark.requires("arxiv")
def test_is_arxiv_identifier() -> None:
    """Check is_arxiv_identifier against known valid and invalid queries."""
    wrapper = ArxivAPIWrapper()

    # Queries that must be recognised as arxiv identifiers.
    accepted = (
        "1605.08386v1",
        "0705.0123",
        "2308.07912",
        "9603067 2308.07912 2308.07912",
    )
    # Queries that must be rejected.
    rejected = (
        "12345",
        "0705.012",
        "0705.012300",
        "1605.08386w1",
    )

    for candidate in accepted:
        assert wrapper.is_arxiv_identifier(candidate)
    for candidate in rejected:
        assert not wrapper.is_arxiv_identifier(candidate)
|
191
poetry.lock
generated
191
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user