chore: Fix pylint error (#1915)

This commit is contained in:
Fangyin Cheng
2024-08-29 16:37:31 +08:00
committed by GitHub
parent 51b4327fdf
commit f72db23bab
6 changed files with 25 additions and 15 deletions

View File

@@ -122,3 +122,6 @@ ignore_missing_imports = True
[mypy-pypdf.*]
ignore_missing_imports = True
[mypy-qianfan.*]
ignore_missing_imports = True

View File

@@ -60,6 +60,7 @@ class EmbeddingLoader:
return TongYiEmbeddings(**tongyi_param)
elif model_name in ["proxy_qianfan"]:
from dbgpt.rag.embedding import QianFanEmbeddings
proxy_param = cast(ProxyEmbeddingParameters, param)
qianfan_param = {"api_key": proxy_param.proxy_api_key}
if proxy_param.proxy_backend:

View File

@@ -563,7 +563,7 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
metadata={
"tags": "privacy",
"help": "The api secret of the current embedding model(OPENAI_API_SECRET)",
}
},
)
proxy_api_version: Optional[str] = field(
default=None,

View File

@@ -14,8 +14,8 @@ from .embeddings import ( # noqa: F401
JinaEmbeddings,
OllamaEmbeddings,
OpenAPIEmbeddings,
TongYiEmbeddings,
QianFanEmbeddings,
TongYiEmbeddings,
)
from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401
@@ -34,5 +34,5 @@ __ALL__ = [
"TongYiEmbeddings",
"CrossEncoderRerankEmbeddings",
"OpenAPIRerankEmbeddings",
"QianFanEmbeddings"
"QianFanEmbeddings",
]

View File

@@ -926,16 +926,18 @@ class TongYiEmbeddings(BaseModel, Embeddings):
class QianFanEmbeddings(BaseModel, Embeddings):
"""Baidu Qianfan Embeddings embedding models.
Embed:
.. code-block:: python
# embed the documents
vectors = embeddings.embed_documents([text1, text2, ...])
Embed:
.. code-block:: python
# embed the query
vectors = embeddings.embed_query(text)
# embed the documents
vectors = embeddings.embed_documents([text1, text2, ...])
# embed the query
vectors = embeddings.embed_query(text)
""" # noqa: E501
""" # noqa: E501
client: Any
chunk_size: int = 16
endpoint: str = ""
@@ -950,7 +952,7 @@ class QianFanEmbeddings(BaseModel, Embeddings):
"""Model name
you could get from https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
for now, we support Embedding-V1 and
for now, we support Embedding-V1 and
- Embedding-V1 (默认模型)
- bge-large-en
- bge-large-zh
@@ -962,7 +964,7 @@ class QianFanEmbeddings(BaseModel, Embeddings):
default="text-embedding-v1", description="The name of the model to use."
)
init_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""init kwargs for qianfan client init, such as `query_per_second` which is
"""init kwargs for qianfan client init, such as `query_per_second` which is
associated with qianfan resource object to limit QPS"""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
@@ -983,7 +985,10 @@ class QianFanEmbeddings(BaseModel, Embeddings):
model_name = kwargs.get("model_name")
if not qianfan_ak or not qianfan_sk or not model_name:
raise ValueError("API key, API secret, and model name are required to initialize QianFanEmbeddings.")
raise ValueError(
"API key, API secret, and model name are required to initialize "
"QianFanEmbeddings."
)
params = {
"model": model_name,
@@ -996,6 +1001,7 @@ class QianFanEmbeddings(BaseModel, Embeddings):
super().__init__(**kwargs)
def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a QianFan embedding model."""
resp = self.embed_documents([text])
return resp[0]
@@ -1011,7 +1017,7 @@ class QianFanEmbeddings(BaseModel, Embeddings):
Each embedding is represented as a list of float values.
"""
text_in_chunks = [
texts[i: i + self.chunk_size]
texts[i : i + self.chunk_size]
for i in range(0, len(texts), self.chunk_size)
]
lst = []

View File

@@ -687,7 +687,7 @@ def default_requires():
"chardet",
"sentencepiece",
"ollama",
"qianfan"
"qianfan",
]
setup_spec.extras["default"] += setup_spec.extras["framework"]
setup_spec.extras["default"] += setup_spec.extras["rag"]