fix:storage add collection

This commit is contained in:
aries_ckt 2025-04-12 14:18:12 +08:00
parent ae4112cf35
commit 5d67f166d9
4 changed files with 76 additions and 55 deletions

View File

@ -176,3 +176,7 @@ class VectorStoreBase(IndexStoreBase, ABC):
def truncate(self) -> List[str]: def truncate(self) -> List[str]:
"""Truncate the collection.""" """Truncate the collection."""
raise NotImplementedError raise NotImplementedError
def create_collection(self, collection_name: str, **kwargs) -> Any:
"""Create the collection."""
raise NotImplementedError

View File

@ -133,16 +133,22 @@ class ChromaStore(VectorStoreBase):
) )
collection_metadata = collection_metadata or {"hnsw:space": "cosine"} collection_metadata = collection_metadata or {"hnsw:space": "cosine"}
self._collection = self._chroma_client.get_or_create_collection( self._collection = self.create_collection(
name=self._collection_name, collection_name=self._collection_name,
embedding_function=None, collection_metadata=collection_metadata,
metadata=collection_metadata,
) )
def get_config(self) -> ChromaVectorConfig: def get_config(self) -> ChromaVectorConfig:
"""Get the vector store config.""" """Get the vector store config."""
return self._vector_store_config return self._vector_store_config
def create_collection(self, collection_name: str, **kwargs) -> Any:
return self._chroma_client.get_or_create_collection(
name=collection_name,
embedding_function=None,
metadata=kwargs.get("collection_metadata"),
)
def similar_search( def similar_search(
self, text, topk, filters: Optional[MetadataFilters] = None self, text, topk, filters: Optional[MetadataFilters] = None
) -> List[Chunk]: ) -> List[Chunk]:

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import logging import logging
import os import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Optional from typing import List, Optional, Any
from dbgpt.core import Chunk, Embeddings from dbgpt.core import Chunk, Embeddings
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
@ -230,17 +230,11 @@ class ElasticStore(VectorStoreBase):
basic_auth=(self.username, self.password), basic_auth=(self.username, self.password),
) )
# create es index # create es index
if not self.vector_name_exists(): self.create_collection(collection_name=self.index_name)
self.es_client_python.indices.create(
index=self.index_name, body=self.index_settings
)
else: else:
logger.warning("ElasticSearch not set username and password") logger.warning("ElasticSearch not set username and password")
self.es_client_python = Elasticsearch(f"http://{self.uri}:{self.port}") self.es_client_python = Elasticsearch(f"http://{self.uri}:{self.port}")
if not self.vector_name_exists(): self.create_collection(collection_name=self.index_name)
self.es_client_python.indices.create(
index=self.index_name, body=self.index_settings
)
except ConnectionError: except ConnectionError:
logger.error("ElasticSearch connection failed") logger.error("ElasticSearch connection failed")
except Exception as e: except Exception as e:
@ -276,6 +270,13 @@ class ElasticStore(VectorStoreBase):
"""Get the vector store config.""" """Get the vector store config."""
return self._vector_store_config return self._vector_store_config
def create_collection(self, collection_name: str, **kwargs) -> Any:
if not self.vector_name_exists():
self.es_client_python.indices.create(
index=collection_name, body=self.index_settings
)
return True
def load_document( def load_document(
self, self,
chunks: List[Chunk], chunks: List[Chunk],

View File

@ -283,17 +283,16 @@ class MilvusStore(VectorStoreBase):
password=self.password, password=self.password,
alias="default", alias="default",
) )
self.col = self.create_collection(collection_name=self.collection_name)
def init_schema_and_load(self, vector_name, documents) -> List[str]: def create_collection(self, collection_name: str, **kwargs) -> Any:
"""Create a Milvus collection. """Create a Milvus collection.
Create a Milvus collection, indexes it with HNSW, load document. Create a Milvus collection and index it with HNSW.
Args:
Args: collection_name (str): your collection name.
vector_name (Embeddings): your collection name. Returns:
documents (List[str]): Text to insert. Collection: the created or existing Milvus collection.
Returns:
List[str]: document ids.
""" """
try: try:
from pymilvus import ( from pymilvus import (
@ -317,25 +316,10 @@ class MilvusStore(VectorStoreBase):
alias="default", alias="default",
# secure=self.secure, # secure=self.secure,
) )
texts = [d.content for d in documents] embeddings = self.embedding.embed_query(collection_name)
metadatas = [d.metadata for d in documents]
embeddings = self.embedding.embed_query(texts[0])
if utility.has_collection(self.collection_name): if utility.has_collection(collection_name):
self.col = Collection(self.collection_name, using=self.alias) return Collection(self.collection_name, using=self.alias)
self.fields = []
for x in self.col.schema.fields:
self.fields.append(x.name)
if x.auto_id:
self.fields.remove(x.name)
if x.is_primary:
self.primary_field = x.name
if (
x.dtype == DataType.FLOAT_VECTOR
or x.dtype == DataType.BINARY_VECTOR
):
self.vector_field = x.name
return self._add_documents(texts, metadatas)
# return self.collection_name # return self.collection_name
dim = len(embeddings) dim = len(embeddings)
@ -345,12 +329,8 @@ class MilvusStore(VectorStoreBase):
text_field = self.text_field text_field = self.text_field
metadata_field = self.metadata_field metadata_field = self.metadata_field
props_field = self.props_field props_field = self.props_field
# self.text_field = text_field
collection_name = vector_name
fields = [] fields = []
max_length = 0 # max_length = 0
for y in texts:
max_length = max(max_length, len(y))
# Create the text field # Create the text field
fields.append(FieldSchema(text_field, DataType.VARCHAR, max_length=65535)) fields.append(FieldSchema(text_field, DataType.VARCHAR, max_length=65535))
# primary key field # primary key field
@ -371,18 +351,48 @@ class MilvusStore(VectorStoreBase):
# milvus index # milvus index
collection.create_index(vector_field, index) collection.create_index(vector_field, index)
collection.load() collection.load()
schema = collection.schema return collection
for x in schema.fields:
def _load_documents(self, documents) -> List[str]:
"""Load documents into Milvus.
Load documents.
Args:
documents (List[Chunk]): chunks to insert.
Returns:
List[str]: document ids.
"""
try:
from pymilvus import (
Collection,
CollectionSchema,
DataType,
FieldSchema,
connections,
utility,
)
from pymilvus.orm.types import infer_dtype_bydata # noqa: F401
except ImportError:
raise ValueError(
"Could not import pymilvus python package. "
"Please install it with `pip install pymilvus`."
)
texts = [d.content for d in documents]
metadatas = [d.metadata for d in documents]
self.fields = []
for x in self.col.schema.fields:
self.fields.append(x.name) self.fields.append(x.name)
if x.auto_id: if x.auto_id:
self.fields.remove(x.name) self.fields.remove(x.name)
if x.is_primary: if x.is_primary:
self.primary_field = x.name self.primary_field = x.name
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR: if (
x.dtype == DataType.FLOAT_VECTOR
or x.dtype == DataType.BINARY_VECTOR
):
self.vector_field = x.name self.vector_field = x.name
ids = self._add_documents(texts, metadatas) return self._add_documents(texts, metadatas)
return ids
def _add_documents( def _add_documents(
self, self,
@ -430,7 +440,7 @@ class MilvusStore(VectorStoreBase):
] ]
doc_ids = [] doc_ids = []
for doc_batch in batched_list: for doc_batch in batched_list:
doc_ids.extend(self.init_schema_and_load(self.collection_name, doc_batch)) doc_ids.extend(self._load_documents(doc_batch))
doc_ids = [str(doc_id) for doc_id in doc_ids] doc_ids = [str(doc_id) for doc_id in doc_ids]
return doc_ids return doc_ids
@ -655,23 +665,23 @@ class MilvusStore(VectorStoreBase):
if isinstance(metadata_filter.value, str): if isinstance(metadata_filter.value, str):
expr = ( expr = (
f"{self.props_field}['{metadata_filter.key}'] " f"{self.props_field}['{metadata_filter.key}'] "
f"{FilterOperator.EQ} '{metadata_filter.value}'" f"{FilterOperator.EQ.value} '{metadata_filter.value}'"
) )
metadata_filters.append(expr) metadata_filters.append(expr)
elif isinstance(metadata_filter.value, List): elif isinstance(metadata_filter.value, List):
expr = ( expr = (
f"{self.props_field}['{metadata_filter.key}'] " f"{self.props_field}['{metadata_filter.key}'] "
f"{FilterOperator.IN} {metadata_filter.value}" f"{FilterOperator.IN.value} {metadata_filter.value}"
) )
metadata_filters.append(expr) metadata_filters.append(expr)
else: else:
expr = ( expr = (
f"{self.props_field}['{metadata_filter.key}'] " f"{self.props_field}['{metadata_filter.key}'] "
f"{FilterOperator.EQ} {str(metadata_filter.value)}" f"{FilterOperator.EQ.value} {str(metadata_filter.value)}"
) )
metadata_filters.append(expr) metadata_filters.append(expr)
if len(metadata_filters) > 1: if len(metadata_filters) > 1:
metadata_filter_expr = f" {filters.condition} ".join(metadata_filters) metadata_filter_expr = f" {filters.condition.value} ".join(metadata_filters)
else: else:
metadata_filter_expr = metadata_filters[0] metadata_filter_expr = metadata_filters[0]
return metadata_filter_expr return metadata_filter_expr