mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-16 01:37:59 +00:00
partners/milvus: allow creating a vectorstore with sparse embeddings (#25284)
# Description Milvus (and `pymilvus`) recently added the option to use [sparse vectors](https://milvus.io/docs/sparse_vector.md#Sparse-Vector) with appropriate search methods (e.g., `SPARSE_INVERTED_INDEX`) and embeddings (e.g., `BM25`, `SPLADE`). This PR allows creating a vector store using langchain's `Milvus` class, setting the matching vector field type to `DataType.SPARSE_FLOAT_VECTOR` and the default index type to `SPARSE_INVERTED_INDEX`. It only extends functionality and is backward compatible. ## Note I am also interested in extending the Milvus class further to support multi-vector search (aka hybrid search), and will be happy to discuss that. See [here](https://github.com/langchain-ai/langchain/discussions/19955), [here](https://github.com/langchain-ai/langchain/pull/20375), and [here](https://github.com/langchain-ai/langchain/discussions/22886) for similar needs. --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
09b04c7e3b
commit
b5d670498f
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
@ -9,6 +9,8 @@ from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
from langchain_milvus.utils.sparse import BaseSparseEmbedding
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_MILVUS_CONNECTION = {
|
||||
@ -110,7 +112,7 @@ class Milvus(VectorStore):
|
||||
Name of the collection.
|
||||
collection_description: str
|
||||
Description of the collection.
|
||||
embedding_function: Embeddings
|
||||
embedding_function: Union[Embeddings, BaseSparseEmbedding]
|
||||
Embedding function to use.
|
||||
|
||||
Key init args — client params:
|
||||
@ -219,7 +221,7 @@ class Milvus(VectorStore):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_function: Embeddings,
|
||||
embedding_function: Union[Embeddings, BaseSparseEmbedding], # type: ignore
|
||||
collection_name: str = "LangChainCollection",
|
||||
collection_description: str = "",
|
||||
collection_properties: Optional[dict[str, Any]] = None,
|
||||
@ -276,6 +278,11 @@ class Milvus(VectorStore):
|
||||
},
|
||||
"GPU_IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
|
||||
"GPU_IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
|
||||
"SPARSE_INVERTED_INDEX": {
|
||||
"metric_type": "IP",
|
||||
"params": {"drop_ratio_build": 0.2},
|
||||
},
|
||||
"SPARSE_WAND": {"metric_type": "IP", "params": {"drop_ratio_build": 0.2}},
|
||||
}
|
||||
|
||||
self.embedding_func = embedding_function
|
||||
@ -340,7 +347,7 @@ class Milvus(VectorStore):
|
||||
)
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Embeddings:
|
||||
def embeddings(self) -> Union[Embeddings, BaseSparseEmbedding]: # type: ignore
|
||||
return self.embedding_func
|
||||
|
||||
def _create_connection_alias(self, connection_args: dict) -> str:
|
||||
@ -402,6 +409,10 @@ class Milvus(VectorStore):
|
||||
logger.error("Failed to create new connection using: %s", alias)
|
||||
raise e
|
||||
|
||||
@property
|
||||
def _is_sparse_embedding(self) -> bool:
|
||||
return isinstance(self.embedding_func, BaseSparseEmbedding)
|
||||
|
||||
def _init(
|
||||
self,
|
||||
embeddings: Optional[list] = None,
|
||||
@ -539,9 +550,14 @@ class Milvus(VectorStore):
|
||||
)
|
||||
)
|
||||
# Create the vector field, supports binary or float vectors
|
||||
fields.append(
|
||||
FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim)
|
||||
)
|
||||
if self._is_sparse_embedding:
|
||||
fields.append(FieldSchema(self._vector_field, DataType.SPARSE_FLOAT_VECTOR))
|
||||
else:
|
||||
fields.append(
|
||||
FieldSchema(
|
||||
self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim
|
||||
)
|
||||
)
|
||||
|
||||
# Create the schema for the collection
|
||||
schema = CollectionSchema(
|
||||
@ -606,11 +622,18 @@ class Milvus(VectorStore):
|
||||
try:
|
||||
# If no index params, use a default HNSW based one
|
||||
if self.index_params is None:
|
||||
self.index_params = {
|
||||
"metric_type": "L2",
|
||||
"index_type": "HNSW",
|
||||
"params": {"M": 8, "efConstruction": 64},
|
||||
}
|
||||
if self._is_sparse_embedding:
|
||||
self.index_params = {
|
||||
"metric_type": "IP",
|
||||
"index_type": "SPARSE_INVERTED_INDEX",
|
||||
"params": {"drop_ratio_build": 0.2},
|
||||
}
|
||||
else:
|
||||
self.index_params = {
|
||||
"metric_type": "L2",
|
||||
"index_type": "HNSW",
|
||||
"params": {"M": 8, "efConstruction": 64},
|
||||
}
|
||||
|
||||
try:
|
||||
self.col.create_index(
|
||||
@ -740,7 +763,7 @@ class Milvus(VectorStore):
|
||||
)
|
||||
|
||||
try:
|
||||
embeddings = self.embedding_func.embed_documents(texts)
|
||||
embeddings: list = self.embedding_func.embed_documents(texts)
|
||||
except NotImplementedError:
|
||||
embeddings = [self.embedding_func.embed_query(x) for x in texts]
|
||||
|
||||
@ -815,7 +838,7 @@ class Milvus(VectorStore):
|
||||
|
||||
def _collection_search(
|
||||
self,
|
||||
embedding: List[float],
|
||||
embedding: List[float] | Dict[int, float],
|
||||
k: int = 4,
|
||||
param: Optional[dict] = None,
|
||||
expr: Optional[str] = None,
|
||||
@ -829,7 +852,8 @@ class Milvus(VectorStore):
|
||||
https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/search.md
|
||||
|
||||
Args:
|
||||
embedding (List[float]): The embedding vector being searched.
|
||||
embedding (List[float] | Dict[int, float]): The embedding vector being
|
||||
searched.
|
||||
k (int, optional): The amount of results to return. Defaults to 4.
|
||||
param (dict): The search params for the specified index.
|
||||
Defaults to None.
|
||||
@ -976,7 +1000,7 @@ class Milvus(VectorStore):
|
||||
|
||||
def similarity_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
embedding: List[float] | Dict[int, float],
|
||||
k: int = 4,
|
||||
param: Optional[dict] = None,
|
||||
expr: Optional[str] = None,
|
||||
@ -990,7 +1014,8 @@ class Milvus(VectorStore):
|
||||
https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/search.md
|
||||
|
||||
Args:
|
||||
embedding (List[float]): The embedding vector being searched.
|
||||
embedding (List[float] | Dict[int, float]): The embedding vector being
|
||||
searched.
|
||||
k (int, optional): The amount of results to return. Defaults to 4.
|
||||
param (dict): The search params for the specified index.
|
||||
Defaults to None.
|
||||
@ -1068,7 +1093,7 @@ class Milvus(VectorStore):
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
embedding: list[float],
|
||||
embedding: list[float] | dict[int, float],
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
@ -1080,7 +1105,8 @@ class Milvus(VectorStore):
|
||||
"""Perform a search and return results that are reordered by MMR.
|
||||
|
||||
Args:
|
||||
embedding (str): The embedding vector being searched.
|
||||
embedding (list[float] | dict[int, float]): The embedding vector being
|
||||
searched.
|
||||
k (int, optional): How many results to give. Defaults to 4.
|
||||
fetch_k (int, optional): Total results to select k from.
|
||||
Defaults to 20.
|
||||
@ -1171,7 +1197,7 @@ class Milvus(VectorStore):
|
||||
def from_texts(
|
||||
cls,
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
embedding: Union[Embeddings, BaseSparseEmbedding], # type: ignore
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
collection_name: str = "LangChainCollection",
|
||||
connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION,
|
||||
@ -1187,7 +1213,7 @@ class Milvus(VectorStore):
|
||||
|
||||
Args:
|
||||
texts (List[str]): Text data.
|
||||
embedding (Embeddings): Embedding function.
|
||||
embedding (Union[Embeddings, BaseSparseEmbedding]): Embedding function.
|
||||
metadatas (Optional[List[dict]]): Metadata for each text if it exists.
|
||||
Defaults to None.
|
||||
collection_name (str, optional): Collection name to use. Defaults to
|
||||
|
@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from langchain_milvus.utils.sparse import BaseSparseEmbedding
|
||||
from langchain_milvus.vectorstores.milvus import Milvus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -141,7 +142,7 @@ class Zilliz(Milvus):
|
||||
def from_texts(
|
||||
cls,
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
embedding: Union[Embeddings, BaseSparseEmbedding],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
collection_name: str = "LangChainCollection",
|
||||
connection_args: Optional[Dict[str, Any]] = None,
|
||||
|
@ -5,6 +5,7 @@ from typing import Any, List, Optional
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_milvus.utils.sparse import BM25SparseEmbedding
|
||||
from langchain_milvus.vectorstores import Milvus
|
||||
from tests.integration_tests.utils import (
|
||||
FakeEmbeddings,
|
||||
@ -304,6 +305,31 @@ def test_milvus_enable_dynamic_field_with_partition_key() -> None:
|
||||
}
|
||||
|
||||
|
||||
def test_milvus_sparse_embeddings() -> None:
|
||||
texts = [
|
||||
"In 'The Clockwork Kingdom' by Augusta Wynter, a brilliant inventor discovers "
|
||||
"a hidden world of clockwork machines and ancient magic, where a rebellion is "
|
||||
"brewing against the tyrannical ruler of the land.",
|
||||
"In 'The Phantom Pilgrim' by Rowan Welles, a charismatic smuggler is hired by "
|
||||
"a mysterious organization to transport a valuable artifact across a war-torn "
|
||||
"continent, but soon finds themselves pursued by assassins and rival factions.",
|
||||
"In 'The Dreamwalker's Journey' by Lyra Snow, a young dreamwalker discovers "
|
||||
"she has the ability to enter people's dreams, but soon finds herself trapped "
|
||||
"in a surreal world of nightmares and illusions, where the boundaries between "
|
||||
"reality and fantasy blur.",
|
||||
]
|
||||
sparse_embedding_func = BM25SparseEmbedding(corpus=texts)
|
||||
docsearch = Milvus.from_texts(
|
||||
embedding=sparse_embedding_func,
|
||||
texts=texts,
|
||||
connection_args={"uri": "./milvus_demo.db"},
|
||||
drop_old=True,
|
||||
)
|
||||
|
||||
output = docsearch.similarity_search("Pilgrim", k=1)
|
||||
assert "Pilgrim" in output[0].page_content
|
||||
|
||||
|
||||
def test_milvus_array_field() -> None:
|
||||
"""Manually specify metadata schema, including an array_field.
|
||||
For more information about array data type and filtering, please refer to
|
||||
@ -365,4 +391,6 @@ def test_milvus_array_field() -> None:
|
||||
# test_milvus_enable_dynamic_field()
|
||||
# test_milvus_disable_dynamic_field()
|
||||
# test_milvus_metadata_field()
|
||||
# test_milvus_enable_dynamic_field_with_partition_key()
|
||||
# test_milvus_sparse_embeddings()
|
||||
# test_milvus_array_field()
|
||||
|
Loading…
Reference in New Issue
Block a user