From b0c48dc9832d0fabe217d67516fd838502c23aa6 Mon Sep 17 00:00:00 2001 From: Liu Jun Date: Wed, 20 Dec 2023 13:49:33 +0800 Subject: [PATCH] community[patch]: make ak and sk optional in qianfan endpoint (#14835) - **Description:** The Qianfan SDK offers multiple authentication methods, but in the `QianfanEndpoint` of Langchain, it currently only supports authentication through AK and SK. In order to accommodate users who wish to use alternative authentication methods, this pull request makes AK and SK optional. This change should not impact existing users, while allowing users to configure other authentication methods as per the Qianfan SDK documentation. - **Issue:** / - **Dependencies:** No - **Tag maintainer:** No - **Twitter handle:** --- .../chat_models/baidu_qianfan_endpoint.py | 8 +++-- .../embeddings/baidu_qianfan_endpoint.py | 30 ++++++++++++------- .../llms/baidu_qianfan_endpoint.py | 30 ++++++++++++------- 3 files changed, 44 insertions(+), 24 deletions(-) diff --git a/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py b/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py index 81e2a544a47..4b617a10f66 100644 --- a/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py +++ b/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py @@ -122,6 +122,7 @@ class QianfanChatEndpoint(BaseChatModel): values, "qianfan_ak", "QIANFAN_AK", + default="", ) ) values["qianfan_sk"] = convert_to_secret_str( @@ -129,14 +130,17 @@ class QianfanChatEndpoint(BaseChatModel): values, "qianfan_sk", "QIANFAN_SK", + default="", ) ) params = { - "ak": values["qianfan_ak"].get_secret_value(), - "sk": values["qianfan_sk"].get_secret_value(), "model": values["model"], "stream": values["streaming"], } + if values["qianfan_ak"].get_secret_value() != "": + params["ak"] = values["qianfan_ak"].get_secret_value() + if values["qianfan_sk"].get_secret_value() != "": + params["sk"] = values["qianfan_sk"].get_secret_value() if values["endpoint"] is not None and values["endpoint"] != "": params["endpoint"] = values["endpoint"] try: diff --git a/libs/community/langchain_community/embeddings/baidu_qianfan_endpoint.py b/libs/community/langchain_community/embeddings/baidu_qianfan_endpoint.py index 01c440ab251..2920447c23b 100644 --- a/libs/community/langchain_community/embeddings/baidu_qianfan_endpoint.py +++ b/libs/community/langchain_community/embeddings/baidu_qianfan_endpoint.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env logger = logging.getLogger(__name__) @@ -67,25 +67,33 @@ class QianfanEmbeddingsEndpoint(BaseModel, Embeddings): ValueError: qianfan package not found, please install it with `pip install qianfan` """ - values["qianfan_ak"] = get_from_dict_or_env( - values, - "qianfan_ak", - "QIANFAN_AK", + values["qianfan_ak"] = convert_to_secret_str( + get_from_dict_or_env( + values, + "qianfan_ak", + "QIANFAN_AK", + default="", + ) ) - values["qianfan_sk"] = get_from_dict_or_env( - values, - "qianfan_sk", - "QIANFAN_SK", + values["qianfan_sk"] = convert_to_secret_str( + get_from_dict_or_env( + values, + "qianfan_sk", + "QIANFAN_SK", + default="", + ) ) try: import qianfan params = { - "ak": values["qianfan_ak"], - "sk": values["qianfan_sk"], "model": values["model"], } + if values["qianfan_ak"].get_secret_value() != "": + params["ak"] = values["qianfan_ak"].get_secret_value() + if values["qianfan_sk"].get_secret_value() != "": + params["sk"] = values["qianfan_sk"].get_secret_value() if values["endpoint"] is not None and values["endpoint"] != "": params["endpoint"] = values["endpoint"] values["client"] = qianfan.Embedding(**params) diff --git a/libs/community/langchain_community/llms/baidu_qianfan_endpoint.py b/libs/community/langchain_community/llms/baidu_qianfan_endpoint.py index 09d765de9d4..36607f1b1e4 100644 --- a/libs/community/langchain_community/llms/baidu_qianfan_endpoint.py +++ b/libs/community/langchain_community/llms/baidu_qianfan_endpoint.py @@ -17,7 +17,7 @@ from langchain_core.callbacks import ( from langchain_core.language_models.llms import LLM from langchain_core.outputs import GenerationChunk from langchain_core.pydantic_v1 import Field, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env logger = logging.getLogger(__name__) @@ -73,22 +73,30 @@ class QianfanLLMEndpoint(LLM): @root_validator() def validate_environment(cls, values: Dict) -> Dict: - values["qianfan_ak"] = get_from_dict_or_env( - values, - "qianfan_ak", - "QIANFAN_AK", + values["qianfan_ak"] = convert_to_secret_str( + get_from_dict_or_env( + values, + "qianfan_ak", + "QIANFAN_AK", + default="", + ) ) - values["qianfan_sk"] = get_from_dict_or_env( - values, - "qianfan_sk", - "QIANFAN_SK", + values["qianfan_sk"] = convert_to_secret_str( + get_from_dict_or_env( + values, + "qianfan_sk", + "QIANFAN_SK", + default="", + ) ) params = { - "ak": values["qianfan_ak"], - "sk": values["qianfan_sk"], "model": values["model"], } + if values["qianfan_ak"].get_secret_value() != "": + params["ak"] = values["qianfan_ak"].get_secret_value() + if values["qianfan_sk"].get_secret_value() != "": + params["sk"] = values["qianfan_sk"].get_secret_value() if values["endpoint"] is not None and values["endpoint"] != "": params["endpoint"] = values["endpoint"] try: