mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-17 02:03:44 +00:00
core[patch]: doc init positional args (#16854)
This commit is contained in:
parent
d80c612c92
commit
2a510c71a0
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Literal
|
||||
from typing import Any, List, Literal
|
||||
|
||||
from langchain_core.load.serializable import Serializable
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
@ -17,6 +17,10 @@ class Document(Serializable):
|
||||
"""
|
||||
type: Literal["Document"] = "Document"
|
||||
|
||||
def __init__(self, page_content: str, **kwargs: Any) -> None:
|
||||
"""Pass page_content in as positional or named arg."""
|
||||
super().__init__(page_content=page_content, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether this class is serializable."""
|
||||
|
@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Callable, Dict, Optional, Sequence
|
||||
from typing import Any, Callable, Dict, Optional, Sequence, cast
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
@ -67,7 +67,9 @@ class LLMChainExtractor(BaseDocumentCompressor):
|
||||
output = self.llm_chain.predict_and_parse(**_input, callbacks=callbacks)
|
||||
if len(output) == 0:
|
||||
continue
|
||||
compressed_docs.append(Document(page_content=output, metadata=doc.metadata))
|
||||
compressed_docs.append(
|
||||
Document(page_content=cast(str, output), metadata=doc.metadata)
|
||||
)
|
||||
return compressed_docs
|
||||
|
||||
async def acompress_documents(
|
||||
|
@ -3,7 +3,7 @@ Ensemble retriever that ensembles the results of
|
||||
multiple retrievers by using weighted Reciprocal Rank Fusion
|
||||
"""
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
@ -195,7 +195,7 @@ class EnsembleRetriever(BaseRetriever):
|
||||
# Enforce that retrieved docs are Documents for each list in retriever_docs
|
||||
for i in range(len(retriever_docs)):
|
||||
retriever_docs[i] = [
|
||||
Document(page_content=doc) if not isinstance(doc, Document) else doc
|
||||
Document(page_content=cast(str, doc)) if isinstance(doc, str) else doc
|
||||
for doc in retriever_docs[i]
|
||||
]
|
||||
|
||||
|
@ -13,10 +13,10 @@ def test_hashed_document_hashing() -> None:
|
||||
|
||||
def test_hashing_with_missing_content() -> None:
|
||||
"""Check that TypeError is raised if page_content is missing."""
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(TypeError):
|
||||
_HashedDocument(
|
||||
metadata={"key": "value"},
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
|
||||
def test_uid_auto_assigned_to_hash() -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user