mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-11 22:09:44 +00:00
fix:space resource error.
This commit is contained in:
@@ -21,6 +21,7 @@ from dbgpt.core import (
|
|||||||
)
|
)
|
||||||
from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker
|
from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker
|
||||||
from dbgpt.rag.retriever.rewrite import QueryRewrite
|
from dbgpt.rag.retriever.rewrite import QueryRewrite
|
||||||
|
from dbgpt.serve.rag.retriever.knowledge_space import KnowledgeSpaceRetriever
|
||||||
from dbgpt.util.tracer import root_tracer, trace
|
from dbgpt.util.tracer import root_tracer, trace
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
@@ -77,7 +78,6 @@ class ChatKnowledge(BaseChat):
|
|||||||
)
|
)
|
||||||
from dbgpt.serve.rag.models.models import (
|
from dbgpt.serve.rag.models.models import (
|
||||||
KnowledgeSpaceDao,
|
KnowledgeSpaceDao,
|
||||||
KnowledgeSpaceEntity,
|
|
||||||
)
|
)
|
||||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||||
|
|
||||||
@@ -113,12 +113,19 @@ class ChatKnowledge(BaseChat):
|
|||||||
# We use reranker, so if the top_k is less than 20,
|
# We use reranker, so if the top_k is less than 20,
|
||||||
# we need to set it to 20
|
# we need to set it to 20
|
||||||
retriever_top_k = max(CFG.RERANK_TOP_K, 20)
|
retriever_top_k = max(CFG.RERANK_TOP_K, 20)
|
||||||
self.embedding_retriever = EmbeddingRetriever(
|
# self.embedding_retriever = EmbeddingRetriever(
|
||||||
|
# top_k=retriever_top_k,
|
||||||
|
# index_store=vector_store_connector.index_client,
|
||||||
|
# query_rewrite=query_rewrite,
|
||||||
|
# rerank=reranker,
|
||||||
|
# )
|
||||||
|
self._space_retriever = KnowledgeSpaceRetriever(
|
||||||
|
space_id=self.knowledge_space,
|
||||||
top_k=retriever_top_k,
|
top_k=retriever_top_k,
|
||||||
index_store=vector_store_connector.index_client,
|
|
||||||
query_rewrite=query_rewrite,
|
query_rewrite=query_rewrite,
|
||||||
rerank=reranker,
|
rerank=reranker,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.prompt_template.template_is_strict = False
|
self.prompt_template.template_is_strict = False
|
||||||
self.relations = None
|
self.relations = None
|
||||||
self.chunk_dao = DocumentChunkDao()
|
self.chunk_dao = DocumentChunkDao()
|
||||||
@@ -275,6 +282,6 @@ class ChatKnowledge(BaseChat):
|
|||||||
with root_tracer.start_span(
|
with root_tracer.start_span(
|
||||||
"execute_similar_search", metadata={"query": query}
|
"execute_similar_search", metadata={"query": query}
|
||||||
):
|
):
|
||||||
return await self.embedding_retriever.aretrieve_with_scores(
|
return await self._space_retriever.aretrieve_with_scores(
|
||||||
query, self.recall_score
|
query, self.recall_score
|
||||||
)
|
)
|
||||||
|
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
|
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
|
||||||
|
|
||||||
@@ -61,6 +61,7 @@ class Chunk(Document):
|
|||||||
default="\n",
|
default="\n",
|
||||||
description="Separator between metadata fields when converting to string.",
|
description="Separator between metadata fields when converting to string.",
|
||||||
)
|
)
|
||||||
|
retriever: Optional[str] = Field(default=None, description="retriever name")
|
||||||
|
|
||||||
def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
|
def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
|
||||||
"""Convert Chunk to dict."""
|
"""Convert Chunk to dict."""
|
||||||
|
@@ -65,7 +65,7 @@ class KnowledgeSpaceRetrieverResource(RetrieverResource):
|
|||||||
|
|
||||||
def __init__(self, name: str, space_name: str, context: Optional[dict] = None):
|
def __init__(self, name: str, space_name: str, context: Optional[dict] = None):
|
||||||
retriever = KnowledgeSpaceRetriever(
|
retriever = KnowledgeSpaceRetriever(
|
||||||
space_name=space_name,
|
space_id=space_name,
|
||||||
top_k=context.get("top_k", None) if context else 4,
|
top_k=context.get("top_k", None) if context else 4,
|
||||||
)
|
)
|
||||||
super().__init__(name, retriever=retriever)
|
super().__init__(name, retriever=retriever)
|
||||||
|
@@ -5,8 +5,12 @@ from dbgpt.component import ComponentType
|
|||||||
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
|
from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG
|
||||||
from dbgpt.core import Chunk
|
from dbgpt.core import Chunk
|
||||||
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
|
||||||
|
from dbgpt.rag.retriever import EmbeddingRetriever, Ranker, QueryRewrite
|
||||||
from dbgpt.rag.retriever.base import BaseRetriever
|
from dbgpt.rag.retriever.base import BaseRetriever
|
||||||
from dbgpt.serve.rag.connector import VectorStoreConnector
|
from dbgpt.serve.rag.connector import VectorStoreConnector
|
||||||
|
from dbgpt.serve.rag.models.models import KnowledgeSpaceDao
|
||||||
|
from dbgpt.serve.rag.retriever.qa_retriever import QARetriever
|
||||||
|
from dbgpt.serve.rag.retriever.retriever_chain import RetrieverChain
|
||||||
from dbgpt.storage.vector_store.filters import MetadataFilters
|
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||||
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
|
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
|
||||||
|
|
||||||
@@ -18,18 +22,24 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
space_name: str = None,
|
space_id: str = None,
|
||||||
top_k: Optional[int] = 4,
|
top_k: Optional[int] = 4,
|
||||||
|
query_rewrite: Optional[QueryRewrite] = None,
|
||||||
|
rerank: Optional[Ranker] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
space_name (str): knowledge space name
|
space_id (str): knowledge space name
|
||||||
top_k (Optional[int]): top k
|
top_k (Optional[int]): top k
|
||||||
|
query_rewrite: (Optional[QueryRewrite]) query rewrite
|
||||||
|
rerank: (Optional[Ranker]) rerank
|
||||||
"""
|
"""
|
||||||
if space_name is None:
|
if space_id is None:
|
||||||
raise ValueError("space_name is required")
|
raise ValueError("space_id is required")
|
||||||
self._space_name = space_name
|
self._space_id = space_id
|
||||||
self._top_k = top_k
|
self._top_k = top_k
|
||||||
|
self._query_rewrite = query_rewrite
|
||||||
|
self._rerank = rerank
|
||||||
embedding_factory = CFG.SYSTEM_APP.get_component(
|
embedding_factory = CFG.SYSTEM_APP.get_component(
|
||||||
"embedding_factory", EmbeddingFactory
|
"embedding_factory", EmbeddingFactory
|
||||||
)
|
)
|
||||||
@@ -37,8 +47,9 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
|||||||
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
|
model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
|
||||||
)
|
)
|
||||||
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
from dbgpt.storage.vector_store.base import VectorStoreConfig
|
||||||
|
space_dao = KnowledgeSpaceDao()
|
||||||
config = VectorStoreConfig(name=self._space_name, embedding_fn=embedding_fn)
|
space = space_dao.get_one({"id": space_id})
|
||||||
|
config = VectorStoreConfig(name=space.name, embedding_fn=embedding_fn)
|
||||||
self._vector_store_connector = VectorStoreConnector(
|
self._vector_store_connector = VectorStoreConnector(
|
||||||
vector_store_type=CFG.VECTOR_STORE_TYPE,
|
vector_store_type=CFG.VECTOR_STORE_TYPE,
|
||||||
vector_store_config=config,
|
vector_store_config=config,
|
||||||
@@ -47,6 +58,20 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
|||||||
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||||
).create()
|
).create()
|
||||||
|
|
||||||
|
self._retriever_chain = RetrieverChain(retrievers=[
|
||||||
|
QARetriever(space_id=space_id,
|
||||||
|
top_k=top_k,
|
||||||
|
embedding_fn=embedding_fn
|
||||||
|
),
|
||||||
|
EmbeddingRetriever(
|
||||||
|
index_store=self._vector_store_connector.index_client,
|
||||||
|
top_k=top_k,
|
||||||
|
query_rewrite=self._query_rewrite,
|
||||||
|
rerank=self._rerank
|
||||||
|
)
|
||||||
|
], executor=self._executor
|
||||||
|
)
|
||||||
|
|
||||||
def _retrieve(
|
def _retrieve(
|
||||||
self, query: str, filters: Optional[MetadataFilters] = None
|
self, query: str, filters: Optional[MetadataFilters] = None
|
||||||
) -> List[Chunk]:
|
) -> List[Chunk]:
|
||||||
@@ -59,8 +84,8 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
|||||||
Return:
|
Return:
|
||||||
List[Chunk]: list of chunks
|
List[Chunk]: list of chunks
|
||||||
"""
|
"""
|
||||||
candidates = self._vector_store_connector.similar_search(
|
candidates = self._retriever_chain.retrieve(
|
||||||
doc=query, topk=self._top_k, filters=filters
|
query=query, filters=filters
|
||||||
)
|
)
|
||||||
return candidates
|
return candidates
|
||||||
|
|
||||||
@@ -80,13 +105,10 @@ class KnowledgeSpaceRetriever(BaseRetriever):
|
|||||||
Return:
|
Return:
|
||||||
List[Chunk]: list of chunks with score
|
List[Chunk]: list of chunks with score
|
||||||
"""
|
"""
|
||||||
candidates_with_score = self._vector_store_connector.similar_search_with_scores(
|
candidates_with_scores = self._retriever_chain.retrieve_with_scores(
|
||||||
doc=query,
|
query, score_threshold, filters
|
||||||
topk=self._top_k,
|
|
||||||
score_threshold=score_threshold,
|
|
||||||
filters=filters,
|
|
||||||
)
|
)
|
||||||
return candidates_with_score
|
return candidates_with_scores
|
||||||
|
|
||||||
async def _aretrieve(
|
async def _aretrieve(
|
||||||
self, query: str, filters: Optional[MetadataFilters] = None
|
self, query: str, filters: Optional[MetadataFilters] = None
|
||||||
|
218
dbgpt/serve/rag/retriever/qa_retriever.py
Normal file
218
dbgpt/serve/rag/retriever/qa_retriever.py
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
import ast
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import List, Optional, Any
|
||||||
|
|
||||||
|
from dbgpt._private.config import Config
|
||||||
|
from dbgpt.app.knowledge.chunk_db import DocumentChunkDao, DocumentChunkEntity
|
||||||
|
from dbgpt.app.knowledge.document_db import KnowledgeDocumentDao
|
||||||
|
|
||||||
|
from dbgpt.component import ComponentType
|
||||||
|
from dbgpt.core import Chunk
|
||||||
|
from dbgpt.rag.retriever.base import BaseRetriever
|
||||||
|
from dbgpt.serve.rag.models.models import KnowledgeSpaceDao, KnowledgeSpaceEntity
|
||||||
|
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||||
|
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
|
||||||
|
from dbgpt.util.similarity_util import calculate_cosine_similarity
|
||||||
|
from dbgpt.util.string_utils import remove_trailing_punctuation
|
||||||
|
|
||||||
|
CFG = Config()
|
||||||
|
CHUNK_PAGE_SIZE = 1000
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class QARetriever(BaseRetriever):
|
||||||
|
"""Document QA retriever."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
space_id: str = None,
|
||||||
|
top_k: Optional[int] = 4,
|
||||||
|
embedding_fn: Optional[Any] = 4,
|
||||||
|
lambda_value: Optional[float] = 1e-5,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
space_id (str): knowledge space name
|
||||||
|
top_k (Optional[int]): top k
|
||||||
|
"""
|
||||||
|
if space_id is None:
|
||||||
|
raise ValueError("space_id is required")
|
||||||
|
self._top_k = top_k
|
||||||
|
self._lambda_value = lambda_value
|
||||||
|
self._space_dao = KnowledgeSpaceDao()
|
||||||
|
self._document_dao = KnowledgeDocumentDao()
|
||||||
|
self._chunk_dao = DocumentChunkDao()
|
||||||
|
self._embedding_fn = embedding_fn
|
||||||
|
|
||||||
|
space = self._space_dao.get_one(
|
||||||
|
{"id": space_id}
|
||||||
|
)
|
||||||
|
if not space:
|
||||||
|
raise ValueError("space not found")
|
||||||
|
self.documents = self._document_dao.get_list({"space": space.name})
|
||||||
|
self._executor = CFG.SYSTEM_APP.get_component(
|
||||||
|
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||||
|
).create()
|
||||||
|
|
||||||
|
def _retrieve(
|
||||||
|
self, query: str, filters: Optional[MetadataFilters] = None
|
||||||
|
) -> List[Chunk]:
|
||||||
|
"""Retrieve knowledge chunks.
|
||||||
|
Args:
|
||||||
|
query (str): query text
|
||||||
|
filters: (Optional[MetadataFilters]) metadata filters.
|
||||||
|
Return:
|
||||||
|
List[Chunk]: list of chunks
|
||||||
|
"""
|
||||||
|
query = remove_trailing_punctuation(query)
|
||||||
|
candidate_results = []
|
||||||
|
for doc in self.documents:
|
||||||
|
if doc.questions:
|
||||||
|
questions = json.loads(doc.questions)
|
||||||
|
if query in questions:
|
||||||
|
chunks = self._chunk_dao.get_document_chunks(
|
||||||
|
DocumentChunkEntity(
|
||||||
|
document_id=doc.id
|
||||||
|
),
|
||||||
|
page_size=CHUNK_PAGE_SIZE
|
||||||
|
)
|
||||||
|
candidates = [
|
||||||
|
Chunk(content=chunk.content,
|
||||||
|
metadata=ast.literal_eval(chunk.meta_info),
|
||||||
|
retriever=self.name(),
|
||||||
|
score=0.0)
|
||||||
|
for chunk in chunks
|
||||||
|
]
|
||||||
|
candidate_results.extend(
|
||||||
|
self._cosine_similarity_rerank(candidates, query)
|
||||||
|
)
|
||||||
|
return candidate_results
|
||||||
|
|
||||||
|
def _retrieve_with_score(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
score_threshold: float,
|
||||||
|
filters: Optional[MetadataFilters] = None,
|
||||||
|
lambda_value: Optional[float] = None,
|
||||||
|
) -> List[Chunk]:
|
||||||
|
"""Retrieve knowledge chunks with score.
|
||||||
|
Args:
|
||||||
|
query (str): query text
|
||||||
|
score_threshold (float): score threshold
|
||||||
|
filters: (Optional[MetadataFilters]) metadata filters.
|
||||||
|
Return:
|
||||||
|
List[Chunk]: list of chunks with score
|
||||||
|
"""
|
||||||
|
query = remove_trailing_punctuation(query)
|
||||||
|
candidate_results = []
|
||||||
|
doc_ids = [doc.id for doc in self.documents]
|
||||||
|
query_param = DocumentChunkEntity()
|
||||||
|
chunks = self._chunk_dao.get_chunks_with_questions(
|
||||||
|
query=query_param,
|
||||||
|
document_ids=doc_ids
|
||||||
|
)
|
||||||
|
for chunk in chunks:
|
||||||
|
if chunk.questions:
|
||||||
|
questions = json.loads(chunk.questions)
|
||||||
|
if query in questions:
|
||||||
|
logger.info(f"qa chunk hit:{chunk}, question:{query}")
|
||||||
|
candidate_results.append(
|
||||||
|
Chunk(content=chunk.content,
|
||||||
|
chunk_id=str(chunk.id),
|
||||||
|
metadata={
|
||||||
|
"prop_field": ast.literal_eval(chunk.meta_info)
|
||||||
|
},
|
||||||
|
retriever=self.name(),
|
||||||
|
score=1.0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if len(candidate_results) > 0:
|
||||||
|
return self._cosine_similarity_rerank(candidate_results, query)
|
||||||
|
|
||||||
|
for doc in self.documents:
|
||||||
|
if doc.questions:
|
||||||
|
questions = json.loads(doc.questions)
|
||||||
|
if query in questions:
|
||||||
|
logger.info(f"qa document hit:{doc}, question:{query}")
|
||||||
|
chunks = self._chunk_dao.get_document_chunks(
|
||||||
|
DocumentChunkEntity(document_id=doc.id),
|
||||||
|
page_size=CHUNK_PAGE_SIZE
|
||||||
|
)
|
||||||
|
candidates_with_scores = [
|
||||||
|
Chunk(content=chunk.content,
|
||||||
|
chunk_id=str(chunk.id),
|
||||||
|
metadata={
|
||||||
|
"prop_field": ast.literal_eval(chunk.meta_info)
|
||||||
|
},
|
||||||
|
retriever=self.name(),
|
||||||
|
score=1.0)
|
||||||
|
for chunk in chunks
|
||||||
|
]
|
||||||
|
candidate_results.extend(
|
||||||
|
self._cosine_similarity_rerank(candidates_with_scores, query)
|
||||||
|
)
|
||||||
|
return candidate_results
|
||||||
|
|
||||||
|
async def _aretrieve(
|
||||||
|
self, query: str, filters: Optional[MetadataFilters] = None
|
||||||
|
) -> List[Chunk]:
|
||||||
|
"""Retrieve knowledge chunks.
|
||||||
|
Args:
|
||||||
|
query (str): query text
|
||||||
|
filters: (Optional[MetadataFilters]) metadata filters.
|
||||||
|
Return:
|
||||||
|
List[Chunk]: list of chunks
|
||||||
|
"""
|
||||||
|
candidates = await blocking_func_to_async(
|
||||||
|
self._executor, self._retrieve, query, filters
|
||||||
|
)
|
||||||
|
return candidates
|
||||||
|
|
||||||
|
async def _aretrieve_with_score(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
score_threshold: float,
|
||||||
|
filters: Optional[MetadataFilters] = None,
|
||||||
|
) -> List[Chunk]:
|
||||||
|
"""Retrieve knowledge chunks with score.
|
||||||
|
Args:
|
||||||
|
query (str): query text
|
||||||
|
score_threshold (float): score threshold
|
||||||
|
filters: (Optional[MetadataFilters]) metadata filters.
|
||||||
|
Return:
|
||||||
|
List[Chunk]: list of chunks with score
|
||||||
|
"""
|
||||||
|
candidates_with_score = await blocking_func_to_async(
|
||||||
|
self._executor, self._retrieve_with_score, query, score_threshold, filters
|
||||||
|
)
|
||||||
|
return candidates_with_score
|
||||||
|
|
||||||
|
def _cosine_similarity_rerank(self, candidates_with_scores: List[Chunk]
|
||||||
|
, query: str) -> List[Chunk]:
|
||||||
|
"""Rerank candidates using cosine similarity."""
|
||||||
|
if len(candidates_with_scores) > self._top_k:
|
||||||
|
for candidate in candidates_with_scores:
|
||||||
|
similarity = calculate_cosine_similarity(
|
||||||
|
embeddings=self._embedding_fn,
|
||||||
|
prediction=query,
|
||||||
|
contexts=[candidate.content]
|
||||||
|
)
|
||||||
|
score = float(similarity.mean())
|
||||||
|
candidate.score = score
|
||||||
|
candidates_with_scores.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
candidates_with_scores = candidates_with_scores[: self._top_k]
|
||||||
|
candidates_with_scores = [
|
||||||
|
Chunk(content=candidate.content,
|
||||||
|
chunk_id=candidate.chunk_id,
|
||||||
|
metadata=candidate.metadata,
|
||||||
|
retriever=self.name(),
|
||||||
|
score=1.0)
|
||||||
|
for candidate in candidates_with_scores
|
||||||
|
]
|
||||||
|
return candidates_with_scores
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def name(self):
|
||||||
|
"""Return retriever name."""
|
||||||
|
return "qa_retriever"
|
78
dbgpt/serve/rag/retriever/retriever_chain.py
Normal file
78
dbgpt/serve/rag/retriever/retriever_chain.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
from concurrent.futures import ThreadPoolExecutor, Executor
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
from dbgpt.core import Chunk
|
||||||
|
from dbgpt.rag.retriever.base import BaseRetriever
|
||||||
|
from dbgpt.storage.vector_store.filters import MetadataFilters
|
||||||
|
from dbgpt.util.executor_utils import blocking_func_to_async
|
||||||
|
|
||||||
|
|
||||||
|
class RetrieverChain(BaseRetriever):
|
||||||
|
"""Retriever chain class."""
|
||||||
|
|
||||||
|
def __init__(self, retrievers: Optional[List[BaseRetriever]] = None,
|
||||||
|
executor: Optional[Executor] = None):
|
||||||
|
"""Create retriever chain instance."""
|
||||||
|
self._retrievers = retrievers or []
|
||||||
|
self._executor = executor or ThreadPoolExecutor()
|
||||||
|
|
||||||
|
def _retrieve(self, query: str, filters: Optional[MetadataFilters] = None) -> List[
|
||||||
|
Chunk]:
|
||||||
|
"""Retrieve knowledge chunks.
|
||||||
|
Args:
|
||||||
|
query (str): query text
|
||||||
|
filters: (Optional[MetadataFilters]) metadata filters.
|
||||||
|
Return:
|
||||||
|
List[Chunk]: list of chunks
|
||||||
|
"""
|
||||||
|
for retriever in self._retrievers:
|
||||||
|
candidates = retriever.retrieve(
|
||||||
|
query, filters
|
||||||
|
)
|
||||||
|
if candidates:
|
||||||
|
return candidates
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _aretrieve(self, query: str, filters: Optional[MetadataFilters] = None) -> \
|
||||||
|
List[Chunk]:
|
||||||
|
"""Retrieve knowledge chunks.
|
||||||
|
Args:
|
||||||
|
query (str): query text
|
||||||
|
filters: (Optional[MetadataFilters]) metadata filters.
|
||||||
|
Return:
|
||||||
|
List[Chunk]: list of chunks
|
||||||
|
"""
|
||||||
|
candidates = await blocking_func_to_async(
|
||||||
|
self._executor, self._retrieve, query, filters
|
||||||
|
)
|
||||||
|
return candidates
|
||||||
|
|
||||||
|
def _retrieve_with_score(self, query: str, score_threshold: float, filters: Optional[MetadataFilters] = None) -> List[Chunk]:
|
||||||
|
"""Retrieve knowledge chunks.
|
||||||
|
Args:
|
||||||
|
query (str): query text
|
||||||
|
filters: (Optional[MetadataFilters]) metadata filters.
|
||||||
|
Return:
|
||||||
|
List[Chunk]: list of chunks
|
||||||
|
"""
|
||||||
|
for retriever in self._retrievers:
|
||||||
|
candidates_with_scores = retriever.retrieve_with_scores(
|
||||||
|
query=query, score_threshold=score_threshold, filters=filters
|
||||||
|
)
|
||||||
|
if candidates_with_scores:
|
||||||
|
return candidates_with_scores
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _aretrieve_with_score(self, query: str, score_threshold: float, filters: Optional[MetadataFilters] = None) -> List[Chunk]:
|
||||||
|
"""Retrieve knowledge chunks with score.
|
||||||
|
Args:
|
||||||
|
query (str): query text
|
||||||
|
score_threshold (float): score threshold
|
||||||
|
filters: (Optional[MetadataFilters]) metadata filters.
|
||||||
|
Return:
|
||||||
|
List[Chunk]: list of chunks with score
|
||||||
|
"""
|
||||||
|
candidates_with_score = await blocking_func_to_async(
|
||||||
|
self._executor, self._retrieve_with_score, query, score_threshold, filters
|
||||||
|
)
|
||||||
|
return candidates_with_score
|
Reference in New Issue
Block a user