mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-18 10:43:36 +00:00
core[patch]: doc init positional args (#16854)
This commit is contained in:
parent
d80c612c92
commit
2a510c71a0
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import List, Literal
|
from typing import Any, List, Literal
|
||||||
|
|
||||||
from langchain_core.load.serializable import Serializable
|
from langchain_core.load.serializable import Serializable
|
||||||
from langchain_core.pydantic_v1 import Field
|
from langchain_core.pydantic_v1 import Field
|
||||||
@ -17,6 +17,10 @@ class Document(Serializable):
|
|||||||
"""
|
"""
|
||||||
type: Literal["Document"] = "Document"
|
type: Literal["Document"] = "Document"
|
||||||
|
|
||||||
|
def __init__(self, page_content: str, **kwargs: Any) -> None:
|
||||||
|
"""Pass page_content in as positional or named arg."""
|
||||||
|
super().__init__(page_content=page_content, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_lc_serializable(cls) -> bool:
|
def is_lc_serializable(cls) -> bool:
|
||||||
"""Return whether this class is serializable."""
|
"""Return whether this class is serializable."""
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Any, Callable, Dict, Optional, Sequence
|
from typing import Any, Callable, Dict, Optional, Sequence, cast
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
@ -67,7 +67,9 @@ class LLMChainExtractor(BaseDocumentCompressor):
|
|||||||
output = self.llm_chain.predict_and_parse(**_input, callbacks=callbacks)
|
output = self.llm_chain.predict_and_parse(**_input, callbacks=callbacks)
|
||||||
if len(output) == 0:
|
if len(output) == 0:
|
||||||
continue
|
continue
|
||||||
compressed_docs.append(Document(page_content=output, metadata=doc.metadata))
|
compressed_docs.append(
|
||||||
|
Document(page_content=cast(str, output), metadata=doc.metadata)
|
||||||
|
)
|
||||||
return compressed_docs
|
return compressed_docs
|
||||||
|
|
||||||
async def acompress_documents(
|
async def acompress_documents(
|
||||||
|
@ -3,7 +3,7 @@ Ensemble retriever that ensembles the results of
|
|||||||
multiple retrievers by using weighted Reciprocal Rank Fusion
|
multiple retrievers by using weighted Reciprocal Rank Fusion
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, cast
|
||||||
|
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
AsyncCallbackManagerForRetrieverRun,
|
||||||
@ -195,7 +195,7 @@ class EnsembleRetriever(BaseRetriever):
|
|||||||
# Enforce that retrieved docs are Documents for each list in retriever_docs
|
# Enforce that retrieved docs are Documents for each list in retriever_docs
|
||||||
for i in range(len(retriever_docs)):
|
for i in range(len(retriever_docs)):
|
||||||
retriever_docs[i] = [
|
retriever_docs[i] = [
|
||||||
Document(page_content=doc) if not isinstance(doc, Document) else doc
|
Document(page_content=cast(str, doc)) if isinstance(doc, str) else doc
|
||||||
for doc in retriever_docs[i]
|
for doc in retriever_docs[i]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -13,10 +13,10 @@ def test_hashed_document_hashing() -> None:
|
|||||||
|
|
||||||
def test_hashing_with_missing_content() -> None:
|
def test_hashing_with_missing_content() -> None:
|
||||||
"""Check that ValueError is raised if page_content is missing."""
|
"""Check that TypeError is raised if page_content is missing."""
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(TypeError):
|
||||||
_HashedDocument(
|
_HashedDocument(
|
||||||
metadata={"key": "value"},
|
metadata={"key": "value"},
|
||||||
)
|
) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def test_uid_auto_assigned_to_hash() -> None:
|
def test_uid_auto_assigned_to_hash() -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user