mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 12:18:24 +00:00
add search kwargs (#664)
This commit is contained in:
parent
65f3a341b0
commit
983b73f47c
@ -1,7 +1,7 @@
|
|||||||
"""Question-answering with sources over a vector database."""
|
"""Question-answering with sources over a vector database."""
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
|
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
@ -15,8 +15,8 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
|
|||||||
"""Vector Database to connect to."""
|
"""Vector Database to connect to."""
|
||||||
k: int = 4
|
k: int = 4
|
||||||
"""Number of results to return from store"""
|
"""Number of results to return from store"""
|
||||||
search_kwargs: Dict[str, Any] = {}
|
search_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
"""Extra search args"""
|
"""Extra search args."""
|
||||||
|
|
||||||
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
|
||||||
question = inputs[self.question_key]
|
question = inputs[self.question_key]
|
||||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra, root_validator
|
from pydantic import BaseModel, Extra, Field, root_validator
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||||
@ -39,6 +39,8 @@ class VectorDBQA(Chain, BaseModel):
|
|||||||
output_key: str = "result" #: :meta private:
|
output_key: str = "result" #: :meta private:
|
||||||
return_source_documents: bool = False
|
return_source_documents: bool = False
|
||||||
"""Return the source documents."""
|
"""Return the source documents."""
|
||||||
|
search_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
"""Extra search args."""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
@ -127,7 +129,9 @@ class VectorDBQA(Chain, BaseModel):
|
|||||||
"""
|
"""
|
||||||
question = inputs[self.input_key]
|
question = inputs[self.input_key]
|
||||||
|
|
||||||
docs = self.vectorstore.similarity_search(question, k=self.k)
|
docs = self.vectorstore.similarity_search(
|
||||||
|
question, k=self.k, **self.search_kwargs
|
||||||
|
)
|
||||||
answer, _ = self.combine_documents_chain.combine_docs(docs, question=question)
|
answer, _ = self.combine_documents_chain.combine_docs(docs, question=question)
|
||||||
|
|
||||||
if self.return_source_documents:
|
if self.return_source_documents:
|
||||||
|
@ -26,7 +26,9 @@ class VectorStore(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
|
def similarity_search(
|
||||||
|
self, query: str, k: int = 4, **kwargs: Any
|
||||||
|
) -> List[Document]:
|
||||||
"""Return docs most similar to query."""
|
"""Return docs most similar to query."""
|
||||||
|
|
||||||
def max_marginal_relevance_search(
|
def max_marginal_relevance_search(
|
||||||
|
@ -106,7 +106,9 @@ class ElasticVectorSearch(VectorStore):
|
|||||||
self.client.indices.refresh(index=self.index_name)
|
self.client.indices.refresh(index=self.index_name)
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
|
def similarity_search(
|
||||||
|
self, query: str, k: int = 4, **kwargs: Any
|
||||||
|
) -> List[Document]:
|
||||||
"""Return docs most similar to query.
|
"""Return docs most similar to query.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -103,7 +103,9 @@ class FAISS(VectorStore):
|
|||||||
docs.append((doc, scores[0][j]))
|
docs.append((doc, scores[0][j]))
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
|
def similarity_search(
|
||||||
|
self, query: str, k: int = 4, **kwargs: Any
|
||||||
|
) -> List[Document]:
|
||||||
"""Return docs most similar to query.
|
"""Return docs most similar to query.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -120,6 +120,7 @@ class Pinecone(VectorStore):
|
|||||||
k: int = 5,
|
k: int = 5,
|
||||||
filter: Optional[dict] = None,
|
filter: Optional[dict] = None,
|
||||||
namespace: Optional[str] = None,
|
namespace: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Return pinecone documents most similar to query.
|
"""Return pinecone documents most similar to query.
|
||||||
|
|
||||||
|
@ -71,7 +71,9 @@ class Weaviate(VectorStore):
|
|||||||
ids.append(_id)
|
ids.append(_id)
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
|
def similarity_search(
|
||||||
|
self, query: str, k: int = 4, **kwargs: Any
|
||||||
|
) -> List[Document]:
|
||||||
"""Look up similar documents in weaviate."""
|
"""Look up similar documents in weaviate."""
|
||||||
content = {"concepts": [query]}
|
content = {"concepts": [query]}
|
||||||
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
||||||
|
Loading…
Reference in New Issue
Block a user