diff --git a/packages/dbgpt-serve/src/dbgpt_serve/rag/storage_manager.py b/packages/dbgpt-serve/src/dbgpt_serve/rag/storage_manager.py
index 5f9ed13df..d38df364d 100644
--- a/packages/dbgpt-serve/src/dbgpt_serve/rag/storage_manager.py
+++ b/packages/dbgpt-serve/src/dbgpt_serve/rag/storage_manager.py
@@ -1,5 +1,6 @@
 """RAG STORAGE MANAGER manager."""
 
+import threading
 from typing import List, Optional, Type
 
 from dbgpt import BaseComponent
@@ -22,6 +23,11 @@ class StorageManager(BaseComponent):
     def __init__(self, system_app: SystemApp):
-        """Create a new ConnectorManager."""
+        """Create a new StorageManager."""
         self.system_app = system_app
+        # Created stores, keyed by (store kind, index name, ...) so the vector,
+        # knowledge-graph and full-text stores of one index never collide.
+        self._store_cache = {}
+        # Serializes store creation so each cache key is built at most once.
+        self._cache_lock = threading.Lock()
         super().__init__(system_app)
 
     def init_app(self, system_app: SystemApp):
@@ -62,17 +68,29 @@ class StorageManager(BaseComponent):
         """Create vector store."""
         app_config = self.system_app.config.configs.get("app_config")
         storage_config = app_config.rag.storage
-        embedding_factory = self.system_app.get_component(
-            "embedding_factory", EmbeddingFactory
-        )
-        embedding_fn = embedding_factory.create()
-        vector_store_config: VectorStoreConfig = storage_config.vector
-        return vector_store_config.create_store(
-            name=index_name,
-            embedding_fn=embedding_fn,
-            max_chunks_once_load=vector_store_config.max_chunks_once_load,
-            max_threads=vector_store_config.max_threads,
-        )
+        # Namespace the key by store kind: index names are shared across kinds.
+        cache_key = ("vector", index_name)
+        cached_store = self._store_cache.get(cache_key)
+        if cached_store is not None:
+            return cached_store
+        with self._cache_lock:
+            # Re-check under the lock: another thread may have created the store
+            # while this one was waiting.
+            if cache_key in self._store_cache:
+                return self._store_cache[cache_key]
+            embedding_factory = self.system_app.get_component(
+                "embedding_factory", EmbeddingFactory
+            )
+            embedding_fn = embedding_factory.create()
+            vector_store_config: VectorStoreConfig = storage_config.vector
+            new_store = vector_store_config.create_store(
+                name=index_name,
+                embedding_fn=embedding_fn,
+                max_chunks_once_load=vector_store_config.max_chunks_once_load,
+                max_threads=vector_store_config.max_threads,
+            )
+            self._store_cache[cache_key] = new_store
+            return new_store
 
     def create_kg_store(
         self, index_name, llm_model: Optional[str] = None
@@ -81,63 +99,91 @@ class StorageManager(BaseComponent):
         app_config = self.system_app.config.configs.get("app_config")
         rag_config = app_config.rag
         storage_config = app_config.rag.storage
-        worker_manager = self.system_app.get_component(
-            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
-        ).create()
-        llm_client = DefaultLLMClient(worker_manager=worker_manager)
-        embedding_factory = self.system_app.get_component(
-            "embedding_factory", EmbeddingFactory
-        )
-        embedding_fn = embedding_factory.create()
-        if storage_config.graph:
-            graph_config = storage_config.graph
-            graph_config.llm_model = llm_model
-            if hasattr(graph_config, "enable_summary") and graph_config.enable_summary:
-                from dbgpt_ext.storage.knowledge_graph.community_summary import (
-                    CommunitySummaryKnowledgeGraph,
-                )
+        # llm_model is part of the key because it configures the created store.
+        cache_key = ("kg", index_name, llm_model)
+        cached_store = self._store_cache.get(cache_key)
+        if cached_store is not None:
+            return cached_store
+        with self._cache_lock:
+            # Re-check under the lock: another thread may have created the store
+            # while this one was waiting.
+            if cache_key in self._store_cache:
+                return self._store_cache[cache_key]
+            worker_manager = self.system_app.get_component(
+                ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
+            ).create()
+            llm_client = DefaultLLMClient(worker_manager=worker_manager)
+            embedding_factory = self.system_app.get_component(
+                "embedding_factory", EmbeddingFactory
+            )
+            embedding_fn = embedding_factory.create()
+            if storage_config.graph:
+                graph_config = storage_config.graph
+                graph_config.llm_model = llm_model
+                if (
+                    hasattr(graph_config, "enable_summary")
+                    and graph_config.enable_summary
+                ):
+                    from dbgpt_ext.storage.knowledge_graph.community_summary import (
+                        CommunitySummaryKnowledgeGraph,
+                    )
 
-                return CommunitySummaryKnowledgeGraph(
-                    config=storage_config.graph,
-                    name=index_name,
-                    llm_client=llm_client,
-                    vector_store_config=storage_config.vector,
-                    kg_extract_top_k=rag_config.kg_extract_top_k,
-                    kg_extract_score_threshold=rag_config.kg_extract_score_threshold,
-                    kg_community_top_k=rag_config.kg_community_top_k,
-                    kg_community_score_threshold=rag_config.kg_community_score_threshold,
-                    kg_triplet_graph_enabled=rag_config.kg_triplet_graph_enabled,
-                    kg_document_graph_enabled=rag_config.kg_document_graph_enabled,
-                    kg_chunk_search_top_k=rag_config.kg_chunk_search_top_k,
-                    kg_extraction_batch_size=rag_config.kg_extraction_batch_size,
-                    kg_community_summary_batch_size=rag_config.kg_community_summary_batch_size,
-                    kg_embedding_batch_size=rag_config.kg_embedding_batch_size,
-                    kg_similarity_top_k=rag_config.kg_similarity_top_k,
-                    kg_similarity_score_threshold=rag_config.kg_similarity_score_threshold,
-                    kg_enable_text_search=rag_config.kg_enable_text_search,
-                    kg_text2gql_model_enabled=rag_config.kg_text2gql_model_enabled,
-                    kg_text2gql_model_name=rag_config.kg_text2gql_model_name,
-                    embedding_fn=embedding_fn,
-                    kg_max_chunks_once_load=rag_config.max_chunks_once_load,
-                    kg_max_threads=rag_config.max_threads,
-                )
-        return BuiltinKnowledgeGraph(
-            config=storage_config.graph,
-            name=index_name,
-            llm_client=llm_client,
-        )
+                    new_store = CommunitySummaryKnowledgeGraph(
+                        config=storage_config.graph,
+                        name=index_name,
+                        llm_client=llm_client,
+                        vector_store_config=storage_config.vector,
+                        kg_extract_top_k=rag_config.kg_extract_top_k,
+                        kg_extract_score_threshold=rag_config.kg_extract_score_threshold,
+                        kg_community_top_k=rag_config.kg_community_top_k,
+                        kg_community_score_threshold=rag_config.kg_community_score_threshold,
+                        kg_triplet_graph_enabled=rag_config.kg_triplet_graph_enabled,
+                        kg_document_graph_enabled=rag_config.kg_document_graph_enabled,
+                        kg_chunk_search_top_k=rag_config.kg_chunk_search_top_k,
+                        kg_extraction_batch_size=rag_config.kg_extraction_batch_size,
+                        kg_community_summary_batch_size=rag_config.kg_community_summary_batch_size,
+                        kg_embedding_batch_size=rag_config.kg_embedding_batch_size,
+                        kg_similarity_top_k=rag_config.kg_similarity_top_k,
+                        kg_similarity_score_threshold=rag_config.kg_similarity_score_threshold,
+                        kg_enable_text_search=rag_config.kg_enable_text_search,
+                        kg_text2gql_model_enabled=rag_config.kg_text2gql_model_enabled,
+                        kg_text2gql_model_name=rag_config.kg_text2gql_model_name,
+                        embedding_fn=embedding_fn,
+                        kg_max_chunks_once_load=rag_config.max_chunks_once_load,
+                        kg_max_threads=rag_config.max_threads,
+                    )
+                    self._store_cache[cache_key] = new_store
+                    return new_store
+            new_store = BuiltinKnowledgeGraph(
+                config=storage_config.graph,
+                name=index_name,
+                llm_client=llm_client,
+            )
+            self._store_cache[cache_key] = new_store
+            return new_store
 
     def create_full_text_store(self, index_name) -> FullTextStoreBase:
         """Create Full Text store."""
         app_config = self.system_app.config.configs.get("app_config")
         rag_config = app_config.rag
         storage_config = app_config.rag.storage
-        return ElasticDocumentStore(
-            es_config=storage_config.full_text,
-            name=index_name,
-            k1=rag_config.bm25_k1,
-            b=rag_config.bm25_b,
-        )
+        cache_key = ("full_text", index_name)
+        cached_store = self._store_cache.get(cache_key)
+        if cached_store is not None:
+            return cached_store
+        with self._cache_lock:
+            # Re-check under the lock: another thread may have created the store
+            # while this one was waiting.
+            if cache_key in self._store_cache:
+                return self._store_cache[cache_key]
+            new_store = ElasticDocumentStore(
+                es_config=storage_config.full_text,
+                name=index_name,
+                k1=rag_config.bm25_k1,
+                b=rag_config.bm25_b,
+            )
+            self._store_cache[cache_key] = new_store
+            return new_store
 
     @property
     def get_vector_supported_types(self) -> List[str]: