mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-25 04:53:36 +00:00
Co-authored-by: 夏姜 <wenfengjiang.jwf@digital-engine.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: wb-lh513319 <wb-lh513319@alibaba-inc.com> Co-authored-by: csunny <cfqsunny@163.com>
170 lines
4.5 KiB
Python
"""Base retriever module."""
|
|
from abc import ABC, abstractmethod
|
|
from enum import Enum
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from dbgpt.core import Chunk
|
|
from dbgpt.storage.vector_store.filters import MetadataFilters
|
|
|
|
|
|
class RetrieverStrategy(str, Enum):
    """Retriever strategy.

    Members are ``str`` subclasses, so each compares equal to its plain
    string value (e.g. ``RetrieverStrategy.EMBEDDING == "embedding"``).

    Attributes:
        - EMBEDDING: embedding retriever
        - GRAPH: graph retriever
        - KEYWORD: keyword retriever
        - HYBRID: hybrid retriever
    """

    EMBEDDING = "embedding"
    GRAPH = "graph"
    KEYWORD = "keyword"
    HYBRID = "hybrid"
|
|
|
|
|
|
class BaseRetriever(ABC):
    """Base retriever.

    Subclasses must implement the four abstract hooks ``_retrieve``,
    ``_aretrieve``, ``_retrieve_with_score`` and ``_aretrieve_with_score``.
    The public ``retrieve`` / ``aretrieve`` / ``retrieve_with_scores`` /
    ``aretrieve_with_scores`` methods forward their arguments to those hooks
    unchanged and return whatever the hook returns.
    """

    def load_document(self, chunks: List[Chunk], **kwargs: Any) -> List[str]:
        """Load document in vector database.

        The base class does not implement this; it always raises. Subclasses
        may override it.

        Args:
            chunks (List[Chunk]): document chunks.
            **kwargs: extra, implementation-specific options.

        Returns:
            List[str]: chunk ids.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def retrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Retrieve knowledge chunks.

        Delegates to :meth:`_retrieve`.

        Args:
            query (str): query text.
            filters (Optional[MetadataFilters]): metadata filters.

        Returns:
            List[Chunk]: list of chunks
        """
        return self._retrieve(query, filters)

    async def aretrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Asynchronously retrieve knowledge chunks.

        Delegates to :meth:`_aretrieve`.

        Args:
            query (str): query text.
            filters (Optional[MetadataFilters]): metadata filters.

        Returns:
            List[Chunk]: list of chunks
        """
        return await self._aretrieve(query, filters)

    def retrieve_with_scores(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Retrieve knowledge chunks with score.

        Delegates to :meth:`_retrieve_with_score`.

        Args:
            query (str): query text.
            score_threshold (float): score threshold.
            filters (Optional[MetadataFilters]): metadata filters.

        Returns:
            List[Chunk]: list of chunks
        """
        return self._retrieve_with_score(query, score_threshold, filters)

    async def aretrieve_with_scores(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Asynchronously retrieve knowledge chunks with score.

        Delegates to :meth:`_aretrieve_with_score`.

        Args:
            query (str): query text.
            score_threshold (float): score threshold.
            filters (Optional[MetadataFilters]): metadata filters.

        Returns:
            List[Chunk]: list of chunks
        """
        return await self._aretrieve_with_score(query, score_threshold, filters)

    @abstractmethod
    def _retrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Retrieve knowledge chunks.

        Args:
            query (str): query text.
            filters (Optional[MetadataFilters]): metadata filters.

        Returns:
            List[Chunk]: list of chunks
        """

    @abstractmethod
    async def _aretrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Async retrieve knowledge chunks.

        Args:
            query (str): query text.
            filters (Optional[MetadataFilters]): metadata filters.

        Returns:
            List[Chunk]: list of chunks
        """

    @abstractmethod
    def _retrieve_with_score(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Retrieve knowledge chunks with score.

        Args:
            query (str): query text.
            score_threshold (float): score threshold.
            filters (Optional[MetadataFilters]): metadata filters.

        Returns:
            List[Chunk]: list of chunks
        """

    @abstractmethod
    async def _aretrieve_with_score(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Async retrieve knowledge chunks with score.

        Args:
            query (str): query text.
            score_threshold (float): score threshold.
            filters (Optional[MetadataFilters]): metadata filters.

        Returns:
            List[Chunk]: list of chunks
        """

    @classmethod
    def name(cls):
        """Return the name of the retriever.

        Raises:
            NotImplementedError: always, in the base class; a subclass must
                override this to provide a name.
        """
        raise NotImplementedError("Current retriever does not support name")
|