diff --git a/packages/dbgpt-core/src/dbgpt/storage/vector_store/base.py b/packages/dbgpt-core/src/dbgpt/storage/vector_store/base.py
index dac1f2c6e..8435901e1 100644
--- a/packages/dbgpt-core/src/dbgpt/storage/vector_store/base.py
+++ b/packages/dbgpt-core/src/dbgpt/storage/vector_store/base.py
@@ -201,3 +201,18 @@ class VectorStoreBase(IndexStoreBase, ABC):
     def truncate(self) -> List[str]:
         """Truncate the collection."""
         raise NotImplementedError
+
+    def create_collection(self, collection_name: str, **kwargs) -> Any:
+        """Create the collection.
+
+        Args:
+            collection_name (str): name of the collection to create.
+            **kwargs: store-specific options.
+
+        Returns:
+            Any: store-specific result.
+
+        Raises:
+            NotImplementedError: unless a subclass overrides this method.
+        """
+        raise NotImplementedError
diff --git a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/chroma_store.py b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/chroma_store.py
index c28ebaa2b..3520cc6d5 100644
--- a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/chroma_store.py
+++ b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/chroma_store.py
@@ -139,16 +139,32 @@ class ChromaStore(VectorStoreBase):
         )
 
         collection_metadata = collection_metadata or {"hnsw:space": "cosine"}
-        self._collection = self._chroma_client.get_or_create_collection(
-            name=self._collection_name,
-            embedding_function=None,
-            metadata=collection_metadata,
+        self._collection = self.create_collection(
+            collection_name=self._collection_name,
+            collection_metadata=collection_metadata,
         )
 
     def get_config(self) -> ChromaVectorConfig:
         """Get the vector store config."""
         return self._vector_store_config
 
+    def create_collection(self, collection_name: str, **kwargs) -> Any:
+        """Get the Chroma collection, creating it if it does not exist.
+
+        Args:
+            collection_name (str): name of the collection.
+            **kwargs: ``collection_metadata`` (dict) is used as the
+                collection metadata; cosine HNSW space when missing or empty.
+
+        Returns:
+            Any: the result of the client's ``get_or_create_collection``.
+        """
+        return self._chroma_client.get_or_create_collection(
+            name=collection_name,
+            embedding_function=None,
+            metadata=kwargs.get("collection_metadata") or {"hnsw:space": "cosine"},
+        )
+
     def similar_search(
         self, text, topk, filters: Optional[MetadataFilters] = None
     ) -> List[Chunk]:
diff --git a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/elastic_store.py b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/elastic_store.py
index 7c3dc091d..f0e3ebc22 100644
--- a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/elastic_store.py
+++ b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/elastic_store.py
@@ -5,7 +5,7 @@ from __future__ import annotations
 import logging
 import os
 from dataclasses import dataclass, field
-from typing import List, Optional
+from typing import Any, List, Optional
 
 from dbgpt.core import Chunk, Embeddings
 from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
@@ -237,17 +237,11 @@ class ElasticStore(VectorStoreBase):
                     basic_auth=(self.username, self.password),
                 )
                 # create es index
-                if not self.vector_name_exists():
-                    self.es_client_python.indices.create(
-                        index=self.index_name, body=self.index_settings
-                    )
+                self.create_collection(collection_name=self.index_name)
             else:
                 logger.warning("ElasticSearch not set username and password")
                 self.es_client_python = Elasticsearch(f"http://{self.uri}:{self.port}")
-                if not self.vector_name_exists():
-                    self.es_client_python.indices.create(
-                        index=self.index_name, body=self.index_settings
-                    )
+                self.create_collection(collection_name=self.index_name)
         except ConnectionError:
             logger.error("ElasticSearch connection failed")
         except Exception as e:
@@ -283,6 +277,26 @@ class ElasticStore(VectorStoreBase):
         """Get the vector store config."""
         return self._vector_store_config
 
+    def create_collection(self, collection_name: str, **kwargs) -> Any:
+        """Create the Elasticsearch index unless it is reported to exist.
+
+        NOTE(review): existence is checked with ``vector_name_exists()``,
+        which takes no name - presumably it checks ``self.index_name``;
+        confirm before passing a different ``collection_name``.
+
+        Args:
+            collection_name (str): name of the index to create.
+            **kwargs: unused.
+
+        Returns:
+            bool: always True, whether or not an index was created.
+        """
+        if not self.vector_name_exists():
+            self.es_client_python.indices.create(
+                index=collection_name, body=self.index_settings
+            )
+        return True
+
     def load_document(
         self,
         chunks: List[Chunk],
diff --git a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/milvus_store.py b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/milvus_store.py
index 8569b2c99..b98de08f5 100644
--- a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/milvus_store.py
+++ b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/milvus_store.py
@@ -287,15 +287,17 @@ class MilvusStore(VectorStoreBase):
             password=self.password,
             alias="default",
         )
+        self.col = self.create_collection(collection_name=self.collection_name)
 
-    def init_schema_and_load(self, vector_name, documents) -> List[str]:
+    def create_collection(self, collection_name: str, **kwargs) -> Any:
         """Create a Milvus collection.
 
-        Create a Milvus collection, indexes it with HNSW, load document.
+        Return the existing collection when one named ``collection_name``
+        exists; otherwise create it, index it with HNSW and load it.
 
         Args:
-            vector_name (Embeddings): your collection name.
-            documents (List[str]): Text to insert.
+            collection_name (str): your collection name.
+            **kwargs: unused.
         Returns:
-            List[str]: document ids.
+            Collection: the existing or newly created collection.
         """
@@ -321,25 +323,11 @@ class MilvusStore(VectorStoreBase):
                 alias="default",
                 # secure=self.secure,
             )
-        texts = [d.content for d in documents]
-        metadatas = [d.metadata for d in documents]
-        embeddings = self.embedding.embed_query(texts[0])
 
-        if utility.has_collection(self.collection_name):
-            self.col = Collection(self.collection_name, using=self.alias)
-            self.fields = []
-            for x in self.col.schema.fields:
-                self.fields.append(x.name)
-                if x.auto_id:
-                    self.fields.remove(x.name)
-                if x.is_primary:
-                    self.primary_field = x.name
-                if (
-                    x.dtype == DataType.FLOAT_VECTOR
-                    or x.dtype == DataType.BINARY_VECTOR
-                ):
-                    self.vector_field = x.name
-            return self._add_documents(texts, metadatas)
-            # return self.collection_name
+        if utility.has_collection(collection_name):
+            # Reuse the existing collection; no embedding call is needed.
+            return Collection(collection_name, using=self.alias)
 
+        # Embed the name only to learn the vector dimension.
+        embeddings = self.embedding.embed_query(collection_name)
         dim = len(embeddings)
@@ -349,12 +337,7 @@ class MilvusStore(VectorStoreBase):
         text_field = self.text_field
         metadata_field = self.metadata_field
         props_field = self.props_field
-        # self.text_field = text_field
-        collection_name = vector_name
         fields = []
-        max_length = 0
-        for y in texts:
-            max_length = max(max_length, len(y))
         # Create the text field
         fields.append(FieldSchema(text_field, DataType.VARCHAR, max_length=65535))
         # primary key field
@@ -375,8 +358,32 @@ class MilvusStore(VectorStoreBase):
         # milvus index
         collection.create_index(vector_field, index)
         collection.load()
-        schema = collection.schema
-        for x in schema.fields:
+        return collection
+
+    def _load_documents(self, documents) -> List[str]:
+        """Load documents into Milvus.
+
+        Refresh the cached schema field names from ``self.col``, then insert
+        the documents' texts and metadata.
+
+        Args:
+            documents: items exposing ``content`` and ``metadata``.
+        Returns:
+            List[str]: document ids.
+        Raises:
+            ValueError: if pymilvus cannot be imported.
+        """
+        try:
+            from pymilvus import DataType
+        except ImportError as exc:
+            raise ValueError(
+                "Could not import pymilvus python package. "
+                "Please install it with `pip install pymilvus`."
+            ) from exc
+        texts = [d.content for d in documents]
+        metadatas = [d.metadata for d in documents]
+        self.fields = []
+        for x in self.col.schema.fields:
             self.fields.append(x.name)
             if x.auto_id:
                 self.fields.remove(x.name)
@@ -384,9 +391,7 @@ class MilvusStore(VectorStoreBase):
                 self.primary_field = x.name
             if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
                 self.vector_field = x.name
-        ids = self._add_documents(texts, metadatas)
-
-        return ids
+        return self._add_documents(texts, metadatas)
 
     def _add_documents(
         self,
@@ -434,7 +439,7 @@ class MilvusStore(VectorStoreBase):
         ]
         doc_ids = []
         for doc_batch in batched_list:
-            doc_ids.extend(self.init_schema_and_load(self.collection_name, doc_batch))
+            doc_ids.extend(self._load_documents(doc_batch))
         doc_ids = [str(doc_id) for doc_id in doc_ids]
         return doc_ids
 
@@ -659,23 +664,23 @@ class MilvusStore(VectorStoreBase):
             if isinstance(metadata_filter.value, str):
                 expr = (
                     f"{self.props_field}['{metadata_filter.key}'] "
-                    f"{FilterOperator.EQ} '{metadata_filter.value}'"
+                    f"{FilterOperator.EQ.value} '{metadata_filter.value}'"
                 )
                 metadata_filters.append(expr)
             elif isinstance(metadata_filter.value, List):
                 expr = (
                     f"{self.props_field}['{metadata_filter.key}'] "
-                    f"{FilterOperator.IN} {metadata_filter.value}"
+                    f"{FilterOperator.IN.value} {metadata_filter.value}"
                 )
                 metadata_filters.append(expr)
             else:
                 expr = (
                     f"{self.props_field}['{metadata_filter.key}'] "
-                    f"{FilterOperator.EQ} {str(metadata_filter.value)}"
+                    f"{FilterOperator.EQ.value} {str(metadata_filter.value)}"
                 )
                 metadata_filters.append(expr)
         if len(metadata_filters) > 1:
-            metadata_filter_expr = f" {filters.condition} ".join(metadata_filters)
+            metadata_filter_expr = f" {filters.condition.value} ".join(metadata_filters)
         else:
             metadata_filter_expr = metadata_filters[0]
         return metadata_filter_expr
diff --git a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/oceanbase_store.py b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/oceanbase_store.py
index 3638ebb48..261655972 100644
--- a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/oceanbase_store.py
+++ b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/oceanbase_store.py
@@ -312,6 +312,18 @@ class OceanBaseStore(VectorStoreBase):
             vidxs=vidx_params,
         )
 
+    def create_collection(self, collection_name: str, **kwargs) -> Any:
+        """Create the collection.
+
+        Embeds ``collection_name`` and passes the embeddings to
+        ``_create_table_with_index``; ``kwargs`` are not used.
+
+        Returns:
+            Any: the result of ``_create_table_with_index``.
+        """
+        embeddings = self.embedding_function.embed_documents([collection_name])
+        return self._create_table_with_index(embeddings)
+
     def load_document(self, chunks: List[Chunk]) -> List[str]:
         """Load document in vector database."""
         batch_size = 100
diff --git a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/pgvector_store.py b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/pgvector_store.py
index 3d0a774c4..d1855c6e1 100644
--- a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/pgvector_store.py
+++ b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/pgvector_store.py
@@ -2,7 +2,7 @@
 
 import logging
 from dataclasses import dataclass, field
-from typing import List, Optional
+from typing import Any, List, Optional
 
 from dbgpt.core import Chunk, Embeddings
 from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
@@ -114,6 +114,17 @@ class PGVectorStore(VectorStoreBase):
         """Get the vector store config."""
         return self._vector_store_config
 
+    def create_collection(self, collection_name: str, **kwargs) -> Any:
+        """Create the collection.
+
+        Delegates to ``vector_store_client.create_collection()``;
+        ``collection_name`` and ``kwargs`` are not forwarded.
+
+        Returns:
+            Any: the result of the client call.
+        """
+        return self.vector_store_client.create_collection()
+
     def similar_search(
         self, text: str, topk: int, filters: Optional[MetadataFilters] = None
     ) -> List[Chunk]:
diff --git a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/weaviate_store.py b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/weaviate_store.py
index b8be7b98f..43abfd686 100644
--- a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/weaviate_store.py
+++ b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/weaviate_store.py
@@ -3,7 +3,7 @@
 import logging
 import os
 from dataclasses import dataclass, field
-from typing import List, Optional
+from typing import Any, List, Optional
 
 from dbgpt.core import Chunk, Embeddings
 from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
@@ -125,6 +125,17 @@ class WeaviateStore(VectorStoreBase):
         """Get the vector store config."""
         return self._vector_store_config
 
+    def create_collection(self, collection_name: str, **kwargs) -> Any:
+        """Create the collection.
+
+        Delegates to ``_default_schema()``; ``collection_name`` and
+        ``kwargs`` are not used.
+
+        Returns:
+            Any: the result of ``_default_schema()``.
+        """
+        return self._default_schema()
+
     def similar_search(
         self, text: str, topk: int, filters: Optional[MetadataFilters] = None
     ) -> List[Chunk]:
diff --git a/packages/dbgpt-serve/src/dbgpt_serve/agent/agents/controller.py b/packages/dbgpt-serve/src/dbgpt_serve/agent/agents/controller.py
index b60e06730..02c97fb85 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/agent/agents/controller.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/agent/agents/controller.py
@@ -142,7 +142,10 @@ class MultiAgents(BaseComponent, ABC):
         ).create()
 
         storage_manager = StorageManager.get_instance(self.system_app)
-        vector_store = storage_manager.create_vector_store(index_name="_agent_memory_")
+        index_name = "_agent_memory_"
+        vector_store = storage_manager.create_vector_store(index_name=index_name)
+        if not vector_store.vector_name_exists():
+            vector_store.create_collection(collection_name=index_name)
         embeddings = EmbeddingFactory.get_instance(self.system_app).create()
         short_term_memory = EnhancedShortTermMemory(
             embeddings, executor=executor, buffer_size=10