mirror of https://github.com/hwchase17/langchain.git
synced 2025-06-23 07:09:31 +00:00
add token max parameter (#7204)
parent 7b585c7585
commit 8410c6a747
@@ -2,7 +2,7 @@
 from __future__ import annotations
 
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from pydantic import Extra, root_validator
 
@@ -198,7 +198,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
     def combine_docs(
         self,
         docs: List[Document],
-        token_max: int = 3000,
+        token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Tuple[str, dict]:
 
@@ -229,7 +229,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
     async def acombine_docs(
         self,
         docs: List[Document],
-        token_max: int = 3000,
+        token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Tuple[str, dict]:
 
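Note: with the per-call default changed to Optional[int] = None, callers that omit token_max now defer to the instance-level default introduced on ReduceDocumentsChain below; passing an explicit integer still overrides it for that call.
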
@@ -152,6 +152,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
     """Chain to use to collapse documents if needed until they can all fit.
     If None, will use the combine_documents_chain.
     This is typically a StuffDocumentsChain."""
+    token_max: int = 3000
+    """The maximum number of tokens to group documents into. For example, if
+    set to 3000 then documents will be grouped into chunks of no greater than
+    3000 tokens before trying to combine them into a smaller chunk."""
 
     class Config:
         """Configuration for this pydantic object."""
 
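As a rough usage sketch of the new field (the wiring below is illustrative, not from this commit; FakeListLLM stands in for a real model):

from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.llms.fake import FakeListLLM
from langchain.prompts import PromptTemplate
from langchain.schema import Document

# Stuff-style chain that concatenates documents into a single prompt.
llm = FakeListLLM(responses=["combined summary"] * 10)
prompt = PromptTemplate.from_template("Summarize: {context}")
stuff_chain = StuffDocumentsChain(
    llm_chain=LLMChain(llm=llm, prompt=prompt),
    document_variable_name="context",
)

# token_max now lives on the chain itself; individual calls may still override it.
reduce_chain = ReduceDocumentsChain(
    combine_documents_chain=stuff_chain,
    token_max=1000,  # instance-level default added by this commit
)

docs = [Document(page_content="lorem ipsum " * 50) for _ in range(4)]
output, _ = reduce_chain.combine_docs(docs)                 # uses self.token_max
output, _ = reduce_chain.combine_docs(docs, token_max=500)  # per-call override
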
@@ -169,7 +173,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
     def combine_docs(
         self,
         docs: List[Document],
-        token_max: int = 3000,
+        token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Tuple[str, dict]:
 
@@ -198,7 +202,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
     async def acombine_docs(
         self,
         docs: List[Document],
-        token_max: int = 3000,
+        token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Tuple[str, dict]:
 
@@ -227,7 +231,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
     def _collapse(
         self,
         docs: List[Document],
-        token_max: int = 3000,
+        token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Tuple[List[Document], dict]:
 
@@ -240,9 +244,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
             input_documents=docs, callbacks=callbacks, **kwargs
         )
 
-        while num_tokens is not None and num_tokens > token_max:
+        _token_max = token_max or self.token_max
+        while num_tokens is not None and num_tokens > _token_max:
             new_result_doc_list = _split_list_of_docs(
-                result_docs, length_func, token_max, **kwargs
+                result_docs, length_func, _token_max, **kwargs
             )
             result_docs = []
             for docs in new_result_doc_list:
 
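One consequence of resolving the override with `or` rather than an explicit None check (an observation about the change above, not part of the commit):

# token_max=None -> falls back to self.token_max
# token_max=500  -> 500 wins for this call
# token_max=0    -> also falls back, because 0 is falsy
_token_max = token_max or self.token_max
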
@@ -254,7 +259,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
     async def _acollapse(
         self,
         docs: List[Document],
-        token_max: int = 3000,
+        token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Tuple[List[Document], dict]:
 
@@ -267,9 +272,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
             input_documents=docs, callbacks=callbacks, **kwargs
         )
 
-        while num_tokens is not None and num_tokens > token_max:
+        _token_max = token_max or self.token_max
+        while num_tokens is not None and num_tokens > _token_max:
             new_result_doc_list = _split_list_of_docs(
-                result_docs, length_func, token_max, **kwargs
+                result_docs, length_func, _token_max, **kwargs
             )
             result_docs = []
             for docs in new_result_doc_list:
 
@@ -79,6 +79,7 @@ def _load_map_reduce_chain(
     reduce_llm: Optional[BaseLanguageModel] = None,
     collapse_llm: Optional[BaseLanguageModel] = None,
     verbose: Optional[bool] = None,
+    token_max: int = 3000,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
     map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
 
@@ -111,6 +112,8 @@ def _load_map_reduce_chain(
     reduce_documents_chain = ReduceDocumentsChain(
         combine_documents_chain=combine_documents_chain,
         collapse_documents_chain=collapse_chain,
+        token_max=token_max,
+        verbose=verbose,
     )
     return MapReduceDocumentsChain(
         llm_chain=map_chain,
 
@@ -99,6 +99,7 @@ def _load_map_reduce_chain(
     verbose: Optional[bool] = None,
     callback_manager: Optional[BaseCallbackManager] = None,
     callbacks: Callbacks = None,
+    token_max: int = 3000,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
     _question_prompt = (
 
@@ -154,6 +155,8 @@ def _load_map_reduce_chain(
     reduce_documents_chain = ReduceDocumentsChain(
         combine_documents_chain=combine_documents_chain,
         collapse_documents_chain=collapse_chain,
+        token_max=token_max,
+        verbose=verbose,
     )
     return MapReduceDocumentsChain(
         llm_chain=map_chain,
 
@@ -48,6 +48,7 @@ def _load_map_reduce_chain(
     reduce_llm: Optional[BaseLanguageModel] = None,
     collapse_llm: Optional[BaseLanguageModel] = None,
     verbose: Optional[bool] = None,
+    token_max: int = 3000,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
     map_chain = LLMChain(llm=llm, prompt=map_prompt, verbose=verbose)
 
@@ -79,6 +80,8 @@ def _load_map_reduce_chain(
     reduce_documents_chain = ReduceDocumentsChain(
         combine_documents_chain=combine_documents_chain,
         collapse_documents_chain=collapse_chain,
+        token_max=token_max,
+        verbose=verbose,
     )
     return MapReduceDocumentsChain(
         llm_chain=map_chain,
 
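Assuming the loader above with map_prompt is the summarize chain's map-reduce loader (an inference from the argument names, not stated in this diff), the new parameter can be threaded through the public entry point roughly like so:

from langchain.chains.summarize import load_summarize_chain
from langchain.llms.fake import FakeListLLM

llm = FakeListLLM(responses=["a short summary"])  # stand-in model for illustration
# token_max is forwarded through _load_map_reduce_chain into ReduceDocumentsChain.
chain = load_summarize_chain(llm, chain_type="map_reduce", token_max=500)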