Compare commits

...

7 Commits

Author          SHA1        Message                                                              Date
Eugene Yurtsev  99e8a085d0  qq                                                                   2023-05-15 11:51:53 -04:00
Eugene Yurtsev  7679fd9825  q                                                                    2023-05-15 11:24:31 -04:00
Eugene Yurtsev  80b7e78437  Merge branch 'master' into base_document_loader_to_retriever        2023-05-15 11:19:44 -04:00
leo-gan         b0e81d5a51  fixed notebook                                                       2023-05-13 20:23:51 -07:00
leo-gan         0d4e3b2766  removed retrievers/arxiv.py and its references in __init__.py       2023-05-13 20:21:04 -07:00
leo-gan         c724703c07  changed notebook example                                             2023-05-13 20:18:50 -07:00
leo-gan         30d34879bf  refactored to class BaseLoader(BaseRetriever). integr tests are OK  2023-05-13 19:33:15 -07:00
6 changed files with 72 additions and 38 deletions

View File

@@ -45,7 +45,7 @@
 "id": "6c15470b-a16b-4e0d-bc6a-6998bafbb5a4",
 "metadata": {},
 "source": [
-"`ArxivRetriever` has these arguments:\n",
+"`ArxivLoader` has these arguments:\n",
 "- optional `load_max_docs`: default=100. Use it to limit number of downloaded documents. It takes time to download all 100 documents, so use a small number for experiments. There is a hard limit of 300 for now.\n",
 "- optional `load_all_available_meta`: default=False. By default only the most important fields downloaded: `Published` (date when document was published/last updated), `Title`, `Authors`, `Summary`. If True, other fields also downloaded.\n",
 "\n",
@@ -77,24 +77,26 @@
 },
 "outputs": [],
 "source": [
-"from langchain.retrievers import ArxivRetriever"
+"from langchain.document_loaders.arxiv import ArxivLoader"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 18,
+"execution_count": 4,
 "id": "f381f642",
 "metadata": {},
 "outputs": [],
 "source": [
-"retriever = ArxivRetriever(load_max_docs=2)"
+"retriever = ArxivLoader(load_max_docs=2)"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 8,
-"id": "20ae1a74",
-"metadata": {},
+"execution_count": 5,
+"id": "6ddc7d22-4a19-44cc-b620-4a5e3f0d7cff",
+"metadata": {
+"tags": []
+},
 "outputs": [],
 "source": [
 "docs = retriever.get_relevant_documents(query='1605.08386')"
@@ -102,7 +104,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 9,
+"execution_count": 7,
 "id": "1d5a5088",
 "metadata": {},
 "outputs": [
@@ -115,7 +117,7 @@
 " 'Summary': 'Graphs on lattice points are studied whose edges come from a finite set of\\nallowed moves of arbitrary length. We show that the diameter of these graphs on\\nfibers of a fixed integer matrix can be bounded from above by a constant. We\\nthen study the mixing behaviour of heat-bath random walks on these graphs. We\\nalso state explicit conditions on the set of moves so that the heat-bath random\\nwalk, a generalization of the Glauber dynamics, is an expander in fixed\\ndimension.'}"
 ]
 },
-"execution_count": 9,
+"execution_count": 7,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -126,7 +128,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 10,
+"execution_count": 8,
 "id": "c0ccd0c7-f6a6-43e7-b842-5f57afb94224",
 "metadata": {},
 "outputs": [
@@ -136,7 +138,7 @@
 "'arXiv:1605.08386v1 [math.CO] 26 May 2016\\nHEAT-BATH RANDOM WALKS WITH MARKOV BASES\\nCAPRICE STANLEY AND TOBIAS WINDISCH\\nAbstract. Graphs on lattice points are studied whose edges come from a finite set of\\nallowed moves of arbitrary length. We show that the diameter of these graphs on fibers of a\\nfixed integer matrix can be bounded from above by a constant. We then study the mixing\\nbehaviour of heat-b'"
 ]
 },
-"execution_count": 10,
+"execution_count": 8,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -155,7 +157,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 11,
+"execution_count": 9,
 "id": "bb3601df-53ea-4826-bdbe-554387bc3ad4",
 "metadata": {
 "tags": []
@@ -179,7 +181,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 12,
+"execution_count": 10,
 "id": "e9c1a114-0410-4804-be30-05f34a9760f9",
 "metadata": {
 "tags": []
@@ -193,7 +195,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 19,
+"execution_count": 14,
 "id": "51a33cc9-ec42-4afc-8a2d-3bfff476aa59",
 "metadata": {
 "tags": []
@@ -204,7 +206,7 @@
 "from langchain.chains import ConversationalRetrievalChain\n",
 "\n",
 "model = ChatOpenAI(model_name='gpt-3.5-turbo') # switch to 'gpt-4'\n",
-"qa = ConversationalRetrievalChain.from_llm(model,retriever=retriever)"
+"qa = ConversationalRetrievalChain.from_llm(model,retriever=retriever, max_tokens_limit=4000)"
 ]
 },
 {
@@ -296,7 +298,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "09794ab5-759c-4b56-95d4-2454d4d86da1",
+"id": "d23e9baf-e86d-457e-9342-82ffaf820aad",
 "metadata": {},
 "outputs": [],
 "source": []

View File

@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, Iterator

 from langchain.docstore.document import Document
 from langchain.document_loaders.base import BaseLoader
@@ -14,18 +14,25 @@ class ArxivLoader(BaseLoader):
     def __init__(
         self,
-        query: str,
+        *,
+        query: Optional[str] = None,
         load_max_docs: Optional[int] = 100,
-        load_all_available_meta: Optional[bool] = False,
+        load_all_available_meta: bool = False,
     ):
         """loader with a query and the maximum number of documents to load."""
         self.query = query
         self.load_max_docs = load_max_docs
         self.load_all_available_meta = load_all_available_meta

     def load(self) -> List[Document]:
         """Loads a query result from arxiv.org into a list of Documents."""
+        return list(self.lazy_load())
+
+    def lazy_load(self) -> Iterator[Document]:
+        """Loads a query result from arxiv.org into a list of Documents."""
         arxiv_client = ArxivAPIWrapper(
             load_max_docs=self.load_max_docs,
             load_all_available_meta=self.load_all_available_meta,
         )
-        docs = arxiv_client.load(self.query)
+        docs = arxiv_client.lazy_load()
         return docs
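
The load()/lazy_load() split above follows a simple pattern: lazy_load() yields documents one at a time, and load() just materializes the iterator, so one code path serves both eager and streaming consumers. A self-contained sketch of the pattern (illustrative names, not this module's API):

from typing import Iterator, List

class LazyLoaderExample:
    def lazy_load(self) -> Iterator[str]:
        # Yield one "document" at a time; nothing is fetched up front.
        for i in range(3):
            yield f"doc-{i}"

    def load(self) -> List[str]:
        # Eager variant: the same list(self.lazy_load()) pattern as above.
        return list(self.lazy_load())

docs = LazyLoaderExample().load()                    # ['doc-0', 'doc-1', 'doc-2']
first = next(iter(LazyLoaderExample().lazy_load()))  # 'doc-0'; the rest stay unfetched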

View File

@@ -1,18 +1,44 @@
-from typing import List
+from typing import List, Iterator, Type, Mapping, Any
+
+from pydantic import BaseModel

+from langchain.document_loaders.base import BaseLoader
 from langchain.schema import BaseRetriever, Document
 from langchain.utilities.arxiv import ArxivAPIWrapper


-class ArxivRetriever(BaseRetriever, ArxivAPIWrapper):
-    """
-    It is effectively a wrapper for ArxivAPIWrapper.
-    It wraps load() to get_relevant_documents().
-    It uses all ArxivAPIWrapper arguments without any change.
-    """
+class QuerySupportingLoader(BaseLoader, BaseModel):
+    """A parameterized loader."""
+
+    query: str
+
+
+class ArxivLoader(QuerySupportingLoader):
+    """Load documents from Arxiv.
+
+    SHOULD LIVE WITH DOCUMENT LOADERS
+    """
+
+    arxiv_api_wrapper: ArxivAPIWrapper
+
+    def load(self) -> List[Document]:
+        """We should stop implementing this and instead implement lazy_load()"""
+        return list(self.lazy_load())
+
+    def lazy_load(
+        self,
+    ) -> Iterator[Document]:
+        """A lazy loader for document content."""
+        yield from self.arxiv_api_wrapper.load(self.query)
+
+
+class DocumentLoaderRetriever(BaseRetriever):
+    loader_cls: Type[QuerySupportingLoader]
+    additional_kwargs: Mapping[str, Any]

     def get_relevant_documents(self, query: str) -> List[Document]:
-        return self.load(query=query)
+        """Get relevant documents for a query."""
+        loader = self.loader_cls(query=query)
+        return loader.load()

-    async def aget_relevant_documents(self, query: str) -> List[Document]:
+    def aget_relevant_documents(self, query: str) -> List[Document]:
         raise NotImplementedError
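
DocumentLoaderRetriever captures the PR's core idea: any loader parameterized by a query can be adapted into a retriever by constructing a fresh loader per query and delegating to its load(). A self-contained sketch of that adapter (illustrative stand-in classes, not this module's API):

from typing import Any, List, Mapping, Optional, Type

class FakeQueryLoader:
    """Stand-in for a QuerySupportingLoader subclass."""

    def __init__(self, query: str):
        self.query = query

    def load(self) -> List[str]:
        return [f"document matching {self.query!r}"]

class LoaderBackedRetriever:
    """Adapts a query-parameterized loader into a retriever."""

    def __init__(
        self,
        loader_cls: Type[FakeQueryLoader],
        additional_kwargs: Optional[Mapping[str, Any]] = None,
    ):
        self.loader_cls = loader_cls
        self.additional_kwargs = dict(additional_kwargs or {})

    def get_relevant_documents(self, query: str) -> List[str]:
        # Build a fresh loader for each query, then delegate to load().
        loader = self.loader_cls(query=query, **self.additional_kwargs)
        return loader.load()

retriever = LoaderBackedRetriever(FakeQueryLoader)
print(retriever.get_relevant_documents("1605.08386"))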

View File

@@ -1,8 +1,7 @@
 """Util that calls Arxiv."""
 import logging
-from typing import Any, Dict, List
-
 from pydantic import BaseModel, Extra, root_validator
+from typing import Any, Dict, List

 from langchain.schema import Document

View File

@@ -1,6 +1,6 @@
 from typing import List

-from langchain.document_loaders.arxiv import ArxivLoader
+from langchain.document_loaders import ArxivLoader
 from langchain.schema import Document

View File

@@ -3,13 +3,13 @@ from typing import List
 import pytest

-from langchain.retrievers import ArxivRetriever
+from langchain.document_loaders import ArxivLoader
 from langchain.schema import Document


 @pytest.fixture
-def retriever() -> ArxivRetriever:
-    return ArxivRetriever()
+def retriever() -> ArxivLoader:
+    return ArxivLoader(query=None)


 def assert_docs(docs: List[Document], all_meta: bool = False) -> None:
@@ -24,13 +24,13 @@ def assert_docs(docs: List[Document], all_meta: bool = False) -> None:
     assert len(set(doc.metadata)) == len(main_meta)


-def test_load_success(retriever: ArxivRetriever) -> None:
+def test_load_success(retriever: ArxivLoader) -> None:
     docs = retriever.get_relevant_documents(query="1605.08386")
     assert len(docs) == 1
     assert_docs(docs, all_meta=False)


-def test_load_success_all_meta(retriever: ArxivRetriever) -> None:
+def test_load_success_all_meta(retriever: ArxivLoader) -> None:
     retriever.load_all_available_meta = True
     retriever.load_max_docs = 2
     docs = retriever.get_relevant_documents(query="ChatGPT")
@@ -39,12 +39,12 @@ def test_load_success_all_meta(retriever: ArxivRetriever) -> None:
 def test_load_success_init_args() -> None:
-    retriever = ArxivRetriever(load_max_docs=1, load_all_available_meta=True)
+    retriever = ArxivLoader(load_max_docs=1, load_all_available_meta=True)
     docs = retriever.get_relevant_documents(query="ChatGPT")
     assert len(docs) == 1
     assert_docs(docs, all_meta=True)


-def test_load_no_result(retriever: ArxivRetriever) -> None:
+def test_load_no_result(retriever: ArxivLoader) -> None:
     docs = retriever.get_relevant_documents("1605.08386WWW")
     assert not docs