mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-03 01:54:44 +00:00
chore: Fix pylint error (#1915)
This commit is contained in:
@@ -122,3 +122,6 @@ ignore_missing_imports = True
|
||||
|
||||
[mypy-pypdf.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-qianfan.*]
|
||||
ignore_missing_imports = True
|
@@ -60,6 +60,7 @@ class EmbeddingLoader:
|
||||
return TongYiEmbeddings(**tongyi_param)
|
||||
elif model_name in ["proxy_qianfan"]:
|
||||
from dbgpt.rag.embedding import QianFanEmbeddings
|
||||
|
||||
proxy_param = cast(ProxyEmbeddingParameters, param)
|
||||
qianfan_param = {"api_key": proxy_param.proxy_api_key}
|
||||
if proxy_param.proxy_backend:
|
||||
|
@@ -563,7 +563,7 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
|
||||
metadata={
|
||||
"tags": "privacy",
|
||||
"help": "The api secret of the current embedding model(OPENAI_API_SECRET)",
|
||||
}
|
||||
},
|
||||
)
|
||||
proxy_api_version: Optional[str] = field(
|
||||
default=None,
|
||||
|
@@ -14,8 +14,8 @@ from .embeddings import ( # noqa: F401
|
||||
JinaEmbeddings,
|
||||
OllamaEmbeddings,
|
||||
OpenAPIEmbeddings,
|
||||
TongYiEmbeddings,
|
||||
QianFanEmbeddings,
|
||||
TongYiEmbeddings,
|
||||
)
|
||||
from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401
|
||||
|
||||
@@ -34,5 +34,5 @@ __ALL__ = [
|
||||
"TongYiEmbeddings",
|
||||
"CrossEncoderRerankEmbeddings",
|
||||
"OpenAPIRerankEmbeddings",
|
||||
"QianFanEmbeddings"
|
||||
"QianFanEmbeddings",
|
||||
]
|
||||
|
@@ -926,16 +926,18 @@ class TongYiEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
class QianFanEmbeddings(BaseModel, Embeddings):
|
||||
"""Baidu Qianfan Embeddings embedding models.
|
||||
Embed:
|
||||
.. code-block:: python
|
||||
|
||||
# embed the documents
|
||||
vectors = embeddings.embed_documents([text1, text2, ...])
|
||||
Embed:
|
||||
.. code-block:: python
|
||||
|
||||
# embed the query
|
||||
vectors = embeddings.embed_query(text)
|
||||
# embed the documents
|
||||
vectors = embeddings.embed_documents([text1, text2, ...])
|
||||
|
||||
# embed the query
|
||||
vectors = embeddings.embed_query(text)
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
""" # noqa: E501
|
||||
client: Any
|
||||
chunk_size: int = 16
|
||||
endpoint: str = ""
|
||||
@@ -950,7 +952,7 @@ class QianFanEmbeddings(BaseModel, Embeddings):
|
||||
"""Model name
|
||||
you could get from https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
|
||||
|
||||
for now, we support Embedding-V1 and
|
||||
for now, we support Embedding-V1 and
|
||||
- Embedding-V1 (默认模型)
|
||||
- bge-large-en
|
||||
- bge-large-zh
|
||||
@@ -962,7 +964,7 @@ class QianFanEmbeddings(BaseModel, Embeddings):
|
||||
default="text-embedding-v1", description="The name of the model to use."
|
||||
)
|
||||
init_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""init kwargs for qianfan client init, such as `query_per_second` which is
|
||||
"""init kwargs for qianfan client init, such as `query_per_second` which is
|
||||
associated with qianfan resource object to limit QPS"""
|
||||
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
@@ -983,7 +985,10 @@ class QianFanEmbeddings(BaseModel, Embeddings):
|
||||
model_name = kwargs.get("model_name")
|
||||
|
||||
if not qianfan_ak or not qianfan_sk or not model_name:
|
||||
raise ValueError("API key, API secret, and model name are required to initialize QianFanEmbeddings.")
|
||||
raise ValueError(
|
||||
"API key, API secret, and model name are required to initialize "
|
||||
"QianFanEmbeddings."
|
||||
)
|
||||
|
||||
params = {
|
||||
"model": model_name,
|
||||
@@ -996,6 +1001,7 @@ class QianFanEmbeddings(BaseModel, Embeddings):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Compute query embeddings using a QianFan embedding model."""
|
||||
resp = self.embed_documents([text])
|
||||
return resp[0]
|
||||
|
||||
@@ -1011,7 +1017,7 @@ class QianFanEmbeddings(BaseModel, Embeddings):
|
||||
Each embedding is represented as a list of float values.
|
||||
"""
|
||||
text_in_chunks = [
|
||||
texts[i: i + self.chunk_size]
|
||||
texts[i : i + self.chunk_size]
|
||||
for i in range(0, len(texts), self.chunk_size)
|
||||
]
|
||||
lst = []
|
||||
|
Reference in New Issue
Block a user