Add knn and query search field options to ElasticKnnSearch (#5641)

in the `ElasticKnnSearch` class added 2 arguments that were not exposed
properly

`knn_search` added:
- `vector_query_field: Optional[str] = 'vector'`
-- vector_query_field: Field name to use in knn search if not default
'vector'

`knn_hybrid_search` added:
- `vector_query_field: Optional[str] = 'vector'`
-- vector_query_field: Field name to use in knn search if not default
'vector'
- `query_field: Optional[str] = 'text'`
-- query_field: Field name to use in search if not default 'text'



Fixes # https://github.com/hwchase17/langchain/issues/5633


cc: @dev2049 @hwchase17

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Jeff Vestal 2023-06-07 22:19:14 -05:00 committed by GitHub
parent cef79ca579
commit 3294774148
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -334,6 +334,8 @@ class ElasticKnnSearch(ElasticVectorSearch):
es_cloud_id: Optional[str] = None, es_cloud_id: Optional[str] = None,
es_user: Optional[str] = None, es_user: Optional[str] = None,
es_password: Optional[str] = None, es_password: Optional[str] = None,
vector_query_field: Optional[str] = "vector",
query_field: Optional[str] = "text",
): ):
""" """
Initializes an instance of the ElasticKnnSearch class and sets up the Initializes an instance of the ElasticKnnSearch class and sets up the
@ -362,6 +364,8 @@ class ElasticKnnSearch(ElasticVectorSearch):
self.embedding = embedding self.embedding = embedding
self.index_name = index_name self.index_name = index_name
self.query_field = query_field
self.vector_query_field = vector_query_field
# If a pre-existing Elasticsearch connection is provided, use it. # If a pre-existing Elasticsearch connection is provided, use it.
if es_connection is not None: if es_connection is not None:
@ -394,17 +398,16 @@ class ElasticKnnSearch(ElasticVectorSearch):
} }
} }
@staticmethod
def _default_knn_query( def _default_knn_query(
self,
query_vector: Optional[List[float]] = None, query_vector: Optional[List[float]] = None,
query: Optional[str] = None, query: Optional[str] = None,
model_id: Optional[str] = None, model_id: Optional[str] = None,
field: Optional[str] = "vector",
k: Optional[int] = 10, k: Optional[int] = 10,
num_candidates: Optional[int] = 10, num_candidates: Optional[int] = 10,
) -> Dict: ) -> Dict:
knn: Dict = { knn: Dict = {
"field": field, "field": self.vector_query_field,
"k": k, "k": k,
"num_candidates": num_candidates, "num_candidates": num_candidates,
} }
@ -462,6 +465,7 @@ class ElasticKnnSearch(ElasticVectorSearch):
source: Whether to include the source of each hit in the results. source: Whether to include the source of each hit in the results.
fields: The fields to include in the source of each hit. If None, all fields: The fields to include in the source of each hit. If None, all
fields are included. fields are included.
vector_query_field: Field name to use in knn search if not default 'vector'
Returns: Returns:
The search results. The search results.
@ -524,6 +528,8 @@ class ElasticKnnSearch(ElasticVectorSearch):
fields fields
The fields to include in the source of each hit. If None, all fields are The fields to include in the source of each hit. If None, all fields are
included. Defaults to None. included. Defaults to None.
vector_query_field: Field name to use in knn search if not default 'vector'
query_field: Field name to use in search if not default 'text'
Returns: Returns:
The search results. The search results.
@ -541,7 +547,9 @@ class ElasticKnnSearch(ElasticVectorSearch):
knn_query_body["boost"] = knn_boost knn_query_body["boost"] = knn_boost
# Generate the body of the standard Elasticsearch query # Generate the body of the standard Elasticsearch query
match_query_body = {"match": {"text": {"query": query, "boost": query_boost}}} match_query_body = {
"match": {self.query_field: {"query": query, "boost": query_boost}}
}
# Perform the hybrid search on the Elasticsearch index and return the results. # Perform the hybrid search on the Elasticsearch index and return the results.
res = self.client.search( res = self.client.search(