mirror of
https://github.com/hwchase17/langchain.git
synced 2025-10-21 17:25:34 +00:00
milvus: New langchain_milvus package and new milvus features (#21077)
New features: - New langchain_milvus package in partner - Milvus collection hybrid search retriever - Zilliz cloud pipeline retriever - Milvus Local guid - Rag-milvus template --------- Signed-off-by: ChengZi <chen.zhang@zilliz.com> Signed-off-by: Jael Gu <mengjia.gu@zilliz.com> Co-authored-by: Jael Gu <mengjia.gu@zilliz.com> Co-authored-by: Jackson <jacksonxie612@gmail.com> Co-authored-by: Erick Friis <erick@langchain.dev> Co-authored-by: Erick Friis <erickfriis@gmail.com>
This commit is contained in:
@@ -0,0 +1,160 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pymilvus import AnnSearchRequest, Collection
|
||||
from pymilvus.client.abstract import BaseRanker, SearchResult # type: ignore
|
||||
|
||||
from langchain_milvus.utils.sparse import BaseSparseEmbedding
|
||||
|
||||
|
||||
class MilvusCollectionHybridSearchRetriever(BaseRetriever):
|
||||
"""This is a hybrid search retriever
|
||||
that uses Milvus Collection to retrieve documents based on multiple fields.
|
||||
For more information, please refer to:
|
||||
https://milvus.io/docs/release_notes.md#Multi-Embedding---Hybrid-Search
|
||||
"""
|
||||
|
||||
collection: Collection
|
||||
"""Milvus Collection object."""
|
||||
rerank: BaseRanker
|
||||
"""Milvus ranker object. Such as WeightedRanker or RRFRanker."""
|
||||
anns_fields: List[str]
|
||||
"""The names of vector fields that are used for ANNS search."""
|
||||
field_embeddings: List[Union[Embeddings, BaseSparseEmbedding]]
|
||||
"""The embedding functions of each vector fields,
|
||||
which can be either Embeddings or BaseSparseEmbedding."""
|
||||
field_search_params: Optional[List[Dict]] = None
|
||||
"""The search parameters of each vector fields.
|
||||
If not specified, the default search parameters will be used."""
|
||||
field_limits: Optional[List[int]] = None
|
||||
"""Limit number of results for each ANNS field.
|
||||
If not specified, the default top_k will be used."""
|
||||
field_exprs: Optional[List[Optional[str]]] = None
|
||||
"""The boolean expression for filtering the search results."""
|
||||
top_k: int = 4
|
||||
"""Final top-K number of documents to retrieve."""
|
||||
text_field: str = "text"
|
||||
"""The text field name,
|
||||
which will be used as the `page_content` of a `Document` object."""
|
||||
output_fields: Optional[List[str]] = None
|
||||
"""Final output fields of the documents.
|
||||
If not specified, all fields except the vector fields will be used as output fields,
|
||||
which will be the `metadata` of a `Document` object."""
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# If some parameters are not specified, set default values
|
||||
if self.field_search_params is None:
|
||||
default_search_params = {
|
||||
"metric_type": "L2",
|
||||
"params": {"nprobe": 10},
|
||||
}
|
||||
self.field_search_params = [default_search_params] * len(self.anns_fields)
|
||||
if self.field_limits is None:
|
||||
self.field_limits = [self.top_k] * len(self.anns_fields)
|
||||
if self.field_exprs is None:
|
||||
self.field_exprs = [None] * len(self.anns_fields)
|
||||
|
||||
# Check the fields
|
||||
self._validate_fields_num()
|
||||
self.output_fields = self._get_output_fields()
|
||||
self._validate_fields_name()
|
||||
|
||||
# Load collection
|
||||
self.collection.load()
|
||||
|
||||
def _validate_fields_num(self) -> None:
|
||||
assert (
|
||||
len(self.anns_fields) >= 2
|
||||
), "At least two fields are required for hybrid search."
|
||||
lengths = [len(self.anns_fields)]
|
||||
if self.field_limits is not None:
|
||||
lengths.append(len(self.field_limits))
|
||||
if self.field_exprs is not None:
|
||||
lengths.append(len(self.field_exprs))
|
||||
|
||||
if not all(length == lengths[0] for length in lengths):
|
||||
raise ValueError("All field-related lists must have the same length.")
|
||||
|
||||
if len(self.field_search_params) != len(self.anns_fields): # type: ignore[arg-type]
|
||||
raise ValueError(
|
||||
"field_search_params must have the same length as anns_fields."
|
||||
)
|
||||
|
||||
def _validate_fields_name(self) -> None:
|
||||
collection_fields = [x.name for x in self.collection.schema.fields]
|
||||
for field in self.anns_fields:
|
||||
assert (
|
||||
field in collection_fields
|
||||
), f"{field} is not a valid field in the collection."
|
||||
assert (
|
||||
self.text_field in collection_fields
|
||||
), f"{self.text_field} is not a valid field in the collection."
|
||||
for field in self.output_fields: # type: ignore[union-attr]
|
||||
assert (
|
||||
field in collection_fields
|
||||
), f"{field} is not a valid field in the collection."
|
||||
|
||||
def _get_output_fields(self) -> List[str]:
|
||||
if self.output_fields:
|
||||
return self.output_fields
|
||||
output_fields = [x.name for x in self.collection.schema.fields]
|
||||
for field in self.anns_fields:
|
||||
if field in output_fields:
|
||||
output_fields.remove(field)
|
||||
if self.text_field not in output_fields:
|
||||
output_fields.append(self.text_field)
|
||||
return output_fields
|
||||
|
||||
def _build_ann_search_requests(self, query: str) -> List[AnnSearchRequest]:
|
||||
search_requests = []
|
||||
for ann_field, embedding, param, limit, expr in zip(
|
||||
self.anns_fields,
|
||||
self.field_embeddings,
|
||||
self.field_search_params, # type: ignore[arg-type]
|
||||
self.field_limits, # type: ignore[arg-type]
|
||||
self.field_exprs, # type: ignore[arg-type]
|
||||
):
|
||||
request = AnnSearchRequest(
|
||||
data=[embedding.embed_query(query)],
|
||||
anns_field=ann_field,
|
||||
param=param,
|
||||
limit=limit,
|
||||
expr=expr,
|
||||
)
|
||||
search_requests.append(request)
|
||||
return search_requests
|
||||
|
||||
def _parse_document(self, data: dict) -> Document:
|
||||
return Document(
|
||||
page_content=data.pop(self.text_field),
|
||||
metadata=data,
|
||||
)
|
||||
|
||||
def _process_search_result(
|
||||
self, search_results: List[SearchResult]
|
||||
) -> List[Document]:
|
||||
documents = []
|
||||
for result in search_results[0]:
|
||||
data = {x: result.entity.get(x) for x in self.output_fields} # type: ignore[union-attr]
|
||||
doc = self._parse_document(data)
|
||||
documents.append(doc)
|
||||
return documents
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
requests = self._build_ann_search_requests(query)
|
||||
search_result = self.collection.hybrid_search(
|
||||
requests, self.rerank, limit=self.top_k, output_fields=self.output_fields
|
||||
)
|
||||
documents = self._process_search_result(search_result)
|
||||
return documents
|
Reference in New Issue
Block a user