mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 15:19:33 +00:00
langchain: Passthrough batch_size on index()/aindex() calls (#19443)
**Description:** This change passes through `batch_size` to `add_documents()`/`aadd_documents()` on calls to `index()` and `aindex()` such that the documents are processed in the expected batch size. **Issue:** #19415 **Dependencies:** N/A **Twitter handle:** N/A
This commit is contained in:
parent
82de8fd6c9
commit
e1a6341940
@ -330,7 +330,7 @@ def index(
|
||||
# Be pessimistic and assume that all vector store writes will fail.
|
||||
# First write to vector store
|
||||
if docs_to_index:
|
||||
vector_store.add_documents(docs_to_index, ids=uids)
|
||||
vector_store.add_documents(docs_to_index, ids=uids, batch_size=batch_size)
|
||||
num_added += len(docs_to_index) - len(seen_docs)
|
||||
num_updated += len(seen_docs)
|
||||
|
||||
@ -544,7 +544,9 @@ async def aindex(
|
||||
# Be pessimistic and assume that all vector store writes will fail.
|
||||
# First write to vector store
|
||||
if docs_to_index:
|
||||
await vector_store.aadd_documents(docs_to_index, ids=uids)
|
||||
await vector_store.aadd_documents(
|
||||
docs_to_index, ids=uids, batch_size=batch_size
|
||||
)
|
||||
num_added += len(docs_to_index) - len(seen_docs)
|
||||
num_updated += len(seen_docs)
|
||||
|
||||
|
@ -20,7 +20,7 @@ from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.vectorstores import VST, VectorStore
|
||||
|
||||
from langchain.indexes import aindex, index
|
||||
from langchain.indexes._api import _abatch
|
||||
from langchain.indexes._api import _abatch, _HashedDocument
|
||||
from langchain.indexes._sql_record_manager import SQLRecordManager
|
||||
|
||||
|
||||
@ -1304,3 +1304,44 @@ async def test_aindexing_force_update(
|
||||
"num_skipped": 0,
|
||||
"num_updated": 2,
|
||||
}
|
||||
|
||||
|
||||
def test_indexing_custom_batch_size(
    record_manager: SQLRecordManager, vector_store: InMemoryVectorStore
) -> None:
    """Test indexing with a custom batch size."""
    # Single-document corpus; its content hash is what index() passes as the id.
    documents = [
        Document(
            page_content="This is a test document.",
            metadata={"source": "1"},
        ),
    ]
    expected_ids = [_HashedDocument.from_document(doc).uid for doc in documents]

    batch_size = 1
    # Intercept the vector store write so we can inspect exactly what
    # index() forwards to add_documents().
    with patch.object(vector_store, "add_documents") as mock_add_documents:
        index(documents, record_manager, vector_store, batch_size=batch_size)
        last_call = mock_add_documents.call_args
        # batch_size must be passed through to the vector store call.
        assert last_call.args == (documents,)
        assert last_call.kwargs == {"ids": expected_ids, "batch_size": batch_size}
||||
|
||||
@pytest.mark.requires("aiosqlite")
async def test_aindexing_custom_batch_size(
    arecord_manager: SQLRecordManager, vector_store: InMemoryVectorStore
) -> None:
    """Test indexing with a custom batch size."""
    # Single-document corpus; its content hash is what aindex() passes as the id.
    documents = [
        Document(
            page_content="This is a test document.",
            metadata={"source": "1"},
        ),
    ]
    expected_ids = [_HashedDocument.from_document(doc).uid for doc in documents]

    batch_size = 1
    # Intercept the async vector store write so we can inspect exactly what
    # aindex() forwards to aadd_documents().
    with patch.object(vector_store, "aadd_documents") as mock_aadd_documents:
        await aindex(documents, arecord_manager, vector_store, batch_size=batch_size)
        last_call = mock_aadd_documents.call_args
        # batch_size must be passed through to the vector store call.
        assert last_call.args == (documents,)
        assert last_call.kwargs == {"ids": expected_ids, "batch_size": batch_size}
|
Loading…
Reference in New Issue
Block a user