update:milvus batch

This commit is contained in:
aries-ckt 2023-05-25 17:17:02 +08:00
parent 85906a3c45
commit adb8a5a316
2 changed files with 9 additions and 8 deletions

View File

@@ -2,9 +2,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Optional from typing import Dict, List, Optional
from langchain.embeddings import HuggingFaceEmbeddings
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.vector_store.connector import VectorStoreConnector from pilot.vector_store.connector import VectorStoreConnector
@@ -35,9 +32,7 @@ class SourceEmbedding(ABC):
self.model_name = model_name self.model_name = model_name
self.vector_store_config = vector_store_config self.vector_store_config = vector_store_config
self.embedding_args = embedding_args self.embedding_args = embedding_args
self.embeddings = HuggingFaceEmbeddings(model_name=self.model_name) self.embeddings = vector_store_config["embeddings"]
vector_store_config["embeddings"] = self.embeddings
self.vector_client = VectorStoreConnector( self.vector_client = VectorStoreConnector(
CFG.VECTOR_STORE_TYPE, vector_store_config CFG.VECTOR_STORE_TYPE, vector_store_config
) )

View File

@@ -6,6 +6,7 @@ from pymilvus import Collection, DataType, connections, utility
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.vector_store.vector_store_base import VectorStoreBase from pilot.vector_store.vector_store_base import VectorStoreBase
CFG = Config() CFG = Config()
@@ -107,6 +108,7 @@ class MilvusStore(VectorStoreBase):
self.col = Collection( self.col = Collection(
self.collection_name, using=self.alias self.collection_name, using=self.alias
) )
self.fields = []
for x in self.col.schema.fields: for x in self.col.schema.fields:
self.fields.append(x.name) self.fields.append(x.name)
if x.auto_id: if x.auto_id:
@@ -131,7 +133,7 @@ class MilvusStore(VectorStoreBase):
max_length = max(max_length, len(y)) max_length = max(max_length, len(y))
# Create the text field # Create the text field
fields.append( fields.append(
FieldSchema(text_field, DataType.VARCHAR, max_length=max_length + 1) FieldSchema(text_field, DataType.VARCHAR, max_length=max_length + 100)
) )
# primary key field # primary key field
fields.append( fields.append(
@@ -248,7 +250,11 @@ class MilvusStore(VectorStoreBase):
def load_document(self, documents) -> None: def load_document(self, documents) -> None:
"""load document in vector database.""" """load document in vector database."""
self.init_schema_and_load(self.collection_name, documents) batch_size = 500
batched_list = [documents[i:i + batch_size] for i in range(0, len(documents), batch_size)]
# docs = []
for doc_batch in batched_list:
self.init_schema_and_load(self.collection_name, doc_batch)
def similar_search(self, text, topk) -> None: def similar_search(self, text, topk) -> None:
"""similar_search in vector database.""" """similar_search in vector database."""