diff --git a/pilot/source_embedding/source_embedding.py b/pilot/source_embedding/source_embedding.py index 94e48e79e..acbf82a73 100644 --- a/pilot/source_embedding/source_embedding.py +++ b/pilot/source_embedding/source_embedding.py @@ -2,6 +2,9 @@ # -*- coding: utf-8 -*- from abc import ABC, abstractmethod from typing import Dict, List, Optional + +from langchain.embeddings import HuggingFaceEmbeddings + from pilot.configs.config import Config from pilot.vector_store.connector import VectorStoreConnector @@ -32,7 +35,9 @@ class SourceEmbedding(ABC): self.model_name = model_name self.vector_store_config = vector_store_config self.embedding_args = embedding_args - self.embeddings = vector_store_config["embeddings"] + self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name) + + vector_store_config["embeddings"] = self.embeddings self.vector_client = VectorStoreConnector( CFG.VECTOR_STORE_TYPE, vector_store_config ) diff --git a/pilot/vector_store/milvus_store.py b/pilot/vector_store/milvus_store.py index b87798d34..8c680b375 100644 --- a/pilot/vector_store/milvus_store.py +++ b/pilot/vector_store/milvus_store.py @@ -1,12 +1,11 @@ from typing import Any, Iterable, List, Optional, Tuple from langchain.docstore.document import Document -from pymilvus import Collection, DataType, connections, utility +from pymilvus import Collection, DataType, connections from pilot.configs.config import Config from pilot.vector_store.vector_store_base import VectorStoreBase - CFG = Config() @@ -30,8 +29,6 @@ class MilvusStore(VectorStoreBase): self.secure = ctx.get("secure", None) self.embedding = ctx.get("embeddings", None) self.fields = [] - self.alias = "default" - # use HNSW by default. self.index_params = { @@ -51,9 +48,7 @@ class MilvusStore(VectorStoreBase): "IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}}, "ANNOY": {"params": {"search_k": 10}}, } - # default collection schema - self.primary_field = "pk_id" - self.vector_field = "vector" + self.text_field = "content" if (self.username is None) != (self.password is None): @@ -103,31 +98,42 @@ class MilvusStore(VectorStoreBase): texts = [d.page_content for d in documents] metadatas = [d.metadata for d in documents] embeddings = self.embedding.embed_query(texts[0]) - - if utility.has_collection(self.collection_name): - self.col = Collection( - self.collection_name, using=self.alias - ) - self.fields = [] - for x in self.col.schema.fields: - self.fields.append(x.name) - if x.auto_id: - self.fields.remove(x.name) - if x.is_primary: - self.primary_field = x.name - if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR: - self.vector_field = x.name - self._add_documents(texts, metadatas) - return self.collection_name - dim = len(embeddings) # Generate unique names - primary_field = self.primary_field - vector_field = self.vector_field - text_field = self.text_field - # self.text_field = text_field + primary_field = "pk_id" + vector_field = "vector" + text_field = "content" + self.text_field = text_field collection_name = vector_name fields = [] + # Determine metadata schema + # if metadatas: + # # Check if all metadata keys line up + # key = metadatas[0].keys() + # for x in metadatas: + # if key != x.keys(): + # raise ValueError( + # "Mismatched metadata. " + # "Make sure all metadata has the same keys and datatype." + # ) + # # Create FieldSchema for each entry in singular metadata. + # for key, value in metadatas[0].items(): + # # Infer the corresponding datatype of the metadata + # dtype = infer_dtype_bydata(value) + # if dtype == DataType.UNKNOWN: + # raise ValueError(f"Unrecognized datatype for {key}.") + # elif dtype == DataType.VARCHAR: + # # Find out max length text based metadata + # max_length = 0 + # for subvalues in metadatas: + # max_length = max(max_length, len(subvalues[key])) + # fields.append( + # FieldSchema(key, DataType.VARCHAR, max_length=max_length + 1) + # ) + # else: + # fields.append(FieldSchema(key, dtype)) + + # Find out max length of texts max_length = 0 for y in texts: max_length = max(max_length, len(y)) @@ -141,6 +147,7 @@ class MilvusStore(VectorStoreBase): ) # vector field fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim)) + # milvus the schema for the collection schema = CollectionSchema(fields) # Create the collection collection = Collection(collection_name, schema) @@ -158,7 +165,7 @@ class MilvusStore(VectorStoreBase): self.primary_field = x.name if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR: self.vector_field = x.name - self._add_documents(texts, metadatas) + self._add_texts(texts, metadatas) return self.collection_name @@ -217,7 +224,7 @@ class MilvusStore(VectorStoreBase): # ) # return _text - def _add_documents( + def _add_texts( self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, @@ -251,11 +258,6 @@ class MilvusStore(VectorStoreBase): def load_document(self, documents) -> None: """load document in vector database.""" self.init_schema_and_load(self.collection_name, documents) - # batch_size = 500 - # batched_list = [documents[i:i + batch_size] for i in range(0, len(documents), batch_size)] - # # docs = [] - # for doc_batch in batched_list: - # self.init_schema_and_load(self.collection_name, doc_batch) def similar_search(self, text, topk) -> None: """similar_search in vector database.""" @@ -318,6 +320,3 @@ class MilvusStore(VectorStoreBase): ) return data[0], ret - - def close(self): - connections.disconnect() \ No newline at end of file