diff --git a/pilot/vector_store/milvus_store.py b/pilot/vector_store/milvus_store.py index 8c680b375..dea33e61d 100644 --- a/pilot/vector_store/milvus_store.py +++ b/pilot/vector_store/milvus_store.py @@ -1,7 +1,7 @@ from typing import Any, Iterable, List, Optional, Tuple from langchain.docstore.document import Document -from pymilvus import Collection, DataType, connections +from pymilvus import Collection, DataType, connections, utility from pilot.configs.config import Config from pilot.vector_store.vector_store_base import VectorStoreBase @@ -29,6 +29,8 @@ class MilvusStore(VectorStoreBase): self.secure = ctx.get("secure", None) self.embedding = ctx.get("embeddings", None) self.fields = [] + self.alias = "default" + # use HNSW by default. self.index_params = { @@ -48,7 +50,9 @@ class MilvusStore(VectorStoreBase): "IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}}, "ANNOY": {"params": {"search_k": 10}}, } - + # default collection schema + self.primary_field = "pk_id" + self.vector_field = "vector" self.text_field = "content" if (self.username is None) != (self.password is None): @@ -98,42 +102,30 @@ class MilvusStore(VectorStoreBase): texts = [d.page_content for d in documents] metadatas = [d.metadata for d in documents] embeddings = self.embedding.embed_query(texts[0]) + + if utility.has_collection(self.collection_name): + self.col = Collection( + self.collection_name, using=self.alias + ) + for x in self.col.schema.fields: + self.fields.append(x.name) + if x.auto_id: + self.fields.remove(x.name) + if x.is_primary: + self.primary_field = x.name + if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR: + self.vector_field = x.name + self._add_documents(texts, metadatas) + return self.collection_name + dim = len(embeddings) # Generate unique names - primary_field = "pk_id" - vector_field = "vector" - text_field = "content" - self.text_field = text_field + primary_field = self.primary_field + vector_field = self.vector_field + text_field = self.text_field + # self.text_field = text_field collection_name = vector_name fields = [] - # Determine metadata schema - # if metadatas: - # # Check if all metadata keys line up - # key = metadatas[0].keys() - # for x in metadatas: - # if key != x.keys(): - # raise ValueError( - # "Mismatched metadata. " - # "Make sure all metadata has the same keys and datatype." - # ) - # # Create FieldSchema for each entry in singular metadata. - # for key, value in metadatas[0].items(): - # # Infer the corresponding datatype of the metadata - # dtype = infer_dtype_bydata(value) - # if dtype == DataType.UNKNOWN: - # raise ValueError(f"Unrecognized datatype for {key}.") - # elif dtype == DataType.VARCHAR: - # # Find out max length text based metadata - # max_length = 0 - # for subvalues in metadatas: - # max_length = max(max_length, len(subvalues[key])) - # fields.append( - # FieldSchema(key, DataType.VARCHAR, max_length=max_length + 1) - # ) - # else: - # fields.append(FieldSchema(key, dtype)) - - # Find out max length of texts max_length = 0 for y in texts: max_length = max(max_length, len(y)) @@ -147,7 +139,6 @@ class MilvusStore(VectorStoreBase): ) # vector field fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim)) - # milvus the schema for the collection schema = CollectionSchema(fields) # Create the collection collection = Collection(collection_name, schema) @@ -165,7 +156,7 @@ class MilvusStore(VectorStoreBase): self.primary_field = x.name if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR: self.vector_field = x.name - self._add_texts(texts, metadatas) + self._add_documents(texts, metadatas) return self.collection_name @@ -224,7 +215,7 @@ class MilvusStore(VectorStoreBase): # ) # return _text - def _add_texts( + def _add_documents( self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, @@ -320,3 +311,6 @@ class MilvusStore(VectorStoreBase): ) return data[0], ret + + def close(self): + connections.disconnect() \ No newline at end of file