fix: space resource error.

This commit is contained in:
aries_ckt
2024-08-14 20:40:27 +08:00
parent 1821f44c13
commit f1ca8a76ad
6 changed files with 347 additions and 21 deletions

View File

@@ -21,6 +21,7 @@ from dbgpt.core import (
) )
from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker
from dbgpt.rag.retriever.rewrite import QueryRewrite from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.serve.rag.retriever.knowledge_space import KnowledgeSpaceRetriever
from dbgpt.util.tracer import root_tracer, trace from dbgpt.util.tracer import root_tracer, trace
CFG = Config() CFG = Config()
@@ -77,7 +78,6 @@ class ChatKnowledge(BaseChat):
) )
from dbgpt.serve.rag.models.models import ( from dbgpt.serve.rag.models.models import (
KnowledgeSpaceDao, KnowledgeSpaceDao,
KnowledgeSpaceEntity,
) )
from dbgpt.storage.vector_store.base import VectorStoreConfig from dbgpt.storage.vector_store.base import VectorStoreConfig
@@ -113,12 +113,19 @@ class ChatKnowledge(BaseChat):
# We use reranker, so if the top_k is less than 20, # We use reranker, so if the top_k is less than 20,
# we need to set it to 20 # we need to set it to 20
retriever_top_k = max(CFG.RERANK_TOP_K, 20) retriever_top_k = max(CFG.RERANK_TOP_K, 20)
self.embedding_retriever = EmbeddingRetriever( # self.embedding_retriever = EmbeddingRetriever(
# top_k=retriever_top_k,
# index_store=vector_store_connector.index_client,
# query_rewrite=query_rewrite,
# rerank=reranker,
# )
self._space_retriever = KnowledgeSpaceRetriever(
space_id=self.knowledge_space,
top_k=retriever_top_k, top_k=retriever_top_k,
index_store=vector_store_connector.index_client,
query_rewrite=query_rewrite, query_rewrite=query_rewrite,
rerank=reranker, rerank=reranker,
) )
self.prompt_template.template_is_strict = False self.prompt_template.template_is_strict = False
self.relations = None self.relations = None
self.chunk_dao = DocumentChunkDao() self.chunk_dao = DocumentChunkDao()
@@ -275,6 +282,6 @@ class ChatKnowledge(BaseChat):
with root_tracer.start_span( with root_tracer.start_span(
"execute_similar_search", metadata={"query": query} "execute_similar_search", metadata={"query": query}
): ):
return await self.embedding_retriever.aretrieve_with_scores( return await self._space_retriever.aretrieve_with_scores(
query, self.recall_score query, self.recall_score
) )

View File

@@ -2,7 +2,7 @@
import json import json
import uuid import uuid
from typing import Any, Dict from typing import Any, Dict, Optional
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
@@ -61,6 +61,7 @@ class Chunk(Document):
default="\n", default="\n",
description="Separator between metadata fields when converting to string.", description="Separator between metadata fields when converting to string.",
) )
retriever: Optional[str] = Field(default=None, description="retriever name")
def to_dict(self, **kwargs: Any) -> Dict[str, Any]: def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
"""Convert Chunk to dict.""" """Convert Chunk to dict."""

View File

@@ -65,7 +65,7 @@ class KnowledgeSpaceRetrieverResource(RetrieverResource):
def __init__(self, name: str, space_name: str, context: Optional[dict] = None): def __init__(self, name: str, space_name: str, context: Optional[dict] = None):
retriever = KnowledgeSpaceRetriever( retriever = KnowledgeSpaceRetriever(
space_name=space_name, space_id=space_name,
top_k=context.get("top_k", None) if context else 4, top_k=context.get("top_k", None) if context else 4,
) )
super().__init__(name, retriever=retriever) super().__init__(name, retriever=retriever)

View File

@@ -5,8 +5,12 @@ from dbgpt.component import ComponentType
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
from dbgpt.core import Chunk from dbgpt.core import Chunk
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.retriever import EmbeddingRetriever, Ranker, QueryRewrite
from dbgpt.rag.retriever.base import BaseRetriever from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.serve.rag.connector import VectorStoreConnector from dbgpt.serve.rag.connector import VectorStoreConnector
from dbgpt.serve.rag.models.models import KnowledgeSpaceDao
from dbgpt.serve.rag.retriever.qa_retriever import QARetriever
from dbgpt.serve.rag.retriever.retriever_chain import RetrieverChain
from dbgpt.storage.vector_store.filters import MetadataFilters from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
@@ -18,18 +22,24 @@ class KnowledgeSpaceRetriever(BaseRetriever):
def __init__( def __init__(
self, self,
space_name: str = None, space_id: str = None,
top_k: Optional[int] = 4, top_k: Optional[int] = 4,
query_rewrite: Optional[QueryRewrite] = None,
rerank: Optional[Ranker] = None,
): ):
""" """
Args: Args:
space_name (str): knowledge space name space_id (str): knowledge space name
top_k (Optional[int]): top k top_k (Optional[int]): top k
query_rewrite: (Optional[QueryRewrite]) query rewrite
rerank: (Optional[Ranker]) rerank
""" """
if space_name is None: if space_id is None:
raise ValueError("space_name is required") raise ValueError("space_id is required")
self._space_name = space_name self._space_id = space_id
self._top_k = top_k self._top_k = top_k
self._query_rewrite = query_rewrite
self._rerank = rerank
embedding_factory = CFG.SYSTEM_APP.get_component( embedding_factory = CFG.SYSTEM_APP.get_component(
"embedding_factory", EmbeddingFactory "embedding_factory", EmbeddingFactory
) )
@@ -37,8 +47,9 @@ class KnowledgeSpaceRetriever(BaseRetriever):
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL] model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
) )
from dbgpt.storage.vector_store.base import VectorStoreConfig from dbgpt.storage.vector_store.base import VectorStoreConfig
space_dao = KnowledgeSpaceDao()
config = VectorStoreConfig(name=self._space_name, embedding_fn=embedding_fn) space = space_dao.get_one({"id": space_id})
config = VectorStoreConfig(name=space.name, embedding_fn=embedding_fn)
self._vector_store_connector = VectorStoreConnector( self._vector_store_connector = VectorStoreConnector(
vector_store_type=CFG.VECTOR_STORE_TYPE, vector_store_type=CFG.VECTOR_STORE_TYPE,
vector_store_config=config, vector_store_config=config,
@@ -47,6 +58,20 @@ class KnowledgeSpaceRetriever(BaseRetriever):
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
).create() ).create()
self._retriever_chain = RetrieverChain(retrievers=[
QARetriever(space_id=space_id,
top_k=top_k,
embedding_fn=embedding_fn
),
EmbeddingRetriever(
index_store=self._vector_store_connector.index_client,
top_k=top_k,
query_rewrite=self._query_rewrite,
rerank=self._rerank
)
], executor=self._executor
)
def _retrieve( def _retrieve(
self, query: str, filters: Optional[MetadataFilters] = None self, query: str, filters: Optional[MetadataFilters] = None
) -> List[Chunk]: ) -> List[Chunk]:
@@ -59,8 +84,8 @@ class KnowledgeSpaceRetriever(BaseRetriever):
Return: Return:
List[Chunk]: list of chunks List[Chunk]: list of chunks
""" """
candidates = self._vector_store_connector.similar_search( candidates = self._retriever_chain.retrieve(
doc=query, topk=self._top_k, filters=filters query=query, filters=filters
) )
return candidates return candidates
@@ -80,13 +105,10 @@ class KnowledgeSpaceRetriever(BaseRetriever):
Return: Return:
List[Chunk]: list of chunks with score List[Chunk]: list of chunks with score
""" """
candidates_with_score = self._vector_store_connector.similar_search_with_scores( candidates_with_scores = self._retriever_chain.retrieve_with_scores(
doc=query, query, score_threshold, filters
topk=self._top_k,
score_threshold=score_threshold,
filters=filters,
) )
return candidates_with_score return candidates_with_scores
async def _aretrieve( async def _aretrieve(
self, query: str, filters: Optional[MetadataFilters] = None self, query: str, filters: Optional[MetadataFilters] = None

View File

@@ -0,0 +1,218 @@
import ast
import json
import logging
from typing import List, Optional, Any
from dbgpt._private.config import Config
from dbgpt.app.knowledge.chunk_db import DocumentChunkDao, DocumentChunkEntity
from dbgpt.app.knowledge.document_db import KnowledgeDocumentDao
from dbgpt.component import ComponentType
from dbgpt.core import Chunk
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.serve.rag.models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
from dbgpt.util.similarity_util import calculate_cosine_similarity
from dbgpt.util.string_utils import remove_trailing_punctuation
CFG = Config()
CHUNK_PAGE_SIZE = 1000
logger = logging.getLogger(__name__)
class QARetriever(BaseRetriever):
    """Document QA retriever.

    Matches the incoming query against pre-extracted ``questions`` stored on
    documents and chunks of a knowledge space, returning the associated
    chunks directly instead of performing a vector similarity search.
    """

    def __init__(
        self,
        space_id: Optional[str] = None,
        top_k: Optional[int] = 4,
        # BUGFIX: default was `4`, which is not a callable embedding
        # function; `None` is the correct "not provided" sentinel.
        embedding_fn: Optional[Any] = None,
        lambda_value: Optional[float] = 1e-5,
    ):
        """Create a QARetriever.

        Args:
            space_id (Optional[str]): knowledge space id (required).
            top_k (Optional[int]): maximum number of chunks to return.
            embedding_fn (Optional[Any]): embedding function used by the
                cosine-similarity rerank step.
            lambda_value (Optional[float]): smoothing factor (currently
                stored but unused by the visible code paths).

        Raises:
            ValueError: if ``space_id`` is missing or no space matches it.
        """
        if space_id is None:
            raise ValueError("space_id is required")
        self._top_k = top_k
        self._lambda_value = lambda_value
        self._space_dao = KnowledgeSpaceDao()
        self._document_dao = KnowledgeDocumentDao()
        self._chunk_dao = DocumentChunkDao()
        self._embedding_fn = embedding_fn
        space = self._space_dao.get_one({"id": space_id})
        if not space:
            raise ValueError("space not found")
        # Documents are loaded once at construction time; a retriever
        # instance does not observe documents added afterwards.
        self.documents = self._document_dao.get_list({"space": space.name})
        self._executor = CFG.SYSTEM_APP.get_component(
            ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
        ).create()

    def _retrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Retrieve knowledge chunks whose document questions match `query`.

        Args:
            query (str): query text
            filters: (Optional[MetadataFilters]) metadata filters
                (accepted for interface compatibility; not applied here).

        Return:
            List[Chunk]: list of chunks
        """
        query = remove_trailing_punctuation(query)
        candidate_results = []
        for doc in self.documents:
            if doc.questions:
                questions = json.loads(doc.questions)
                if query in questions:
                    chunks = self._chunk_dao.get_document_chunks(
                        DocumentChunkEntity(document_id=doc.id),
                        page_size=CHUNK_PAGE_SIZE,
                    )
                    candidates = [
                        Chunk(
                            content=chunk.content,
                            # meta_info is stored as a Python-literal string.
                            metadata=ast.literal_eval(chunk.meta_info),
                            retriever=self.name(),
                            score=0.0,
                        )
                        for chunk in chunks
                    ]
                    candidate_results.extend(
                        self._cosine_similarity_rerank(candidates, query)
                    )
        return candidate_results

    def _retrieve_with_score(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
        lambda_value: Optional[float] = None,
    ) -> List[Chunk]:
        """Retrieve knowledge chunks with score.

        Chunk-level question matches take priority; document-level question
        matches are only consulted when no chunk matched.

        Args:
            query (str): query text
            score_threshold (float): score threshold (accepted for interface
                compatibility; matches are assigned a fixed score of 1.0).
            filters: (Optional[MetadataFilters]) metadata filters
                (accepted for interface compatibility; not applied here).
            lambda_value: (Optional[float]) unused override hook.

        Return:
            List[Chunk]: list of chunks with score
        """
        query = remove_trailing_punctuation(query)
        candidate_results = []
        doc_ids = [doc.id for doc in self.documents]
        query_param = DocumentChunkEntity()
        chunks = self._chunk_dao.get_chunks_with_questions(
            query=query_param, document_ids=doc_ids
        )
        for chunk in chunks:
            if chunk.questions:
                questions = json.loads(chunk.questions)
                if query in questions:
                    logger.info(f"qa chunk hit:{chunk}, question:{query}")
                    candidate_results.append(
                        Chunk(
                            content=chunk.content,
                            chunk_id=str(chunk.id),
                            metadata={
                                "prop_field": ast.literal_eval(chunk.meta_info)
                            },
                            retriever=self.name(),
                            score=1.0,
                        )
                    )
        if len(candidate_results) > 0:
            return self._cosine_similarity_rerank(candidate_results, query)
        for doc in self.documents:
            if doc.questions:
                questions = json.loads(doc.questions)
                if query in questions:
                    logger.info(f"qa document hit:{doc}, question:{query}")
                    chunks = self._chunk_dao.get_document_chunks(
                        DocumentChunkEntity(document_id=doc.id),
                        page_size=CHUNK_PAGE_SIZE,
                    )
                    candidates_with_scores = [
                        Chunk(
                            content=chunk.content,
                            chunk_id=str(chunk.id),
                            metadata={
                                "prop_field": ast.literal_eval(chunk.meta_info)
                            },
                            retriever=self.name(),
                            score=1.0,
                        )
                        for chunk in chunks
                    ]
                    candidate_results.extend(
                        self._cosine_similarity_rerank(
                            candidates_with_scores, query
                        )
                    )
        return candidate_results

    async def _aretrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Retrieve knowledge chunks asynchronously.

        Args:
            query (str): query text
            filters: (Optional[MetadataFilters]) metadata filters.

        Return:
            List[Chunk]: list of chunks
        """
        candidates = await blocking_func_to_async(
            self._executor, self._retrieve, query, filters
        )
        return candidates

    async def _aretrieve_with_score(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Retrieve knowledge chunks with score asynchronously.

        Args:
            query (str): query text
            score_threshold (float): score threshold
            filters: (Optional[MetadataFilters]) metadata filters.

        Return:
            List[Chunk]: list of chunks with score
        """
        candidates_with_score = await blocking_func_to_async(
            self._executor, self._retrieve_with_score, query, score_threshold, filters
        )
        return candidates_with_score

    def _cosine_similarity_rerank(
        self, candidates_with_scores: List[Chunk], query: str
    ) -> List[Chunk]:
        """Rerank candidates using cosine similarity.

        Only reranks (and truncates to ``top_k``) when there are more
        candidates than ``top_k``; otherwise candidates pass through as-is.
        Surviving candidates are re-emitted with a fixed score of 1.0.
        """
        if len(candidates_with_scores) > self._top_k:
            for candidate in candidates_with_scores:
                similarity = calculate_cosine_similarity(
                    embeddings=self._embedding_fn,
                    prediction=query,
                    contexts=[candidate.content],
                )
                score = float(similarity.mean())
                candidate.score = score
            candidates_with_scores.sort(key=lambda x: x.score, reverse=True)
            candidates_with_scores = candidates_with_scores[: self._top_k]
            candidates_with_scores = [
                Chunk(
                    content=candidate.content,
                    chunk_id=candidate.chunk_id,
                    metadata=candidate.metadata,
                    retriever=self.name(),
                    score=1.0,
                )
                for candidate in candidates_with_scores
            ]
        return candidates_with_scores

    @classmethod
    def name(cls):
        """Return retriever name."""
        # BUGFIX: first parameter of a classmethod is the class, so it is
        # named `cls` (was `self`).
        return "qa_retriever"

View File

@@ -0,0 +1,78 @@
from concurrent.futures import ThreadPoolExecutor, Executor
from typing import Optional, List
from dbgpt.core import Chunk
from dbgpt.rag.retriever.base import BaseRetriever
from dbgpt.storage.vector_store.filters import MetadataFilters
from dbgpt.util.executor_utils import blocking_func_to_async
class RetrieverChain(BaseRetriever):
    """Retriever chain class.

    Consults its retrievers in order and returns the results of the first
    one that yields any chunks.
    """

    def __init__(
        self,
        retrievers: Optional[List[BaseRetriever]] = None,
        executor: Optional[Executor] = None,
    ):
        """Create retriever chain instance.

        Args:
            retrievers: ordered retrievers to consult; defaults to empty.
            executor: executor used to run synchronous retrieval off the
                event loop; a fresh ThreadPoolExecutor if not given.
        """
        self._retrievers = retrievers or []
        self._executor = executor or ThreadPoolExecutor()

    def _retrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Retrieve knowledge chunks.

        Args:
            query (str): query text
            filters: (Optional[MetadataFilters]) metadata filters.

        Return:
            List[Chunk]: chunks from the first retriever with a non-empty
                result, or an empty list.
        """
        for sub_retriever in self._retrievers:
            found = sub_retriever.retrieve(query, filters)
            if found:
                return found
        return []

    async def _aretrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Retrieve knowledge chunks asynchronously.

        Args:
            query (str): query text
            filters: (Optional[MetadataFilters]) metadata filters.

        Return:
            List[Chunk]: list of chunks
        """
        return await blocking_func_to_async(
            self._executor, self._retrieve, query, filters
        )

    def _retrieve_with_score(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Retrieve knowledge chunks with score.

        Args:
            query (str): query text
            score_threshold (float): score threshold
            filters: (Optional[MetadataFilters]) metadata filters.

        Return:
            List[Chunk]: scored chunks from the first retriever with a
                non-empty result, or an empty list.
        """
        for sub_retriever in self._retrievers:
            found = sub_retriever.retrieve_with_scores(
                query=query, score_threshold=score_threshold, filters=filters
            )
            if found:
                return found
        return []

    async def _aretrieve_with_score(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Retrieve knowledge chunks with score asynchronously.

        Args:
            query (str): query text
            score_threshold (float): score threshold
            filters: (Optional[MetadataFilters]) metadata filters.

        Return:
            List[Chunk]: list of chunks with score
        """
        return await blocking_func_to_async(
            self._executor, self._retrieve_with_score, query, score_threshold, filters
        )