chore: Add pylint for DB-GPT rag lib (#1267)

This commit is contained in:
Fangyin Cheng
2024-03-07 23:27:43 +08:00
committed by GitHub
parent aaaf34db17
commit 7446817340
70 changed files with 1135 additions and 587 deletions

View File

@@ -1,7 +1,8 @@
"""Embedding retriever operator."""
from functools import reduce
from typing import Any, Optional
from dbgpt.core.awel.task.base import IN
from dbgpt.core.interface.operators.retriever import RetrieverOperator
from dbgpt.rag.retriever.embedding import EmbeddingRetriever
from dbgpt.rag.retriever.rerank import Ranker
@@ -10,25 +11,29 @@ from dbgpt.storage.vector_store.connector import VectorStoreConnector
class EmbeddingRetrieverOperator(RetrieverOperator[Any, Any]):
"""The Embedding Retriever Operator."""
def __init__(
self,
vector_store_connector: VectorStoreConnector,
top_k: int,
score_threshold: Optional[float] = 0.3,
score_threshold: float = 0.3,
query_rewrite: Optional[QueryRewrite] = None,
rerank: Ranker = None,
vector_store_connector: VectorStoreConnector = None,
rerank: Optional[Ranker] = None,
**kwargs
):
"""Create a new EmbeddingRetrieverOperator."""
super().__init__(**kwargs)
self._score_threshold = score_threshold
self._retriever = EmbeddingRetriever(
vector_store_connector=vector_store_connector,
top_k=top_k,
query_rewrite=query_rewrite,
rerank=rerank,
vector_store_connector=vector_store_connector,
)
def retrieve(self, query: IN) -> Any:
def retrieve(self, query: Any) -> Any:
"""Retrieve the candidates."""
if isinstance(query, str):
return self._retriever.retrieve_with_scores(query, self._score_threshold)
elif isinstance(query, list):