mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-07 12:00:46 +00:00
fix(ChatKnowledge): add aload_document (#1548)
This commit is contained in:
@@ -37,7 +37,7 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
|
||||
def __init__(self, config: BuiltinKnowledgeGraphConfig):
|
||||
"""Create builtin knowledge graph instance."""
|
||||
self._config = config
|
||||
|
||||
super().__init__()
|
||||
self._llm_client = config.llm_client
|
||||
if not self._llm_client:
|
||||
raise ValueError("No llm client provided.")
|
||||
|
@@ -2,6 +2,7 @@
|
||||
import logging
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import ConfigDict, Field
|
||||
@@ -9,6 +10,7 @@ from dbgpt.core import Chunk, Embeddings
|
||||
from dbgpt.core.awel.flow import Parameter
|
||||
from dbgpt.rag.index.base import IndexStoreBase, IndexStoreConfig
|
||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||
from dbgpt.util.executor_utils import blocking_func_to_async
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -102,6 +104,10 @@ class VectorStoreConfig(IndexStoreConfig):
|
||||
class VectorStoreBase(IndexStoreBase, ABC):
|
||||
"""Vector store base class."""
|
||||
|
||||
def __init__(self, executor: Optional[ThreadPoolExecutor] = None):
|
||||
"""Initialize vector store."""
|
||||
super().__init__(executor)
|
||||
|
||||
def filter_by_score_threshold(
|
||||
self, chunks: List[Chunk], score_threshold: float
|
||||
) -> List[Chunk]:
|
||||
@@ -160,7 +166,7 @@ class VectorStoreBase(IndexStoreBase, ABC):
|
||||
return 1.0 - distance / math.sqrt(2)
|
||||
|
||||
async def aload_document(self, chunks: List[Chunk]) -> List[str]: # type: ignore
|
||||
"""Load document in index database.
|
||||
"""Async load document in index database.
|
||||
|
||||
Args:
|
||||
chunks(List[Chunk]): document chunks.
|
||||
@@ -168,4 +174,4 @@ class VectorStoreBase(IndexStoreBase, ABC):
|
||||
Return:
|
||||
List[str]: chunk ids.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
return await blocking_func_to_async(self._executor, self.load_document, chunks)
|
||||
|
@@ -62,6 +62,7 @@ class ChromaStore(VectorStoreBase):
|
||||
Args:
|
||||
vector_store_config(ChromaVectorConfig): vector store config.
|
||||
"""
|
||||
super().__init__()
|
||||
chroma_vector_config = vector_store_config.to_dict(exclude_none=True)
|
||||
chroma_path = chroma_vector_config.get(
|
||||
"persist_path", os.path.join(PILOT_PATH, "data")
|
||||
|
@@ -170,14 +170,22 @@ class VectorStoreConnector:
|
||||
)
|
||||
|
||||
async def aload_document(self, chunks: List[Chunk]) -> List[str]:
|
||||
"""Load document in vector database.
|
||||
"""Async load document in vector database.
|
||||
|
||||
Args:
|
||||
- chunks: document chunks.
|
||||
Return chunk ids.
|
||||
"""
|
||||
return await self.client.aload_document(
|
||||
chunks,
|
||||
max_chunks_once_load = (
|
||||
self._index_store_config.max_chunks_once_load
|
||||
if self._index_store_config
|
||||
else 10
|
||||
)
|
||||
max_threads = (
|
||||
self._index_store_config.max_threads if self._index_store_config else 1
|
||||
)
|
||||
return await self.client.aload_document_with_limit(
|
||||
chunks, max_chunks_once_load, max_threads
|
||||
)
|
||||
|
||||
def similar_search(
|
||||
|
@@ -125,6 +125,7 @@ class ElasticStore(VectorStoreBase):
|
||||
Args:
|
||||
vector_store_config (ElasticsearchVectorConfig): ElasticsearchStore config.
|
||||
"""
|
||||
super().__init__()
|
||||
connect_kwargs = {}
|
||||
elasticsearch_vector_config = vector_store_config.dict()
|
||||
self.uri = elasticsearch_vector_config.get("uri") or os.getenv(
|
||||
|
@@ -149,8 +149,14 @@ class MilvusStore(VectorStoreBase):
|
||||
vector_store_config (MilvusVectorConfig): MilvusStore config.
|
||||
refer to https://milvus.io/docs/v2.0.x/manage_connection.md
|
||||
"""
|
||||
from pymilvus import connections
|
||||
|
||||
super().__init__()
|
||||
try:
|
||||
from pymilvus import connections
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import pymilvus python package. "
|
||||
"Please install it with `pip install pymilvus`."
|
||||
)
|
||||
connect_kwargs = {}
|
||||
milvus_vector_config = vector_store_config.to_dict()
|
||||
self.uri = milvus_vector_config.get("uri") or os.getenv(
|
||||
@@ -373,8 +379,13 @@ class MilvusStore(VectorStoreBase):
|
||||
self, text, topk, filters: Optional[MetadataFilters] = None
|
||||
) -> List[Chunk]:
|
||||
"""Perform a search on a query string and return results."""
|
||||
from pymilvus import Collection, DataType
|
||||
|
||||
try:
|
||||
from pymilvus import Collection, DataType
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import pymilvus python package. "
|
||||
"Please install it with `pip install pymilvus`."
|
||||
)
|
||||
"""similar_search in vector database."""
|
||||
self.col = Collection(self.collection_name)
|
||||
schema = self.col.schema
|
||||
@@ -419,7 +430,13 @@ class MilvusStore(VectorStoreBase):
|
||||
Returns:
|
||||
List[Tuple[Document, float]]: Result doc and score.
|
||||
"""
|
||||
from pymilvus import Collection
|
||||
try:
|
||||
from pymilvus import Collection, DataType
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import pymilvus python package. "
|
||||
"Please install it with `pip install pymilvus`."
|
||||
)
|
||||
|
||||
self.col = Collection(self.collection_name)
|
||||
schema = self.col.schema
|
||||
@@ -429,7 +446,6 @@ class MilvusStore(VectorStoreBase):
|
||||
self.fields.remove(x.name)
|
||||
if x.is_primary:
|
||||
self.primary_field = x.name
|
||||
from pymilvus import DataType
|
||||
|
||||
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
|
||||
self.vector_field = x.name
|
||||
@@ -526,15 +542,26 @@ class MilvusStore(VectorStoreBase):
|
||||
|
||||
def vector_name_exists(self):
|
||||
"""Whether vector name exists."""
|
||||
from pymilvus import utility
|
||||
try:
|
||||
from pymilvus import utility
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import pymilvus python package. "
|
||||
"Please install it with `pip install pymilvus`."
|
||||
)
|
||||
|
||||
"""is vector store name exist."""
|
||||
return utility.has_collection(self.collection_name)
|
||||
|
||||
def delete_vector_name(self, vector_name: str):
|
||||
"""Delete vector name."""
|
||||
from pymilvus import utility
|
||||
|
||||
try:
|
||||
from pymilvus import utility
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import pymilvus python package. "
|
||||
"Please install it with `pip install pymilvus`."
|
||||
)
|
||||
"""milvus delete collection name"""
|
||||
logger.info(f"milvus vector_name:{vector_name} begin delete...")
|
||||
utility.drop_collection(self.collection_name)
|
||||
@@ -542,8 +569,13 @@ class MilvusStore(VectorStoreBase):
|
||||
|
||||
def delete_by_ids(self, ids):
|
||||
"""Delete vector by ids."""
|
||||
from pymilvus import Collection
|
||||
|
||||
try:
|
||||
from pymilvus import Collection
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import pymilvus python package. "
|
||||
"Please install it with `pip install pymilvus`."
|
||||
)
|
||||
self.col = Collection(self.collection_name)
|
||||
# milvus delete vectors by ids
|
||||
logger.info(f"begin delete milvus ids: {ids}")
|
||||
|
@@ -717,7 +717,7 @@ class OceanBaseStore(VectorStoreBase):
|
||||
"""Create a OceanBaseStore instance."""
|
||||
if vector_store_config.embedding_fn is None:
|
||||
raise ValueError("embedding_fn is required for OceanBaseStore")
|
||||
|
||||
super().__init__()
|
||||
self.embeddings = vector_store_config.embedding_fn
|
||||
self.collection_name = vector_store_config.name
|
||||
vector_store_config = vector_store_config.dict()
|
||||
|
@@ -63,6 +63,7 @@ class PGVectorStore(VectorStoreBase):
|
||||
raise ImportError(
|
||||
"Please install the `langchain` package to use the PGVector."
|
||||
)
|
||||
super().__init__()
|
||||
self.connection_string = vector_store_config.connection_string
|
||||
self.embeddings = vector_store_config.embedding_fn
|
||||
self.collection_name = vector_store_config.name
|
||||
|
@@ -68,7 +68,7 @@ class WeaviateStore(VectorStoreBase):
|
||||
"Could not import weaviate python package. "
|
||||
"Please install it with `pip install weaviate-client`."
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
self.weaviate_url = vector_store_config.weaviate_url
|
||||
self.embedding = vector_store_config.embedding_fn
|
||||
self.vector_name = vector_store_config.name
|
||||
|
Reference in New Issue
Block a user