fix(ChatKnowledge): add aload_document (#1548)

This commit is contained in:
Aries-ckt
2024-05-23 11:59:34 +08:00
committed by GitHub
parent 7f55aa4b6e
commit 83d7e9d82d
14 changed files with 180 additions and 238 deletions

View File

@@ -37,7 +37,7 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
def __init__(self, config: BuiltinKnowledgeGraphConfig):
"""Create builtin knowledge graph instance."""
self._config = config
super().__init__()
self._llm_client = config.llm_client
if not self._llm_client:
raise ValueError("No llm client provided.")

View File

@@ -2,6 +2,7 @@
import logging
import math
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import Any, List, Optional
from dbgpt._private.pydantic import ConfigDict, Field
@@ -9,6 +10,7 @@ from dbgpt.core import Chunk, Embeddings
from dbgpt.core.awel.flow import Parameter
from dbgpt.rag.index.base import IndexStoreBase, IndexStoreConfig
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.executor_utils import blocking_func_to_async
from dbgpt.util.i18n_utils import _
logger = logging.getLogger(__name__)
@@ -102,6 +104,10 @@ class VectorStoreConfig(IndexStoreConfig):
class VectorStoreBase(IndexStoreBase, ABC):
"""Vector store base class."""
def __init__(self, executor: Optional[ThreadPoolExecutor] = None):
    """Initialize vector store.

    Args:
        executor(Optional[ThreadPoolExecutor]): thread pool passed on
            unchanged to the parent class initializer; defaults to None.
    """
    super().__init__(executor)
def filter_by_score_threshold(
self, chunks: List[Chunk], score_threshold: float
) -> List[Chunk]:
@@ -160,7 +166,7 @@ class VectorStoreBase(IndexStoreBase, ABC):
return 1.0 - distance / math.sqrt(2)
async def aload_document(self, chunks: List[Chunk]) -> List[str]: # type: ignore
"""Load document in index database.
"""Async load document in index database.
Args:
chunks(List[Chunk]): document chunks.
@@ -168,4 +174,4 @@ class VectorStoreBase(IndexStoreBase, ABC):
Return:
List[str]: chunk ids.
"""
raise NotImplementedError
return await blocking_func_to_async(self._executor, self.load_document, chunks)

View File

@@ -62,6 +62,7 @@ class ChromaStore(VectorStoreBase):
Args:
vector_store_config(ChromaVectorConfig): vector store config.
"""
super().__init__()
chroma_vector_config = vector_store_config.to_dict(exclude_none=True)
chroma_path = chroma_vector_config.get(
"persist_path", os.path.join(PILOT_PATH, "data")

View File

@@ -170,14 +170,22 @@ class VectorStoreConnector:
)
async def aload_document(self, chunks: List[Chunk]) -> List[str]:
    """Async load document in vector database.

    Reads ``max_chunks_once_load`` and ``max_threads`` from
    ``self._index_store_config`` when that config is set (falling back to
    10 and 1 respectively when it is not) and forwards them, together with
    ``chunks``, to ``self.client.aload_document_with_limit``.

    Args:
        chunks(List[Chunk]): document chunks.

    Return:
        List[str]: chunk ids.
    """
    config = self._index_store_config
    # Without a config, use the same fallbacks for both limits.
    max_chunks_once_load = config.max_chunks_once_load if config else 10
    max_threads = config.max_threads if config else 1
    return await self.client.aload_document_with_limit(
        chunks, max_chunks_once_load, max_threads
    )
def similar_search(

View File

@@ -125,6 +125,7 @@ class ElasticStore(VectorStoreBase):
Args:
vector_store_config (ElasticsearchVectorConfig): ElasticsearchStore config.
"""
super().__init__()
connect_kwargs = {}
elasticsearch_vector_config = vector_store_config.dict()
self.uri = elasticsearch_vector_config.get("uri") or os.getenv(

View File

@@ -149,8 +149,14 @@ class MilvusStore(VectorStoreBase):
vector_store_config (MilvusVectorConfig): MilvusStore config.
refer to https://milvus.io/docs/v2.0.x/manage_connection.md
"""
from pymilvus import connections
super().__init__()
try:
from pymilvus import connections
except ImportError:
raise ValueError(
"Could not import pymilvus python package. "
"Please install it with `pip install pymilvus`."
)
connect_kwargs = {}
milvus_vector_config = vector_store_config.to_dict()
self.uri = milvus_vector_config.get("uri") or os.getenv(
@@ -373,8 +379,13 @@ class MilvusStore(VectorStoreBase):
self, text, topk, filters: Optional[MetadataFilters] = None
) -> List[Chunk]:
"""Perform a search on a query string and return results."""
from pymilvus import Collection, DataType
try:
from pymilvus import Collection, DataType
except ImportError:
raise ValueError(
"Could not import pymilvus python package. "
"Please install it with `pip install pymilvus`."
)
"""similar_search in vector database."""
self.col = Collection(self.collection_name)
schema = self.col.schema
@@ -419,7 +430,13 @@ class MilvusStore(VectorStoreBase):
Returns:
List[Tuple[Document, float]]: Result doc and score.
"""
from pymilvus import Collection
try:
from pymilvus import Collection, DataType
except ImportError:
raise ValueError(
"Could not import pymilvus python package. "
"Please install it with `pip install pymilvus`."
)
self.col = Collection(self.collection_name)
schema = self.col.schema
@@ -429,7 +446,6 @@ class MilvusStore(VectorStoreBase):
self.fields.remove(x.name)
if x.is_primary:
self.primary_field = x.name
from pymilvus import DataType
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
self.vector_field = x.name
@@ -526,15 +542,26 @@ class MilvusStore(VectorStoreBase):
def vector_name_exists(self):
    """Whether vector name exists.

    Returns:
        The result of ``pymilvus.utility.has_collection`` for
        ``self.collection_name``.

    Raises:
        ValueError: if the ``pymilvus`` package cannot be imported.
    """
    # The import is attempted only inside the guard so that a missing
    # package is reported through the ValueError below; the original
    # ImportError is kept as the cause.
    try:
        from pymilvus import utility
    except ImportError as err:
        raise ValueError(
            "Could not import pymilvus python package. "
            "Please install it with `pip install pymilvus`."
        ) from err
    return utility.has_collection(self.collection_name)
def delete_vector_name(self, vector_name: str):
"""Delete vector name."""
from pymilvus import utility
try:
from pymilvus import utility
except ImportError:
raise ValueError(
"Could not import pymilvus python package. "
"Please install it with `pip install pymilvus`."
)
"""milvus delete collection name"""
logger.info(f"milvus vector_name:{vector_name} begin delete...")
utility.drop_collection(self.collection_name)
@@ -542,8 +569,13 @@ class MilvusStore(VectorStoreBase):
def delete_by_ids(self, ids):
"""Delete vector by ids."""
from pymilvus import Collection
try:
from pymilvus import Collection
except ImportError:
raise ValueError(
"Could not import pymilvus python package. "
"Please install it with `pip install pymilvus`."
)
self.col = Collection(self.collection_name)
# milvus delete vectors by ids
logger.info(f"begin delete milvus ids: {ids}")

View File

@@ -717,7 +717,7 @@ class OceanBaseStore(VectorStoreBase):
"""Create a OceanBaseStore instance."""
if vector_store_config.embedding_fn is None:
raise ValueError("embedding_fn is required for OceanBaseStore")
super().__init__()
self.embeddings = vector_store_config.embedding_fn
self.collection_name = vector_store_config.name
vector_store_config = vector_store_config.dict()

View File

@@ -63,6 +63,7 @@ class PGVectorStore(VectorStoreBase):
raise ImportError(
"Please install the `langchain` package to use the PGVector."
)
super().__init__()
self.connection_string = vector_store_config.connection_string
self.embeddings = vector_store_config.embedding_fn
self.collection_name = vector_store_config.name

View File

@@ -68,7 +68,7 @@ class WeaviateStore(VectorStoreBase):
"Could not import weaviate python package. "
"Please install it with `pip install weaviate-client`."
)
super().__init__()
self.weaviate_url = vector_store_config.weaviate_url
self.embedding = vector_store_config.embedding_fn
self.vector_name = vector_store_config.name