From 8eb266069f7d1968b919f03af0311651ce2427f3 Mon Sep 17 00:00:00 2001 From: aries-ckt <916701291@qq.com> Date: Thu, 25 May 2023 19:58:32 +0800 Subject: [PATCH] update:milvus text_field max length --- pilot/source_embedding/source_embedding.py | 7 +- pilot/vector_store/milvus_store.py | 79 +++++++++++----------- 2 files changed, 41 insertions(+), 45 deletions(-) diff --git a/pilot/source_embedding/source_embedding.py b/pilot/source_embedding/source_embedding.py index acbf82a73..94e48e79e 100644 --- a/pilot/source_embedding/source_embedding.py +++ b/pilot/source_embedding/source_embedding.py @@ -2,9 +2,6 @@ # -*- coding: utf-8 -*- from abc import ABC, abstractmethod from typing import Dict, List, Optional - -from langchain.embeddings import HuggingFaceEmbeddings - from pilot.configs.config import Config from pilot.vector_store.connector import VectorStoreConnector @@ -35,9 +32,7 @@ class SourceEmbedding(ABC): self.model_name = model_name self.vector_store_config = vector_store_config self.embedding_args = embedding_args - self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name) - - vector_store_config["embeddings"] = self.embeddings + self.embeddings = vector_store_config["embeddings"] self.vector_client = VectorStoreConnector( CFG.VECTOR_STORE_TYPE, vector_store_config ) diff --git a/pilot/vector_store/milvus_store.py b/pilot/vector_store/milvus_store.py index 8c680b375..78977cce8 100644 --- a/pilot/vector_store/milvus_store.py +++ b/pilot/vector_store/milvus_store.py @@ -1,11 +1,12 @@ from typing import Any, Iterable, List, Optional, Tuple from langchain.docstore.document import Document -from pymilvus import Collection, DataType, connections +from pymilvus import Collection, DataType, connections, utility from pilot.configs.config import Config from pilot.vector_store.vector_store_base import VectorStoreBase + CFG = Config() @@ -29,6 +30,8 @@ class MilvusStore(VectorStoreBase): self.secure = ctx.get("secure", None) self.embedding = ctx.get("embeddings", None) self.fields = [] + self.alias = "default" + # use HNSW by default. self.index_params = { @@ -48,7 +51,9 @@ class MilvusStore(VectorStoreBase): "IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}}, "ANNOY": {"params": {"search_k": 10}}, } - + # default collection schema + self.primary_field = "pk_id" + self.vector_field = "vector" self.text_field = "content" if (self.username is None) != (self.password is None): @@ -98,48 +103,37 @@ class MilvusStore(VectorStoreBase): texts = [d.page_content for d in documents] metadatas = [d.metadata for d in documents] embeddings = self.embedding.embed_query(texts[0]) + + if utility.has_collection(self.collection_name): + self.col = Collection( + self.collection_name, using=self.alias + ) + self.fields = [] + for x in self.col.schema.fields: + self.fields.append(x.name) + if x.auto_id: + self.fields.remove(x.name) + if x.is_primary: + self.primary_field = x.name + if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR: + self.vector_field = x.name + self._add_documents(texts, metadatas) + return self.collection_name + dim = len(embeddings) # Generate unique names - primary_field = "pk_id" - vector_field = "vector" - text_field = "content" - self.text_field = text_field + primary_field = self.primary_field + vector_field = self.vector_field + text_field = self.text_field + # self.text_field = text_field collection_name = vector_name fields = [] - # Determine metadata schema - # if metadatas: - # # Check if all metadata keys line up - # key = metadatas[0].keys() - # for x in metadatas: - # if key != x.keys(): - # raise ValueError( - # "Mismatched metadata. " - # "Make sure all metadata has the same keys and datatype." - # ) - # # Create FieldSchema for each entry in singular metadata. - # for key, value in metadatas[0].items(): - # # Infer the corresponding datatype of the metadata - # dtype = infer_dtype_bydata(value) - # if dtype == DataType.UNKNOWN: - # raise ValueError(f"Unrecognized datatype for {key}.") - # elif dtype == DataType.VARCHAR: - # # Find out max length text based metadata - # max_length = 0 - # for subvalues in metadatas: - # max_length = max(max_length, len(subvalues[key])) - # fields.append( - # FieldSchema(key, DataType.VARCHAR, max_length=max_length + 1) - # ) - # else: - # fields.append(FieldSchema(key, dtype)) - - # Find out max length of texts max_length = 0 for y in texts: max_length = max(max_length, len(y)) # Create the text field fields.append( - FieldSchema(text_field, DataType.VARCHAR, max_length=max_length + 1) + FieldSchema(text_field, DataType.VARCHAR, max_length= 65535) ) # primary key field fields.append( @@ -147,7 +141,6 @@ class MilvusStore(VectorStoreBase): ) # vector field fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim)) - # milvus the schema for the collection schema = CollectionSchema(fields) # Create the collection collection = Collection(collection_name, schema) @@ -165,7 +158,7 @@ class MilvusStore(VectorStoreBase): self.primary_field = x.name if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR: self.vector_field = x.name - self._add_texts(texts, metadatas) + self._add_documents(texts, metadatas) return self.collection_name @@ -224,7 +217,7 @@ class MilvusStore(VectorStoreBase): # ) # return _text - def _add_texts( + def _add_documents( self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, @@ -257,7 +250,12 @@ class MilvusStore(VectorStoreBase): def load_document(self, documents) -> None: """load document in vector database.""" - self.init_schema_and_load(self.collection_name, documents) + # self.init_schema_and_load(self.collection_name, documents) + batch_size = 500 + batched_list = [documents[i:i + batch_size] for i in range(0, len(documents), batch_size)] + # docs = [] + for doc_batch in batched_list: + self.init_schema_and_load(self.collection_name, doc_batch) def similar_search(self, text, topk) -> None: """similar_search in vector database.""" @@ -320,3 +318,6 @@ class MilvusStore(VectorStoreBase): ) return data[0], ret + + def close(self): + connections.disconnect() \ No newline at end of file