From f1ca8a76adbfe9c54926c624dc3c52977922a200 Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Wed, 14 Aug 2024 20:40:27 +0800 Subject: [PATCH] fix:space resource error. --- dbgpt/app/scene/chat_knowledge/v1/chat.py | 15 +- dbgpt/core/interface/knowledge.py | 3 +- dbgpt/serve/agent/resource/knowledge.py | 2 +- dbgpt/serve/rag/retriever/knowledge_space.py | 52 +++-- dbgpt/serve/rag/retriever/qa_retriever.py | 218 +++++++++++++++++++ dbgpt/serve/rag/retriever/retriever_chain.py | 78 +++++++ 6 files changed, 347 insertions(+), 21 deletions(-) create mode 100644 dbgpt/serve/rag/retriever/qa_retriever.py create mode 100644 dbgpt/serve/rag/retriever/retriever_chain.py diff --git a/dbgpt/app/scene/chat_knowledge/v1/chat.py b/dbgpt/app/scene/chat_knowledge/v1/chat.py index c80cad423..59f9e5048 100644 --- a/dbgpt/app/scene/chat_knowledge/v1/chat.py +++ b/dbgpt/app/scene/chat_knowledge/v1/chat.py @@ -21,6 +21,7 @@ from dbgpt.core import ( ) from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker from dbgpt.rag.retriever.rewrite import QueryRewrite +from dbgpt.serve.rag.retriever.knowledge_space import KnowledgeSpaceRetriever from dbgpt.util.tracer import root_tracer, trace CFG = Config() @@ -77,7 +78,6 @@ class ChatKnowledge(BaseChat): ) from dbgpt.serve.rag.models.models import ( KnowledgeSpaceDao, - KnowledgeSpaceEntity, ) from dbgpt.storage.vector_store.base import VectorStoreConfig @@ -113,12 +113,19 @@ class ChatKnowledge(BaseChat): # We use reranker, so if the top_k is less than 20, # we need to set it to 20 retriever_top_k = max(CFG.RERANK_TOP_K, 20) - self.embedding_retriever = EmbeddingRetriever( + # self.embedding_retriever = EmbeddingRetriever( + # top_k=retriever_top_k, + # index_store=vector_store_connector.index_client, + # query_rewrite=query_rewrite, + # rerank=reranker, + # ) + self._space_retriever = KnowledgeSpaceRetriever( + space_id=self.knowledge_space, top_k=retriever_top_k, - index_store=vector_store_connector.index_client, 
query_rewrite=query_rewrite, rerank=reranker, ) + self.prompt_template.template_is_strict = False self.relations = None self.chunk_dao = DocumentChunkDao() @@ -275,6 +282,6 @@ class ChatKnowledge(BaseChat): with root_tracer.start_span( "execute_similar_search", metadata={"query": query} ): - return await self.embedding_retriever.aretrieve_with_scores( + return await self._space_retriever.aretrieve_with_scores( query, self.recall_score ) diff --git a/dbgpt/core/interface/knowledge.py b/dbgpt/core/interface/knowledge.py index e155e3169..6116439c1 100644 --- a/dbgpt/core/interface/knowledge.py +++ b/dbgpt/core/interface/knowledge.py @@ -2,7 +2,7 @@ import json import uuid -from typing import Any, Dict +from typing import Any, Dict, Optional from dbgpt._private.pydantic import BaseModel, Field, model_to_dict @@ -61,6 +61,7 @@ class Chunk(Document): default="\n", description="Separator between metadata fields when converting to string.", ) + retriever: Optional[str] = Field(default=None, description="retriever name") def to_dict(self, **kwargs: Any) -> Dict[str, Any]: """Convert Chunk to dict.""" diff --git a/dbgpt/serve/agent/resource/knowledge.py b/dbgpt/serve/agent/resource/knowledge.py index e7c11df21..90359be4c 100644 --- a/dbgpt/serve/agent/resource/knowledge.py +++ b/dbgpt/serve/agent/resource/knowledge.py @@ -65,7 +65,7 @@ class KnowledgeSpaceRetrieverResource(RetrieverResource): def __init__(self, name: str, space_name: str, context: Optional[dict] = None): retriever = KnowledgeSpaceRetriever( - space_name=space_name, + space_id=space_name, top_k=context.get("top_k", None) if context else 4, ) super().__init__(name, retriever=retriever) diff --git a/dbgpt/serve/rag/retriever/knowledge_space.py b/dbgpt/serve/rag/retriever/knowledge_space.py index 6711c36db..61f818318 100644 --- a/dbgpt/serve/rag/retriever/knowledge_space.py +++ b/dbgpt/serve/rag/retriever/knowledge_space.py @@ -5,8 +5,12 @@ from dbgpt.component import ComponentType from 
dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG from dbgpt.core import Chunk from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory +from dbgpt.rag.retriever import EmbeddingRetriever, Ranker, QueryRewrite from dbgpt.rag.retriever.base import BaseRetriever from dbgpt.serve.rag.connector import VectorStoreConnector +from dbgpt.serve.rag.models.models import KnowledgeSpaceDao +from dbgpt.serve.rag.retriever.qa_retriever import QARetriever +from dbgpt.serve.rag.retriever.retriever_chain import RetrieverChain from dbgpt.storage.vector_store.filters import MetadataFilters from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async @@ -18,18 +22,24 @@ class KnowledgeSpaceRetriever(BaseRetriever): def __init__( self, - space_name: str = None, + space_id: str = None, top_k: Optional[int] = 4, + query_rewrite: Optional[QueryRewrite] = None, + rerank: Optional[Ranker] = None, ): """ Args: - space_name (str): knowledge space name + space_id (str): knowledge space name top_k (Optional[int]): top k + query_rewrite: (Optional[QueryRewrite]) query rewrite + rerank: (Optional[Ranker]) rerank """ - if space_name is None: - raise ValueError("space_name is required") - self._space_name = space_name + if space_id is None: + raise ValueError("space_id is required") + self._space_id = space_id self._top_k = top_k + self._query_rewrite = query_rewrite + self._rerank = rerank embedding_factory = CFG.SYSTEM_APP.get_component( "embedding_factory", EmbeddingFactory ) @@ -37,8 +47,9 @@ class KnowledgeSpaceRetriever(BaseRetriever): model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL] ) from dbgpt.storage.vector_store.base import VectorStoreConfig - - config = VectorStoreConfig(name=self._space_name, embedding_fn=embedding_fn) + space_dao = KnowledgeSpaceDao() + space = space_dao.get_one({"id": space_id}) + config = VectorStoreConfig(name=space.name, embedding_fn=embedding_fn) self._vector_store_connector = VectorStoreConnector( 
vector_store_type=CFG.VECTOR_STORE_TYPE, vector_store_config=config, @@ -47,6 +58,20 @@ class KnowledgeSpaceRetriever(BaseRetriever): ComponentType.EXECUTOR_DEFAULT, ExecutorFactory ).create() + self._retriever_chain = RetrieverChain(retrievers=[ + QARetriever(space_id=space_id, + top_k=top_k, + embedding_fn=embedding_fn + ), + EmbeddingRetriever( + index_store=self._vector_store_connector.index_client, + top_k=top_k, + query_rewrite=self._query_rewrite, + rerank=self._rerank + ) + ], executor=self._executor + ) + def _retrieve( self, query: str, filters: Optional[MetadataFilters] = None ) -> List[Chunk]: @@ -59,8 +84,8 @@ class KnowledgeSpaceRetriever(BaseRetriever): Return: List[Chunk]: list of chunks """ - candidates = self._vector_store_connector.similar_search( - doc=query, topk=self._top_k, filters=filters + candidates = self._retriever_chain.retrieve( + query=query, filters=filters ) return candidates @@ -80,13 +105,10 @@ class KnowledgeSpaceRetriever(BaseRetriever): Return: List[Chunk]: list of chunks with score """ - candidates_with_score = self._vector_store_connector.similar_search_with_scores( - doc=query, - topk=self._top_k, - score_threshold=score_threshold, - filters=filters, + candidates_with_scores = self._retriever_chain.retrieve_with_scores( + query, score_threshold, filters ) - return candidates_with_score + return candidates_with_scores async def _aretrieve( self, query: str, filters: Optional[MetadataFilters] = None diff --git a/dbgpt/serve/rag/retriever/qa_retriever.py b/dbgpt/serve/rag/retriever/qa_retriever.py new file mode 100644 index 000000000..21ba2f704 --- /dev/null +++ b/dbgpt/serve/rag/retriever/qa_retriever.py @@ -0,0 +1,218 @@ +import ast +import json +import logging +from typing import List, Optional, Any + +from dbgpt._private.config import Config +from dbgpt.app.knowledge.chunk_db import DocumentChunkDao, DocumentChunkEntity +from dbgpt.app.knowledge.document_db import KnowledgeDocumentDao + +from dbgpt.component import 
ComponentType +from dbgpt.core import Chunk +from dbgpt.rag.retriever.base import BaseRetriever +from dbgpt.serve.rag.models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity +from dbgpt.storage.vector_store.filters import MetadataFilters +from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async +from dbgpt.util.similarity_util import calculate_cosine_similarity +from dbgpt.util.string_utils import remove_trailing_punctuation + +CFG = Config() +CHUNK_PAGE_SIZE = 1000 +logger = logging.getLogger(__name__) + + +class QARetriever(BaseRetriever): + """Document QA retriever.""" + + def __init__( + self, + space_id: str = None, + top_k: Optional[int] = 4, + embedding_fn: Optional[Any] = None, + lambda_value: Optional[float] = 1e-5, + ): + """ + Args: + space_id (str): knowledge space name + top_k (Optional[int]): top k + """ + if space_id is None: + raise ValueError("space_id is required") + self._top_k = top_k + self._lambda_value = lambda_value + self._space_dao = KnowledgeSpaceDao() + self._document_dao = KnowledgeDocumentDao() + self._chunk_dao = DocumentChunkDao() + self._embedding_fn = embedding_fn + + space = self._space_dao.get_one( + {"id": space_id} + ) + if not space: + raise ValueError("space not found") + self.documents = self._document_dao.get_list({"space": space.name}) + self._executor = CFG.SYSTEM_APP.get_component( + ComponentType.EXECUTOR_DEFAULT, ExecutorFactory + ).create() + + def _retrieve( + self, query: str, filters: Optional[MetadataFilters] = None + ) -> List[Chunk]: + """Retrieve knowledge chunks. + Args: + query (str): query text + filters: (Optional[MetadataFilters]) metadata filters. 
+ Return: + List[Chunk]: list of chunks + """ + query = remove_trailing_punctuation(query) + candidate_results = [] + for doc in self.documents: + if doc.questions: + questions = json.loads(doc.questions) + if query in questions: + chunks = self._chunk_dao.get_document_chunks( + DocumentChunkEntity( + document_id=doc.id + ), + page_size=CHUNK_PAGE_SIZE + ) + candidates = [ + Chunk(content=chunk.content, + metadata=ast.literal_eval(chunk.meta_info), + retriever=self.name(), + score=0.0) + for chunk in chunks + ] + candidate_results.extend( + self._cosine_similarity_rerank(candidates, query) + ) + return candidate_results + + def _retrieve_with_score( + self, + query: str, + score_threshold: float, + filters: Optional[MetadataFilters] = None, + lambda_value: Optional[float] = None, + ) -> List[Chunk]: + """Retrieve knowledge chunks with score. + Args: + query (str): query text + score_threshold (float): score threshold + filters: (Optional[MetadataFilters]) metadata filters. + Return: + List[Chunk]: list of chunks with score + """ + query = remove_trailing_punctuation(query) + candidate_results = [] + doc_ids = [doc.id for doc in self.documents] + query_param = DocumentChunkEntity() + chunks = self._chunk_dao.get_chunks_with_questions( + query=query_param, + document_ids=doc_ids + ) + for chunk in chunks: + if chunk.questions: + questions = json.loads(chunk.questions) + if query in questions: + logger.info(f"qa chunk hit:{chunk}, question:{query}") + candidate_results.append( + Chunk(content=chunk.content, + chunk_id=str(chunk.id), + metadata={ + "prop_field": ast.literal_eval(chunk.meta_info) + }, + retriever=self.name(), + score=1.0 + ) + ) + if len(candidate_results) > 0: + return self._cosine_similarity_rerank(candidate_results, query) + + for doc in self.documents: + if doc.questions: + questions = json.loads(doc.questions) + if query in questions: + logger.info(f"qa document hit:{doc}, question:{query}") + chunks = self._chunk_dao.get_document_chunks( + 
DocumentChunkEntity(document_id=doc.id), + page_size=CHUNK_PAGE_SIZE + ) + candidates_with_scores = [ + Chunk(content=chunk.content, + chunk_id=str(chunk.id), + metadata={ + "prop_field": ast.literal_eval(chunk.meta_info) + }, + retriever=self.name(), + score=1.0) + for chunk in chunks + ] + candidate_results.extend( + self._cosine_similarity_rerank(candidates_with_scores, query) + ) + return candidate_results + + async def _aretrieve( + self, query: str, filters: Optional[MetadataFilters] = None + ) -> List[Chunk]: + """Retrieve knowledge chunks. + Args: + query (str): query text + filters: (Optional[MetadataFilters]) metadata filters. + Return: + List[Chunk]: list of chunks + """ + candidates = await blocking_func_to_async( + self._executor, self._retrieve, query, filters + ) + return candidates + + async def _aretrieve_with_score( + self, + query: str, + score_threshold: float, + filters: Optional[MetadataFilters] = None, + ) -> List[Chunk]: + """Retrieve knowledge chunks with score. + Args: + query (str): query text + score_threshold (float): score threshold + filters: (Optional[MetadataFilters]) metadata filters. 
+ Return: + List[Chunk]: list of chunks with score + """ + candidates_with_score = await blocking_func_to_async( + self._executor, self._retrieve_with_score, query, score_threshold, filters + ) + return candidates_with_score + + def _cosine_similarity_rerank(self, candidates_with_scores: List[Chunk] + , query: str) -> List[Chunk]: + """Rerank candidates using cosine similarity.""" + if len(candidates_with_scores) > self._top_k: + for candidate in candidates_with_scores: + similarity = calculate_cosine_similarity( + embeddings=self._embedding_fn, + prediction=query, + contexts=[candidate.content] + ) + score = float(similarity.mean()) + candidate.score = score + candidates_with_scores.sort(key=lambda x: x.score, reverse=True) + candidates_with_scores = candidates_with_scores[: self._top_k] + candidates_with_scores = [ + Chunk(content=candidate.content, + chunk_id=candidate.chunk_id, + metadata=candidate.metadata, + retriever=self.name(), + score=1.0) + for candidate in candidates_with_scores + ] + return candidates_with_scores + + @classmethod + def name(self): + """Return retriever name.""" + return "qa_retriever" diff --git a/dbgpt/serve/rag/retriever/retriever_chain.py b/dbgpt/serve/rag/retriever/retriever_chain.py new file mode 100644 index 000000000..18918198a --- /dev/null +++ b/dbgpt/serve/rag/retriever/retriever_chain.py @@ -0,0 +1,78 @@ +from concurrent.futures import ThreadPoolExecutor, Executor +from typing import Optional, List + +from dbgpt.core import Chunk +from dbgpt.rag.retriever.base import BaseRetriever +from dbgpt.storage.vector_store.filters import MetadataFilters +from dbgpt.util.executor_utils import blocking_func_to_async + + +class RetrieverChain(BaseRetriever): + """Retriever chain class.""" + + def __init__(self, retrievers: Optional[List[BaseRetriever]] = None, + executor: Optional[Executor] = None): + """Create retriever chain instance.""" + self._retrievers = retrievers or [] + self._executor = executor or ThreadPoolExecutor() + + def 
_retrieve(self, query: str, filters: Optional[MetadataFilters] = None) -> List[ + Chunk]: + """Retrieve knowledge chunks. + Args: + query (str): query text + filters: (Optional[MetadataFilters]) metadata filters. + Return: + List[Chunk]: list of chunks + """ + for retriever in self._retrievers: + candidates = retriever.retrieve( + query, filters + ) + if candidates: + return candidates + return [] + + async def _aretrieve(self, query: str, filters: Optional[MetadataFilters] = None) -> \ + List[Chunk]: + """Retrieve knowledge chunks. + Args: + query (str): query text + filters: (Optional[MetadataFilters]) metadata filters. + Return: + List[Chunk]: list of chunks + """ + candidates = await blocking_func_to_async( + self._executor, self._retrieve, query, filters + ) + return candidates + + def _retrieve_with_score(self, query: str, score_threshold: float, filters: Optional[MetadataFilters] = None) -> List[Chunk]: + """Retrieve knowledge chunks. + Args: + query (str): query text + filters: (Optional[MetadataFilters]) metadata filters. + Return: + List[Chunk]: list of chunks + """ + for retriever in self._retrievers: + candidates_with_scores = retriever.retrieve_with_scores( + query=query, score_threshold=score_threshold, filters=filters + ) + if candidates_with_scores: + return candidates_with_scores + return [] + + async def _aretrieve_with_score(self, query: str, score_threshold: float, filters: Optional[MetadataFilters] = None) -> List[Chunk]: + """Retrieve knowledge chunks with score. + Args: + query (str): query text + score_threshold (float): score threshold + filters: (Optional[MetadataFilters]) metadata filters. + Return: + List[Chunk]: list of chunks with score + """ + candidates_with_score = await blocking_func_to_async( + self._executor, self._retrieve_with_score, query, score_threshold, filters + ) + return candidates_with_score \ No newline at end of file