DB-GPT/dbgpt/rag/retriever/base.py
明天 b124ecc10b
feat: (0.6)New UI (#1855)
Co-authored-by: 夏姜 <wenfengjiang.jwf@digital-engine.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: wb-lh513319 <wb-lh513319@alibaba-inc.com>
Co-authored-by: csunny <cfqsunny@163.com>
2024-08-21 17:37:45 +08:00

170 lines
4.5 KiB
Python

"""Base retriever module."""
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional
from dbgpt.core import Chunk
from dbgpt.storage.vector_store.filters import MetadataFilters
class RetrieverStrategy(str, Enum):
    """Retriever strategy.

    The enum also subclasses ``str``, so every member compares equal to its
    plain string value (e.g. ``RetrieverStrategy.EMBEDDING == "embedding"``).

    Members:
        - EMBEDDING: embedding retriever
        - GRAPH: graph retriever
        - KEYWORD: keyword retriever
        - HYBRID: hybrid retriever
    """

    EMBEDDING = "embedding"
    GRAPH = "graph"
    KEYWORD = "keyword"
    HYBRID = "hybrid"
class BaseRetriever(ABC):
    """Abstract base class for knowledge retrievers.

    The public entry points (``retrieve``, ``aretrieve``,
    ``retrieve_with_scores`` and ``aretrieve_with_scores``) are thin wrappers
    that forward their arguments unchanged to the matching abstract hook
    (``_retrieve``, ``_aretrieve``, ``_retrieve_with_score`` and
    ``_aretrieve_with_score``). Concrete retrievers must implement all four
    hooks; ``load_document`` and ``name`` are optional and raise
    ``NotImplementedError`` unless a subclass overrides them.
    """

    # NOTE(review): per PEP 484, ``**kwargs: Dict[str, Any]`` types every
    # individual keyword value as a dict; ``**kwargs: Any`` is probably what
    # was intended. The annotation is kept unchanged so the signature stays
    # identical for existing overrides -- confirm before changing it.
    def load_document(self, chunks: List[Chunk], **kwargs: Dict[str, Any]) -> List[str]:
        """Load document chunks into the retriever's storage.

        The base implementation does not support loading and always raises;
        retrievers that can ingest documents override this method.

        Args:
            chunks (List[Chunk]): document chunks.
            **kwargs: extra implementation-specific options.

        Returns:
            List[str]: chunk ids.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError("Current retriever does not support load_document")

    def retrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Retrieve knowledge chunks.

        Delegates to ``_retrieve``.

        Args:
            query (str): query text.
            filters (Optional[MetadataFilters]): metadata filters.

        Returns:
            List[Chunk]: list of chunks
        """
        return self._retrieve(query, filters)

    async def aretrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Asynchronously retrieve knowledge chunks.

        Delegates to ``_aretrieve`` and awaits its result.

        Args:
            query (str): query text.
            filters (Optional[MetadataFilters]): metadata filters.

        Returns:
            List[Chunk]: list of chunks
        """
        return await self._aretrieve(query, filters)

    def retrieve_with_scores(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Retrieve knowledge chunks with score.

        Delegates to ``_retrieve_with_score``.

        Args:
            query (str): query text.
            score_threshold (float): score threshold.
            filters (Optional[MetadataFilters]): metadata filters.

        Returns:
            List[Chunk]: list of chunks
        """
        return self._retrieve_with_score(query, score_threshold, filters)

    async def aretrieve_with_scores(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Asynchronously retrieve knowledge chunks with score.

        Delegates to ``_aretrieve_with_score`` and awaits its result.

        Args:
            query (str): query text.
            score_threshold (float): score threshold.
            filters (Optional[MetadataFilters]): metadata filters.

        Returns:
            List[Chunk]: list of chunks
        """
        return await self._aretrieve_with_score(query, score_threshold, filters)

    @abstractmethod
    def _retrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Retrieve knowledge chunks (hook called by ``retrieve``).

        Args:
            query (str): query text.
            filters (Optional[MetadataFilters]): metadata filters.

        Returns:
            List[Chunk]: list of chunks
        """

    @abstractmethod
    async def _aretrieve(
        self, query: str, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Asynchronously retrieve knowledge chunks (hook for ``aretrieve``).

        Args:
            query (str): query text.
            filters (Optional[MetadataFilters]): metadata filters.

        Returns:
            List[Chunk]: list of chunks
        """

    @abstractmethod
    def _retrieve_with_score(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Retrieve knowledge chunks with score.

        Hook called by ``retrieve_with_scores``.

        Args:
            query (str): query text.
            score_threshold (float): score threshold.
            filters (Optional[MetadataFilters]): metadata filters.

        Returns:
            List[Chunk]: list of chunks
        """

    @abstractmethod
    async def _aretrieve_with_score(
        self,
        query: str,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Asynchronously retrieve knowledge chunks with score.

        Hook called by ``aretrieve_with_scores``.

        Args:
            query (str): query text.
            score_threshold (float): score threshold.
            filters (Optional[MetadataFilters]): metadata filters.

        Returns:
            List[Chunk]: list of chunks
        """

    @classmethod
    def name(cls) -> str:
        """Return the name of the retriever.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError("Current retriever does not support name")