update:milvus text_field max length

This commit is contained in:
aries-ckt
2023-05-25 19:58:32 +08:00
parent c0d62f3620
commit 8eb266069f
2 changed files with 41 additions and 45 deletions

View File

@@ -2,9 +2,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Optional from typing import Dict, List, Optional
from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.vector_store.connector import VectorStoreConnector from pilot.vector_store.connector import VectorStoreConnector
@@ -35,9 +32,7 @@ class SourceEmbedding(ABC):
self.model_name = model_name self.model_name = model_name
self.vector_store_config = vector_store_config self.vector_store_config = vector_store_config
self.embedding_args = embedding_args self.embedding_args = embedding_args
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name) self.embeddings = vector_store_config["embeddings"]
vector_store_config["embeddings"] = self.embeddings
self.vector_client = VectorStoreConnector( self.vector_client = VectorStoreConnector(
CFG.VECTOR_STORE_TYPE, vector_store_config CFG.VECTOR_STORE_TYPE, vector_store_config
) )

View File

@@ -1,11 +1,12 @@
from typing import Any, Iterable, List, Optional, Tuple from typing import Any, Iterable, List, Optional, Tuple
from langchain.docstore.document import Document from langchain.docstore.document import Document
from pymilvus import Collection, DataType, connections from pymilvus import Collection, DataType, connections, utility
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.vector_store.vector_store_base import VectorStoreBase from pilot.vector_store.vector_store_base import VectorStoreBase
CFG = Config() CFG = Config()
@@ -29,6 +30,8 @@ class MilvusStore(VectorStoreBase):
self.secure = ctx.get("secure", None) self.secure = ctx.get("secure", None)
self.embedding = ctx.get("embeddings", None) self.embedding = ctx.get("embeddings", None)
self.fields = [] self.fields = []
self.alias = "default"
# use HNSW by default. # use HNSW by default.
self.index_params = { self.index_params = {
@@ -48,7 +51,9 @@ class MilvusStore(VectorStoreBase):
"IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}}, "IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}},
"ANNOY": {"params": {"search_k": 10}}, "ANNOY": {"params": {"search_k": 10}},
} }
# default collection schema
self.primary_field = "pk_id"
self.vector_field = "vector"
self.text_field = "content" self.text_field = "content"
if (self.username is None) != (self.password is None): if (self.username is None) != (self.password is None):
@@ -98,48 +103,37 @@ class MilvusStore(VectorStoreBase):
texts = [d.page_content for d in documents] texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents] metadatas = [d.metadata for d in documents]
embeddings = self.embedding.embed_query(texts[0]) embeddings = self.embedding.embed_query(texts[0])
if utility.has_collection(self.collection_name):
self.col = Collection(
self.collection_name, using=self.alias
)
self.fields = []
for x in self.col.schema.fields:
self.fields.append(x.name)
if x.auto_id:
self.fields.remove(x.name)
if x.is_primary:
self.primary_field = x.name
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
self.vector_field = x.name
self._add_documents(texts, metadatas)
return self.collection_name
dim = len(embeddings) dim = len(embeddings)
# Generate unique names # Generate unique names
primary_field = "pk_id" primary_field = self.primary_field
vector_field = "vector" vector_field = self.vector_field
text_field = "content" text_field = self.text_field
self.text_field = text_field # self.text_field = text_field
collection_name = vector_name collection_name = vector_name
fields = [] fields = []
# Determine metadata schema
# if metadatas:
# # Check if all metadata keys line up
# key = metadatas[0].keys()
# for x in metadatas:
# if key != x.keys():
# raise ValueError(
# "Mismatched metadata. "
# "Make sure all metadata has the same keys and datatype."
# )
# # Create FieldSchema for each entry in singular metadata.
# for key, value in metadatas[0].items():
# # Infer the corresponding datatype of the metadata
# dtype = infer_dtype_bydata(value)
# if dtype == DataType.UNKNOWN:
# raise ValueError(f"Unrecognized datatype for {key}.")
# elif dtype == DataType.VARCHAR:
# # Find out max length text based metadata
# max_length = 0
# for subvalues in metadatas:
# max_length = max(max_length, len(subvalues[key]))
# fields.append(
# FieldSchema(key, DataType.VARCHAR, max_length=max_length + 1)
# )
# else:
# fields.append(FieldSchema(key, dtype))
# Find out max length of texts
max_length = 0 max_length = 0
for y in texts: for y in texts:
max_length = max(max_length, len(y)) max_length = max(max_length, len(y))
# Create the text field # Create the text field
fields.append( fields.append(
FieldSchema(text_field, DataType.VARCHAR, max_length=max_length + 1) FieldSchema(text_field, DataType.VARCHAR, max_length= 65535)
) )
# primary key field # primary key field
fields.append( fields.append(
@@ -147,7 +141,6 @@ class MilvusStore(VectorStoreBase):
) )
# vector field # vector field
fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim)) fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim))
# milvus the schema for the collection
schema = CollectionSchema(fields) schema = CollectionSchema(fields)
# Create the collection # Create the collection
collection = Collection(collection_name, schema) collection = Collection(collection_name, schema)
@@ -165,7 +158,7 @@ class MilvusStore(VectorStoreBase):
self.primary_field = x.name self.primary_field = x.name
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR: if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
self.vector_field = x.name self.vector_field = x.name
self._add_texts(texts, metadatas) self._add_documents(texts, metadatas)
return self.collection_name return self.collection_name
@@ -224,7 +217,7 @@ class MilvusStore(VectorStoreBase):
# ) # )
# return _text # return _text
def _add_texts( def _add_documents(
self, self,
texts: Iterable[str], texts: Iterable[str],
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
@@ -257,7 +250,12 @@ class MilvusStore(VectorStoreBase):
def load_document(self, documents) -> None: def load_document(self, documents) -> None:
"""load document in vector database.""" """load document in vector database."""
self.init_schema_and_load(self.collection_name, documents) # self.init_schema_and_load(self.collection_name, documents)
batch_size = 500
batched_list = [documents[i:i + batch_size] for i in range(0, len(documents), batch_size)]
# docs = []
for doc_batch in batched_list:
self.init_schema_and_load(self.collection_name, doc_batch)
def similar_search(self, text, topk) -> None: def similar_search(self, text, topk) -> None:
"""similar_search in vector database.""" """similar_search in vector database."""
@@ -320,3 +318,6 @@ class MilvusStore(VectorStoreBase):
) )
return data[0], ret return data[0], ret
def close(self):
connections.disconnect()