fix:storage add collection

This commit is contained in:
aries_ckt 2025-04-12 14:18:12 +08:00
parent ae4112cf35
commit 5d67f166d9
4 changed files with 76 additions and 55 deletions

View File

@ -176,3 +176,7 @@ class VectorStoreBase(IndexStoreBase, ABC):
def truncate(self) -> List[str]:
    """Truncate the collection.

    This base version carries no implementation; a concrete vector store is
    expected to override it.

    Raises:
        NotImplementedError: Always, when the base version is called.
    """
    raise NotImplementedError()
def create_collection(self, collection_name: str, **kwargs) -> Any:
    """Create the collection.

    Args:
        collection_name (str): Name of the collection to create.
        **kwargs: Extra options; unused by this base version.

    Raises:
        NotImplementedError: Always, when the base version is called.
    """
    raise NotImplementedError()

View File

@ -133,16 +133,22 @@ class ChromaStore(VectorStoreBase):
)
collection_metadata = collection_metadata or {"hnsw:space": "cosine"}
self._collection = self._chroma_client.get_or_create_collection(
name=self._collection_name,
embedding_function=None,
metadata=collection_metadata,
self._collection = self.create_collection(
collection_name=self._collection_name,
collection_metadata=collection_metadata,
)
def get_config(self) -> ChromaVectorConfig:
    """Return the configuration object stored on this vector store."""
    config = self._vector_store_config
    return config
def create_collection(self, collection_name: str, **kwargs) -> Any:
    """Get the Chroma collection with the given name, creating it if needed.

    Args:
        collection_name (str): Name of the collection.
        **kwargs: Optional settings. ``collection_metadata`` (dict) is passed
            to the client as the collection metadata; when it is missing or
            empty, ``{"hnsw:space": "cosine"}`` is used instead.

    Returns:
        Any: Whatever the client's ``get_or_create_collection`` returns.
    """
    # Apply the same cosine-space fallback that this store uses where it sets
    # up its own collection, so a direct call without metadata does not create
    # a collection in a different distance space.
    metadata = kwargs.get("collection_metadata") or {"hnsw:space": "cosine"}
    return self._chroma_client.get_or_create_collection(
        name=collection_name,
        embedding_function=None,
        metadata=metadata,
    )
def similar_search(
self, text, topk, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import logging
import os
from dataclasses import dataclass, field
from typing import List, Optional
from typing import List, Optional, Any
from dbgpt.core import Chunk, Embeddings
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
@ -230,17 +230,11 @@ class ElasticStore(VectorStoreBase):
basic_auth=(self.username, self.password),
)
# create es index
if not self.vector_name_exists():
self.es_client_python.indices.create(
index=self.index_name, body=self.index_settings
)
self.create_collection(collection_name=self.index_name)
else:
logger.warning("ElasticSearch not set username and password")
self.es_client_python = Elasticsearch(f"http://{self.uri}:{self.port}")
if not self.vector_name_exists():
self.es_client_python.indices.create(
index=self.index_name, body=self.index_settings
)
self.create_collection(collection_name=self.index_name)
except ConnectionError:
logger.error("ElasticSearch connection failed")
except Exception as e:
@ -276,6 +270,13 @@ class ElasticStore(VectorStoreBase):
"""Get the vector store config."""
return self._vector_store_config
def create_collection(self, collection_name: str, **kwargs) -> Any:
    """Create the Elasticsearch index named ``collection_name`` when needed.

    The index is created with ``self.index_settings`` as its body, and only
    when ``self.vector_name_exists()`` reports that nothing exists yet.

    NOTE(review): ``vector_name_exists`` is called without a name argument;
    confirm it checks the same index that ``collection_name`` refers to.

    Args:
        collection_name (str): Name of the index to create.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        Any: Always ``True``.
    """
    already_present = self.vector_name_exists()
    if not already_present:
        indices_api = self.es_client_python.indices
        indices_api.create(index=collection_name, body=self.index_settings)
    return True
def load_document(
self,
chunks: List[Chunk],

View File

@ -283,17 +283,16 @@ class MilvusStore(VectorStoreBase):
password=self.password,
alias="default",
)
self.col = self.create_collection(collection_name=self.collection_name)
def init_schema_and_load(self, vector_name, documents) -> List[str]:
def create_collection(self, collection_name: str, **kwargs) -> Any:
"""Create a Milvus collection.
Create a Milvus collection, indexes it with HNSW, load document.
Args:
vector_name (Embeddings): your collection name.
documents (List[str]): Text to insert.
Returns:
List[str]: document ids.
Create a Milvus collection, indexes it with HNSW, load document
Args:
collection_name (str): your collection name.
Returns:
List[str]: document ids.
"""
try:
from pymilvus import (
@ -317,25 +316,10 @@ class MilvusStore(VectorStoreBase):
alias="default",
# secure=self.secure,
)
texts = [d.content for d in documents]
metadatas = [d.metadata for d in documents]
embeddings = self.embedding.embed_query(texts[0])
embeddings = self.embedding.embed_query(collection_name)
if utility.has_collection(self.collection_name):
self.col = Collection(self.collection_name, using=self.alias)
self.fields = []
for x in self.col.schema.fields:
self.fields.append(x.name)
if x.auto_id:
self.fields.remove(x.name)
if x.is_primary:
self.primary_field = x.name
if (
x.dtype == DataType.FLOAT_VECTOR
or x.dtype == DataType.BINARY_VECTOR
):
self.vector_field = x.name
return self._add_documents(texts, metadatas)
if utility.has_collection(collection_name):
return Collection(self.collection_name, using=self.alias)
# return self.collection_name
dim = len(embeddings)
@ -345,12 +329,8 @@ class MilvusStore(VectorStoreBase):
text_field = self.text_field
metadata_field = self.metadata_field
props_field = self.props_field
# self.text_field = text_field
collection_name = vector_name
fields = []
max_length = 0
for y in texts:
max_length = max(max_length, len(y))
# max_length = 0
# Create the text field
fields.append(FieldSchema(text_field, DataType.VARCHAR, max_length=65535))
# primary key field
@ -371,18 +351,48 @@ class MilvusStore(VectorStoreBase):
# milvus index
collection.create_index(vector_field, index)
collection.load()
schema = collection.schema
for x in schema.fields:
return collection
def _load_documents(self, documents) -> List[str]:
    """Insert documents into the Milvus collection held in ``self.col``.

    Refreshes ``self.fields``, ``self.primary_field`` and ``self.vector_field``
    from the schema of ``self.col``, then delegates the actual insert to
    ``self._add_documents``.

    Args:
        documents: Items to insert; each one must expose ``content`` and
            ``metadata`` attributes.

    Returns:
        List[str]: The document ids returned by ``self._add_documents``.

    Raises:
        ValueError: If the ``pymilvus`` package cannot be imported.
    """
    try:
        # DataType is the only pymilvus name this method uses.
        from pymilvus import DataType
    except ImportError as err:
        raise ValueError(
            "Could not import pymilvus python package. "
            "Please install it with `pip install pymilvus`."
        ) from err

    texts = [doc.content for doc in documents]
    metadatas = [doc.metadata for doc in documents]

    schema_fields = self.col.schema.fields
    # Auto-id fields are filled in by Milvus, so they are left out of the
    # list of fields to insert.
    self.fields = [
        schema_field.name
        for schema_field in schema_fields
        if not schema_field.auto_id
    ]
    vector_dtypes = (DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR)
    for schema_field in schema_fields:
        if schema_field.is_primary:
            self.primary_field = schema_field.name
        if schema_field.dtype in vector_dtypes:
            self.vector_field = schema_field.name

    return self._add_documents(texts, metadatas)
def _add_documents(
self,
@ -430,7 +440,7 @@ class MilvusStore(VectorStoreBase):
]
doc_ids = []
for doc_batch in batched_list:
doc_ids.extend(self.init_schema_and_load(self.collection_name, doc_batch))
doc_ids.extend(self._load_documents(doc_batch))
doc_ids = [str(doc_id) for doc_id in doc_ids]
return doc_ids
@ -655,23 +665,23 @@ class MilvusStore(VectorStoreBase):
if isinstance(metadata_filter.value, str):
expr = (
f"{self.props_field}['{metadata_filter.key}'] "
f"{FilterOperator.EQ} '{metadata_filter.value}'"
f"{FilterOperator.EQ.value} '{metadata_filter.value}'"
)
metadata_filters.append(expr)
elif isinstance(metadata_filter.value, List):
expr = (
f"{self.props_field}['{metadata_filter.key}'] "
f"{FilterOperator.IN} {metadata_filter.value}"
f"{FilterOperator.IN.value} {metadata_filter.value}"
)
metadata_filters.append(expr)
else:
expr = (
f"{self.props_field}['{metadata_filter.key}'] "
f"{FilterOperator.EQ} {str(metadata_filter.value)}"
f"{FilterOperator.EQ.value} {str(metadata_filter.value)}"
)
metadata_filters.append(expr)
if len(metadata_filters) > 1:
metadata_filter_expr = f" {filters.condition} ".join(metadata_filters)
metadata_filter_expr = f" {filters.condition.value} ".join(metadata_filters)
else:
metadata_filter_expr = metadata_filters[0]
return metadata_filter_expr