diff --git a/configs/dbgpt-bm25-rag.toml b/configs/dbgpt-bm25-rag.toml new file mode 100644 index 000000000..7ae1388f7 --- /dev/null +++ b/configs/dbgpt-bm25-rag.toml @@ -0,0 +1,50 @@ +[system] +# Load language from environment variable(It is set by the hook) +language = "${env:DBGPT_LANG:-zh}" +log_level = "INFO" +api_keys = [] +encrypt_key = "your_secret_key" + +# Server Configurations +[service.web] +host = "0.0.0.0" +port = 5670 + +[service.web.database] +type = "sqlite" +path = "pilot/meta_data/dbgpt.db" + + +[rag] +chunk_size=1000 +chunk_overlap=100 +similarity_top_k=5 +similarity_score_threshold=0.0 +max_chunks_once_load=10 +max_threads=1 +rerank_top_k=3 + +[rag.storage] +[rag.storage.vector] +type = "Chroma" +persist_path = "pilot/data" + +[rag.storage.full_text] +type = "ElasticSearch" +host="127.0.0.1" +port=9200 + + +# Model Configurations +[models] +[[models.llms]] +name = "${env:LLM_MODEL_NAME:-gpt-4o}" +provider = "${env:LLM_MODEL_PROVIDER:-proxy/openai}" +api_base = "${env:OPENAI_API_BASE:-https://api.openai.com/v1}" +api_key = "${env:OPENAI_API_KEY}" + +[[models.embeddings]] +name = "${env:EMBEDDING_MODEL_NAME:-text-embedding-3-small}" +provider = "${env:EMBEDDING_MODEL_PROVIDER:-proxy/openai}" +api_url = "${env:EMBEDDING_MODEL_API_URL:-https://api.openai.com/v1/embeddings}" +api_key = "${env:OPENAI_API_KEY}" diff --git a/docs/docs/installation/integrations/bm25_rag_install.md b/docs/docs/installation/integrations/bm25_rag_install.md new file mode 100644 index 000000000..68b945937 --- /dev/null +++ b/docs/docs/installation/integrations/bm25_rag_install.md @@ -0,0 +1,45 @@ +# BM25 RAG + +In this example, we will show how to use Elasticsearch as the DB-GPT RAG storage. Using an Elasticsearch BM25 full-text index to implement RAG can, to some extent, alleviate the uncertainty and interpretability issues brought about by vector database retrieval. + +### Install Dependencies + +First, you need to install the `dbgpt elasticsearch storage` library. 
+ +```bash +uv sync --all-packages --frozen \ +--extra "base" \ +--extra "proxy_openai" \ +--extra "rag" \ +--extra "storage_elasticsearch" \ +--extra "dbgpts" +``` + +### Prepare Elasticsearch + +Prepare the Elasticsearch database service, reference: [Elasticsearch Installation](https://www.elastic.co/guide/en/elasticsearch/reference/current/install-elasticsearch.html). + +### Elasticsearch Configuration + + +Set the rag storage variables below in the `configs/dbgpt-bm25-rag.toml` file, to let DB-GPT know how to connect to Elasticsearch. + +```toml + +[rag.storage] +[rag.storage.full_text] +type = "ElasticSearch" +host = "127.0.0.1" +port = 9200 +``` + +Then run the following command to start the webserver: +```bash +uv run python packages/dbgpt-app/src/dbgpt_app/dbgpt_server.py --config configs/dbgpt-bm25-rag.toml +``` + +Optionally +```bash +uv run python packages/dbgpt-app/src/dbgpt_app/dbgpt_server.py --config configs/dbgpt-bm25-rag.toml +``` + diff --git a/docs/sidebars.js b/docs/sidebars.js index 014598a58..bfd9ca228 100755 --- a/docs/sidebars.js +++ b/docs/sidebars.js @@ -107,6 +107,10 @@ const sidebars = { type: "doc", id: "installation/integrations/oceanbase_rag_install" }, + { + type: "doc", + id: "installation/integrations/bm25_rag_install" + }, { type: "doc", id: "installation/integrations/milvus_rag_install" diff --git a/examples/rag/bm25_retriever_example.py b/examples/rag/bm25_retriever_example.py index e321b62c8..8560c2b45 100644 --- a/examples/rag/bm25_retriever_example.py +++ b/examples/rag/bm25_retriever_example.py @@ -5,7 +5,7 @@ from dbgpt.configs.model_config import ROOT_PATH from dbgpt_ext.rag import ChunkParameters from dbgpt_ext.rag.assembler.bm25 import BM25Assembler from dbgpt_ext.rag.knowledge import KnowledgeFactory -from dbgpt_ext.storage.vector_store.elastic_store import ElasticsearchVectorConfig +from dbgpt_ext.storage.vector_store.elastic_store import ElasticsearchStoreConfig """Embedding rag example. 
pre-requirements: @@ -19,8 +19,7 @@ from dbgpt_ext.storage.vector_store.elastic_store import ElasticsearchVectorConf def _create_es_config(): """Create vector connector.""" - return ElasticsearchVectorConfig( - name="bm25_es_dbgpt", + return ElasticsearchStoreConfig( uri="localhost", port="9200", user="elastic", diff --git a/examples/rag/cross_encoder_rerank_example.py b/examples/rag/cross_encoder_rerank_example.py index bf0cb5464..a8e395b18 100644 --- a/examples/rag/cross_encoder_rerank_example.py +++ b/examples/rag/cross_encoder_rerank_example.py @@ -25,14 +25,16 @@ def _create_vector_connector(): """Create vector connector.""" config = ChromaVectorConfig( persist_path=PILOT_PATH, + ) + + return ChromaStore( + config, name="embedding_rag_test", embedding_fn=DefaultEmbeddingFactory( default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"), ).create(), ) - return ChromaStore(config) - async def main(): file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md") diff --git a/examples/rag/db_schema_rag_example.py b/examples/rag/db_schema_rag_example.py index a73f7f1fd..f8d27b834 100644 --- a/examples/rag/db_schema_rag_example.py +++ b/examples/rag/db_schema_rag_example.py @@ -4,8 +4,7 @@ from dbgpt.configs.model_config import MODEL_PATH, PILOT_PATH from dbgpt.rag.embedding import DefaultEmbeddingFactory from dbgpt_ext.datasource.rdbms.conn_sqlite import SQLiteTempConnector from dbgpt_ext.rag.assembler import DBSchemaAssembler -from dbgpt_ext.storage.vector_store.chroma_store import ChromaVectorConfig -from dbgpt_serve.rag.connector import VectorStoreConnector +from dbgpt_ext.storage.vector_store.chroma_store import ChromaStore, ChromaVectorConfig """DB struct rag example. 
pre-requirements: @@ -46,12 +45,12 @@ def _create_temporary_connection(): def _create_vector_connector(): """Create vector connector.""" - return VectorStoreConnector.from_default( - "Chroma", - vector_store_config=ChromaVectorConfig( - name="db_schema_vector_store_name", - persist_path=os.path.join(PILOT_PATH, "data"), - ), + config = ChromaVectorConfig( + persist_path=PILOT_PATH, + ) + return ChromaStore( + config, + name="embedding_rag_test", embedding_fn=DefaultEmbeddingFactory( default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"), ).create(), diff --git a/examples/rag/embedding_rag_example.py b/examples/rag/embedding_rag_example.py index 9e21898f0..8c775f034 100644 --- a/examples/rag/embedding_rag_example.py +++ b/examples/rag/embedding_rag_example.py @@ -25,14 +25,16 @@ def _create_vector_connector(): """Create vector connector.""" config = ChromaVectorConfig( persist_path=PILOT_PATH, + ) + + return ChromaStore( + config, name="embedding_rag_test", embedding_fn=DefaultEmbeddingFactory( default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"), ).create(), ) - return ChromaStore(config) - async def main(): file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md") diff --git a/examples/rag/graph_rag_example.py b/examples/rag/graph_rag_example.py index b84edd2a8..0a4dbd55b 100644 --- a/examples/rag/graph_rag_example.py +++ b/examples/rag/graph_rag_example.py @@ -10,13 +10,12 @@ from dbgpt.rag.retriever import RetrieverStrategy from dbgpt_ext.rag import ChunkParameters from dbgpt_ext.rag.assembler import EmbeddingAssembler from dbgpt_ext.rag.knowledge import KnowledgeFactory +from dbgpt_ext.storage.graph_store.tugraph_store import TuGraphStoreConfig from dbgpt_ext.storage.knowledge_graph.community_summary import ( CommunitySummaryKnowledgeGraph, - CommunitySummaryKnowledgeGraphConfig, ) from dbgpt_ext.storage.knowledge_graph.knowledge_graph import ( BuiltinKnowledgeGraph, - BuiltinKnowledgeGraphConfig, ) """GraphRAG example. 
@@ -61,26 +60,22 @@ async def test_community_graph_rag(): def __create_naive_kg_connector(): """Create knowledge graph connector.""" return BuiltinKnowledgeGraph( - config=BuiltinKnowledgeGraphConfig( - name="naive_graph_rag_test", - embedding_fn=None, - llm_client=llm_client, - model_name=model_name, - graph_store_type="MemoryGraph", - ), + config=TuGraphStoreConfig(), + name="naive_graph_rag_test", + embedding_fn=None, + llm_client=llm_client, + llm_model=model_name, ) def __create_community_kg_connector(): """Create community knowledge graph connector.""" return CommunitySummaryKnowledgeGraph( - config=CommunitySummaryKnowledgeGraphConfig( - name="community_graph_rag_test", - embedding_fn=DefaultEmbeddingFactory.openai(), - llm_client=llm_client, - model_name=model_name, - graph_store_type="TuGraphGraph", - ), + config=TuGraphStoreConfig(), + name="community_graph_rag_test", + embedding_fn=DefaultEmbeddingFactory.openai(), + llm_client=llm_client, + llm_model=model_name, ) diff --git a/examples/rag/keyword_rag_example.py b/examples/rag/keyword_rag_example.py index 25a236e85..8aaf6a257 100644 --- a/examples/rag/keyword_rag_example.py +++ b/examples/rag/keyword_rag_example.py @@ -6,8 +6,8 @@ from dbgpt_ext.rag import ChunkParameters from dbgpt_ext.rag.assembler import EmbeddingAssembler from dbgpt_ext.rag.knowledge import KnowledgeFactory from dbgpt_ext.storage.full_text.elasticsearch import ( - ElasticDocumentConfig, ElasticDocumentStore, + ElasticsearchStoreConfig, ) """Keyword rag example. 
@@ -22,15 +22,14 @@ from dbgpt_ext.storage.full_text.elasticsearch import ( def _create_es_connector(): """Create es connector.""" - config = ElasticDocumentConfig( - name="keyword_rag_test", + config = ElasticsearchStoreConfig( uri="localhost", port="9200", user="elastic", password="dbgpt", ) - return ElasticDocumentStore(config) + return ElasticDocumentStore(config, name="keyword_rag_test") async def main(): diff --git a/examples/rag/metadata_filter_example.py b/examples/rag/metadata_filter_example.py index f98e46e93..a63c884ef 100644 --- a/examples/rag/metadata_filter_example.py +++ b/examples/rag/metadata_filter_example.py @@ -23,14 +23,16 @@ def _create_vector_connector(): """Create vector connector.""" config = ChromaVectorConfig( persist_path=PILOT_PATH, - name="metadata_rag_test", + ) + + return ChromaStore( + config, + name="embedding_rag_test", embedding_fn=DefaultEmbeddingFactory( default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"), ).create(), ) - return ChromaStore(config) - async def main(): file_path = os.path.join(ROOT_PATH, "docs/docs/awel/awel.md") diff --git a/examples/rag/rag_embedding_api_example.py b/examples/rag/rag_embedding_api_example.py index a07e065db..7d9a48c71 100644 --- a/examples/rag/rag_embedding_api_example.py +++ b/examples/rag/rag_embedding_api_example.py @@ -56,11 +56,13 @@ def _create_vector_connector(): """Create vector connector.""" config = ChromaVectorConfig( persist_path=PILOT_PATH, - name="embedding_api_rag_test", - embedding_fn=_create_embeddings(), ) - return ChromaStore(config) + return ChromaStore( + config, + name="embedding_rag_test", + embedding_fn=_create_embeddings(), + ) async def main(): diff --git a/examples/rag/retriever_evaluation_example.py b/examples/rag/retriever_evaluation_example.py index a21c7df61..ee8e7ddf3 100644 --- a/examples/rag/retriever_evaluation_example.py +++ b/examples/rag/retriever_evaluation_example.py @@ -27,15 +27,17 @@ def _create_embeddings( ).create() -def 
_create_vector_connector(embeddings: Embeddings): +def _create_vector_connector(): """Create vector connector.""" config = ChromaVectorConfig( persist_path=PILOT_PATH, - name="embedding_rag_test", - embedding_fn=embeddings, ) - return ChromaStore(config) + return ChromaStore( + config, + name="embedding_rag_test", + embedding_fn=_create_embeddings(), + ) async def main(): diff --git a/examples/rag/simple_dbschema_retriever_example.py b/examples/rag/simple_dbschema_retriever_example.py index 1a2329a50..a8b08b83a 100644 --- a/examples/rag/simple_dbschema_retriever_example.py +++ b/examples/rag/simple_dbschema_retriever_example.py @@ -39,15 +39,17 @@ from dbgpt_ext.storage.vector_store.chroma_store import ChromaStore, ChromaVecto def _create_vector_connector(): """Create vector connector.""" config = ChromaVectorConfig( - persist_path=os.path.join(PILOT_PATH, "data"), - name="vector_name", + persist_path=PILOT_PATH, + ) + + return ChromaStore( + config, + name="embedding_rag_test", embedding_fn=DefaultEmbeddingFactory( default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"), ).create(), ) - return ChromaStore(config) - def _create_temporary_connection(): """Create a temporary database connection for testing.""" diff --git a/examples/rag/simple_rag_embedding_example.py b/examples/rag/simple_rag_embedding_example.py index 27d6b0f75..d764fc259 100644 --- a/examples/rag/simple_rag_embedding_example.py +++ b/examples/rag/simple_rag_embedding_example.py @@ -27,14 +27,16 @@ def _create_vector_connector(): """Create vector connector.""" config = ChromaVectorConfig( persist_path=PILOT_PATH, + ) + + return ChromaStore( + config, name="embedding_rag_test", embedding_fn=DefaultEmbeddingFactory( default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"), ).create(), ) - return ChromaStore(config) - class TriggerReqBody(BaseModel): url: str = Field(..., description="url") diff --git a/examples/rag/simple_rag_retriever_example.py 
b/examples/rag/simple_rag_retriever_example.py index a59d95c53..f6930f875 100644 --- a/examples/rag/simple_rag_retriever_example.py +++ b/examples/rag/simple_rag_retriever_example.py @@ -76,14 +76,16 @@ def _create_vector_connector(): """Create vector connector.""" config = ChromaVectorConfig( persist_path=PILOT_PATH, + ) + + return ChromaStore( + config, name="embedding_rag_test", embedding_fn=DefaultEmbeddingFactory( default_model_name=os.path.join(MODEL_PATH, "text2vec-large-chinese"), ).create(), ) - return ChromaStore(config) - with DAG("simple_sdk_rag_retriever_example") as dag: vector_store = _create_vector_connector() diff --git a/packages/dbgpt-app/src/dbgpt_app/component_configs.py b/packages/dbgpt-app/src/dbgpt_app/component_configs.py index d5e6c525f..b6acbf06c 100644 --- a/packages/dbgpt-app/src/dbgpt_app/component_configs.py +++ b/packages/dbgpt-app/src/dbgpt_app/component_configs.py @@ -5,6 +5,7 @@ from dbgpt.component import SystemApp from dbgpt.configs.model_config import MODEL_DISK_CACHE_DIR, resolve_root_path from dbgpt.util.executor_utils import DefaultExecutorFactory from dbgpt_app.config import ApplicationConfig, ServiceWebParameters +from dbgpt_serve.rag.storage_manager import StorageManager logger = logging.getLogger(__name__) @@ -34,6 +35,7 @@ def initialize_components( system_app.register(DefaultScheduler) system_app.register_instance(controller) system_app.register(ConnectorManager) + system_app.register(StorageManager) from dbgpt_serve.agent.hub.controller import module_plugin diff --git a/packages/dbgpt-app/src/dbgpt_app/config.py b/packages/dbgpt-app/src/dbgpt_app/config.py index 9a8a78395..d13b4b304 100644 --- a/packages/dbgpt-app/src/dbgpt_app/config.py +++ b/packages/dbgpt-app/src/dbgpt_app/config.py @@ -7,16 +7,15 @@ from dbgpt.model.parameter import ( ModelServiceConfig, ) from dbgpt.storage.cache.manager import ModelCacheParameters -from dbgpt.storage.vector_store.base import VectorStoreConfig from dbgpt.util.configure import 
HookConfig from dbgpt.util.i18n_utils import _ from dbgpt.util.parameter_utils import BaseParameters from dbgpt.util.tracer import TracerParameters from dbgpt.util.utils import LoggingParameters from dbgpt_ext.datasource.rdbms.conn_sqlite import SQLiteConnectorParameters -from dbgpt_ext.storage.knowledge_graph.knowledge_graph import ( - BuiltinKnowledgeGraphConfig, -) +from dbgpt_ext.storage.graph_store.tugraph_store import TuGraphStoreConfig +from dbgpt_ext.storage.vector_store.chroma_store import ChromaVectorConfig +from dbgpt_ext.storage.vector_store.elastic_store import ElasticsearchStoreConfig from dbgpt_serve.core import BaseServeConfig @@ -52,22 +51,22 @@ class SystemParameters: @dataclass class StorageConfig(BaseParameters): - vector: VectorStoreConfig = field( - default_factory=VectorStoreConfig, + vector: Optional[ChromaVectorConfig] = field( + default_factory=lambda: ChromaVectorConfig(), metadata={ "help": _("default vector type"), }, ) - graph: BuiltinKnowledgeGraphConfig = field( - default_factory=BuiltinKnowledgeGraphConfig, + graph: Optional[TuGraphStoreConfig] = field( + default=None, metadata={ "help": _("default graph type"), }, ) - full_text: BuiltinKnowledgeGraphConfig = field( - default_factory=BuiltinKnowledgeGraphConfig, + full_text: Optional[ElasticsearchStoreConfig] = field( + default=None, metadata={ - "help": _("default graph type"), + "help": _("default full text type"), }, ) @@ -93,7 +92,7 @@ class RagParameters(BaseParameters): default=10, metadata={"help": _("knowledge search top k")}, ) - similarity_score_threshold: Optional[int] = field( + similarity_score_threshold: Optional[float] = field( default=0.0, metadata={"help": _("knowledge search top similarity score")}, ) @@ -117,14 +116,86 @@ class RagParameters(BaseParameters): default_factory=lambda: StorageConfig(), metadata={"help": _("Storage configuration")}, ) - graph_search_top_k: Optional[int] = field( - default=3, + knowledge_graph_chunk_search_top_k: Optional[int] = field( 
+ default=5, metadata={"help": _("knowledge graph search top k")}, ) - graph_community_summary_enabled: Optional[bool] = field( + kg_enable_summary: Optional[bool] = field( default=False, metadata={"help": _("graph community summary enabled")}, ) + llm_model: Optional[str] = field( + default=None, + metadata={"help": _("kg extract llm model")}, + ) + kg_extract_top_k: Optional[int] = field( + default=5, + metadata={"help": _("kg extract top k")}, + ) + kg_extract_score_threshold: Optional[float] = field( + default=0.3, + metadata={"help": _("kg extract score threshold")}, + ) + kg_community_top_k: Optional[int] = field( + default=50, + metadata={"help": _("kg community top k")}, + ) + kg_community_score_threshold: Optional[float] = field( + default=0.3, + metadata={"help": _("kg_community_score_threshold")}, + ) + kg_triplet_graph_enabled: Optional[bool] = field( + default=True, + metadata={"help": _("kg_triplet_graph_enabled")}, + ) + kg_document_graph_enabled: Optional[bool] = field( + default=True, + metadata={"help": _("kg_document_graph_enabled")}, + ) + kg_chunk_search_top_k: Optional[int] = field( + default=5, + metadata={"help": _("kg_chunk_search_top_k")}, + ) + kg_extraction_batch_size: Optional[int] = field( + default=3, + metadata={"help": _("kg_extraction_batch_size")}, + ) + kg_community_summary_batch_size: Optional[int] = field( + default=20, + metadata={"help": _("kg_community_summary_batch_size")}, + ) + kg_embedding_batch_size: Optional[int] = field( + default=20, + metadata={"help": _("kg_embedding_batch_size")}, + ) + kg_similarity_top_k: Optional[int] = field( + default=5, + metadata={"help": _("kg_similarity_top_k")}, + ) + kg_similarity_score_threshold: Optional[float] = field( + default=0.7, + metadata={"help": _("kg_similarity_score_threshold")}, + ) + kg_enable_text_search: Optional[bool] = field( + default=False, + metadata={"help": _("kg_enable_text_search")}, + ) + kg_text2gql_model_enabled: Optional[bool] = field( + default=False, + 
metadata={"help": _("kg_text2gql_model_enabled")}, + ) + kg_text2gql_model_name: Optional[str] = field( + default=None, + metadata={"help": _("text2gql_model_name")}, + ) + bm25_k1: Optional[float] = field( + default=2.0, + metadata={"help": _("bm25_k1")}, + ) + bm25_b: Optional[float] = field( + default=0.75, + metadata={"help": _("bm25_b")}, + ) @dataclass diff --git a/packages/dbgpt-app/src/dbgpt_app/dbgpt_server.py b/packages/dbgpt-app/src/dbgpt_app/dbgpt_server.py index d9ea04dd9..908a9a6c4 100644 --- a/packages/dbgpt-app/src/dbgpt_app/dbgpt_server.py +++ b/packages/dbgpt-app/src/dbgpt_app/dbgpt_server.py @@ -257,6 +257,7 @@ def run_webserver(config_file: str): def scan_configs(): from dbgpt.model import scan_model_providers from dbgpt_app.initialization.serve_initialization import scan_serve_configs + from dbgpt_ext.storage import scan_storage_configs from dbgpt_serve.datasource.manages.connector_manager import ConnectorManager cm = ConnectorManager(system_app) @@ -266,6 +267,8 @@ def scan_configs(): scan_model_providers() # Register all serve configs scan_serve_configs() + # Register all storage configs + scan_storage_configs() def load_config(config_file: str = None) -> ApplicationConfig: diff --git a/packages/dbgpt-app/src/dbgpt_app/knowledge/api.py b/packages/dbgpt-app/src/dbgpt_app/knowledge/api.py index a630aa422..775f14156 100644 --- a/packages/dbgpt-app/src/dbgpt_app/knowledge/api.py +++ b/packages/dbgpt-app/src/dbgpt_app/knowledge/api.py @@ -12,10 +12,8 @@ from dbgpt.configs.model_config import ( KNOWLEDGE_UPLOAD_ROOT_PATH, ) from dbgpt.core.awel.dag.dag_manager import DAGManager -from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory from dbgpt.rag.retriever import BaseRetriever from dbgpt.rag.retriever.embedding import EmbeddingRetriever -from dbgpt.storage.vector_store.base import VectorStoreConfig from dbgpt.util.i18n_utils import _ from dbgpt.util.tracer import SpanType, root_tracer from dbgpt_app.knowledge.request.request import ( 
@@ -49,8 +47,10 @@ from dbgpt_serve.rag.api.schemas import ( KnowledgeStorageType, KnowledgeSyncRequest, ) -from dbgpt_serve.rag.connector import VectorStoreConnector + +# from dbgpt_serve.rag.connector import VectorStoreConnector from dbgpt_serve.rag.service.service import Service +from dbgpt_serve.rag.storage_manager import StorageManager logger = logging.getLogger(__name__) @@ -505,22 +505,12 @@ def chunk_edit( @router.post("/knowledge/{vector_name}/query") -def similar_query(space_name: str, query_request: KnowledgeQueryRequest): +def similarity_query(space_name: str, query_request: KnowledgeQueryRequest): print(f"Received params: {space_name}, {query_request}") - embedding_factory = CFG.SYSTEM_APP.get_component( - "embedding_factory", EmbeddingFactory - ) - config = VectorStoreConfig( - name=space_name, - embedding_fn=embedding_factory.create(), - ) - vector_store_connector = VectorStoreConnector( - vector_store_type=CFG.VECTOR_STORE_TYPE, - vector_store_config=config, - system_app=CFG.SYSTEM_APP, - ) + storage_manager = StorageManager.get_instance(CFG.SYSTEM_APP) + vector_store_connector = storage_manager.create_vector_store(index_name=space_name) retriever = EmbeddingRetriever( - top_k=query_request.top_k, index_store=vector_store_connector.index_client + top_k=query_request.top_k, index_store=vector_store_connector ) chunks = retriever.retrieve(query_request.query) res = [ diff --git a/packages/dbgpt-app/src/dbgpt_app/knowledge/service.py b/packages/dbgpt-app/src/dbgpt_app/knowledge/service.py index 829327984..2f8eaec78 100644 --- a/packages/dbgpt-app/src/dbgpt_app/knowledge/service.py +++ b/packages/dbgpt-app/src/dbgpt_app/knowledge/service.py @@ -12,10 +12,8 @@ from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG from dbgpt.core import LLMClient from dbgpt.model import DefaultLLMClient from dbgpt.model.cluster import WorkerManagerFactory -from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory from dbgpt.rag.knowledge.base import 
KnowledgeType from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker -from dbgpt.storage.vector_store.base import VectorStoreConfig from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async from dbgpt.util.tracer import root_tracer, trace from dbgpt_app.knowledge.request.request import ( @@ -36,7 +34,6 @@ from dbgpt_app.knowledge.request.response import ( from dbgpt_ext.rag.assembler.summary import SummaryAssembler from dbgpt_ext.rag.chunk_manager import ChunkParameters from dbgpt_ext.rag.knowledge.factory import KnowledgeFactory -from dbgpt_serve.rag.connector import VectorStoreConnector from dbgpt_serve.rag.models.chunk_db import DocumentChunkDao, DocumentChunkEntity from dbgpt_serve.rag.models.document_db import ( KnowledgeDocumentDao, @@ -45,6 +42,7 @@ from dbgpt_serve.rag.models.document_db import ( from dbgpt_serve.rag.models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity from dbgpt_serve.rag.retriever.knowledge_space import KnowledgeSpaceRetriever from dbgpt_serve.rag.service.service import SyncStatus +from dbgpt_serve.rag.storage_manager import StorageManager knowledge_space_dao = KnowledgeSpaceDao() knowledge_document_dao = KnowledgeDocumentDao() @@ -82,6 +80,14 @@ class KnowledgeService: rag_config = CFG.SYSTEM_APP.config.configs.get("app_config").rag return rag_config + @property + def storage_manager(self): + return StorageManager.get_instance(CFG.SYSTEM_APP) + + @property + def system_app(self): + return CFG.SYSTEM_APP + def create_knowledge_space(self, request: KnowledgeSpaceRequest): """create knowledge space Args: @@ -91,7 +97,7 @@ class KnowledgeService: name=request.name, ) if request.vector_type == "VectorStore": - request.vector_type = self.rag_config.storage.vector.get("type") + request.vector_type = self.rag_config.storage.vector.get_type_value() if request.vector_type == "KnowledgeGraph": knowledge_space_name_pattern = r"^[a-zA-Z0-9\u4e00-\u9fa5]+$" if not re.match(knowledge_space_name_pattern, request.name): 
@@ -412,28 +418,15 @@ class KnowledgeService: if len(spaces) != 1: raise Exception(f"invalid space name:{space_name}") space = spaces[0] - - embedding_factory = CFG.SYSTEM_APP.get_component( - "embedding_factory", EmbeddingFactory - ) - embedding_fn = embedding_factory.create() - config = VectorStoreConfig( - name=space.name, - embedding_fn=embedding_fn, - llm_client=self.llm_client, - model_name=None, - ) if space.domain_type == DOMAIN_TYPE_FINANCIAL_REPORT: conn_manager = CFG.local_db_manager conn_manager.delete_db(f"{space.name}_fin_report") - vector_store_connector = VectorStoreConnector( - vector_store_type=space.vector_type, - vector_store_config=config, - system_app=CFG.SYSTEM_APP, + storage_connector = self.storage_manager.get_storage_connector( + index_name=space_name, storage_type=space.vector_type ) # delete vectors - vector_store_connector.delete_vector_name(space.name) + storage_connector.delete_vector_name(space.name) document_query = KnowledgeDocumentEntity(space=space.name) # delete chunks documents = knowledge_document_dao.get_documents(document_query) @@ -462,23 +455,11 @@ class KnowledgeService: vector_ids = documents[0].vector_ids if vector_ids is not None: - embedding_factory = CFG.SYSTEM_APP.get_component( - "embedding_factory", EmbeddingFactory - ) - embedding_fn = embedding_factory.create() - config = VectorStoreConfig( - name=space.name, - embedding_fn=embedding_fn, - llm_client=self.llm_client, - model_name=None, - ) - vector_store_connector = VectorStoreConnector( - vector_store_type=space.vector_type, - vector_store_config=config, - system_app=CFG.SYSTEM_APP, + storage_connector = self.storage_manager.get_storage_connector( + index_name=space_name, storage_type=space.vector_type ) # delete vector by ids - vector_store_connector.delete_by_ids(vector_ids) + storage_connector.delete_by_ids(vector_ids) # delete chunks document_chunk_dao.raw_delete(documents[0].id) # delete document @@ -628,29 +609,12 @@ class KnowledgeService: return chat 
def query_graph(self, space_name, limit): - embedding_factory = CFG.SYSTEM_APP.get_component( - "embedding_factory", EmbeddingFactory - ) - embedding_fn = embedding_factory.create() spaces = self.get_knowledge_space(KnowledgeSpaceRequest(name=space_name)) if len(spaces) != 1: raise Exception(f"invalid space name:{space_name}") - space = spaces[0] - print(CFG.LLM_MODEL) - config = VectorStoreConfig( - name=space.name, - embedding_fn=embedding_fn, - max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD, - llm_client=self.llm_client, - model_name=None, - ) - - vector_store_connector = VectorStoreConnector( - vector_store_type=space.vector_type, - vector_store_config=config, - system_app=CFG.SYSTEM_APP, - ) - graph = vector_store_connector.client.query_graph(limit=limit) + # space = spaces[0] + graph_store = self.storage_manager.create_kg_store(index_name=space_name) + graph = graph_store.query_graph(limit=limit) res = {"nodes": [], "edges": []} for node in graph.vertices(): res["nodes"].append( diff --git a/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/v1/chat.py b/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/v1/chat.py index 114a24d11..be453003c 100644 --- a/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/v1/chat.py +++ b/packages/dbgpt-app/src/dbgpt_app/scene/chat_knowledge/v1/chat.py @@ -234,7 +234,7 @@ class ChatKnowledge(BaseChat): from dbgpt_ext.storage import __knowledge_graph__ as graph_storages if spaces[0].vector_type in graph_storages: - return self.rag_config.graph_search_top_k + return self.rag_config.kg_chunk_search_top_k return self.rag_config.similarity_top_k diff --git a/packages/dbgpt-core/src/dbgpt/component.py b/packages/dbgpt-core/src/dbgpt/component.py index 253f40d11..938ec59e3 100644 --- a/packages/dbgpt-core/src/dbgpt/component.py +++ b/packages/dbgpt-core/src/dbgpt/component.py @@ -96,6 +96,7 @@ class ComponentType(str, Enum): AWEL_DAG_MANAGER = "dbgpt_awel_dag_manager" UNIFIED_METADATA_DB_MANAGER_FACTORY = 
"dbgpt_unified_metadata_db_manager_factory" CONNECTOR_MANAGER = "dbgpt_connector_manager" + RAG_STORAGE_MANAGER = "dbgpt_rag_storage_manager" AGENT_MANAGER = "dbgpt_agent_manager" RESOURCE_MANAGER = "dbgpt_resource_manager" VARIABLES_PROVIDER = "dbgpt_variables_provider" diff --git a/packages/dbgpt-core/src/dbgpt/storage/base.py b/packages/dbgpt-core/src/dbgpt/storage/base.py index cbf1992bf..7017fc3d0 100644 --- a/packages/dbgpt-core/src/dbgpt/storage/base.py +++ b/packages/dbgpt-core/src/dbgpt/storage/base.py @@ -4,50 +4,24 @@ import logging import time from abc import ABC, abstractmethod from concurrent.futures import Executor, ThreadPoolExecutor -from typing import Any, Dict, List, Optional +from dataclasses import dataclass +from typing import List, Optional -from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict -from dbgpt.core import Chunk, Embeddings +from dbgpt.core import Chunk from dbgpt.storage.vector_store.filters import MetadataFilters +from dbgpt.util import BaseParameters from dbgpt.util.executor_utils import blocking_func_to_async_no_executor logger = logging.getLogger(__name__) -class IndexStoreConfig(BaseModel): +@dataclass +class IndexStoreConfig(BaseParameters): """Index store config.""" - model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") - - type: Optional[str] = Field( - default=None, - description="storage type", - ) - - name: str = Field( - default="dbgpt_collection", - description="The name of index store, if not set, will use the default name.", - ) - embedding_fn: Optional[Embeddings] = Field( - default=None, - description="The embedding function of vector store, if not set, will use the " - "default embedding function.", - ) - max_chunks_once_load: int = Field( - default=10, - description="The max number of chunks to load at once. If your document is " - "large, you can set this value to a larger number to speed up the loading " - "process. 
Default is 10.", - ) - max_threads: int = Field( - default=1, - description="The max number of threads to use. Default is 1. If you set this " - "bigger than 1, please make sure your vector store is thread-safe.", - ) - - def to_dict(self, **kwargs) -> Dict[str, Any]: - """Convert to dict.""" - return model_to_dict(self, **kwargs) + def create_store(self, **kwargs) -> "IndexStoreBase": + """Create a new index store from the config.""" + raise NotImplementedError("Current index store does not support create_store") class IndexStoreBase(ABC): diff --git a/packages/dbgpt-core/src/dbgpt/storage/graph_store/base.py b/packages/dbgpt-core/src/dbgpt/storage/graph_store/base.py index b9b706bc9..682da0b82 100644 --- a/packages/dbgpt-core/src/dbgpt/storage/graph_store/base.py +++ b/packages/dbgpt-core/src/dbgpt/storage/graph_store/base.py @@ -2,35 +2,33 @@ import logging from abc import ABC, abstractmethod -from typing import Optional +from dataclasses import dataclass -from dbgpt._private.pydantic import BaseModel, ConfigDict, Field -from dbgpt.core import Embeddings +from dbgpt.util import BaseParameters, RegisterParameters logger = logging.getLogger(__name__) -class GraphStoreConfig(BaseModel): +@dataclass +class GraphStoreConfig(BaseParameters, RegisterParameters): """Graph store config.""" - model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") - - name: str = Field( - default="dbgpt_collection", - description="The name of graph store, inherit from index store.", - ) - embedding_fn: Optional[Embeddings] = Field( - default=None, - description="The embedding function of graph store, optional.", - ) - enable_summary: bool = Field( - default=False, - description="Enable graph community summary or not.", - ) - enable_similarity_search: bool = Field( - default=False, - description="Enable similarity search or not.", - ) + # name: str = Field( + # default="dbgpt_collection", + # description="The name of graph store, inherit from index store.", + # ) + # 
embedding_fn: Optional[Embeddings] = Field( + # default=None, + # description="The embedding function of graph store, optional.", + # ) + # enable_summary: bool = Field( + # default=False, + # description="Enable graph community summary or not.", + # ) + # enable_similarity_search: bool = Field( + # default=False, + # description="Enable similarity search or not.", + # ) class GraphStoreBase(ABC): diff --git a/packages/dbgpt-core/src/dbgpt/storage/knowledge_graph/base.py b/packages/dbgpt-core/src/dbgpt/storage/knowledge_graph/base.py index 60f35fc3d..7dae4a7f6 100644 --- a/packages/dbgpt-core/src/dbgpt/storage/knowledge_graph/base.py +++ b/packages/dbgpt-core/src/dbgpt/storage/knowledge_graph/base.py @@ -2,23 +2,23 @@ import logging from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import List, Optional from pydantic import Field -from dbgpt._private.pydantic import ConfigDict from dbgpt.core import Chunk from dbgpt.storage.base import IndexStoreBase, IndexStoreConfig from dbgpt.storage.graph_store.graph import Graph +from dbgpt.util import RegisterParameters logger = logging.getLogger(__name__) -class KnowledgeGraphConfig(IndexStoreConfig): +@dataclass +class KnowledgeGraphConfig(IndexStoreConfig, RegisterParameters): """Knowledge graph config.""" - model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") - class KnowledgeGraphBase(IndexStoreBase, ABC): """Knowledge graph base class.""" diff --git a/packages/dbgpt-core/src/dbgpt/storage/vector_store/base.py b/packages/dbgpt-core/src/dbgpt/storage/vector_store/base.py index 5e9ac8dc3..ebfc3cd16 100644 --- a/packages/dbgpt-core/src/dbgpt/storage/vector_store/base.py +++ b/packages/dbgpt-core/src/dbgpt/storage/vector_store/base.py @@ -4,13 +4,14 @@ import logging import math from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field from typing import Any, List, Optional -from dbgpt._private.pydantic 
import ConfigDict, Field from dbgpt.core import Chunk, Embeddings from dbgpt.core.awel.flow import Parameter from dbgpt.storage.base import IndexStoreBase, IndexStoreConfig from dbgpt.storage.vector_store.filters import MetadataFilters +from dbgpt.util import RegisterParameters from dbgpt.util.executor_utils import blocking_func_to_async from dbgpt.util.i18n_utils import _ @@ -84,33 +85,31 @@ _COMMON_PARAMETERS = [ ] -class VectorStoreConfig(IndexStoreConfig): +@dataclass +class VectorStoreConfig(IndexStoreConfig, RegisterParameters): """Vector store config.""" - model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") + user: Optional[str] = field( + default=None, + metadata={ + "help": _( + "The user of vector store, if not set, will use the default user." + ), + }, + ) + password: Optional[str] = field( + default=None, + metadata={ + "help": _( + "The password of vector store, if not set, " + "will use the default password." + ), + }, + ) - user: Optional[str] = Field( - default=None, - description="The user of vector store, if not set, will use the default user.", - ) - password: Optional[str] = Field( - default=None, - description=( - "The password of vector store, if not set, will use the default password." 
- ), - ) - topk: int = Field( - default=5, - description="Topk of vector search", - ) - score_threshold: float = Field( - default=0.3, - description="Recall score of vector search", - ) - type: Optional[str] = Field( - default=None, - description="vector storage type", - ) + def create_store(self, **kwargs) -> "VectorStoreBase": + """Create a new index store from the config.""" + raise NotImplementedError("Current vector store does not support create_store") class VectorStoreBase(IndexStoreBase, ABC): diff --git a/packages/dbgpt-ext/src/dbgpt_ext/rag/assembler/bm25.py b/packages/dbgpt-ext/src/dbgpt_ext/rag/assembler/bm25.py index e4a37f74c..d7282199d 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/rag/assembler/bm25.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/rag/assembler/bm25.py @@ -10,7 +10,7 @@ from dbgpt.util.executor_utils import blocking_func_to_async from dbgpt_ext.rag.assembler.base import BaseAssembler from dbgpt_ext.rag.chunk_manager import ChunkParameters from dbgpt_ext.rag.retriever.bm25 import BM25Retriever -from dbgpt_ext.storage.vector_store.elastic_store import ElasticsearchVectorConfig +from dbgpt_ext.storage.vector_store.elastic_store import ElasticsearchStoreConfig class BM25Assembler(BaseAssembler): @@ -44,7 +44,8 @@ class BM25Assembler(BaseAssembler): def __init__( self, knowledge: Knowledge, - es_config: ElasticsearchVectorConfig, + es_config: ElasticsearchStoreConfig, + name: Optional[str] = "dbgpt", k1: Optional[float] = 2.0, b: Optional[float] = 0.75, chunk_parameters: Optional[ChunkParameters] = None, @@ -55,7 +56,7 @@ class BM25Assembler(BaseAssembler): Args: knowledge: (Knowledge) Knowledge datasource. - es_config: (ElasticsearchVectorConfig) Elasticsearch config. + es_config: (ElasticsearchStoreConfig) Elasticsearch config. k1 (Optional[float]): Controls non-linear term frequency normalization (saturation). The default value is 2.0. 
b (Optional[float]): Controls to what degree document length normalizes @@ -70,7 +71,7 @@ class BM25Assembler(BaseAssembler): self._es_port = es_config.port self._es_username = es_config.user self._es_password = es_config.password - self._index_name = es_config.name + self._index_name = name self._k1 = k1 self._b = b if self._es_username and self._es_password: @@ -123,7 +124,8 @@ class BM25Assembler(BaseAssembler): def load_from_knowledge( cls, knowledge: Knowledge, - es_config: ElasticsearchVectorConfig, + es_config: ElasticsearchStoreConfig, + name: Optional[str] = "dbgpt", k1: Optional[float] = 2.0, b: Optional[float] = 0.75, chunk_parameters: Optional[ChunkParameters] = None, @@ -132,7 +134,8 @@ class BM25Assembler(BaseAssembler): Args: knowledge: (Knowledge) Knowledge datasource. - es_config: (ElasticsearchVectorConfig) Elasticsearch config. + es_config: (ElasticsearchStoreConfig) Elasticsearch config. + name: (Optional[str]) BM25 name. k1: (Optional[float]) BM25 parameter k1. b: (Optional[float]) BM25 parameter b. chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for @@ -144,6 +147,7 @@ class BM25Assembler(BaseAssembler): return cls( knowledge=knowledge, es_config=es_config, + name=name, k1=k1, b=b, chunk_parameters=chunk_parameters, @@ -153,7 +157,8 @@ class BM25Assembler(BaseAssembler): async def aload_from_knowledge( cls, knowledge: Knowledge, - es_config: ElasticsearchVectorConfig, + es_config: ElasticsearchStoreConfig, + name: Optional[str] = "dbgpt", k1: Optional[float] = 2.0, b: Optional[float] = 0.75, chunk_parameters: Optional[ChunkParameters] = None, @@ -163,7 +168,7 @@ class BM25Assembler(BaseAssembler): Args: knowledge: (Knowledge) Knowledge datasource. - es_config: (ElasticsearchVectorConfig) Elasticsearch config. + es_config: (ElasticsearchStoreConfig) Elasticsearch config. k1: (Optional[float]) BM25 parameter k1. b: (Optional[float]) BM25 parameter b. 
chunk_parameters: (Optional[ChunkParameters]) ChunkManager to use for @@ -178,6 +183,7 @@ class BM25Assembler(BaseAssembler): cls, knowledge, es_config=es_config, + name=name, k1=k1, b=b, chunk_parameters=chunk_parameters, diff --git a/packages/dbgpt-ext/src/dbgpt_ext/rag/assembler/db_schema.py b/packages/dbgpt-ext/src/dbgpt_ext/rag/assembler/db_schema.py index 42b31ee9f..276960e01 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/rag/assembler/db_schema.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/rag/assembler/db_schema.py @@ -21,7 +21,7 @@ class DBSchemaAssembler(BaseAssembler): from dbgpt.datasource.rdbms.conn_sqlite import SQLiteTempConnector from dbgpt_serve.rag.assembler.db_struct import DBSchemaAssembler - from dbgpt.storage.vector_store.connector import VectorStoreConnector + from dbgpt.storage.vector_store.connector import VectorStoreBase from dbgpt.storage.vector_store.chroma_store import ChromaVectorConfig connection = SQLiteTempConnector.create_temporary_db() diff --git a/packages/dbgpt-ext/src/dbgpt_ext/rag/retriever/graph_retriever/graph_retriever.py b/packages/dbgpt-ext/src/dbgpt_ext/rag/retriever/graph_retriever/graph_retriever.py index bd9cedc50..bb0c56c85 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/rag/retriever/graph_retriever/graph_retriever.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/rag/retriever/graph_retriever/graph_retriever.py @@ -2,8 +2,9 @@ import logging import os -from typing import List, Tuple, Union +from typing import List, Optional, Tuple, Union +from dbgpt.core import Embeddings, LLMClient from dbgpt.rag.transformer.keyword_extractor import KeywordExtractor from dbgpt.rag.transformer.simple_intent_translator import SimpleIntentTranslator from dbgpt.storage.graph_store.graph import Graph, MemoryGraph @@ -33,79 +34,86 @@ class GraphRetriever(GraphRetrieverBase): def __init__( self, - config, graph_store_adapter, + llm_client: Optional[LLMClient] = None, + llm_model: Optional[str] = None, + triplet_graph_enabled: Optional[bool] = True, + 
document_graph_enabled: Optional[bool] = True, + extract_top_k: Optional[int] = 5, + kg_chunk_search_top_k: Optional[int] = 5, + similarity_top_k: Optional[int] = 5, + similarity_score_threshold: Optional[float] = 0.7, + embedding_fn: Optional[Embeddings] = None, + embedding_batch_size: Optional[int] = 20, + enable_text_search: Optional[bool] = False, + text2gql_model_enabled: Optional[bool] = False, + text2gql_model_name: Optional[str] = None, ): """Initialize Graph Retriever.""" - self._triplet_graph_enabled = config.triplet_graph_enabled or ( + self._triplet_graph_enabled = triplet_graph_enabled or ( os.getenv("TRIPLET_GRAPH_ENABLED", "").lower() == "true" ) - self._document_graph_enabled = config.document_graph_enabled or ( + self._document_graph_enabled = document_graph_enabled or ( os.getenv("DOCUMENT_GRAPH_ENABLED", "").lower() == "true" ) triplet_topk = int( - config.extract_topk or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE") + extract_top_k or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE") ) document_topk = int( - config.knowledge_graph_chunk_search_top_size - or os.getenv("KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE") + kg_chunk_search_top_k or os.getenv("KNOWLEDGE_GRAPH_CHUNK_SEARCH_TOP_SIZE") ) - llm_client = config.llm_client - model_name = config.model_name + llm_client = llm_client + model_name = llm_model self._enable_similarity_search = ( graph_store_adapter.graph_store.enable_similarity_search ) self._embedding_batch_size = int( - config.knowledge_graph_embedding_batch_size - or os.getenv("KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE") + embedding_batch_size or os.getenv("KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE") ) similarity_search_topk = int( - config.similarity_search_topk - or os.getenv("KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_TOP_SIZE") + similarity_top_k or os.getenv("KNOWLEDGE_GRAPH_SIMILARITY_SEARCH_TOP_SIZE") ) similarity_search_score_threshold = float( - config.extract_score_threshold + similarity_score_threshold or 
os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE") ) - self._enable_text_search = config.enable_text_search or ( + self._enable_text_search = enable_text_search or ( os.getenv("TEXT_SEARCH_ENABLED", "").lower() == "true" ) - text2gql_model_enabled = config.text2gql_model_enabled or ( + text2gql_model_enabled = text2gql_model_enabled or ( os.getenv("TEXT2GQL_MODEL_ENABLED", "").lower() == "true" ) - text2gql_model_name = config.text2gql_model_name or os.getenv( - "TEXT2GQL_MODEL_NAME" + text2gql_model_name = text2gql_model_name or os.getenv("TEXT2GQL_MODEL_NAME") + text2gql_model_enabled = ( + os.environ["TEXT2GQL_MODEL_ENABLED"].lower() == "true" + if "TEXT2GQL_MODEL_ENABLED" in os.environ + else text2gql_model_enabled + ) + text2gql_model_name = os.getenv( + "TEXT2GQL_MODEL_NAME", + text2gql_model_name, ) text2gql_model_enabled = ( os.environ["TEXT2GQL_MODEL_ENABLED"].lower() == "true" if "TEXT2GQL_MODEL_ENABLED" in os.environ - else config.text2gql_model_enabled + else text2gql_model_enabled ) text2gql_model_name = os.getenv( "TEXT2GQL_MODEL_NAME", - config.text2gql_model_name, + text2gql_model_name, ) text2gql_model_enabled = ( os.environ["TEXT2GQL_MODEL_ENABLED"].lower() == "true" if "TEXT2GQL_MODEL_ENABLED" in os.environ - else config.text2gql_model_enabled + else text2gql_model_enabled ) text2gql_model_name = os.getenv( "TEXT2GQL_MODEL_NAME", - config.text2gql_model_name, - ) - text2gql_model_enabled = ( - os.environ["TEXT2GQL_MODEL_ENABLED"].lower() == "true" - if "TEXT2GQL_MODEL_ENABLED" in os.environ - else config.text2gql_model_enabled - ) - text2gql_model_name = os.getenv( - "TEXT2GQL_MODEL_NAME", - config.text2gql_model_name, + text2gql_model_name, ) self._keyword_extractor = KeywordExtractor(llm_client, model_name) - self._text_embedder = TextEmbedder(config.embedding_fn) + self._text_embedder = TextEmbedder(embedding_fn) intent_interpreter = SimpleIntentTranslator(llm_client, model_name) if text2gql_model_enabled: diff --git 
a/packages/dbgpt-ext/src/dbgpt_ext/rag/transformer/graph_extractor.py b/packages/dbgpt-ext/src/dbgpt_ext/rag/transformer/graph_extractor.py index c8b2794f0..eb2be1c9e 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/rag/transformer/graph_extractor.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/rag/transformer/graph_extractor.py @@ -17,19 +17,27 @@ class GraphExtractor(LLMExtractor): """GraphExtractor class.""" def __init__( - self, llm_client: LLMClient, model_name: str, chunk_history: VectorStoreBase + self, + llm_client: LLMClient, + model_name: str, + chunk_history: VectorStoreBase, + index_name: str, + max_chunks_once_load: Optional[int] = 10, + max_threads: Optional[int] = 1, + top_k: Optional[int] = 5, + score_threshold: Optional[float] = 0.7, ): """Initialize the GraphExtractor.""" super().__init__(llm_client, model_name, GRAPH_EXTRACT_PT_CN) self._chunk_history = chunk_history - config = self._chunk_history.get_config() + # config = self._chunk_history.get_config() - self._vector_space = config.name - self._max_chunks_once_load = config.max_chunks_once_load - self._max_threads = config.max_threads - self._topk = config.topk - self._score_threshold = config.score_threshold + self._vector_space = index_name + self._max_chunks_once_load = max_chunks_once_load + self._max_threads = max_threads + self._topk = top_k + self._score_threshold = score_threshold async def aload_chunk_context(self, texts: List[str]) -> Dict[str, str]: """Load chunk context.""" diff --git a/packages/dbgpt-ext/src/dbgpt_ext/storage/__init__.py b/packages/dbgpt-ext/src/dbgpt_ext/storage/__init__.py index 0525602a6..75a42774f 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/storage/__init__.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/storage/__init__.py @@ -50,11 +50,11 @@ def _import_oceanbase() -> Tuple[Type, Type]: def _import_elastic() -> Tuple[Type, Type]: from dbgpt_ext.storage.vector_store.elastic_store import ( - ElasticsearchVectorConfig, + ElasticsearchStoreConfig, ElasticStore, ) - return 
ElasticStore, ElasticsearchVectorConfig + return ElasticStore, ElasticsearchStoreConfig def _import_builtin_knowledge_graph() -> Tuple[Type, Type]: @@ -115,6 +115,35 @@ def _select_rag_storage(name: str) -> Tuple[Type, Type]: raise AttributeError(f"Could not find: {name}") +_HAS_SCAN = False + + +def scan_storage_configs(): + """Scan storage configs.""" + from dbgpt.storage.base import IndexStoreConfig + from dbgpt.util.module_utils import ModelScanner, ScannerConfig + + global _HAS_SCAN + + if _HAS_SCAN: + return + modules = [ + "dbgpt_ext.storage.vector_store", + "dbgpt_ext.storage.knowledge_graph", + "dbgpt_ext.storage.graph_store", + ] + + scanner = ModelScanner[IndexStoreConfig]() + for module in modules: + config = ScannerConfig( + module_path=module, + base_class=IndexStoreConfig, + ) + scanner.scan_and_register(config) + _HAS_SCAN = True + return scanner.get_registered_items() + + __vector_store__ = [ "Chroma", "Milvus", diff --git a/packages/dbgpt-ext/src/dbgpt_ext/storage/full_text/elasticsearch.py b/packages/dbgpt-ext/src/dbgpt_ext/storage/full_text/elasticsearch.py index 2f6ff6d00..b2c26a625 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/storage/full_text/elasticsearch.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/storage/full_text/elasticsearch.py @@ -11,21 +11,19 @@ from dbgpt.storage.full_text.base import FullTextStoreBase from dbgpt.storage.vector_store.filters import MetadataFilters from dbgpt.util import string_utils from dbgpt.util.executor_utils import blocking_func_to_async -from dbgpt_ext.storage.vector_store.elastic_store import ElasticsearchVectorConfig - - -class ElasticDocumentConfig(ElasticsearchVectorConfig): - """Elasticsearch document store config.""" - - k1: Optional[float] = 2.0 - b: Optional[float] = 0.75 +from dbgpt_ext.storage.vector_store.elastic_store import ElasticsearchStoreConfig class ElasticDocumentStore(FullTextStoreBase): """Elasticsearch index store.""" def __init__( - self, es_config: ElasticDocumentConfig, executor: 
Optional[Executor] = None + self, + es_config: ElasticsearchStoreConfig, + name: Optional[str] = "dbgpt", + k1: Optional[float] = 2.0, + b: Optional[float] = 0.75, + executor: Optional[Executor] = None, ): """Init elasticsearch index store. @@ -46,17 +44,17 @@ class ElasticDocumentStore(FullTextStoreBase): self._es_password = es_config.password or os.getenv( "ELASTICSEARCH_PASSWORD", "dbgpt" ) - self._index_name = es_config.name.lower() - if string_utils.contains_chinese(es_config.name): - bytes_str = es_config.name.encode("utf-8") + self._index_name = name.lower() + if string_utils.contains_chinese(name): + bytes_str = name.encode("utf-8") hex_str = bytes_str.hex() self._index_name = "dbgpt_" + hex_str # k1 (Optional[float]): Controls non-linear term frequency normalization # (saturation). The default value is 2.0. - self._k1 = es_config.k1 or 2.0 + self._k1 = k1 or 2.0 # b (Optional[float]): Controls to what degree document length normalizes # tf values. The default value is 0.75. - self._b = es_config.b or 0.75 + self._b = b or 0.75 if self._es_username and self._es_password: self._es_client = Elasticsearch( hosts=[f"http://{self._es_url}:{self._es_port}"], diff --git a/packages/dbgpt-ext/src/dbgpt_ext/storage/graph_store/tugraph_store.py b/packages/dbgpt-ext/src/dbgpt_ext/storage/graph_store/tugraph_store.py index 78b68d7ab..777a58000 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/storage/graph_store/tugraph_store.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/storage/graph_store/tugraph_store.py @@ -4,9 +4,9 @@ import base64 import json import logging import os +from dataclasses import dataclass, field from typing import List -from dbgpt._private.pydantic import ConfigDict, Field from dbgpt.storage.graph_store.base import GraphStoreBase, GraphStoreConfig from dbgpt.storage.graph_store.graph import GraphElemType from dbgpt_ext.datasource.conn_tugraph import TuGraphConnector @@ -14,66 +14,89 @@ from dbgpt_ext.datasource.conn_tugraph import TuGraphConnector logger = 
logging.getLogger(__name__) +@dataclass class TuGraphStoreConfig(GraphStoreConfig): """TuGraph store config.""" - model_config = ConfigDict(arbitrary_types_allowed=True) + __type__ = "TuGraph" - host: str = Field( + host: str = field( default="127.0.0.1", - description="TuGraph host", + metadata={ + "description": "TuGraph host", + }, ) - port: int = Field( + port: int = field( default=7687, - description="TuGraph port", + metadata={ + "description": "TuGraph port", + }, ) - username: str = Field( + username: str = field( default="admin", - description="login username", + metadata={ + "description": "login username", + }, ) - password: str = Field( + password: str = field( default="73@TuGraph", - description="login password", + metadata={ + "description": "login password", + }, ) - vertex_type: str = Field( + vertex_type: str = field( default=GraphElemType.ENTITY.value, - description="The type of entity vertex, `entity` by default.", + metadata={ + "description": "The type of vertex, `entity` by default.", + }, ) - document_type: str = Field( + document_type: str = field( default=GraphElemType.DOCUMENT.value, - description="The type of document vertex, `document` by default.", + metadata={ + "description": "The type of document vertex, `document` by default.", + }, ) - chunk_type: str = Field( + chunk_type: str = field( default=GraphElemType.CHUNK.value, - description="The type of chunk vertex, `relation` by default.", + metadata={ + "description": "The type of chunk vertex, `relation` by default.", + }, ) - edge_type: str = Field( + edge_type: str = field( default=GraphElemType.RELATION.value, - description="The type of relation edge, `relation` by default.", + metadata={ + "description": "The type of relation edge, `relation` by default.", + }, ) - include_type: str = Field( + include_type: str = field( default=GraphElemType.INCLUDE.value, - description="The type of include edge, `include` by default.", + metadata={ + "description": "The type of include edge, 
`include` by default.", + }, ) - next_type: str = Field( + next_type: str = field( default=GraphElemType.NEXT.value, - description="The type of next edge, `next` by default.", + metadata={ + "description": "The type of next edge, `next` by default.", + }, ) - plugin_names: List[str] = Field( - default=["leiden"], - description=( - "Plugins need to be loaded when initialize TuGraph, " - "code: https://github.com/TuGraph-family" - "/dbgpt-tugraph-plugins/tree/master/cpp" - ), + plugin_names: List[str] = field( + default_factory=lambda: ["leiden"], + metadata={ + "description": "The list of plugin names to be uploaded to the database.", + }, ) - enable_summary: bool = Field( + enable_summary: bool = field( default=True, - description="Enable graph community summary or not.", + metadata={ + "description": "Enable graph community summary or not.", + }, ) - enable_similarity_search: bool = Field( + enable_similarity_search: bool = field( default=False, - description="Enable the similarity search or not", + metadata={ + "description": "Enable the similarity search or not", + }, ) diff --git a/packages/dbgpt-ext/src/dbgpt_ext/storage/knowledge_graph/community/community_metastore.py b/packages/dbgpt-ext/src/dbgpt_ext/storage/knowledge_graph/community/community_metastore.py index ad6acbe76..5792ff154 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/storage/knowledge_graph/community/community_metastore.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/storage/knowledge_graph/community/community_metastore.py @@ -18,18 +18,24 @@ class BuiltinCommunityMetastore(CommunityMetastore): """Builtin Community metastore.""" def __init__( - self, vector_store: VectorStoreBase, rdb_store: Optional[RDBMSConnector] = None + self, + vector_store: VectorStoreBase, + rdb_store: Optional[RDBMSConnector] = None, + index_name: Optional[str] = None, + max_chunks_once_load: Optional[int] = 10, + max_threads: Optional[int] = 1, + top_k: Optional[int] = 5, + score_threshold: Optional[float] = 0.7, ): 
"""Initialize Community metastore.""" self._vector_store = vector_store self._rdb_store = rdb_store - config = self._vector_store.get_config() - self._vector_space = config.name - self._max_chunks_once_load = config.max_chunks_once_load - self._max_threads = config.max_threads - self._topk = config.topk - self._score_threshold = config.score_threshold + self._vector_space = index_name + self._max_chunks_once_load = max_chunks_once_load + self._max_threads = max_threads + self._topk = top_k + self._score_threshold = score_threshold def get(self, community_id: str) -> Community: """Get community.""" diff --git a/packages/dbgpt-ext/src/dbgpt_ext/storage/knowledge_graph/community/community_store.py b/packages/dbgpt-ext/src/dbgpt_ext/storage/knowledge_graph/community/community_store.py index 6aa2bda0d..cce0dca16 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/storage/knowledge_graph/community/community_store.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/storage/knowledge_graph/community/community_store.py @@ -25,11 +25,23 @@ class CommunityStore: graph_store_adapter: GraphStoreAdapter, community_summarizer: CommunitySummarizer, vector_store: VectorStoreBase, + index_name: Optional[str] = None, + max_chunks_once_load: Optional[int] = 10, + max_threads: Optional[int] = 1, + top_k: Optional[int] = 5, + score_threshold: Optional[float] = 0.7, ): """Initialize the CommunityStore class.""" self._graph_store_adapter = graph_store_adapter self._community_summarizer = community_summarizer - self._meta_store = BuiltinCommunityMetastore(vector_store) + self._meta_store = BuiltinCommunityMetastore( + vector_store=vector_store, + index_name=index_name, + max_chunks_once_load=max_chunks_once_load, + max_threads=max_threads, + top_k=top_k, + score_threshold=score_threshold, + ) async def build_communities(self, batch_size: int = 1): """Discover communities.""" diff --git a/packages/dbgpt-ext/src/dbgpt_ext/storage/knowledge_graph/community_summary.py 
b/packages/dbgpt-ext/src/dbgpt_ext/storage/knowledge_graph/community_summary.py index 9de29bc91..e66d5fc7b 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/storage/knowledge_graph/community_summary.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/storage/knowledge_graph/community_summary.py @@ -5,9 +5,9 @@ import os import uuid from typing import List, Optional, Tuple -from dbgpt._private.pydantic import ConfigDict, Field -from dbgpt.core import Chunk, LLMClient +from dbgpt.core import Chunk, Embeddings, LLMClient from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource +from dbgpt.storage.graph_store.base import GraphStoreConfig from dbgpt.storage.knowledge_graph.base import ParagraphChunk from dbgpt.storage.vector_store.base import VectorStoreConfig from dbgpt.storage.vector_store.filters import MetadataFilters @@ -17,13 +17,13 @@ from dbgpt_ext.rag.transformer.community_summarizer import CommunitySummarizer from dbgpt_ext.rag.transformer.graph_embedder import GraphEmbedder from dbgpt_ext.rag.transformer.graph_extractor import GraphExtractor from dbgpt_ext.rag.transformer.text_embedder import TextEmbedder +from dbgpt_ext.storage.graph_store.tugraph_store import TuGraphStoreConfig from dbgpt_ext.storage.knowledge_graph.community.community_store import CommunityStore from dbgpt_ext.storage.knowledge_graph.knowledge_graph import ( GRAPH_PARAMETERS, BuiltinKnowledgeGraph, BuiltinKnowledgeGraphConfig, ) -from dbgpt_ext.storage.vector_store.factory import VectorStoreFactory logger = logging.getLogger(__name__) @@ -139,87 +139,6 @@ logger = logging.getLogger(__name__) ), ], ) -class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig): - """Community summary knowledge graph config.""" - - model_config = ConfigDict(arbitrary_types_allowed=True) - - vector_store_type: str = Field( - default="Chroma", - description="The type of vector store.", - ) - user: Optional[str] = Field( - default=None, - description="The user of vector store, if not set, 
will use the default user.", - ) - password: Optional[str] = Field( - default=None, - description=( - "The password of vector store, if not set, will use the default password." - ), - ) - extract_topk: int = Field( - default=5, - description="Topk of knowledge graph extract", - ) - extract_score_threshold: float = Field( - default=0.3, - description="Recall score of knowledge graph extract", - ) - community_topk: int = Field( - default=50, - description="Topk of community search in knowledge graph", - ) - community_score_threshold: float = Field( - default=0.3, - description="Recall score of community search in knowledge graph", - ) - triplet_graph_enabled: bool = Field( - default=True, - description="Enable the graph search for triplets", - ) - document_graph_enabled: bool = Field( - default=True, - description="Enable the graph search for documents and chunks", - ) - knowledge_graph_chunk_search_top_size: int = Field( - default=5, - description="Top size of knowledge graph chunk search", - ) - knowledge_graph_extraction_batch_size: int = Field( - default=20, - description="Batch size of triplets extraction from the text", - ) - community_summary_batch_size: int = Field( - default=20, - description="Batch size of parallel community building process", - ) - knowledge_graph_embedding_batch_size: int = Field( - default=20, - description="Batch size of triplets embedding from the text", - ) - similarity_search_topk: int = Field( - default=5, - description="Topk of similarity search", - ) - similarity_search_score_threshold: float = Field( - default=0.7, - description="Recall score of similarity search", - ) - enable_text_search: bool = Field( - default=False, - description="Enable text2gql search or not.", - ) - text2gql_model_enabled: bool = Field( - default=False, - description="Enable fine-tuned text2gql model for text2gql translation.", - ) - text2gql_model_name: str = Field( - default=None, - description="LLM Model for text2gql translation.", - ) - - 
@register_resource( _("Community Summary Knowledge Graph"), "community_summary_knowledge_graph", @@ -239,104 +158,137 @@ class CommunitySummaryKnowledgeGraphConfig(BuiltinKnowledgeGraphConfig): class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph): """Community summary knowledge graph class.""" - def __init__(self, config: CommunitySummaryKnowledgeGraphConfig): + def __init__( + self, + config: GraphStoreConfig, + name: Optional[str] = "dbgpt", + llm_client: Optional[LLMClient] = None, + llm_model: Optional[str] = None, + kg_extract_top_k: Optional[int] = 5, + kg_extract_score_threshold: Optional[float] = 0.3, + kg_community_top_k: Optional[int] = 50, + kg_community_score_threshold: Optional[float] = 0.3, + kg_triplet_graph_enabled: Optional[bool] = True, + kg_document_graph_enabled: Optional[bool] = True, + kg_chunk_search_top_k: Optional[int] = 5, + kg_extraction_batch_size: Optional[int] = 3, + kg_community_summary_batch_size: Optional[int] = 20, + kg_embedding_batch_size: Optional[int] = 20, + kg_similarity_top_k: Optional[int] = 5, + kg_similarity_score_threshold: Optional[float] = 0.7, + kg_enable_text_search: Optional[float] = False, + kg_text2gql_model_enabled: Optional[bool] = False, + kg_text2gql_model_name: Optional[str] = None, + embedding_fn: Optional[Embeddings] = None, + vector_store_config: Optional["VectorStoreConfig"] = None, + kg_max_chunks_once_load: Optional[int] = 10, + kg_max_threads: Optional[int] = 1, + ): """Initialize community summary knowledge graph class.""" - super().__init__(config) + super().__init__( + config=config, name=name, llm_client=llm_client, llm_model=llm_model + ) self._config = config - self._vector_store_type = config.vector_store_type or os.getenv( + self._vector_store_type = config.get_type_value() or os.getenv( "VECTOR_STORE_TYPE" ) self._extract_topk = int( - config.extract_topk or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE") + kg_extract_top_k or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_TOP_SIZE") ) 
self._extract_score_threshold = float( - config.extract_score_threshold + kg_extract_score_threshold or os.getenv("KNOWLEDGE_GRAPH_EXTRACT_SEARCH_RECALL_SCORE") ) self._community_topk = int( - config.community_topk - or os.getenv("KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE") + kg_community_top_k or os.getenv("KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_TOP_SIZE") ) self._community_score_threshold = float( - config.community_score_threshold + kg_community_score_threshold or os.getenv("KNOWLEDGE_GRAPH_COMMUNITY_SEARCH_RECALL_SCORE") ) - self._document_graph_enabled = config.document_graph_enabled or ( + self._document_graph_enabled = kg_document_graph_enabled or ( os.getenv("DOCUMENT_GRAPH_ENABLED", "").lower() == "true" ) - self._triplet_graph_enabled = config.triplet_graph_enabled or ( + self._triplet_graph_enabled = kg_triplet_graph_enabled or ( os.getenv("TRIPLET_GRAPH_ENABLED", "").lower() == "true" ) self._triplet_extraction_batch_size = int( - config.knowledge_graph_extraction_batch_size + kg_extraction_batch_size or os.getenv("KNOWLEDGE_GRAPH_EXTRACTION_BATCH_SIZE") ) self._triplet_embedding_batch_size = int( - config.knowledge_graph_embedding_batch_size - or os.getenv("KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE") + kg_embedding_batch_size or os.getenv("KNOWLEDGE_GRAPH_EMBEDDING_BATCH_SIZE") ) self._community_summary_batch_size = int( - config.community_summary_batch_size - or os.getenv("COMMUNITY_SUMMARY_BATCH_SIZE") + kg_community_summary_batch_size or os.getenv("COMMUNITY_SUMMARY_BATCH_SIZE") ) - - def extractor_configure(name: str, cfg: VectorStoreConfig): - cfg.name = name - cfg.embedding_fn = config.embedding_fn - cfg.max_chunks_once_load = config.max_chunks_once_load - cfg.max_threads = config.max_threads - cfg.user = config.user - cfg.password = config.password - cfg.topk = self._extract_topk - cfg.score_threshold = self._extract_score_threshold + self._embedding_fn = embedding_fn + self._vector_store_config = vector_store_config self._graph_extractor = GraphExtractor( 
self._llm_client, self._model_name, - VectorStoreFactory.create( - self._vector_store_type, - config.name + "_CHUNK_HISTORY", - extractor_configure, + vector_store_config.create_store( + name=name + "_CHUNK_HISTORY", embedding_fn=embedding_fn ), + index_name=name, + max_chunks_once_load=kg_max_chunks_once_load, + max_threads=kg_max_threads, + top_k=kg_extract_top_k, + score_threshold=kg_extract_score_threshold, ) - self._graph_embedder = GraphEmbedder(self._config.embedding_fn) - self._text_embedder = TextEmbedder(self._config.embedding_fn) + self._graph_embedder = GraphEmbedder(embedding_fn) + self._text_embedder = TextEmbedder(embedding_fn) - def community_store_configure(name: str, cfg: VectorStoreConfig): - cfg.name = name - cfg.embedding_fn = config.embedding_fn - cfg.max_chunks_once_load = config.max_chunks_once_load - cfg.max_threads = config.max_threads - cfg.user = config.user - cfg.password = config.password - cfg.topk = self._community_topk - cfg.score_threshold = self._community_score_threshold + # def community_store_configure(name: str, cfg: VectorStoreConfig): + # cfg.name = name + # cfg.embedding_fn = self._embedding_fn + # cfg.max_chunks_once_load = max_chunks_once_load + # cfg.max_threads = max_threads + # cfg.user = config.user + # cfg.password = config.password + # cfg.topk = self._community_topk + # cfg.score_threshold = self._community_score_threshold self._community_store = CommunityStore( self._graph_store_adapter, CommunitySummarizer(self._llm_client, self._model_name), - VectorStoreFactory.create( - self._vector_store_type, - config.name + "_COMMUNITY_SUMMARY", - community_store_configure, + vector_store_config.create_store( + name=name + "_COMMUNITY_SUMMARY", embedding_fn=embedding_fn ), + index_name=name, + max_chunks_once_load=kg_max_chunks_once_load, + max_threads=kg_max_threads, + top_k=kg_community_top_k, + score_threshold=kg_extract_score_threshold, ) self._graph_retriever = GraphRetriever( - config, self._graph_store_adapter, + 
llm_client=llm_client, + llm_model=llm_model, + triplet_graph_enabled=kg_triplet_graph_enabled, + document_graph_enabled=kg_document_graph_enabled, + extract_top_k=kg_extract_top_k, + kg_chunk_search_top_k=kg_chunk_search_top_k, + similarity_top_k=kg_similarity_top_k, + similarity_score_threshold=kg_similarity_score_threshold, + embedding_fn=embedding_fn, + embedding_batch_size=kg_embedding_batch_size, + text2gql_model_enabled=kg_text2gql_model_enabled, + text2gql_model_name=kg_text2gql_model_name, ) - def get_config(self) -> BuiltinKnowledgeGraphConfig: + def get_config(self) -> TuGraphStoreConfig: """Get the knowledge graph config.""" return self._config async def aload_document(self, chunks: List[Chunk]) -> List[str]: """Extract and persist graph from the document file.""" if not self.vector_name_exists(): - self._graph_store_adapter.create_graph(self.get_config().name) + self._graph_store_adapter.create_graph(self._graph_name) await self._aload_document_graph(chunks) await self._aload_triplet_graph(chunks) await self._community_store.build_communities( @@ -447,7 +399,8 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph): chunk_name=doc_name, ) - # chunk.metadata = {"Header0": "title", "Header1": "title", ..., "source": "source_path"} # noqa: E501 + # chunk.metadata = {"Header0": "title", + # "Header1": "title", ..., "source": "source_path"} # noqa: E501 for chunk_index, chunk in enumerate(chunks): parent = None directory_keys = list(chunk.metadata.keys())[ @@ -542,7 +495,7 @@ class CommunitySummaryKnowledgeGraph(BuiltinKnowledgeGraph): self._graph_embedder.truncate() logger.info("Truncate text embedder") self._text_embedder.truncate() - return [self._config.name] + return [self._graph_name] def delete_vector_name(self, index_name: str): """Delete knowledge graph.""" diff --git a/packages/dbgpt-ext/src/dbgpt_ext/storage/knowledge_graph/knowledge_graph.py b/packages/dbgpt-ext/src/dbgpt_ext/storage/knowledge_graph/knowledge_graph.py index 
636a7097a..a49410546 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/storage/knowledge_graph/knowledge_graph.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/storage/knowledge_graph/knowledge_graph.py @@ -3,9 +3,9 @@ import asyncio import logging import os +from dataclasses import dataclass, field from typing import List, Optional -from dbgpt._private.pydantic import ConfigDict, Field from dbgpt.core import Chunk, Embeddings, LLMClient from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource from dbgpt.rag.transformer.keyword_extractor import KeywordExtractor @@ -16,6 +16,7 @@ from dbgpt.storage.vector_store.filters import MetadataFilters from dbgpt.util.i18n_utils import _ from dbgpt_ext.rag.transformer.triplet_extractor import TripletExtractor from dbgpt_ext.storage.graph_store.factory import GraphStoreFactory +from dbgpt_ext.storage.graph_store.tugraph_store import TuGraphStoreConfig from dbgpt_ext.storage.knowledge_graph.community.base import GraphStoreAdapter from dbgpt_ext.storage.knowledge_graph.community.factory import GraphStoreAdapterFactory @@ -98,16 +99,19 @@ GRAPH_PARAMETERS = [ ), ], ) +@dataclass class BuiltinKnowledgeGraphConfig(KnowledgeGraphConfig): """Builtin knowledge graph config.""" - model_config = ConfigDict(arbitrary_types_allowed=True) + __type__ = "TuGraph" - llm_client: LLMClient = Field(default=None, description="The default llm client.") + llm_model: Optional[str] = field( + default=None, metadata={"description": "llm model name."} + ) - model_name: str = Field(default=None, description="The name of llm model.") - - type: str = Field(default="TuGraph", description="The type of graph store.") + graph_type: Optional[str] = field( + default="TuGraph", metadata={"description": "graph store type."} + ) @register_resource( @@ -129,33 +133,38 @@ class BuiltinKnowledgeGraphConfig(KnowledgeGraphConfig): class BuiltinKnowledgeGraph(KnowledgeGraphBase): """Builtin knowledge graph class.""" - def __init__(self, config: 
BuiltinKnowledgeGraphConfig): + def __init__( + self, + config: GraphStoreConfig = None, + name: Optional[str] = "dbgpt", + llm_client: Optional[LLMClient] = None, + llm_model: Optional[str] = None, + ): """Create builtin knowledge graph instance.""" super().__init__() self._config = config - - self._llm_client = config.llm_client + self._llm_client = llm_client + self._graph_name = name if not self._llm_client: raise ValueError("No llm client provided.") - self._model_name = config.model_name + self._model_name = llm_model self._triplet_extractor = TripletExtractor(self._llm_client, self._model_name) self._keyword_extractor = KeywordExtractor(self._llm_client, self._model_name) self._graph_store: GraphStoreBase = self.__init_graph_store(config) self._graph_store_adapter: GraphStoreAdapter = self.__init_graph_store_adapter() - def __init_graph_store(self, config: BuiltinKnowledgeGraphConfig) -> GraphStoreBase: + def __init_graph_store(self, config: GraphStoreConfig) -> GraphStoreBase: def configure(cfg: GraphStoreConfig): - cfg.name = config.name - cfg.embedding_fn = config.embedding_fn + cfg.name = self._graph_name - graph_store_type = config.type or os.getenv("GRAPH_STORE_TYPE") - return GraphStoreFactory.create(graph_store_type, configure, config.dict()) + graph_store_type = config.get_type_value() or os.getenv("GRAPH_STORE_TYPE") + return GraphStoreFactory.create(graph_store_type, configure, config.to_dict()) def __init_graph_store_adapter(self): return GraphStoreAdapterFactory.create(self._graph_store) - def get_config(self) -> BuiltinKnowledgeGraphConfig: + def get_config(self) -> TuGraphStoreConfig: """Get the knowledge graph config.""" return self._config @@ -171,7 +180,7 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase): # wait async tasks completed if not self.vector_name_exists(): - self._graph_store_adapter.create_graph(self.get_config().name) + self._graph_store_adapter.create_graph(self._graph_name) tasks = [process_chunk(chunk) for chunk in chunks] 
loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) @@ -188,7 +197,7 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase): List[str]: chunk ids. """ if not self.vector_name_exists(): - self._graph_store_adapter.create_graph(self.get_config().name) + self._graph_store_adapter.create_graph(self._graph_name) for chunk in chunks: triplets = await self._triplet_extractor.extract(chunk.content) for triplet in triplets: @@ -257,7 +266,7 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase): def truncate(self) -> List[str]: """Truncate knowledge graph.""" - logger.info(f"Truncate graph {self._config.name}") + logger.info(f"Truncate graph {self._graph_name}") self._graph_store_adapter.truncate() logger.info("Truncate keyword extractor") @@ -266,7 +275,7 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase): logger.info("Truncate triplet extractor") self._triplet_extractor.truncate() - return [self._config.name] + return [self._graph_name] def delete_vector_name(self, index_name: str): """Delete vector name.""" @@ -286,4 +295,4 @@ class BuiltinKnowledgeGraph(KnowledgeGraphBase): def vector_name_exists(self) -> bool: """Whether name exists.""" - return self._graph_store_adapter.graph_store.is_exist(self._config.name) + return self._graph_store_adapter.graph_store.is_exist(self._graph_name) diff --git a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/chroma_store.py b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/chroma_store.py index 012e3693d..11b308e2d 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/chroma_store.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/chroma_store.py @@ -2,11 +2,11 @@ import logging import os +from dataclasses import dataclass, field from typing import Any, Dict, Iterable, List, Mapping, Optional, Union -from dbgpt._private.pydantic import ConfigDict, Field from dbgpt.configs.model_config import PILOT_PATH, resolve_root_path -from dbgpt.core import Chunk +from dbgpt.core import Chunk, Embeddings from 
dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource from dbgpt.storage.vector_store.base import ( _COMMON_PARAMETERS, @@ -18,8 +18,6 @@ from dbgpt.util.i18n_utils import _ logger = logging.getLogger(__name__) -CHROMA_COLLECTION_NAME = "langchain" - @register_resource( _("Chroma Config"), @@ -38,21 +36,29 @@ CHROMA_COLLECTION_NAME = "langchain" ), ], ) +@dataclass class ChromaVectorConfig(VectorStoreConfig): """Chroma vector store config.""" - model_config = ConfigDict(arbitrary_types_allowed=True) + __type__ = "Chroma" - persist_path: Optional[str] = Field( + persist_path: Optional[str] = field( default=os.getenv("CHROMA_PERSIST_PATH", None), - description="the persist path of vector store.", + metadata={ + "help": _("The persist path of vector store."), + }, ) - collection_metadata: Optional[dict] = Field( + collection_metadata: Optional[dict] = field( default=None, - description="the index metadata of vector store, if not set, will use the " - "default metadata.", + metadata={ + "help": _("The metadata of collection."), + }, ) + def create_store(self, **kwargs) -> "ChromaStore": + """Create index store.""" + return ChromaStore(vector_store_config=self, **kwargs) + @register_resource( _("Chroma Vector Store"), @@ -73,11 +79,22 @@ class ChromaVectorConfig(VectorStoreConfig): class ChromaStore(VectorStoreBase): """Chroma vector store.""" - def __init__(self, vector_store_config: ChromaVectorConfig) -> None: + def __init__( + self, + vector_store_config: ChromaVectorConfig, + name: Optional[str], + embedding_fn: Optional[Embeddings] = None, + chroma_client: Optional["PersistentClient"] = None, # type: ignore # noqa + collection_metadata: Optional[dict] = None, + ) -> None: """Create a ChromaStore instance. Args: vector_store_config(ChromaVectorConfig): vector store config. + name(str): collection name. + embedding_fn(Embeddings): embedding function. + chroma_client(PersistentClient): chroma client. 
+ collection_metadata(dict): collection metadata. """ super().__init__() self._vector_store_config = vector_store_config @@ -85,28 +102,27 @@ class ChromaStore(VectorStoreBase): from chromadb import PersistentClient, Settings except ImportError: raise ImportError("Please install chroma package first.") - chroma_vector_config = vector_store_config.to_dict(exclude_none=True) + chroma_vector_config = vector_store_config.to_dict() chroma_path = chroma_vector_config.get( "persist_path", os.path.join(PILOT_PATH, "data") ) - self.persist_dir = os.path.join( - resolve_root_path(chroma_path), vector_store_config.name + ".vectordb" - ) - self.embeddings = vector_store_config.embedding_fn + self.persist_dir = os.path.join(resolve_root_path(chroma_path) + "/chromadb") + self.embeddings = embedding_fn + if not self.embeddings: + raise ValueError("Embeddings is None") chroma_settings = Settings( # chroma_db_impl="duckdb+parquet", => deprecated configuration of Chroma persist_directory=self.persist_dir, anonymized_telemetry=False, ) - self._chroma_client = PersistentClient( - path=self.persist_dir, settings=chroma_settings - ) - - collection_metadata = chroma_vector_config.get("collection_metadata") or { - "hnsw:space": "cosine" - } + self._chroma_client = chroma_client + if not self._chroma_client: + self._chroma_client = PersistentClient( + path=self.persist_dir, settings=chroma_settings + ) + collection_metadata = collection_metadata or {"hnsw:space": "cosine"} self._collection = self._chroma_client.get_or_create_collection( - name=CHROMA_COLLECTION_NAME, + name=name, embedding_function=None, metadata=collection_metadata, ) diff --git a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/elastic_store.py b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/elastic_store.py index f46ccf88a..6e19ceb08 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/elastic_store.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/elastic_store.py @@ -4,9 +4,9 
@@ from __future__ import annotations import logging import os +from dataclasses import dataclass, field from typing import List, Optional -from dbgpt._private.pydantic import Field from dbgpt.core import Chunk, Embeddings from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource from dbgpt.storage.vector_store.base import ( @@ -73,47 +73,63 @@ logger = logging.getLogger(__name__) ], description=_("Elasticsearch vector config."), ) -class ElasticsearchVectorConfig(VectorStoreConfig): +@dataclass +class ElasticsearchStoreConfig(VectorStoreConfig): """Elasticsearch vector store config.""" - class Config: - """Config for BaseModel.""" + __type__ = "ElasticSearch" - arbitrary_types_allowed = True - - uri: str = Field( + uri: str = field( default="localhost", - description="The uri of elasticsearch store, if not set, will use the default " - "uri.", + metadata={ + "description": "The uri of elasticsearch store, if not set, " + "will use the default uri." + }, ) - port: str = Field( + port: str = field( default="9200", - description="The port of elasticsearch store, if not set, will use the default " - "port.", + metadata={ + "description": "The port of elasticsearch store, if not set, will use the " + "default port." + }, ) - alias: str = Field( + alias: str = field( default="default", - description="The alias of elasticsearch store, if not set, will use the " - "default " - "alias.", + metadata={ + "description": "The alias of elasticsearch store, if not set, will use the " + "default alias." + }, ) - index_name: str = Field( + index_name: str = field( default="index_name_test", - description="The index name of elasticsearch store, if not set, will use the " - "default index name.", + metadata={ + "description": "The index name of elasticsearch store, if not set, will" + " use " + "the default index name." 
+ }, ) - metadata_field: str = Field( + metadata_field: str = field( default="metadata", - description="The metadata field of elasticsearch store, if not set, will use " - "the default metadata field.", + metadata={ + "description": "The metadata field of elasticsearch store, " + "if not set, will " + "use the default metadata field." + }, ) - secure: str = Field( + secure: str = field( default="", - description="The secure of elasticsearch store, if not set, will use the " - "default secure.", + metadata={ + "description": "The secure of elasticsearch store, i" + "f not set, will use the " + "default secure." + }, ) + def create_store(self, **kwargs) -> "ElasticStore": + """Create Elastic store.""" + return ElasticStore(vector_store_config=self, **kwargs) + @register_resource( _("Elastic Vector Store"), @@ -124,7 +140,7 @@ class ElasticsearchVectorConfig(VectorStoreConfig): Parameter.build_from( _("Elastic Config"), "vector_store_config", - ElasticsearchVectorConfig, + ElasticsearchStoreConfig, description=_("the elastic config of vector store."), optional=True, default=None, @@ -134,11 +150,16 @@ class ElasticsearchVectorConfig(VectorStoreConfig): class ElasticStore(VectorStoreBase): """Elasticsearch vector store.""" - def __init__(self, vector_store_config: ElasticsearchVectorConfig) -> None: + def __init__( + self, + vector_store_config: ElasticsearchStoreConfig, + name: Optional[str], + embedding_fn: Optional[Embeddings] = None, + ) -> None: """Create a ElasticsearchStore instance. Args: - vector_store_config (ElasticsearchVectorConfig): ElasticsearchStore config. + vector_store_config (ElasticsearchStoreConfig): ElasticsearchStore config. 
""" super().__init__() self._vector_store_config = vector_store_config @@ -158,21 +179,19 @@ class ElasticStore(VectorStoreBase): "ELASTICSEARCH_PASSWORD" ) or elasticsearch_vector_config.get("password") - self.collection_name = ( - elasticsearch_vector_config.get("name") or vector_store_config.name - ) + self.collection_name = name # name to hex if string_utils.contains_chinese(self.collection_name): bytes_str = self.collection_name.encode("utf-8") hex_str = bytes_str.hex() self.collection_name = hex_str - if vector_store_config.embedding_fn is None: + if embedding_fn is None: # Perform runtime checks on self.embedding to # ensure it has been correctly set and loaded raise ValueError("embedding_fn is required for ElasticSearchStore") # to lower case self.index_name = self.collection_name.lower() - self.embedding: Embeddings = vector_store_config.embedding_fn + self.embedding: Embeddings = embedding_fn self.fields: List = [] if (self.username is None) != (self.password is None): @@ -251,7 +270,7 @@ class ElasticStore(VectorStoreBase): except Exception as e: logger.error(f"ElasticSearch connection failed: {e}") - def get_config(self) -> ElasticsearchVectorConfig: + def get_config(self) -> ElasticsearchStoreConfig: """Get the vector store config.""" return self._vector_store_config diff --git a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/factory.py b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/factory.py deleted file mode 100644 index fac690884..000000000 --- a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/factory.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Vector store factory.""" - -import logging -from typing import Tuple, Type - -from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig -from dbgpt_ext.storage import __vector_store__ as vector_store_list -from dbgpt_ext.storage import _select_rag_storage - -logger = logging.getLogger(__name__) - - -class VectorStoreFactory: - """Factory for vector store.""" - - 
@staticmethod - def create( - vector_store_type: str, vector_space_name: str, vector_store_configure=None - ) -> VectorStoreBase: - """Create a VectorStore instance. - - Args: - - vector_store_type: vector store type Chroma, Milvus, etc. - - vector_store_config: vector store config - """ - store_cls, cfg_cls = VectorStoreFactory.__find_type(vector_store_type) - - try: - config = cfg_cls() - if vector_store_configure: - vector_store_configure(vector_space_name, config) - return store_cls(config) - except Exception as e: - logger.error("create vector store failed: %s", e) - raise e - - @staticmethod - def __find_type(vector_store_type: str) -> Tuple[Type, Type]: - for t in vector_store_list: - if t.lower() == vector_store_type.lower(): - store_cls, cfg_cls = _select_rag_storage(t) - if issubclass(store_cls, VectorStoreBase) and issubclass( - cfg_cls, VectorStoreConfig - ): - return store_cls, cfg_cls - raise Exception(f"Vector store {vector_store_type} not supported") diff --git a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/milvus_store.py b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/milvus_store.py index 81e54ed19..841720113 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/milvus_store.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/milvus_store.py @@ -5,9 +5,9 @@ from __future__ import annotations import json import logging import os +from dataclasses import dataclass, field from typing import Any, Iterable, List, Optional -from dbgpt._private.pydantic import ConfigDict, Field from dbgpt.core import Chunk, Embeddings from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource from dbgpt.storage.vector_store.base import ( @@ -94,51 +94,82 @@ logger = logging.getLogger(__name__) ], description=_("Milvus vector config."), ) +@dataclass class MilvusVectorConfig(VectorStoreConfig): """Milvus vector store config.""" - model_config = ConfigDict(arbitrary_types_allowed=True) + __type__ = "Milvus" - 
uri: Optional[str] = Field( + uri: str = field( default=None, - description="The uri of milvus store, if not set, will use the default uri.", + metadata={ + "help": _("The uri of milvus store, if not set, will use the default uri.") + }, ) - port: str = Field( + port: str = field( default="19530", - description="The port of milvus store, if not set, will use the default port.", + metadata={ + "help": _( + "The port of milvus store, if not set, will use the default port." + ) + }, ) - alias: str = Field( + alias: str = field( default="default", - description="The alias of milvus store, if not set, will use the default " - "alias.", + metadata={ + "help": _( + "The alias of milvus store, if not set, will use the default alias." + ) + }, ) - primary_field: str = Field( + primary_field: str = field( default="pk_id", - description="The primary field of milvus store, if not set, will use the " - "default primary field.", + metadata={ + "help": _( + "The primary field of milvus store, i" + "f not set, will use the default primary field." + ) + }, ) - text_field: str = Field( + text_field: str = field( default="content", - description="The text field of milvus store, if not set, will use the default " - "text field.", + metadata={ + "help": _( + "The text field of milvus store, if not set, will use the " + "default text field." + ) + }, ) - embedding_field: str = Field( + embedding_field: str = field( default="vector", - description="The embedding field of milvus store, if not set, will use the " - "default embedding field.", + metadata={ + "help": _( + "The embedding field of milvus store, if not set, will use the " + "default embedding field." + ) + }, ) - metadata_field: str = Field( + metadata_field: str = field( default="metadata", - description="The metadata field of milvus store, if not set, will use the " - "default metadata field.", + metadata={ + "help": _( + "The metadata field of milvus store, if not set, will use the " + "default metadata field." 
+ ) + }, ) - secure: str = Field( + secure: str = field( default="", - description="The secure of milvus store, if not set, will use the default " - "secure.", + metadata={ + "help": _("The secure of milvus store, if not set, will use the default ") + }, ) + def create_store(self, **kwargs) -> "MilvusStore": + """Create Milvus Store.""" + return MilvusStore(vector_store_config=self, **kwargs) + @register_resource( _("Milvus Vector Store"), @@ -159,7 +190,12 @@ class MilvusVectorConfig(VectorStoreConfig): class MilvusStore(VectorStoreBase): """Milvus vector store.""" - def __init__(self, vector_store_config: MilvusVectorConfig) -> None: + def __init__( + self, + vector_store_config: MilvusVectorConfig, + name: Optional[str], + embedding_fn: Optional[Embeddings] = None, + ) -> None: """Create a MilvusStore instance. Args: @@ -190,18 +226,16 @@ class MilvusStore(VectorStoreBase): ) self.secure = milvus_vector_config.get("secure") or os.getenv("MILVUS_SECURE") - self.collection_name = ( - milvus_vector_config.get("name") or vector_store_config.name - ) + self.collection_name = name if string_utils.contains_chinese(self.collection_name): bytes_str = self.collection_name.encode("utf-8") hex_str = bytes_str.hex() self.collection_name = hex_str - if vector_store_config.embedding_fn is None: + if embedding_fn is None: # Perform runtime checks on self.embedding to # ensure it has been correctly set and loaded raise ValueError("embedding_fn is required for MilvusStore") - self.embedding: Embeddings = vector_store_config.embedding_fn + self.embedding: Embeddings = embedding_fn self.fields: List = [] self.alias = milvus_vector_config.get("alias") or "default" diff --git a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/oceanbase_store.py b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/oceanbase_store.py index ddc8ff820..7848598f6 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/oceanbase_store.py +++ 
b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/oceanbase_store.py @@ -5,14 +5,14 @@ import logging import math import os import uuid +from dataclasses import dataclass, field from typing import Any, List, Optional, Tuple import numpy as np -from pydantic import Field from sqlalchemy import JSON, Column, String, Table, func, text from sqlalchemy.dialects.mysql import LONGTEXT -from dbgpt.core import Chunk +from dbgpt.core import Chunk, Embeddings from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource from dbgpt.storage.vector_store.base import ( _COMMON_PARAMETERS, @@ -122,35 +122,48 @@ def _normalize(vector: List[float]) -> List[float]: ], description="OceanBase vector store config.", ) +@dataclass class OceanBaseConfig(VectorStoreConfig): """OceanBase vector store config.""" - class Config: - """Config for BaseModel.""" + __type__ = "OceanBase" - arbitrary_types_allowed = True + ob_host: Optional[str] = field( + default=None, + metadata={ + "help": "The host of oceanbase, if not set, will use the default host." + }, + ) + ob_port: Optional[int] = field( + default=None, + metadata={ + "help": "The port of oceanbase, if not set, will use the default port." + }, + ) + ob_user: Optional[str] = field( + default=None, + metadata={ + "help": "The user of oceanbase, if not set, will use the default user." + }, + ) + ob_password: Optional[str] = field( + default=None, + metadata={ + "help": "The password of oceanbase, if not set, " + "will use the default password" + }, + ) + ob_database: Optional[str] = field( + default=None, + metadata={ + "help": "The database for vector tables, if not set, " + "will use the default database." 
+ }, + ) - """OceanBase config""" - ob_host: Optional[str] = Field( - default=None, - description="oceanbase host", - ) - ob_port: Optional[int] = Field( - default=None, - description="oceanbase port", - ) - ob_user: Optional[str] = Field( - default=None, - description="user to login", - ) - ob_password: Optional[str] = Field( - default=None, - description="password to login", - ) - ob_database: Optional[str] = Field( - default=None, - description="database for vector tables", - ) + def create_store(self, **kwargs) -> "OceanBaseStore": + """Create OceanBase store.""" + return OceanBaseStore(vector_store_config=self, **kwargs) @register_resource( @@ -172,7 +185,12 @@ class OceanBaseConfig(VectorStoreConfig): class OceanBaseStore(VectorStoreBase): """OceanBase vector store.""" - def __init__(self, vector_store_config: OceanBaseConfig) -> None: + def __init__( + self, + vector_store_config: OceanBaseConfig, + name: Optional[str], + embedding_fn: Optional[Embeddings] = None, + ) -> None: """Create a OceanBaseStore instance.""" try: from pyobvector import ObVecClient # type: ignore @@ -188,8 +206,8 @@ class OceanBaseStore(VectorStoreBase): super().__init__() self._vector_store_config = vector_store_config - self.embedding_function = vector_store_config.embedding_fn - self.table_name = vector_store_config.name + self.embedding_function = embedding_fn + self.table_name = name vector_store_config_map = vector_store_config.to_dict() OB_HOST = str( diff --git a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/pgvector_store.py b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/pgvector_store.py index 950469301..87c63dd22 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/pgvector_store.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/pgvector_store.py @@ -1,10 +1,10 @@ """Postgres vector store.""" import logging +from dataclasses import dataclass, field from typing import List, Optional -from dbgpt._private.pydantic import ConfigDict, 
Field -from dbgpt.core import Chunk +from dbgpt.core import Chunk, Embeddings from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource from dbgpt.storage.vector_store.base import ( _COMMON_PARAMETERS, @@ -37,17 +37,22 @@ logger = logging.getLogger(__name__) ], description="PG vector config.", ) +@dataclass class PGVectorConfig(VectorStoreConfig): """PG vector store config.""" - model_config = ConfigDict(arbitrary_types_allowed=True) - - connection_string: str = Field( + connection_string: str = field( default=None, - description="the connection string of vector store, if not set, will use the " - "default connection string.", + metadata={ + "description": "the connection string of vector store, " + "if not set, will use the default connection string." + }, ) + def create_store(self, **kwargs) -> "PGVectorStore": + """Create PGVector Store.""" + return PGVectorStore(vector_store_config=self, **kwargs) + @register_resource( _("PG Vector Store"), @@ -71,7 +76,12 @@ class PGVectorStore(VectorStoreBase): To use this, you should have the ``pgvector`` python package installed. 
""" - def __init__(self, vector_store_config: PGVectorConfig) -> None: + def __init__( + self, + vector_store_config: PGVectorConfig, + name: Optional[str], + embedding_fn: Optional[Embeddings] = None, + ) -> None: """Create a PGVectorStore instance.""" try: from langchain.vectorstores import PGVector # mypy: ignore @@ -83,8 +93,8 @@ class PGVectorStore(VectorStoreBase): self._vector_store_config = vector_store_config self.connection_string = vector_store_config.connection_string - self.embeddings = vector_store_config.embedding_fn - self.collection_name = vector_store_config.name + self.embeddings = embedding_fn + self.collection_name = name self.vector_store_client = PGVector( embedding_function=self.embeddings, # type: ignore diff --git a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/weaviate_store.py b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/weaviate_store.py index f916424b6..1e410c4fb 100644 --- a/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/weaviate_store.py +++ b/packages/dbgpt-ext/src/dbgpt_ext/storage/vector_store/weaviate_store.py @@ -2,10 +2,10 @@ import logging import os +from dataclasses import dataclass, field from typing import List, Optional -from dbgpt._private.pydantic import ConfigDict, Field -from dbgpt.core import Chunk +from dbgpt.core import Chunk, Embeddings from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource from dbgpt.storage.vector_store.base import ( _COMMON_PARAMETERS, @@ -45,20 +45,28 @@ logger = logging.getLogger(__name__) ), ], ) +@dataclass class WeaviateVectorConfig(VectorStoreConfig): """Weaviate vector store config.""" - model_config = ConfigDict(arbitrary_types_allowed=True) - - weaviate_url: str = Field( + weaviate_url: str = field( default=os.getenv("WEAVIATE_URL", None), - description="weaviate url address, if not set, will use the default url.", + metadata={ + "description": "weaviate url address, if not set, " + "will use the default url.", + }, ) - persist_path: 
str = Field( + persist_path: str = field( default=os.getenv("WEAVIATE_PERSIST_PATH", None), - description="weaviate persist path.", + metadata={ + "description": "weaviate persist path.", + }, ) + def create_store(self, **kwargs) -> "WeaviateStore": + """Create Weaviate Store.""" + return WeaviateStore(vector_store_config=self, **kwargs) + @register_resource( _("Weaviate Vector Store"), @@ -79,7 +87,12 @@ class WeaviateVectorConfig(VectorStoreConfig): class WeaviateStore(VectorStoreBase): """Weaviate database.""" - def __init__(self, vector_store_config: WeaviateVectorConfig) -> None: + def __init__( + self, + vector_store_config: WeaviateVectorConfig, + name: Optional[str], + embedding_fn: Optional[Embeddings] = None, + ) -> None: """Initialize with Weaviate client.""" try: import weaviate @@ -92,10 +105,10 @@ class WeaviateStore(VectorStoreBase): self._vector_store_config = vector_store_config self.weaviate_url = vector_store_config.weaviate_url - self.embedding = vector_store_config.embedding_fn - self.vector_name = vector_store_config.name + self.embedding = embedding_fn + self.vector_name = name self.persist_dir = os.path.join( - vector_store_config.persist_path, vector_store_config.name + ".vectordb" + vector_store_config.persist_path, name + ".vectordb" ) self.vector_store_client = weaviate.Client(self.weaviate_url) diff --git a/packages/dbgpt-serve/src/dbgpt_serve/datasource/service/db_summary_client.py b/packages/dbgpt-serve/src/dbgpt_serve/datasource/service/db_summary_client.py index 4a9395b49..5c0f1d835 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/datasource/service/db_summary_client.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/datasource/service/db_summary_client.py @@ -8,11 +8,12 @@ from dbgpt.component import SystemApp from dbgpt.core import Embeddings from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory from dbgpt.rag.text_splitter.text_splitter import RDBTextSplitter +from dbgpt.storage.vector_store.base import 
VectorStoreBase from dbgpt_ext.rag import ChunkParameters from dbgpt_ext.rag.summary.gdbms_db_summary import GdbmsSummary from dbgpt_ext.rag.summary.rdbms_db_summary import RdbmsSummary from dbgpt_serve.datasource.manages import ConnectorManager -from dbgpt_serve.rag.connector import VectorStoreConnector +from dbgpt_serve.rag.storage_manager import StorageManager logger = logging.getLogger(__name__) @@ -68,8 +69,8 @@ class DBSummaryClient: ) retriever = DBSchemaRetriever( top_k=topk, - table_vector_store_connector=table_vector_connector.index_client, - field_vector_store_connector=field_vector_connector.index_client, + table_vector_store_connector=table_vector_connector, + field_vector_store_connector=field_vector_connector, separator="--table-field-separator--", ) @@ -116,8 +117,8 @@ class DBSummaryClient: ) db_assembler = DBSchemaAssembler.load_from_connection( connector=db_summary_client.db, - table_vector_store_connector=table_vector_connector.index_client, - field_vector_store_connector=field_vector_connector.index_client, + table_vector_store_connector=table_vector_connector, + field_vector_store_connector=field_vector_connector, chunk_parameters=chunk_parameters, max_seq_length=self.app_config.service.web.embedding_model_max_seq_len, ) @@ -157,23 +158,14 @@ class DBSummaryClient: def _get_vector_connector_by_db( self, dbname - ) -> Tuple[VectorStoreConnector, VectorStoreConnector]: - from dbgpt.storage.vector_store.base import VectorStoreConfig - + ) -> Tuple[VectorStoreBase, VectorStoreBase]: vector_store_name = dbname + "_profile" - table_vector_store_config = VectorStoreConfig(name=vector_store_name) - table_vector_connector = VectorStoreConnector.from_default( - self.storage_config.vector.get("type"), - self.embeddings, - vector_store_config=table_vector_store_config, - system_app=self.system_app, + storage_manager = StorageManager.get_instance(self.system_app) + table_vector_store = storage_manager.create_vector_store( + index_name=vector_store_name ) 
field_vector_store_name = dbname + "_profile_field" - field_vector_store_config = VectorStoreConfig(name=field_vector_store_name) - field_vector_connector = VectorStoreConnector.from_default( - self.storage_config.vector.get("type"), - self.embeddings, - vector_store_config=field_vector_store_config, - system_app=self.system_app, + field_vector_store = storage_manager.create_vector_store( + index_name=field_vector_store_name ) - return table_vector_connector, field_vector_connector + return table_vector_store, field_vector_store diff --git a/packages/dbgpt-serve/src/dbgpt_serve/datasource/service/service.py b/packages/dbgpt-serve/src/dbgpt_serve/datasource/service/service.py index fd5d92180..1ec6b44e7 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/datasource/service/service.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/datasource/service/service.py @@ -10,7 +10,6 @@ from dbgpt.component import ComponentType, SystemApp from dbgpt.core.awel.dag.dag_manager import DAGManager from dbgpt.datasource.parameter import BaseDatasourceParameters from dbgpt.storage.metadata import BaseDao -from dbgpt.storage.vector_store.base import VectorStoreConfig from dbgpt.util.executor_utils import ExecutorFactory from dbgpt_ext.datasource.schema import DBType from dbgpt_serve.core import BaseService, ResourceTypes @@ -19,8 +18,8 @@ from dbgpt_serve.datasource.manages.connect_config_db import ( ConnectConfigDao, ConnectConfigEntity, ) -from dbgpt_serve.rag.connector import VectorStoreConnector +from ...rag.storage_manager import StorageManager from ..api.schemas import ( DatasourceCreateRequest, DatasourceQueryResponse, @@ -93,6 +92,12 @@ class Service( raise ValueError("SYSTEM_APP is not set") return ConnectorManager.get_instance(self._system_app) + @property + def storage_manager(self) -> StorageManager: + if not self._system_app: + raise ValueError("SYSTEM_APP is not set") + return StorageManager.get_instance(self._system_app) + def create( self, request: 
Union[DatasourceCreateRequest, DatasourceServeRequest] ) -> DatasourceQueryResponse: @@ -229,13 +234,10 @@ class Service( """ db_config = self._dao.get_one({"id": datasource_id}) vector_name = db_config.db_name + "_profile" - vector_store_config = VectorStoreConfig(name=vector_name) - _vector_connector = VectorStoreConnector( - vector_store_type=CFG.VECTOR_STORE_TYPE, - vector_store_config=vector_store_config, - system_app=self._system_app, + vector_connector = self.storage_manager.create_vector_store( + index_name=vector_name ) - _vector_connector.delete_vector_name(vector_name) + vector_connector.delete_vector_name(vector_name) if db_config: self._dao.delete({"id": datasource_id}) return db_config @@ -300,13 +302,10 @@ class Service( """ db_config = self._dao.get_one({"id": datasource_id}) vector_name = db_config.db_name + "_profile" - vector_store_config = VectorStoreConfig(name=vector_name) - _vector_connector = VectorStoreConnector( - vector_store_type=CFG.VECTOR_STORE_TYPE, - vector_store_config=vector_store_config, - system_app=self._system_app, + vector_connector = self.storage_manager.create_vector_store( + index_name=vector_name ) - _vector_connector.delete_vector_name(vector_name) + vector_connector.delete_vector_name(vector_name) self._db_summary_client.db_summary_embedding( db_config.db_name, db_config.db_type ) diff --git a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/service.py b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/service.py index 9d2262ff4..3a1c2a13e 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/service.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/evaluate/service/service.py @@ -15,15 +15,16 @@ from dbgpt.rag.evaluation import RetrieverEvaluator from dbgpt.rag.evaluation.answer import AnswerRelevancyMetric from dbgpt.rag.evaluation.retriever import RetrieverSimilarityMetric from dbgpt.storage.metadata import BaseDao -from dbgpt.storage.vector_store.base import VectorStoreConfig from 
dbgpt_serve.rag.operators.knowledge_space import SpaceRetrieverOperator from ...agent.agents.controller import multi_agents from ...agent.evaluation.evaluation import AgentEvaluator, AgentOutputOperator from ...core import BaseService from ...prompt.service.service import Service as PromptService -from ...rag.connector import VectorStoreConnector + +# from ...rag.connector import VectorStoreConnector from ...rag.service.service import Service as RagService +from ...rag.storage_manager import StorageManager from ..api.schemas import EvaluateServeRequest, EvaluateServeResponse, EvaluationScene from ..config import SERVE_SERVICE_COMPONENT_NAME, ServeConfig from ..models.models import ServeDao, ServeEntity @@ -64,6 +65,10 @@ class Service(BaseService[ServeEntity, EvaluateServeRequest, EvaluateServeRespon """ self._system_app = system_app + @property + def storage_manager(self): + return StorageManager.get_instance(self._system_app) + @property def dao(self) -> BaseDao[ServeEntity, EvaluateServeRequest, EvaluateServeResponse]: """Returns the internal DAO.""" @@ -104,17 +109,13 @@ class Service(BaseService[ServeEntity, EvaluateServeRequest, EvaluateServeRespon ) embeddings = embedding_factory.create() - config = VectorStoreConfig( - name=scene_value, - embedding_fn=embeddings, - ) space = self.rag_service.get({"space_id": str(scene_value)}) if not space: raise ValueError(f"Space {scene_value} not found") - vector_store_connector = VectorStoreConnector( - vector_store_type=space.vector_type, - vector_store_config=config, - system_app=self._system_app, + storage_connector = self.storage_manager.get_storage_connector( + index_name=space.name, + storage_type=space.vector_type, + llm_model=context.get("llm_model"), ) evaluator = RetrieverEvaluator( operator_cls=SpaceRetrieverOperator, @@ -122,7 +123,7 @@ class Service(BaseService[ServeEntity, EvaluateServeRequest, EvaluateServeRespon operator_kwargs={ "space_id": str(scene_value), "top_k": self._serve_config.similarity_top_k, 
- "vector_store_connector": vector_store_connector, + "vector_store_connector": storage_connector, }, ) metrics = [] diff --git a/packages/dbgpt-serve/src/dbgpt_serve/rag/connector.py b/packages/dbgpt-serve/src/dbgpt_serve/rag/connector.py index 3f4b46711..679d1642d 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/rag/connector.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/rag/connector.py @@ -11,6 +11,7 @@ from dbgpt.core import Chunk, Embeddings from dbgpt.storage.base import IndexStoreConfig from dbgpt.storage.vector_store.base import VectorStoreConfig from dbgpt.storage.vector_store.filters import MetadataFilters +from dbgpt_ext.storage import __document_store__ as supported_full_tet_list from dbgpt_ext.storage import __knowledge_graph__ as supported_kg_store_list from dbgpt_ext.storage import __vector_store__ as supported_vector_store_list @@ -70,19 +71,14 @@ class VectorStoreConnector: self._vector_store_type = vector_store_type self._embeddings = vector_store_config.embedding_fn - config_dict = {} - storage_config = self.app_config.rag.storage - if vector_store_type in supported_vector_store_list: - config_dict = storage_config.vector - elif vector_store_type in supported_kg_store_list: - config_dict = storage_config.graph + config_dict = self._adapt_storage_config(vector_store_type).to_dict() for key in vector_store_config.to_dict().keys(): value = getattr(vector_store_config, key) if value is not None: config_dict[key] = value - for key, value in vector_store_config.model_extra.items(): - if value is not None: - config_dict[key] = value + # for key, value in vector_store_config.model_extra.items(): + # if value is not None: + # config_dict[key] = value config = self.config_class(**config_dict) try: if vector_store_type in pools and config.name in pools[vector_store_type]: @@ -97,10 +93,10 @@ class VectorStoreConnector: def _rewrite_index_store_type(self, index_store_type): # Rewrite Knowledge Graph Type if self.app_config.rag.storage.graph: - graph_dict = 
self.app_config.rag.storage.graph + graph_config = self.app_config.rag.storage.graph if ( - isinstance(graph_dict, dict) - and graph_dict.get("enable_summary", "false").lower() == "true" + hasattr(graph_config, "enable_summary") + and graph_config.enable_summary.lower() == "true" ): if index_store_type == "KnowledgeGraph": return "CommunitySummaryKnowledgeGraph" @@ -286,3 +282,14 @@ class VectorStoreConnector: for cls_name in rag_storages: store_cls, config_cls = _select_rag_storage(cls_name) connector[cls_name] = (store_cls, config_cls) + + def _adapt_storage_config(self, vector_store_type): + """Adapt storage config.""" + storage_config = self.app_config.rag.storage + if vector_store_type in supported_vector_store_list: + return storage_config.vector + elif vector_store_type in supported_kg_store_list: + return storage_config.graph + elif vector_store_type in supported_full_tet_list: + return storage_config.full_text + raise ValueError(f"storage type {vector_store_type} not supported") diff --git a/packages/dbgpt-serve/src/dbgpt_serve/rag/operators/knowledge_space.py b/packages/dbgpt-serve/src/dbgpt_serve/rag/operators/knowledge_space.py index 649dd7100..f7e72f65f 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/rag/operators/knowledge_space.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/rag/operators/knowledge_space.py @@ -100,6 +100,7 @@ class SpaceRetrieverOperator(RetrieverOperator[IN, OUT]): space_retriever = KnowledgeSpaceRetriever( space_id=self._space_id, top_k=self._top_k, + system_app=self._service.system_app, ) if isinstance(query, str): candidates = space_retriever.retrieve_with_scores(query, self._recall_score) diff --git a/packages/dbgpt-serve/src/dbgpt_serve/rag/retriever/knowledge_space.py b/packages/dbgpt-serve/src/dbgpt_serve/rag/retriever/knowledge_space.py index 828c0db6f..07374fab9 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/rag/retriever/knowledge_space.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/rag/retriever/knowledge_space.py @@ 
-2,17 +2,15 @@ from typing import List, Optional from dbgpt.component import ComponentType, SystemApp from dbgpt.core import Chunk -from dbgpt.model import DefaultLLMClient -from dbgpt.model.cluster import WorkerManagerFactory from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory from dbgpt.rag.retriever import EmbeddingRetriever, QueryRewrite, Ranker from dbgpt.rag.retriever.base import BaseRetriever from dbgpt.storage.vector_store.filters import MetadataFilters from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async -from dbgpt_serve.rag.connector import VectorStoreConnector from dbgpt_serve.rag.models.models import KnowledgeSpaceDao from dbgpt_serve.rag.retriever.qa_retriever import QARetriever from dbgpt_serve.rag.retriever.retriever_chain import RetrieverChain +from dbgpt_serve.rag.storage_manager import StorageManager class KnowledgeSpaceRetriever(BaseRetriever): @@ -49,7 +47,6 @@ class KnowledgeSpaceRetriever(BaseRetriever): "embedding_factory", EmbeddingFactory ) embedding_fn = embedding_factory.create() - from dbgpt.storage.vector_store.base import VectorStoreConfig space_dao = KnowledgeSpaceDao() space = space_dao.get_one({"id": space_id}) @@ -57,21 +54,10 @@ class KnowledgeSpaceRetriever(BaseRetriever): space = space_dao.get_one({"name": space_id}) if space is None: raise ValueError(f"Knowledge space {space_id} not found") - worker_manager = self._system_app.get_component( - ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory - ).create() - llm_client = DefaultLLMClient(worker_manager=worker_manager) - config = VectorStoreConfig( - name=space.name, - embedding_fn=embedding_fn, - llm_client=llm_client, - llm_model=self._llm_model, - ) - - self._vector_store_connector = VectorStoreConnector( - vector_store_type=space.vector_type, - vector_store_config=config, - system_app=self._system_app, + storage_connector = self.storage_manager.get_storage_connector( + space.name, + space.vector_type, + self._llm_model, ) 
self._executor = self._system_app.get_component( ComponentType.EXECUTOR_DEFAULT, ExecutorFactory @@ -86,7 +72,7 @@ class KnowledgeSpaceRetriever(BaseRetriever): system_app=system_app, ), EmbeddingRetriever( - index_store=self._vector_store_connector.index_client, + index_store=storage_connector, top_k=self._top_k, query_rewrite=self._query_rewrite, rerank=self._rerank, @@ -95,6 +81,10 @@ class KnowledgeSpaceRetriever(BaseRetriever): executor=self._executor, ) + @property + def storage_manager(self): + return StorageManager.get_instance(self._system_app) + def _retrieve( self, query: str, filters: Optional[MetadataFilters] = None ) -> List[Chunk]: diff --git a/packages/dbgpt-serve/src/dbgpt_serve/rag/service/service.py b/packages/dbgpt-serve/src/dbgpt_serve/rag/service/service.py index b8e38b8dc..8cd1475ba 100644 --- a/packages/dbgpt-serve/src/dbgpt_serve/rag/service/service.py +++ b/packages/dbgpt-serve/src/dbgpt_serve/rag/service/service.py @@ -18,13 +18,11 @@ from dbgpt.configs.model_config import ( from dbgpt.core import Chunk, LLMClient from dbgpt.model import DefaultLLMClient from dbgpt.model.cluster import WorkerManagerFactory -from dbgpt.rag.embedding import EmbeddingFactory from dbgpt.rag.embedding.embedding_factory import RerankEmbeddingFactory from dbgpt.rag.knowledge import ChunkStrategy, KnowledgeType from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker from dbgpt.storage.metadata import BaseDao from dbgpt.storage.metadata._base_dao import QUERY_SPEC -from dbgpt.storage.vector_store.base import VectorStoreConfig from dbgpt.util.pagination_utils import PaginationResult from dbgpt.util.string_utils import remove_trailing_punctuation from dbgpt.util.tracer import root_tracer, trace @@ -33,7 +31,6 @@ from dbgpt_ext.rag.assembler import EmbeddingAssembler from dbgpt_ext.rag.chunk_manager import ChunkParameters from dbgpt_ext.rag.knowledge import KnowledgeFactory from dbgpt_serve.core import BaseService -from dbgpt_serve.rag.connector import 
VectorStoreConnector from ..api.schemas import ( ChunkServeRequest, @@ -52,6 +49,7 @@ from ..models.document_db import ( ) from ..models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity from ..retriever.knowledge_space import KnowledgeSpaceRetriever +from ..storage_manager import StorageManager logger = logging.getLogger(__name__) @@ -96,6 +94,10 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes self._chunk_dao = self._chunk_dao or DocumentChunkDao() self._system_app = system_app + @property + def storage_manager(self): + return StorageManager.get_instance(self._system_app) + @property def dao( self, @@ -286,14 +288,7 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes space = self.get(query_request) if space is None: raise HTTPException(status_code=400, detail=f"Space {space_id} not found") - config = VectorStoreConfig( - name=space.name, llm_client=self.llm_client, model_name=None - ) - vector_store_connector = VectorStoreConnector( - vector_store_type=space.vector_type, - vector_store_config=config, - system_app=self._system_app, - ) + vector_store_connector = self.create_vector_store(space.name) # delete vectors vector_store_connector.delete_vector_name(space.name) document_query = KnowledgeDocumentEntity(space=space.name) @@ -360,14 +355,7 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes vector_ids = docuemnt.vector_ids if vector_ids is not None: - config = VectorStoreConfig( - name=space.name, llm_client=self.llm_client, model_name=None - ) - vector_store_connector = VectorStoreConnector( - vector_store_type=space.vector_type, - vector_store_config=config, - system_app=self._system_app, - ) + vector_store_connector = self.create_vector_store(space.name) # delete vector by ids vector_store_connector.delete_by_ids(vector_ids) # delete chunks @@ -498,25 +486,9 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes 
chunk_parameters: ChunkParameters, ) -> None: """sync knowledge document chunk into vector store""" - embedding_factory = self._system_app.get_component( - "embedding_factory", EmbeddingFactory - ) - embedding_fn = embedding_factory.create() - from dbgpt.storage.vector_store.base import VectorStoreConfig - space = self.get({"id": space_id}) - config = VectorStoreConfig( - name=space.name, - embedding_fn=embedding_fn, - max_chunks_once_load=self._serve_config.max_chunks_once_load, - max_threads=self._serve_config.max_threads, - llm_client=self.llm_client, - model_name=None, - ) - vector_store_connector = VectorStoreConnector( - vector_store_type=space.vector_type, - vector_store_config=config, - system_app=self._system_app, + storage_connector = self.storage_manager.get_storage_connector( + space.name, space.vector_type ) knowledge = None if not space.domain_type or ( @@ -531,17 +503,17 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes doc.gmt_modified = datetime.now() self._document_dao.update_knowledge_document(doc) asyncio.create_task( - self.async_doc_embedding( - knowledge, chunk_parameters, vector_store_connector, doc, space + self.async_doc_process( + knowledge, chunk_parameters, storage_connector, doc, space ) ) logger.info(f"begin save document chunks, doc:{doc.doc_name}") - @trace("async_doc_embedding") - async def async_doc_embedding( - self, knowledge, chunk_parameters, vector_store_connector, doc, space + @trace("async_doc_process") + async def async_doc_process( + self, knowledge, chunk_parameters, storage_connector, doc, space ): - """async document embedding into vector db + """async document process into storage Args: - knowledge: Knowledge - chunk_parameters: ChunkParameters @@ -572,13 +544,11 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes doc.chunk_size = len(chunk_docs) vector_ids = [chunk.chunk_id for chunk in chunk_docs] else: - max_chunks_once_load = ( - 
vector_store_connector._index_store_config.max_chunks_once_load - ) - max_threads = vector_store_connector._index_store_config.max_threads + max_chunks_once_load = self.config.max_chunks_once_load + max_threads = self.config.max_threads assembler = await EmbeddingAssembler.aload_from_knowledge( knowledge=knowledge, - index_store=vector_store_connector.index_client, + index_store=storage_connector, chunk_parameters=chunk_parameters, ) diff --git a/packages/dbgpt-serve/src/dbgpt_serve/rag/storage_manager.py b/packages/dbgpt-serve/src/dbgpt_serve/rag/storage_manager.py new file mode 100644 index 000000000..ffe122676 --- /dev/null +++ b/packages/dbgpt-serve/src/dbgpt_serve/rag/storage_manager.py @@ -0,0 +1,152 @@ +"""RAG storage manager.""" + +from typing import List, Optional, Type + +from dbgpt import BaseComponent +from dbgpt.component import ComponentType, SystemApp +from dbgpt.model import DefaultLLMClient +from dbgpt.model.cluster import WorkerManagerFactory +from dbgpt.rag.embedding import EmbeddingFactory +from dbgpt.storage.base import IndexStoreBase +from dbgpt.storage.full_text.base import FullTextStoreBase +from dbgpt.storage.vector_store.base import VectorStoreBase, VectorStoreConfig +from dbgpt_ext.storage.full_text.elasticsearch import ElasticDocumentStore +from dbgpt_ext.storage.knowledge_graph.knowledge_graph import BuiltinKnowledgeGraph + + +class StorageManager(BaseComponent): + """RAG storage manager.""" + + name = ComponentType.RAG_STORAGE_MANAGER + + def __init__(self, system_app: SystemApp): + """Create a new StorageManager.""" + self.system_app = system_app + super().__init__(system_app) + + def init_app(self, system_app: SystemApp): + """Init component.""" + self.system_app = system_app + + def storage_config(self): + """Storage config.""" + app_config = self.system_app.config.configs.get("app_config") + return app_config.rag.storage + + def get_storage_connector( + self, index_name: str, storage_type: str, llm_model: 
Optional[str] = None + ) -> IndexStoreBase: + """Get storage connector.""" + supported_vector_types = self.get_vector_supported_types + storage_config = self.storage_config() + if storage_type in supported_vector_types: + return self.create_vector_store(index_name) + elif storage_type == "KnowledgeGraph": + if not storage_config.graph: + raise ValueError( + "Graph storage is not configured. Please check your config. " + "Reference configs/dbgpt-graphrag.toml" + ) + return self.create_kg_store(index_name, llm_model) + elif storage_type == "FullText": + if not storage_config.full_text: + raise ValueError( + "FullText storage is not configured. Please check your config. " + "Reference configs/dbgpt-bm25-rag.toml" + ) + return self.create_full_text_store(index_name) + else: + raise ValueError(f"Does not support storage type {storage_type}") + + def create_vector_store(self, index_name) -> VectorStoreBase: + """Create vector store.""" + app_config = self.system_app.config.configs.get("app_config") + storage_config = app_config.rag.storage + embedding_factory = self.system_app.get_component( + "embedding_factory", EmbeddingFactory + ) + embedding_fn = embedding_factory.create() + vector_store_config: VectorStoreConfig = storage_config.vector + return vector_store_config.create_store( + name=index_name, embedding_fn=embedding_fn + ) + + def create_kg_store( + self, index_name, llm_model: Optional[str] = None + ) -> BuiltinKnowledgeGraph: + """Create knowledge graph store.""" + app_config = self.system_app.config.configs.get("app_config") + rag_config = app_config.rag + storage_config = app_config.rag.storage + worker_manager = self.system_app.get_component( + ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory + ).create() + llm_client = DefaultLLMClient(worker_manager=worker_manager) + embedding_factory = self.system_app.get_component( + "embedding_factory", EmbeddingFactory + ) + embedding_fn = embedding_factory.create() + if storage_config.graph: + graph_config = 
storage_config.graph + graph_config.llm_model = llm_model + if hasattr(graph_config, "enable_summary") and graph_config.enable_summary: + from dbgpt_ext.storage.knowledge_graph.community_summary import ( + CommunitySummaryKnowledgeGraph, + ) + + return CommunitySummaryKnowledgeGraph( + config=storage_config.graph, + name=index_name, + llm_client=llm_client, + vector_store_config=storage_config.vector, + kg_extract_top_k=rag_config.kg_extract_top_k, + kg_extract_score_threshold=rag_config.kg_extract_score_threshold, + kg_community_top_k=rag_config.kg_community_top_k, + kg_community_score_threshold=rag_config.kg_community_score_threshold, + kg_triplet_graph_enabled=rag_config.kg_triplet_graph_enabled, + kg_document_graph_enabled=rag_config.kg_document_graph_enabled, + kg_chunk_search_top_k=rag_config.kg_chunk_search_top_k, + kg_extraction_batch_size=rag_config.kg_extraction_batch_size, + kg_community_summary_batch_size=rag_config.kg_community_summary_batch_size, + kg_embedding_batch_size=rag_config.kg_embedding_batch_size, + kg_similarity_top_k=rag_config.kg_similarity_top_k, + kg_similarity_score_threshold=rag_config.kg_similarity_score_threshold, + kg_enable_text_search=rag_config.kg_enable_text_search, + kg_text2gql_model_enabled=rag_config.kg_text2gql_model_enabled, + kg_text2gql_model_name=rag_config.kg_text2gql_model_name, + embedding_fn=embedding_fn, + kg_max_chunks_once_load=rag_config.max_chunks_once_load, + kg_max_threads=rag_config.max_threads, + ) + return BuiltinKnowledgeGraph( + config=storage_config.graph, + name=index_name, + llm_client=llm_client, + ) + + def create_full_text_store(self, index_name) -> FullTextStoreBase: + """Create Full Text store.""" + app_config = self.system_app.config.configs.get("app_config") + rag_config = app_config.rag + storage_config = app_config.rag.storage + return ElasticDocumentStore( + es_config=storage_config.full_text, + name=index_name, + k1=rag_config.bm25_k1, + b=rag_config.bm25_b, + ) + + @property + def 
get_vector_supported_types(self) -> List[str]: + """Get all supported types.""" + support_types = [] + vector_store_classes = _get_all_subclasses() + for vector_cls in vector_store_classes: + support_types.append(vector_cls.__type__) + return support_types + + +def _get_all_subclasses() -> List[Type[VectorStoreConfig]]: + """Get all subclasses of cls.""" + + return VectorStoreConfig.__subclasses__()