update: Milvus text_field max length (fixed at 65535)

This commit is contained in:
aries-ckt
2023-05-25 19:58:32 +08:00
parent c0d62f3620
commit 8eb266069f
2 changed files with 41 additions and 45 deletions

View File

@@ -2,9 +2,6 @@
# -*- coding: utf-8 -*-
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.config import Config
from pilot.vector_store.connector import VectorStoreConnector
@@ -35,9 +32,7 @@ class SourceEmbedding(ABC):
self.model_name = model_name
self.vector_store_config = vector_store_config
self.embedding_args = embedding_args
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name)
vector_store_config["embeddings"] = self.embeddings
self.embeddings = vector_store_config["embeddings"]
self.vector_client = VectorStoreConnector(
CFG.VECTOR_STORE_TYPE, vector_store_config
)

View File

@@ -1,11 +1,12 @@
from typing import Any, Iterable, List, Optional, Tuple
from langchain.docstore.document import Document
from pymilvus import Collection, DataType, connections
from pymilvus import Collection, DataType, connections, utility
from pilot.configs.config import Config
from pilot.vector_store.vector_store_base import VectorStoreBase
CFG = Config()
@@ -29,6 +30,8 @@ class MilvusStore(VectorStoreBase):
self.secure = ctx.get("secure", None)
self.embedding = ctx.get("embeddings", None)
self.fields = []
self.alias = "default"
# use HNSW by default.
self.index_params = {
@@ -48,7 +51,9 @@ class MilvusStore(VectorStoreBase):
"IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}},
"ANNOY": {"params": {"search_k": 10}},
}
# default collection schema
self.primary_field = "pk_id"
self.vector_field = "vector"
self.text_field = "content"
if (self.username is None) != (self.password is None):
@@ -98,48 +103,37 @@ class MilvusStore(VectorStoreBase):
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
embeddings = self.embedding.embed_query(texts[0])
if utility.has_collection(self.collection_name):
self.col = Collection(
self.collection_name, using=self.alias
)
self.fields = []
for x in self.col.schema.fields:
self.fields.append(x.name)
if x.auto_id:
self.fields.remove(x.name)
if x.is_primary:
self.primary_field = x.name
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
self.vector_field = x.name
self._add_documents(texts, metadatas)
return self.collection_name
dim = len(embeddings)
# Generate unique names
primary_field = "pk_id"
vector_field = "vector"
text_field = "content"
self.text_field = text_field
primary_field = self.primary_field
vector_field = self.vector_field
text_field = self.text_field
# self.text_field = text_field
collection_name = vector_name
fields = []
# Determine metadata schema
# if metadatas:
# # Check if all metadata keys line up
# key = metadatas[0].keys()
# for x in metadatas:
# if key != x.keys():
# raise ValueError(
# "Mismatched metadata. "
# "Make sure all metadata has the same keys and datatype."
# )
# # Create FieldSchema for each entry in singular metadata.
# for key, value in metadatas[0].items():
# # Infer the corresponding datatype of the metadata
# dtype = infer_dtype_bydata(value)
# if dtype == DataType.UNKNOWN:
# raise ValueError(f"Unrecognized datatype for {key}.")
# elif dtype == DataType.VARCHAR:
# # Find out max length text based metadata
# max_length = 0
# for subvalues in metadatas:
# max_length = max(max_length, len(subvalues[key]))
# fields.append(
# FieldSchema(key, DataType.VARCHAR, max_length=max_length + 1)
# )
# else:
# fields.append(FieldSchema(key, dtype))
# Find out max length of texts
max_length = 0
for y in texts:
max_length = max(max_length, len(y))
# Create the text field
fields.append(
FieldSchema(text_field, DataType.VARCHAR, max_length=max_length + 1)
FieldSchema(text_field, DataType.VARCHAR, max_length= 65535)
)
# primary key field
fields.append(
@@ -147,7 +141,6 @@ class MilvusStore(VectorStoreBase):
)
# vector field
fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
# Build the Milvus schema for the collection
schema = CollectionSchema(fields)
# Create the collection
collection = Collection(collection_name, schema)
@@ -165,7 +158,7 @@ class MilvusStore(VectorStoreBase):
self.primary_field = x.name
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
self.vector_field = x.name
self._add_texts(texts, metadatas)
self._add_documents(texts, metadatas)
return self.collection_name
@@ -224,7 +217,7 @@ class MilvusStore(VectorStoreBase):
# )
# return _text
def _add_texts(
def _add_documents(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
@@ -257,7 +250,12 @@ class MilvusStore(VectorStoreBase):
def load_document(self, documents) -> None:
    """Load documents into the vector database in fixed-size batches.

    Args:
        documents: Sliceable sequence of documents to store. Each batch is
            passed to ``init_schema_and_load`` together with
            ``self.collection_name``. An empty sequence results in no call.
    """
    # Each document must be loaded exactly once: only the batched path runs,
    # with no additional whole-list call before it (that would insert every
    # document twice).
    # NOTE(review): 500 per batch presumably keeps each insert request
    # small - confirm against the Milvus deployment limits.
    batch_size = 500
    for start in range(0, len(documents), batch_size):
        self.init_schema_and_load(
            self.collection_name, documents[start : start + batch_size]
        )
def similar_search(self, text, topk) -> None:
"""similar_search in vector database."""
@@ -320,3 +318,6 @@ class MilvusStore(VectorStoreBase):
)
return data[0], ret
def close(self):
    """Disconnect this store's Milvus connection.

    The connection is identified by ``self.alias``; pymilvus'
    ``connections.disconnect`` requires the alias and raises ``TypeError``
    when called without it.
    """
    connections.disconnect(self.alias)