mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 07:35:18 +00:00
fix map reduce chain (#550)
This commit is contained in:
parent
ba0cbb4a41
commit
330a5b42d4
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra, root_validator
|
from pydantic import BaseModel, Extra, root_validator
|
||||||
|
|
||||||
@ -11,6 +11,13 @@ from langchain.chains.llm import LLMChain
|
|||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
|
|
||||||
|
|
||||||
|
class CombineDocsProtocol(Protocol):
|
||||||
|
"""Interface for the combine_docs method."""
|
||||||
|
|
||||||
|
def __call__(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
|
||||||
|
"""Interface for the combine_docs method."""
|
||||||
|
|
||||||
|
|
||||||
def _split_list_of_docs(
|
def _split_list_of_docs(
|
||||||
docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any
|
docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any
|
||||||
) -> List[List[Document]]:
|
) -> List[List[Document]]:
|
||||||
@ -38,10 +45,10 @@ def _split_list_of_docs(
|
|||||||
|
|
||||||
def _collapse_docs(
|
def _collapse_docs(
|
||||||
docs: List[Document],
|
docs: List[Document],
|
||||||
combine_document_func: Callable,
|
combine_document_func: CombineDocsProtocol,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Document:
|
) -> Document:
|
||||||
result = combine_document_func(docs, **kwargs)
|
result, _ = combine_document_func(docs, **kwargs)
|
||||||
combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
|
combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
|
||||||
for doc in docs[1:]:
|
for doc in docs[1:]:
|
||||||
for k, v in doc.metadata.items():
|
for k, v in doc.metadata.items():
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""Test functionality related to combining documents."""
|
"""Test functionality related to combining documents."""
|
||||||
|
|
||||||
from typing import List
|
from typing import Any, List, Tuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -12,11 +12,11 @@ from langchain.docstore.document import Document
|
|||||||
|
|
||||||
|
|
||||||
def _fake_docs_len_func(docs: List[Document]) -> int:
|
def _fake_docs_len_func(docs: List[Document]) -> int:
|
||||||
return len(_fake_combine_docs_func(docs))
|
return len(_fake_combine_docs_func(docs)[0])
|
||||||
|
|
||||||
|
|
||||||
def _fake_combine_docs_func(docs: List[Document]) -> str:
|
def _fake_combine_docs_func(docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
|
||||||
return "".join([d.page_content for d in docs])
|
return "".join([d.page_content for d in docs]), {}
|
||||||
|
|
||||||
|
|
||||||
def test__split_list_long_single_doc() -> None:
|
def test__split_list_long_single_doc() -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user