mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-09 14:35:50 +00:00
feat: add support for arxiv identifier in ArxivAPIWrapper() (#9318)
- Description: this PR adds support for arXiv identifiers to `ArxivAPIWrapper`. I modified the `run()` and `load()` functions in `arxiv.py`, using a regex to recognize whether the query consists of one or more arXiv identifiers (see [https://info.arxiv.org/help/find/index.html](https://info.arxiv.org/help/find/index.html)). If so, it directly looks up the papers corresponding to those identifiers instead of running a free-text search. I also modified and added tests in `test_arxiv.py`. - Issue: #9047 - Dependencies: N/A - Tag maintainer: N/A --------- Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
d3c2ca5656
commit
05b75f3f13
@ -1,6 +1,7 @@
|
|||||||
"""Util that calls Arxiv."""
|
"""Util that calls Arxiv."""
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain.pydantic_v1 import BaseModel, root_validator
|
from langchain.pydantic_v1 import BaseModel, root_validator
|
||||||
@ -17,6 +18,9 @@ class ArxivAPIWrapper(BaseModel):
|
|||||||
This wrapper will use the Arxiv API to conduct searches and
|
This wrapper will use the Arxiv API to conduct searches and
|
||||||
fetch document summaries. By default, it will return the document summaries
|
fetch document summaries. By default, it will return the document summaries
|
||||||
of the top-k results.
|
of the top-k results.
|
||||||
|
If the query is in the form of arxiv identifier
|
||||||
|
(see https://info.arxiv.org/help/find/index.html), it will return the paper
|
||||||
|
corresponding to the arxiv identifier.
|
||||||
It limits the Document content by doc_content_chars_max.
|
It limits the Document content by doc_content_chars_max.
|
||||||
Set doc_content_chars_max=None if you don't want to limit the content size.
|
Set doc_content_chars_max=None if you don't want to limit the content size.
|
||||||
|
|
||||||
@ -54,6 +58,18 @@ class ArxivAPIWrapper(BaseModel):
|
|||||||
load_all_available_meta: bool = False
|
load_all_available_meta: bool = False
|
||||||
doc_content_chars_max: Optional[int] = 4000
|
doc_content_chars_max: Optional[int] = 4000
|
||||||
|
|
||||||
|
def is_arxiv_identifier(self, query: str) -> bool:
    """Check whether a query consists solely of arxiv identifiers.

    Only the first ``self.ARXIV_MAX_QUERY_LENGTH`` characters of ``query`` are
    examined. That prefix is split on whitespace and every resulting token
    must match the identifier pattern in full.

    Args:
        query: The raw query string, possibly holding several
            whitespace-separated identifiers.

    Returns:
        True if the query holds at least one token and every token is an
        identifier; False otherwise, including for an empty or
        whitespace-only query.
    """
    # Two accepted shapes:
    #   * two digits, a two-digit value 01-12, a dot, 4-5 digits and an
    #     optional "v<digits>" suffix (e.g. "1605.08386v1");
    #   * seven digits followed by anything (e.g. "9603067").
    # NOTE(review): the trailing ".*" on the second shape accepts arbitrary
    # suffixes; kept unchanged for backward compatibility - confirm intent.
    arxiv_identifier_pattern = r"\d{2}(0[1-9]|1[0-2])\.\d{4,5}(v\d+|)|\d{7}.*"

    query_items = query[: self.ARXIV_MAX_QUERY_LENGTH].split()
    # Without this guard an empty query would vacuously count as "all
    # identifiers" and callers would search with an empty id list.
    if not query_items:
        return False

    # fullmatch requires the whole token to match, which is what the
    # match-then-compare-with-group(0) idiom was checking.
    return all(
        re.fullmatch(arxiv_identifier_pattern, query_item) is not None
        for query_item in query_items
    )
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that the python package exists in environment."""
|
"""Validate that the python package exists in environment."""
|
||||||
@ -88,9 +104,15 @@ class ArxivAPIWrapper(BaseModel):
|
|||||||
query: a plaintext search query
|
query: a plaintext search query
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
try:
|
try:
|
||||||
results = self.arxiv_search( # type: ignore
|
if self.is_arxiv_identifier(query):
|
||||||
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
|
results = self.arxiv_search(
|
||||||
).results()
|
id_list=query.split(),
|
||||||
|
max_results=self.top_k_results,
|
||||||
|
).results()
|
||||||
|
else:
|
||||||
|
results = self.arxiv_search( # type: ignore
|
||||||
|
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
|
||||||
|
).results()
|
||||||
except self.arxiv_exceptions as ex:
|
except self.arxiv_exceptions as ex:
|
||||||
return f"Arxiv exception: {ex}"
|
return f"Arxiv exception: {ex}"
|
||||||
docs = [
|
docs = [
|
||||||
@ -129,9 +151,15 @@ class ArxivAPIWrapper(BaseModel):
|
|||||||
try:
|
try:
|
||||||
# Remove the ":" and "-" from the query, as they can cause search problems
|
# Remove the ":" and "-" from the query, as they can cause search problems
|
||||||
query = query.replace(":", "").replace("-", "")
|
query = query.replace(":", "").replace("-", "")
|
||||||
results = self.arxiv_search( # type: ignore
|
if self.is_arxiv_identifier(query):
|
||||||
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.load_max_docs
|
results = self.arxiv_search(
|
||||||
).results()
|
id_list=query[: self.ARXIV_MAX_QUERY_LENGTH].split(),
|
||||||
|
max_results=self.load_max_docs,
|
||||||
|
).results()
|
||||||
|
else:
|
||||||
|
results = self.arxiv_search( # type: ignore
|
||||||
|
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.load_max_docs
|
||||||
|
).results()
|
||||||
except self.arxiv_exceptions as ex:
|
except self.arxiv_exceptions as ex:
|
||||||
logger.debug("Error on arxiv: %s", ex)
|
logger.debug("Error on arxiv: %s", ex)
|
||||||
return []
|
return []
|
||||||
|
523
libs/langchain/poetry.lock
generated
523
libs/langchain/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -346,6 +346,7 @@ extended_testing = [
|
|||||||
"faiss-cpu",
|
"faiss-cpu",
|
||||||
"openapi-schema-pydantic",
|
"openapi-schema-pydantic",
|
||||||
"markdownify",
|
"markdownify",
|
||||||
|
"arxiv",
|
||||||
"dashvector",
|
"dashvector",
|
||||||
"sqlite-vss",
|
"sqlite-vss",
|
||||||
"timescale-vector",
|
"timescale-vector",
|
||||||
|
@ -15,13 +15,38 @@ def api_client() -> ArxivAPIWrapper:
|
|||||||
return ArxivAPIWrapper()
|
return ArxivAPIWrapper()
|
||||||
|
|
||||||
|
|
||||||
def test_run_success(api_client: ArxivAPIWrapper) -> None:
|
def test_run_success_paper_name(api_client: ArxivAPIWrapper) -> None:
|
||||||
"""Test that returns the correct answer"""
|
"""Test a query of paper name that returns the correct answer"""
|
||||||
|
|
||||||
output = api_client.run("1605.08386")
|
output = api_client.run("Heat-bath random walks with Markov bases")
|
||||||
|
assert "Probability distributions for Markov chains based quantum walks" in output
|
||||||
|
assert (
|
||||||
|
"Transformations of random walks on groups via Markov stopping times" in output
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
"Recurrence of Multidimensional Persistent Random Walks. Fourier and Series "
|
||||||
|
"Criteria" in output
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_success_arxiv_identifier(api_client: ArxivAPIWrapper) -> None:
|
||||||
|
"""Test a query of an arxiv identifier returns the correct answer"""
|
||||||
|
|
||||||
|
output = api_client.run("1605.08386v1")
|
||||||
assert "Heat-bath random walks with Markov bases" in output
|
assert "Heat-bath random walks with Markov bases" in output
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_success_multiple_arxiv_identifiers(api_client: ArxivAPIWrapper) -> None:
    """Run a query of several arxiv identifiers and check every title is returned"""
    # One expected paper title per identifier in the query below.
    expected_titles = (
        "Heat-bath random walks with Markov bases",
        "Scaling Language-Image Pre-training via Masking",
        "Ultra-low mass PBHs in the early universe can explain the PTA signal",
    )

    output = api_client.run("1605.08386v1 2212.00794v2 2308.07912")

    for title in expected_titles:
        assert title in output
|
||||||
|
|
||||||
|
|
||||||
def test_run_returns_several_docs(api_client: ArxivAPIWrapper) -> None:
|
def test_run_returns_several_docs(api_client: ArxivAPIWrapper) -> None:
|
||||||
"""Test that returns several docs"""
|
"""Test that returns several docs"""
|
||||||
|
|
||||||
@ -43,14 +68,30 @@ def assert_docs(docs: List[Document]) -> None:
|
|||||||
assert set(doc.metadata) == {"Published", "Title", "Authors", "Summary"}
|
assert set(doc.metadata) == {"Published", "Title", "Authors", "Summary"}
|
||||||
|
|
||||||
|
|
||||||
def test_load_success(api_client: ArxivAPIWrapper) -> None:
|
def test_load_success_paper_name(api_client: ArxivAPIWrapper) -> None:
|
||||||
"""Test that returns one document"""
|
"""Test a query of paper name that returns one document"""
|
||||||
|
|
||||||
docs = api_client.load("1605.08386")
|
docs = api_client.load("Heat-bath random walks with Markov bases")
|
||||||
|
assert len(docs) == 3
|
||||||
|
assert_docs(docs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_success_arxiv_identifier(api_client: ArxivAPIWrapper) -> None:
|
||||||
|
"""Test a query of an arxiv identifier that returns one document"""
|
||||||
|
|
||||||
|
docs = api_client.load("1605.08386v1")
|
||||||
assert len(docs) == 1
|
assert len(docs) == 1
|
||||||
assert_docs(docs)
|
assert_docs(docs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_success_multiple_arxiv_identifiers(api_client: ArxivAPIWrapper) -> None:
    """Load a query of three arxiv identifiers and expect three valid documents"""
    identifiers = "1605.08386v1 2212.00794v2 2308.07912"

    loaded = api_client.load(identifiers)

    # Three identifiers were requested, so three documents are expected.
    assert len(loaded) == 3
    assert_docs(loaded)
|
||||||
|
|
||||||
|
|
||||||
def test_load_returns_no_result(api_client: ArxivAPIWrapper) -> None:
|
def test_load_returns_no_result(api_client: ArxivAPIWrapper) -> None:
|
||||||
"""Test that returns no docs"""
|
"""Test that returns no docs"""
|
||||||
|
|
||||||
|
17
libs/langchain/tests/unit_tests/utilities/test_arxiv.py
Normal file
17
libs/langchain/tests/unit_tests/utilities/test_arxiv.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
import pytest as pytest
|
||||||
|
|
||||||
|
from langchain.utilities import ArxivAPIWrapper
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("arxiv")
def test_is_arxiv_identifier() -> None:
    """Check is_arxiv_identifier accepts valid identifiers and rejects invalid ones"""
    wrapper = ArxivAPIWrapper()

    # Queries made up only of well-formed identifiers.
    accepted = (
        "1605.08386v1",
        "0705.0123",
        "2308.07912",
        "9603067 2308.07912 2308.07912",
    )
    # Too short, too few or too many digits after the dot, or a bad suffix.
    rejected = (
        "12345",
        "0705.012",
        "0705.012300",
        "1605.08386w1",
    )

    for query in accepted:
        assert wrapper.is_arxiv_identifier(query)
    for query in rejected:
        assert not wrapper.is_arxiv_identifier(query)
|
191
poetry.lock
generated
191
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user