diff --git a/libs/core/langchain_core/documents/base.py b/libs/core/langchain_core/documents/base.py index f33d7c82b0f..c72bfa724fd 100644 --- a/libs/core/langchain_core/documents/base.py +++ b/libs/core/langchain_core/documents/base.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Literal +from typing import Any, List, Literal from langchain_core.load.serializable import Serializable from langchain_core.pydantic_v1 import Field @@ -17,6 +17,10 @@ class Document(Serializable): """ type: Literal["Document"] = "Document" + def __init__(self, page_content: str, **kwargs: Any) -> None: + """Pass page_content in as positional or named arg.""" + super().__init__(page_content=page_content, **kwargs) + @classmethod def is_lc_serializable(cls) -> bool: """Return whether this class is serializable.""" diff --git a/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py b/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py index 0a04eb5322d..af8319f1ffc 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py +++ b/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio -from typing import Any, Callable, Dict, Optional, Sequence +from typing import Any, Callable, Dict, Optional, Sequence, cast from langchain_core.documents import Document from langchain_core.language_models import BaseLanguageModel @@ -67,7 +67,9 @@ class LLMChainExtractor(BaseDocumentCompressor): output = self.llm_chain.predict_and_parse(**_input, callbacks=callbacks) if len(output) == 0: continue - compressed_docs.append(Document(page_content=output, metadata=doc.metadata)) + compressed_docs.append( + Document(page_content=cast(str, output), metadata=doc.metadata) + ) return compressed_docs async def acompress_documents( diff --git a/libs/langchain/langchain/retrievers/ensemble.py 
b/libs/langchain/langchain/retrievers/ensemble.py index 050cb88fc4c..34838f23fed 100644 --- a/libs/langchain/langchain/retrievers/ensemble.py +++ b/libs/langchain/langchain/retrievers/ensemble.py @@ -3,7 +3,7 @@ Ensemble retriever that ensemble the results of multiple retrievers by using weighted Reciprocal Rank Fusion """ import asyncio -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, cast from langchain_core.callbacks import ( AsyncCallbackManagerForRetrieverRun, @@ -195,7 +195,7 @@ class EnsembleRetriever(BaseRetriever): # Enforce that retrieved docs are Documents for each list in retriever_docs for i in range(len(retriever_docs)): retriever_docs[i] = [ - Document(page_content=doc) if not isinstance(doc, Document) else doc + Document(page_content=cast(str, doc)) if isinstance(doc, str) else doc for doc in retriever_docs[i] ] diff --git a/libs/langchain/tests/unit_tests/indexes/test_hashed_document.py b/libs/langchain/tests/unit_tests/indexes/test_hashed_document.py index 2cf60e3a1a0..ff0121967ca 100644 --- a/libs/langchain/tests/unit_tests/indexes/test_hashed_document.py +++ b/libs/langchain/tests/unit_tests/indexes/test_hashed_document.py @@ -13,10 +13,10 @@ def test_hashed_document_hashing() -> None: def test_hashing_with_missing_content() -> None: - """Check that ValueError is raised if page_content is missing.""" - with pytest.raises(ValueError): + """Check that TypeError is raised if page_content is missing.""" + with pytest.raises(TypeError): _HashedDocument( metadata={"key": "value"}, - ) + ) # type: ignore def test_uid_auto_assigned_to_hash() -> None: