mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 00:23:25 +00:00
Add params to reduce K dynamically to reduce it below token limit (#739)
Referring to #687, I implemented the functionality to reduce K if it exceeds the token limit. Edit: I should have run make lint locally. Also, this only applies to `StuffDocumentsChain`.
This commit is contained in:
parent
d2f882158f
commit
28efbb05bf
@ -1,8 +1,10 @@
|
|||||||
"""Question-answering with sources over a vector database."""
|
"""Question-answering with sources over a vector database."""
|
||||||
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||||
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
|
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.vectorstores.base import VectorStore
|
from langchain.vectorstores.base import VectorStore
|
||||||
@ -15,11 +17,36 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
|
|||||||
"""Vector Database to connect to."""
|
"""Vector Database to connect to."""
|
||||||
k: int = 4
"""Number of results to return from store"""
reduce_k_below_max_tokens: bool = False
"""Reduce the number of results to return from store based on tokens limit"""
max_tokens_limit: int = 3375
"""Restrict the docs to return from store based on tokens,
enforced only for StuffDocumentsChain and if reduce_k_below_max_tokens is set to true"""
search_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Extra search args."""
|
||||||
|
|
||||||
|
def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
    """Trim ``docs`` from the end until the total token count fits the limit.

    Trimming only happens when ``reduce_k_below_max_tokens`` is enabled AND
    the combine chain is a ``StuffDocumentsChain`` (the chain that stuffs
    every document into a single prompt, so the combined size matters).
    Otherwise the full list is returned unchanged in content.

    Returns a (possibly shorter) prefix copy of ``docs``.
    """
    keep = len(docs)

    # Short-circuit first on the cheap flag, mirroring the original
    # `flag and isinstance(...)` evaluation order.
    if not self.reduce_k_below_max_tokens:
        return docs[:keep]
    if not isinstance(self.combine_documents_chain, StuffDocumentsChain):
        return docs[:keep]

    llm = self.combine_documents_chain.llm_chain.llm
    doc_tokens = [llm.get_num_tokens(d.page_content) for d in docs]
    total = sum(doc_tokens)
    # Drop documents from the tail until we are under the budget;
    # terminates at keep == 0 since token counts are non-negative.
    while total > self.max_tokens_limit:
        keep -= 1
        total -= doc_tokens[keep]

    return docs[:keep]
|
|
||||||
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
    """Retrieve documents for the input question from the vector store.

    Looks up the question under ``self.question_key``, runs a similarity
    search with ``k`` and any extra ``search_kwargs``, then defers to
    ``_reduce_tokens_below_limit`` so the result can be trimmed to the
    configured token budget.
    """
    query = inputs[self.question_key]
    retrieved = self.vectorstore.similarity_search(
        query, k=self.k, **self.search_kwargs
    )
    return self._reduce_tokens_below_limit(retrieved)
|
Loading…
Reference in New Issue
Block a user