fix(ChatKnowledge): add aload_document (#1548)

This commit is contained in:
Aries-ckt 2024-05-23 11:59:34 +08:00 committed by GitHub
parent 7f55aa4b6e
commit 83d7e9d82d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 180 additions and 238 deletions

View File

@ -27,6 +27,7 @@ from dbgpt.configs.model_config import (
EMBEDDING_MODEL_CONFIG, EMBEDDING_MODEL_CONFIG,
KNOWLEDGE_UPLOAD_ROOT_PATH, KNOWLEDGE_UPLOAD_ROOT_PATH,
) )
from dbgpt.rag import ChunkParameters
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import ChunkStrategy from dbgpt.rag.knowledge.base import ChunkStrategy
from dbgpt.rag.knowledge.factory import KnowledgeFactory from dbgpt.rag.knowledge.factory import KnowledgeFactory
@ -235,13 +236,30 @@ async def document_upload(
@router.post("/knowledge/{space_name}/document/sync") @router.post("/knowledge/{space_name}/document/sync")
def document_sync(space_name: str, request: DocumentSyncRequest): async def document_sync(
space_name: str,
request: DocumentSyncRequest,
service: Service = Depends(get_rag_service),
):
logger.info(f"Received params: {space_name}, {request}") logger.info(f"Received params: {space_name}, {request}")
try: try:
knowledge_space_service.sync_knowledge_document( space = service.get({"name": space_name})
space_name=space_name, sync_request=request if space is None:
return Result.failed(code="E000X", msg=f"space {space_name} not exist")
if request.doc_ids is None or len(request.doc_ids) == 0:
return Result.failed(code="E000X", msg="doc_ids is None")
sync_request = KnowledgeSyncRequest(
doc_id=request.doc_ids[0],
space_id=str(space.id),
model_name=request.model_name,
) )
return Result.succ([]) sync_request.chunk_parameters = ChunkParameters(
chunk_strategy="Automatic",
chunk_size=request.chunk_size or 512,
chunk_overlap=request.chunk_overlap or 50,
)
doc_ids = await service.sync_document(requests=[sync_request])
return Result.succ(doc_ids)
except Exception as e: except Exception as e:
return Result.failed(code="E000X", msg=f"document sync error {e}") return Result.failed(code="E000X", msg=f"document sync error {e}")

View File

@ -1,7 +1,6 @@
import json import json
import logging import logging
from datetime import datetime from datetime import datetime
from typing import List
from dbgpt._private.config import Config from dbgpt._private.config import Config
from dbgpt.app.knowledge.chunk_db import DocumentChunkDao, DocumentChunkEntity from dbgpt.app.knowledge.chunk_db import DocumentChunkDao, DocumentChunkEntity
@ -32,13 +31,8 @@ from dbgpt.rag.assembler.embedding import EmbeddingAssembler
from dbgpt.rag.assembler.summary import SummaryAssembler from dbgpt.rag.assembler.summary import SummaryAssembler
from dbgpt.rag.chunk_manager import ChunkParameters from dbgpt.rag.chunk_manager import ChunkParameters
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.knowledge.base import ChunkStrategy, KnowledgeType from dbgpt.rag.knowledge.base import KnowledgeType
from dbgpt.rag.knowledge.factory import KnowledgeFactory from dbgpt.rag.knowledge.factory import KnowledgeFactory
from dbgpt.rag.text_splitter.text_splitter import (
RecursiveCharacterTextSplitter,
SpacyTextSplitter,
)
from dbgpt.serve.rag.api.schemas import KnowledgeSyncRequest
from dbgpt.serve.rag.models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity from dbgpt.serve.rag.models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity
from dbgpt.serve.rag.service.service import SyncStatus from dbgpt.serve.rag.service.service import SyncStatus
from dbgpt.storage.vector_store.base import VectorStoreConfig from dbgpt.storage.vector_store.base import VectorStoreConfig
@ -199,186 +193,6 @@ class KnowledgeService:
total = knowledge_document_dao.get_knowledge_documents_count(query) total = knowledge_document_dao.get_knowledge_documents_count(query)
return DocumentQueryResponse(data=data, total=total, page=page) return DocumentQueryResponse(data=data, total=total, page=page)
def batch_document_sync(
    self,
    space_name,
    sync_requests: List[KnowledgeSyncRequest],
) -> List[int]:
    """Batch sync knowledge document chunks into the vector store.

    Args:
        space_name: Knowledge space name.
        sync_requests: One ``KnowledgeSyncRequest`` per document to sync.

    Returns:
        List[int]: Ids of the documents passed to
            ``_sync_knowledge_document``, in request order.

    Raises:
        Exception: If a requested document cannot be found, or if its status
            is already RUNNING or FINISHED.
    """
    doc_ids = []
    for sync_request in sync_requests:
        docs = knowledge_document_dao.documents_by_ids([sync_request.doc_id])
        if not docs:
            # The lookup came back empty; the message must say the document
            # is missing (the previous wording claimed the opposite).
            raise Exception(
                f"there is no document called, doc_id: {sync_request.doc_id}"
            )
        doc = docs[0]
        if doc.status in (SyncStatus.RUNNING.name, SyncStatus.FINISHED.name):
            raise Exception(
                f" doc:{doc.doc_name} status is {doc.status}, can not sync"
            )
        chunk_parameters = sync_request.chunk_parameters
        if chunk_parameters.chunk_strategy != ChunkStrategy.CHUNK_BY_SIZE.name:
            # For every strategy except CHUNK_BY_SIZE the request's size and
            # overlap are replaced by the space context values, falling back
            # to the global configuration when the space has no context.
            space_context = self.get_space_context(space_name)
            chunk_parameters.chunk_size = (
                CFG.KNOWLEDGE_CHUNK_SIZE
                if space_context is None
                else int(space_context["embedding"]["chunk_size"])
            )
            chunk_parameters.chunk_overlap = (
                CFG.KNOWLEDGE_CHUNK_OVERLAP
                if space_context is None
                else int(space_context["embedding"]["chunk_overlap"])
            )
        self._sync_knowledge_document(space_name, doc, chunk_parameters)
        doc_ids.append(doc.id)
    return doc_ids
def sync_knowledge_document(self, space_name, sync_request: DocumentSyncRequest):
"""Sync the chunks of the requested documents into the vector store.

For each id in ``sync_request.doc_ids`` the document is looked up, a text
splitter is built from the configured language and the request's chunking
options, and the document is passed to ``_sync_knowledge_document``.

Args:
    space_name: Knowledge space name.
    sync_request: DocumentSyncRequest carrying ``doc_ids`` plus optional
        ``model_name``, ``chunk_size``, ``chunk_overlap``, ``separators``
        and ``pre_separator``.

Raises:
    Exception: If a document's status is already RUNNING or FINISHED.
    ValueError: If several separators are given while ``CFG.LANGUAGE`` is
        not "en".
"""
from dbgpt.rag.text_splitter.pre_text_splitter import PreTextSplitter
doc_ids = sync_request.doc_ids
# The model name is stored on the service instance, defaulting to CFG.LLM_MODEL.
self.model_name = sync_request.model_name or CFG.LLM_MODEL
for doc_id in doc_ids:
query = KnowledgeDocumentEntity(id=doc_id)
docs = knowledge_document_dao.get_documents(query)
if len(docs) == 0:
raise Exception(
# NOTE(review): this reads `sync_request.doc_id`, but the request is only
# shown to expose `doc_ids`; the loop variable `doc_id` looks like the
# intended value, and "there are document called" presumably means
# "there is no document" - confirm and fix.
f"there are document called, doc_id: {sync_request.doc_id}"
)
doc = docs[0]
# A document that is already syncing or already synced is rejected.
if (
doc.status == SyncStatus.RUNNING.name
or doc.status == SyncStatus.FINISHED.name
):
raise Exception(
f" doc:{doc.doc_name} status is {doc.status}, can not sync"
)
# Defaults come from the space context when one exists, otherwise from CFG.
space_context = self.get_space_context(space_name)
chunk_size = (
CFG.KNOWLEDGE_CHUNK_SIZE
if space_context is None
else int(space_context["embedding"]["chunk_size"])
)
chunk_overlap = (
CFG.KNOWLEDGE_CHUNK_OVERLAP
if space_context is None
else int(space_context["embedding"]["chunk_overlap"])
)
# Truthy values on the request override the defaults computed above.
if sync_request.chunk_size:
chunk_size = sync_request.chunk_size
if sync_request.chunk_overlap:
chunk_overlap = sync_request.chunk_overlap
separators = sync_request.separators or None
from dbgpt.rag.chunk_manager import ChunkParameters
chunk_parameters = ChunkParameters(
chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
# "en" uses the recursive character splitter; any other language tries the
# spaCy splitter first.
if CFG.LANGUAGE == "en":
text_splitter = RecursiveCharacterTextSplitter(
separators=separators,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
length_function=len,
)
else:
if separators and len(separators) > 1:
raise ValueError(
# NOTE(review): "multipsle" is a typo in this runtime message.
"SpacyTextSplitter do not support multipsle separators"
)
try:
separator = "\n\n" if not separators else separators[0]
text_splitter = SpacyTextSplitter(
separator=separator,
pipeline="zh_core_web_sm",
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
except Exception:
# Any failure building the spaCy splitter falls back to the recursive
# character splitter.
text_splitter = RecursiveCharacterTextSplitter(
separators=separators,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
# An optional pre-separator wraps whichever splitter was chosen above.
if sync_request.pre_separator:
logger.info(f"Use preseparator, {sync_request.pre_separator}")
text_splitter = PreTextSplitter(
pre_separator=sync_request.pre_separator,
text_splitter_impl=text_splitter,
)
# NOTE(review): indentation is lost in this view, so it cannot be told whether
# the next assignment belongs to the `if sync_request.pre_separator` branch - confirm.
chunk_parameters.text_splitter = text_splitter
self._sync_knowledge_document(space_name, doc, chunk_parameters)
# NOTE(review): if this return sits after the loop, only the last document's id
# is returned and an empty `doc_ids` leaves `doc` unbound - confirm the nesting.
return doc.id
def _sync_knowledge_document(
    self,
    space_name,
    doc: KnowledgeDocumentEntity,
    chunk_parameters: ChunkParameters,
) -> List[Chunk]:
    """Chunk a document and queue its chunks for embedding.

    The document is marked RUNNING and saved with its chunk count, then
    ``async_doc_embedding`` is submitted to the default executor.

    Args:
        space_name: Knowledge space name; must resolve to exactly one space.
        doc: Document entity whose content is chunked.
        chunk_parameters: Chunking parameters passed to the assembler.

    Returns:
        List[Chunk]: Chunks produced by the assembler.

    Raises:
        Exception: If the space lookup does not return exactly one space.
    """
    factory = CFG.SYSTEM_APP.get_component("embedding_factory", EmbeddingFactory)
    embedding_fn = factory.create(
        model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
    )

    matches = self.get_knowledge_space(KnowledgeSpaceRequest(name=space_name))
    if len(matches) != 1:
        raise Exception(f"invalid space name:{space_name}")
    space = matches[0]

    from dbgpt.storage.vector_store.base import VectorStoreConfig

    store_config = VectorStoreConfig(
        name=space.name,
        embedding_fn=embedding_fn,
        max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD,
        llm_client=self.llm_client,
        model_name=self.model_name,
    )
    connector = VectorStoreConnector(
        vector_store_type=space.vector_type, vector_store_config=store_config
    )
    assembler = EmbeddingAssembler.load_from_knowledge(
        knowledge=KnowledgeFactory.create(
            datasource=doc.content,
            knowledge_type=KnowledgeType.get_by_value(doc.doc_type),
        ),
        chunk_parameters=chunk_parameters,
        embeddings=embedding_fn,
        vector_store_connector=connector,
    )
    chunks = assembler.get_chunks()

    # Save the RUNNING state before the embedding job is queued.
    doc.status = SyncStatus.RUNNING.name
    doc.chunk_size = len(chunks)
    doc.gmt_modified = datetime.now()
    knowledge_document_dao.update_knowledge_document(doc)

    executor_factory = CFG.SYSTEM_APP.get_component(
        ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
    )
    executor_factory.create().submit(
        self.async_doc_embedding, assembler, chunks, doc
    )
    logger.info(f"begin save document chunks, doc:{doc.doc_name}")
    return chunks
async def document_summary(self, request: DocumentSummaryRequest): async def document_summary(self, request: DocumentSummaryRequest):
"""get document summary """get document summary
Args: Args:

View File

@ -1,9 +1,11 @@
"""Embedding Assembler.""" """Embedding Assembler."""
from concurrent.futures import ThreadPoolExecutor
from typing import Any, List, Optional from typing import Any, List, Optional
from dbgpt.core import Chunk, Embeddings from dbgpt.core import Chunk, Embeddings
from dbgpt.storage.vector_store.connector import VectorStoreConnector from dbgpt.storage.vector_store.connector import VectorStoreConnector
from ...util.executor_utils import blocking_func_to_async
from ..assembler.base import BaseAssembler from ..assembler.base import BaseAssembler
from ..chunk_manager import ChunkParameters from ..chunk_manager import ChunkParameters
from ..embedding.embedding_factory import DefaultEmbeddingFactory from ..embedding.embedding_factory import DefaultEmbeddingFactory
@ -98,6 +100,41 @@ class EmbeddingAssembler(BaseAssembler):
embeddings=embeddings, embeddings=embeddings,
) )
@classmethod
async def aload_from_knowledge(
    cls,
    knowledge: Knowledge,
    vector_store_connector: VectorStoreConnector,
    chunk_parameters: Optional[ChunkParameters] = None,
    embedding_model: Optional[str] = None,
    embeddings: Optional[Embeddings] = None,
    executor: Optional[ThreadPoolExecutor] = None,
) -> "EmbeddingAssembler":
    """Build an EmbeddingAssembler from a knowledge source without blocking.

    The assembler is constructed on a worker thread through
    ``blocking_func_to_async``, so the calling coroutine only awaits it.

    Args:
        knowledge: (Knowledge) Knowledge datasource.
        vector_store_connector: (VectorStoreConnector) VectorStoreConnector
            to use.
        chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for
            chunking.
        embedding_model: (Optional[str]) Embedding model to use.
        embeddings: (Optional[Embeddings]) Embeddings to use.
        executor: (Optional[ThreadPoolExecutor]) ThreadPoolExecutor to use.
            When omitted, a temporary pool is created for this call only.

    Returns:
        EmbeddingAssembler
    """
    owns_executor = executor is None
    pool = ThreadPoolExecutor() if owns_executor else executor
    try:
        return await blocking_func_to_async(
            pool,
            cls,
            knowledge,
            vector_store_connector,
            chunk_parameters,
            embedding_model,
            embeddings,
        )
    finally:
        if owns_executor:
            # Release the temporary pool explicitly instead of leaving it to
            # garbage collection; wait=False keeps the event loop unblocked.
            pool.shutdown(wait=False)
def persist(self) -> List[str]: def persist(self) -> List[str]:
"""Persist chunks into vector store. """Persist chunks into vector store.

View File

@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict
from dbgpt.core import Chunk, Embeddings from dbgpt.core import Chunk, Embeddings
from dbgpt.storage.vector_store.filters import MetadataFilters from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.executor_utils import blocking_func_to_async
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -46,6 +47,10 @@ class IndexStoreConfig(BaseModel):
class IndexStoreBase(ABC): class IndexStoreBase(ABC):
"""Index store base class.""" """Index store base class."""
def __init__(self, executor: Optional[ThreadPoolExecutor] = None):
"""Init index store."""
self._executor = executor or ThreadPoolExecutor()
@abstractmethod @abstractmethod
def load_document(self, chunks: List[Chunk]) -> List[str]: def load_document(self, chunks: List[Chunk]) -> List[str]:
"""Load document in index database. """Load document in index database.
@ -143,6 +148,27 @@ class IndexStoreBase(ABC):
) )
return ids return ids
async def aload_document_with_limit(
    self, chunks: List[Chunk], max_chunks_once_load: int = 10, max_threads: int = 1
) -> List[str]:
    """Run ``load_document_with_limit`` on the store's executor and await it.

    Args:
        chunks(List[Chunk]): Document chunks.
        max_chunks_once_load(int): Max number of chunks to load at once.
        max_threads(int): Max number of threads to use.

    Return:
        List[str]: Chunk ids.
    """
    load_args = (chunks, max_chunks_once_load, max_threads)
    chunk_ids = await blocking_func_to_async(
        self._executor, self.load_document_with_limit, *load_args
    )
    return chunk_ids
def similar_search( def similar_search(
self, text: str, topk: int, filters: Optional[MetadataFilters] = None self, text: str, topk: int, filters: Optional[MetadataFilters] = None
) -> List[Chunk]: ) -> List[Chunk]:

View File

@ -443,7 +443,7 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
space_id, space_id,
doc_vo: DocumentVO, doc_vo: DocumentVO,
chunk_parameters: ChunkParameters, chunk_parameters: ChunkParameters,
) -> List[Chunk]: ) -> None:
"""sync knowledge document chunk into vector store""" """sync knowledge document chunk into vector store"""
embedding_factory = CFG.SYSTEM_APP.get_component( embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory "embedding_factory", EmbeddingFactory
@ -470,47 +470,45 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes
datasource=doc.content, datasource=doc.content,
knowledge_type=KnowledgeType.get_by_value(doc.doc_type), knowledge_type=KnowledgeType.get_by_value(doc.doc_type),
) )
assembler = EmbeddingAssembler.load_from_knowledge(
knowledge=knowledge,
chunk_parameters=chunk_parameters,
vector_store_connector=vector_store_connector,
)
chunk_docs = assembler.get_chunks()
doc.status = SyncStatus.RUNNING.name doc.status = SyncStatus.RUNNING.name
doc.chunk_size = len(chunk_docs)
doc.gmt_modified = datetime.now() doc.gmt_modified = datetime.now()
self._document_dao.update_knowledge_document(doc) self._document_dao.update_knowledge_document(doc)
# executor = CFG.SYSTEM_APP.get_component( # asyncio.create_task(self.async_doc_embedding(assembler, chunk_docs, doc))
# ComponentType.EXECUTOR_DEFAULT, ExecutorFactory asyncio.create_task(
# ).create() self.async_doc_embedding(
# executor.submit(self.async_doc_embedding, assembler, chunk_docs, doc) knowledge, chunk_parameters, vector_store_connector, doc
asyncio.create_task(self.async_doc_embedding(assembler, chunk_docs, doc)) )
)
logger.info(f"begin save document chunks, doc:{doc.doc_name}") logger.info(f"begin save document chunks, doc:{doc.doc_name}")
return chunk_docs # return chunk_docs
@trace("async_doc_embedding") @trace("async_doc_embedding")
async def async_doc_embedding(self, assembler, chunk_docs, doc): async def async_doc_embedding(
self, knowledge, chunk_parameters, vector_store_connector, doc
):
"""async document embedding into vector db """async document embedding into vector db
Args: Args:
- client: EmbeddingEngine Client - knowledge: Knowledge
- chunk_docs: List[Document] - chunk_parameters: ChunkParameters
- doc: KnowledgeDocumentEntity - vector_store_connector: vector_store_connector
- doc: doc
""" """
logger.info( logger.info(f"async doc embedding sync, doc:{doc.doc_name}")
f"async doc embedding sync, doc:{doc.doc_name}, chunks length is {len(chunk_docs)}"
)
try: try:
with root_tracer.start_span( with root_tracer.start_span(
"app.knowledge.assembler.persist", "app.knowledge.assembler.persist",
metadata={"doc": doc.doc_name, "chunks": len(chunk_docs)}, metadata={"doc": doc.doc_name},
): ):
# vector_ids = assembler.persist() assembler = await EmbeddingAssembler.aload_from_knowledge(
space = self.get({"name": doc.space}) knowledge=knowledge,
if space and space.vector_type == "KnowledgeGraph": chunk_parameters=chunk_parameters,
vector_ids = await assembler.apersist() vector_store_connector=vector_store_connector,
else: )
vector_ids = assembler.persist() chunk_docs = assembler.get_chunks()
doc.chunk_size = len(chunk_docs)
vector_ids = await assembler.apersist()
doc.status = SyncStatus.FINISHED.name doc.status = SyncStatus.FINISHED.name
doc.result = "document embedding success" doc.result = "document embedding success"
if vector_ids is not None: if vector_ids is not None:

View File

@ -37,7 +37,7 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase):
def __init__(self, config: BuiltinKnowledgeGraphConfig): def __init__(self, config: BuiltinKnowledgeGraphConfig):
"""Create builtin knowledge graph instance.""" """Create builtin knowledge graph instance."""
self._config = config self._config = config
super().__init__()
self._llm_client = config.llm_client self._llm_client = config.llm_client
if not self._llm_client: if not self._llm_client:
raise ValueError("No llm client provided.") raise ValueError("No llm client provided.")

View File

@ -2,6 +2,7 @@
import logging import logging
import math import math
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import Any, List, Optional from typing import Any, List, Optional
from dbgpt._private.pydantic import ConfigDict, Field from dbgpt._private.pydantic import ConfigDict, Field
@ -9,6 +10,7 @@ from dbgpt.core import Chunk, Embeddings
from dbgpt.core.awel.flow import Parameter from dbgpt.core.awel.flow import Parameter
from dbgpt.rag.index.base import IndexStoreBase, IndexStoreConfig from dbgpt.rag.index.base import IndexStoreBase, IndexStoreConfig
from dbgpt.storage.vector_store.filters import MetadataFilters from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.executor_utils import blocking_func_to_async
from dbgpt.util.i18n_utils import _ from dbgpt.util.i18n_utils import _
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -102,6 +104,10 @@ class VectorStoreConfig(IndexStoreConfig):
class VectorStoreBase(IndexStoreBase, ABC): class VectorStoreBase(IndexStoreBase, ABC):
"""Vector store base class.""" """Vector store base class."""
def __init__(self, executor: Optional[ThreadPoolExecutor] = None) -> None:
"""Initialize vector store.

Args:
    executor: Optional thread pool, passed unchanged to the parent
        initializer.
"""
super().__init__(executor)
def filter_by_score_threshold( def filter_by_score_threshold(
self, chunks: List[Chunk], score_threshold: float self, chunks: List[Chunk], score_threshold: float
) -> List[Chunk]: ) -> List[Chunk]:
@ -160,7 +166,7 @@ class VectorStoreBase(IndexStoreBase, ABC):
return 1.0 - distance / math.sqrt(2) return 1.0 - distance / math.sqrt(2)
async def aload_document(self, chunks: List[Chunk]) -> List[str]: # type: ignore async def aload_document(self, chunks: List[Chunk]) -> List[str]: # type: ignore
"""Load document in index database. """Async load document in index database.
Args: Args:
chunks(List[Chunk]): document chunks. chunks(List[Chunk]): document chunks.
@ -168,4 +174,4 @@ class VectorStoreBase(IndexStoreBase, ABC):
Return: Return:
List[str]: chunk ids. List[str]: chunk ids.
""" """
raise NotImplementedError return await blocking_func_to_async(self._executor, self.load_document, chunks)

View File

@ -62,6 +62,7 @@ class ChromaStore(VectorStoreBase):
Args: Args:
vector_store_config(ChromaVectorConfig): vector store config. vector_store_config(ChromaVectorConfig): vector store config.
""" """
super().__init__()
chroma_vector_config = vector_store_config.to_dict(exclude_none=True) chroma_vector_config = vector_store_config.to_dict(exclude_none=True)
chroma_path = chroma_vector_config.get( chroma_path = chroma_vector_config.get(
"persist_path", os.path.join(PILOT_PATH, "data") "persist_path", os.path.join(PILOT_PATH, "data")

View File

@ -170,14 +170,22 @@ class VectorStoreConnector:
) )
async def aload_document(self, chunks: List[Chunk]) -> List[str]: async def aload_document(self, chunks: List[Chunk]) -> List[str]:
"""Load document in vector database. """Async load document in vector database.
Args: Args:
- chunks: document chunks. - chunks: document chunks.
Return chunk ids. Return chunk ids.
""" """
return await self.client.aload_document( max_chunks_once_load = (
chunks, self._index_store_config.max_chunks_once_load
if self._index_store_config
else 10
)
max_threads = (
self._index_store_config.max_threads if self._index_store_config else 1
)
return await self.client.aload_document_with_limit(
chunks, max_chunks_once_load, max_threads
) )
def similar_search( def similar_search(

View File

@ -125,6 +125,7 @@ class ElasticStore(VectorStoreBase):
Args: Args:
vector_store_config (ElasticsearchVectorConfig): ElasticsearchStore config. vector_store_config (ElasticsearchVectorConfig): ElasticsearchStore config.
""" """
super().__init__()
connect_kwargs = {} connect_kwargs = {}
elasticsearch_vector_config = vector_store_config.dict() elasticsearch_vector_config = vector_store_config.dict()
self.uri = elasticsearch_vector_config.get("uri") or os.getenv( self.uri = elasticsearch_vector_config.get("uri") or os.getenv(

View File

@ -149,8 +149,14 @@ class MilvusStore(VectorStoreBase):
vector_store_config (MilvusVectorConfig): MilvusStore config. vector_store_config (MilvusVectorConfig): MilvusStore config.
refer to https://milvus.io/docs/v2.0.x/manage_connection.md refer to https://milvus.io/docs/v2.0.x/manage_connection.md
""" """
from pymilvus import connections super().__init__()
try:
from pymilvus import connections
except ImportError:
raise ValueError(
"Could not import pymilvus python package. "
"Please install it with `pip install pymilvus`."
)
connect_kwargs = {} connect_kwargs = {}
milvus_vector_config = vector_store_config.to_dict() milvus_vector_config = vector_store_config.to_dict()
self.uri = milvus_vector_config.get("uri") or os.getenv( self.uri = milvus_vector_config.get("uri") or os.getenv(
@ -373,8 +379,13 @@ class MilvusStore(VectorStoreBase):
self, text, topk, filters: Optional[MetadataFilters] = None self, text, topk, filters: Optional[MetadataFilters] = None
) -> List[Chunk]: ) -> List[Chunk]:
"""Perform a search on a query string and return results.""" """Perform a search on a query string and return results."""
from pymilvus import Collection, DataType try:
from pymilvus import Collection, DataType
except ImportError:
raise ValueError(
"Could not import pymilvus python package. "
"Please install it with `pip install pymilvus`."
)
"""similar_search in vector database.""" """similar_search in vector database."""
self.col = Collection(self.collection_name) self.col = Collection(self.collection_name)
schema = self.col.schema schema = self.col.schema
@ -419,7 +430,13 @@ class MilvusStore(VectorStoreBase):
Returns: Returns:
List[Tuple[Document, float]]: Result doc and score. List[Tuple[Document, float]]: Result doc and score.
""" """
from pymilvus import Collection try:
from pymilvus import Collection, DataType
except ImportError:
raise ValueError(
"Could not import pymilvus python package. "
"Please install it with `pip install pymilvus`."
)
self.col = Collection(self.collection_name) self.col = Collection(self.collection_name)
schema = self.col.schema schema = self.col.schema
@ -429,7 +446,6 @@ class MilvusStore(VectorStoreBase):
self.fields.remove(x.name) self.fields.remove(x.name)
if x.is_primary: if x.is_primary:
self.primary_field = x.name self.primary_field = x.name
from pymilvus import DataType
if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR: if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR:
self.vector_field = x.name self.vector_field = x.name
@ -526,15 +542,26 @@ class MilvusStore(VectorStoreBase):
def vector_name_exists(self): def vector_name_exists(self):
"""Whether vector name exists.""" """Whether vector name exists."""
from pymilvus import utility try:
from pymilvus import utility
except ImportError:
raise ValueError(
"Could not import pymilvus python package. "
"Please install it with `pip install pymilvus`."
)
"""is vector store name exist.""" """is vector store name exist."""
return utility.has_collection(self.collection_name) return utility.has_collection(self.collection_name)
def delete_vector_name(self, vector_name: str): def delete_vector_name(self, vector_name: str):
"""Delete vector name.""" """Delete vector name."""
from pymilvus import utility try:
from pymilvus import utility
except ImportError:
raise ValueError(
"Could not import pymilvus python package. "
"Please install it with `pip install pymilvus`."
)
"""milvus delete collection name""" """milvus delete collection name"""
logger.info(f"milvus vector_name:{vector_name} begin delete...") logger.info(f"milvus vector_name:{vector_name} begin delete...")
utility.drop_collection(self.collection_name) utility.drop_collection(self.collection_name)
@ -542,8 +569,13 @@ class MilvusStore(VectorStoreBase):
def delete_by_ids(self, ids): def delete_by_ids(self, ids):
"""Delete vector by ids.""" """Delete vector by ids."""
from pymilvus import Collection try:
from pymilvus import Collection
except ImportError:
raise ValueError(
"Could not import pymilvus python package. "
"Please install it with `pip install pymilvus`."
)
self.col = Collection(self.collection_name) self.col = Collection(self.collection_name)
# milvus delete vectors by ids # milvus delete vectors by ids
logger.info(f"begin delete milvus ids: {ids}") logger.info(f"begin delete milvus ids: {ids}")

View File

@ -717,7 +717,7 @@ class OceanBaseStore(VectorStoreBase):
"""Create a OceanBaseStore instance.""" """Create a OceanBaseStore instance."""
if vector_store_config.embedding_fn is None: if vector_store_config.embedding_fn is None:
raise ValueError("embedding_fn is required for OceanBaseStore") raise ValueError("embedding_fn is required for OceanBaseStore")
super().__init__()
self.embeddings = vector_store_config.embedding_fn self.embeddings = vector_store_config.embedding_fn
self.collection_name = vector_store_config.name self.collection_name = vector_store_config.name
vector_store_config = vector_store_config.dict() vector_store_config = vector_store_config.dict()

View File

@ -63,6 +63,7 @@ class PGVectorStore(VectorStoreBase):
raise ImportError( raise ImportError(
"Please install the `langchain` package to use the PGVector." "Please install the `langchain` package to use the PGVector."
) )
super().__init__()
self.connection_string = vector_store_config.connection_string self.connection_string = vector_store_config.connection_string
self.embeddings = vector_store_config.embedding_fn self.embeddings = vector_store_config.embedding_fn
self.collection_name = vector_store_config.name self.collection_name = vector_store_config.name

View File

@ -68,7 +68,7 @@ class WeaviateStore(VectorStoreBase):
"Could not import weaviate python package. " "Could not import weaviate python package. "
"Please install it with `pip install weaviate-client`." "Please install it with `pip install weaviate-client`."
) )
super().__init__()
self.weaviate_url = vector_store_config.weaviate_url self.weaviate_url = vector_store_config.weaviate_url
self.embedding = vector_store_config.embedding_fn self.embedding = vector_store_config.embedding_fn
self.vector_name = vector_store_config.name self.vector_name = vector_store_config.name