chore: Fix pylint error (#1915)

This commit is contained in:
Fangyin Cheng
2024-08-29 16:37:31 +08:00
committed by GitHub
parent 51b4327fdf
commit f72db23bab
6 changed files with 25 additions and 15 deletions

View File

@@ -122,3 +122,6 @@ ignore_missing_imports = True
[mypy-pypdf.*] [mypy-pypdf.*]
ignore_missing_imports = True ignore_missing_imports = True
[mypy-qianfan.*]
ignore_missing_imports = True

View File

@@ -60,6 +60,7 @@ class EmbeddingLoader:
return TongYiEmbeddings(**tongyi_param) return TongYiEmbeddings(**tongyi_param)
elif model_name in ["proxy_qianfan"]: elif model_name in ["proxy_qianfan"]:
from dbgpt.rag.embedding import QianFanEmbeddings from dbgpt.rag.embedding import QianFanEmbeddings
proxy_param = cast(ProxyEmbeddingParameters, param) proxy_param = cast(ProxyEmbeddingParameters, param)
qianfan_param = {"api_key": proxy_param.proxy_api_key} qianfan_param = {"api_key": proxy_param.proxy_api_key}
if proxy_param.proxy_backend: if proxy_param.proxy_backend:

View File

@@ -563,7 +563,7 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
metadata={ metadata={
"tags": "privacy", "tags": "privacy",
"help": "The api secret of the current embedding model(OPENAI_API_SECRET)", "help": "The api secret of the current embedding model(OPENAI_API_SECRET)",
} },
) )
proxy_api_version: Optional[str] = field( proxy_api_version: Optional[str] = field(
default=None, default=None,

View File

@@ -14,8 +14,8 @@ from .embeddings import ( # noqa: F401
JinaEmbeddings, JinaEmbeddings,
OllamaEmbeddings, OllamaEmbeddings,
OpenAPIEmbeddings, OpenAPIEmbeddings,
TongYiEmbeddings,
QianFanEmbeddings, QianFanEmbeddings,
TongYiEmbeddings,
) )
from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401 from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401
@@ -34,5 +34,5 @@ __ALL__ = [
"TongYiEmbeddings", "TongYiEmbeddings",
"CrossEncoderRerankEmbeddings", "CrossEncoderRerankEmbeddings",
"OpenAPIRerankEmbeddings", "OpenAPIRerankEmbeddings",
"QianFanEmbeddings" "QianFanEmbeddings",
] ]

View File

@@ -926,6 +926,7 @@ class TongYiEmbeddings(BaseModel, Embeddings):
class QianFanEmbeddings(BaseModel, Embeddings): class QianFanEmbeddings(BaseModel, Embeddings):
"""Baidu Qianfan Embeddings embedding models. """Baidu Qianfan Embeddings embedding models.
Embed: Embed:
.. code-block:: python .. code-block:: python
@@ -936,6 +937,7 @@ class QianFanEmbeddings(BaseModel, Embeddings):
vectors = embeddings.embed_query(text) vectors = embeddings.embed_query(text)
""" # noqa: E501 """ # noqa: E501
client: Any client: Any
chunk_size: int = 16 chunk_size: int = 16
endpoint: str = "" endpoint: str = ""
@@ -983,7 +985,10 @@ class QianFanEmbeddings(BaseModel, Embeddings):
model_name = kwargs.get("model_name") model_name = kwargs.get("model_name")
if not qianfan_ak or not qianfan_sk or not model_name: if not qianfan_ak or not qianfan_sk or not model_name:
raise ValueError("API key, API secret, and model name are required to initialize QianFanEmbeddings.") raise ValueError(
"API key, API secret, and model name are required to initialize "
"QianFanEmbeddings."
)
params = { params = {
"model": model_name, "model": model_name,
@@ -996,6 +1001,7 @@ class QianFanEmbeddings(BaseModel, Embeddings):
super().__init__(**kwargs) super().__init__(**kwargs)
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a QianFan embedding model."""
resp = self.embed_documents([text]) resp = self.embed_documents([text])
return resp[0] return resp[0]

View File

@@ -687,7 +687,7 @@ def default_requires():
"chardet", "chardet",
"sentencepiece", "sentencepiece",
"ollama", "ollama",
"qianfan" "qianfan",
] ]
setup_spec.extras["default"] += setup_spec.extras["framework"] setup_spec.extras["default"] += setup_spec.extras["framework"]
setup_spec.extras["default"] += setup_spec.extras["rag"] setup_spec.extras["default"] += setup_spec.extras["rag"]