mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-05 06:33:20 +00:00
community[patch]: Update root_validators ChatModels: ChatBaichuan, QianfanChatEndpoint, MiniMaxChat, ChatSparkLLM, ChatZhipuAI (#22853)
This PR updates root validators for: - ChatModels: ChatBaichuan, QianfanChatEndpoint, MiniMaxChat, ChatSparkLLM, ChatZhipuAI Issues #22819 --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
cb6cf4b631
commit
bc4cd9c5cc
@ -89,7 +89,7 @@ class ChatBaichuan(BaseChatModel):
|
||||
|
||||
baichuan_api_base: str = Field(default=DEFAULT_API_BASE)
|
||||
"""Baichuan custom endpoints"""
|
||||
baichuan_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
|
||||
baichuan_api_key: SecretStr = Field(alias="api_key")
|
||||
"""Baichuan API Key"""
|
||||
baichuan_secret_key: Optional[SecretStr] = None
|
||||
"""[DEPRECATED, keeping it for for backward compatibility] Baichuan Secret Key"""
|
||||
@ -142,7 +142,7 @@ class ChatBaichuan(BaseChatModel):
|
||||
values["model_kwargs"] = extra
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
values["baichuan_api_base"] = get_from_dict_or_env(
|
||||
values,
|
||||
@ -153,11 +153,10 @@ class ChatBaichuan(BaseChatModel):
|
||||
values["baichuan_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
values,
|
||||
"baichuan_api_key",
|
||||
["baichuan_api_key", "api_key"],
|
||||
"BAICHUAN_API_KEY",
|
||||
)
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
@property
|
||||
|
@ -135,7 +135,7 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
|
||||
client: Any #: :meta private:
|
||||
|
||||
qianfan_ak: Optional[SecretStr] = Field(default=None, alias="api_key")
|
||||
qianfan_ak: SecretStr = Field(alias="api_key")
|
||||
"""Qianfan API KEY"""
|
||||
qianfan_sk: Optional[SecretStr] = Field(default=None, alias="secret_key")
|
||||
"""Qianfan SECRET KEY"""
|
||||
@ -171,35 +171,43 @@ class QianfanChatEndpoint(BaseChatModel):
|
||||
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@root_validator()
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
values["qianfan_ak"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
values,
|
||||
"qianfan_ak",
|
||||
["qianfan_ak", "api_key"],
|
||||
"QIANFAN_AK",
|
||||
default="",
|
||||
)
|
||||
)
|
||||
values["qianfan_sk"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
values,
|
||||
"qianfan_sk",
|
||||
["qianfan_sk", "secret_key"],
|
||||
"QIANFAN_SK",
|
||||
default="",
|
||||
)
|
||||
)
|
||||
|
||||
default_values = {
|
||||
name: field.default
|
||||
for name, field in cls.__fields__.items()
|
||||
if field.default is not None
|
||||
}
|
||||
default_values.update(values)
|
||||
params = {
|
||||
**values.get("init_kwargs", {}),
|
||||
"model": values["model"],
|
||||
"stream": values["streaming"],
|
||||
"model": default_values.get("model"),
|
||||
"stream": default_values.get("streaming"),
|
||||
}
|
||||
if values["qianfan_ak"].get_secret_value() != "":
|
||||
params["ak"] = values["qianfan_ak"].get_secret_value()
|
||||
if values["qianfan_sk"].get_secret_value() != "":
|
||||
params["sk"] = values["qianfan_sk"].get_secret_value()
|
||||
if values["endpoint"] is not None and values["endpoint"] != "":
|
||||
params["endpoint"] = values["endpoint"]
|
||||
if (
|
||||
default_values.get("endpoint") is not None
|
||||
and default_values["endpoint"] != ""
|
||||
):
|
||||
params["endpoint"] = default_values["endpoint"]
|
||||
try:
|
||||
import qianfan
|
||||
|
||||
|
@ -166,7 +166,7 @@ class MiniMaxChat(BaseChatModel):
|
||||
)
|
||||
minimax_group_id: Optional[str] = Field(default=None, alias="group_id")
|
||||
"""[DEPRECATED, keeping it for for backward compatibility] Group Id"""
|
||||
minimax_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
|
||||
minimax_api_key: SecretStr = Field(alias="api_key")
|
||||
"""Minimax API Key"""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
@ -176,14 +176,18 @@ class MiniMaxChat(BaseChatModel):
|
||||
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@root_validator(allow_reuse=True)
|
||||
@root_validator(pre=True, allow_reuse=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["minimax_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "minimax_api_key", "MINIMAX_API_KEY")
|
||||
get_from_dict_or_env(
|
||||
values,
|
||||
["minimax_api_key", "api_key"],
|
||||
"MINIMAX_API_KEY",
|
||||
)
|
||||
)
|
||||
values["minimax_group_id"] = get_from_dict_or_env(
|
||||
values, "minimax_group_id", "MINIMAX_GROUP_ID"
|
||||
values, ["minimax_group_id", "group_id"], "MINIMAX_GROUP_ID"
|
||||
)
|
||||
# Get custom api url from environment.
|
||||
values["minimax_api_host"] = get_from_dict_or_env(
|
||||
|
@ -195,21 +195,21 @@ class ChatSparkLLM(BaseChatModel):
|
||||
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
values["spark_app_id"] = get_from_dict_or_env(
|
||||
values,
|
||||
"spark_app_id",
|
||||
["spark_app_id", "app_id"],
|
||||
"IFLYTEK_SPARK_APP_ID",
|
||||
)
|
||||
values["spark_api_key"] = get_from_dict_or_env(
|
||||
values,
|
||||
"spark_api_key",
|
||||
["spark_api_key", "api_key"],
|
||||
"IFLYTEK_SPARK_API_KEY",
|
||||
)
|
||||
values["spark_api_secret"] = get_from_dict_or_env(
|
||||
values,
|
||||
"spark_api_secret",
|
||||
["spark_api_secret", "api_secret"],
|
||||
"IFLYTEK_SPARK_API_SECRET",
|
||||
)
|
||||
values["spark_api_url"] = get_from_dict_or_env(
|
||||
@ -224,9 +224,15 @@ class ChatSparkLLM(BaseChatModel):
|
||||
"IFLYTEK_SPARK_LLM_DOMAIN",
|
||||
SPARK_LLM_DOMAIN,
|
||||
)
|
||||
|
||||
# put extra params into model_kwargs
|
||||
values["model_kwargs"]["temperature"] = values["temperature"] or cls.temperature
|
||||
values["model_kwargs"]["top_k"] = values["top_k"] or cls.top_k
|
||||
default_values = {
|
||||
name: field.default
|
||||
for name, field in cls.__fields__.items()
|
||||
if field.default is not None
|
||||
}
|
||||
values["model_kwargs"]["temperature"] = default_values.get("temperature")
|
||||
values["model_kwargs"]["top_k"] = default_values.get("top_k")
|
||||
|
||||
values["client"] = _SparkLLMClient(
|
||||
app_id=values["spark_app_id"],
|
||||
|
@ -377,10 +377,10 @@ class ChatZhipuAI(BaseChatModel):
|
||||
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@root_validator()
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
values["zhipuai_api_key"] = get_from_dict_or_env(
|
||||
values, "zhipuai_api_key", "ZHIPUAI_API_KEY"
|
||||
values, ["zhipuai_api_key", "api_key"], "ZHIPUAI_API_KEY"
|
||||
)
|
||||
values["zhipuai_api_base"] = get_from_dict_or_env(
|
||||
values, "zhipuai_api_base", "ZHIPUAI_API_BASE", default=ZHIPUAI_API_BASE
|
||||
|
@ -7,7 +7,7 @@ from langchain_community.chat_models.baichuan import ChatBaichuan
|
||||
|
||||
|
||||
def test_chat_baichuan_default() -> None:
|
||||
chat = ChatBaichuan(streaming=True)
|
||||
chat = ChatBaichuan(streaming=True) # type: ignore[call-arg]
|
||||
message = HumanMessage(content="请完整背诵将进酒,背诵5遍")
|
||||
response = chat.invoke([message])
|
||||
assert isinstance(response, AIMessage)
|
||||
@ -15,7 +15,7 @@ def test_chat_baichuan_default() -> None:
|
||||
|
||||
|
||||
def test_chat_baichuan_default_non_streaming() -> None:
|
||||
chat = ChatBaichuan()
|
||||
chat = ChatBaichuan() # type: ignore[call-arg]
|
||||
message = HumanMessage(content="请完整背诵将进酒,背诵5遍")
|
||||
response = chat.invoke([message])
|
||||
assert isinstance(response, AIMessage)
|
||||
@ -39,7 +39,7 @@ def test_chat_baichuan_turbo_non_streaming() -> None:
|
||||
|
||||
|
||||
def test_chat_baichuan_with_temperature() -> None:
|
||||
chat = ChatBaichuan(temperature=1.0)
|
||||
chat = ChatBaichuan(temperature=1.0) # type: ignore[call-arg]
|
||||
message = HumanMessage(content="Hello")
|
||||
response = chat.invoke([message])
|
||||
assert isinstance(response, AIMessage)
|
||||
@ -47,7 +47,7 @@ def test_chat_baichuan_with_temperature() -> None:
|
||||
|
||||
|
||||
def test_chat_baichuan_with_kwargs() -> None:
|
||||
chat = ChatBaichuan()
|
||||
chat = ChatBaichuan() # type: ignore[call-arg]
|
||||
message = HumanMessage(content="百川192K API是什么时候上线的?")
|
||||
response = chat.invoke(
|
||||
[message], temperature=0.88, top_p=0.7, with_search_enhance=True
|
||||
@ -58,7 +58,7 @@ def test_chat_baichuan_with_kwargs() -> None:
|
||||
|
||||
|
||||
def test_extra_kwargs() -> None:
|
||||
chat = ChatBaichuan(temperature=0.88, top_p=0.7, with_search_enhance=True)
|
||||
chat = ChatBaichuan(temperature=0.88, top_p=0.7, with_search_enhance=True) # type: ignore[call-arg]
|
||||
assert chat.temperature == 0.88
|
||||
assert chat.top_p == 0.7
|
||||
assert chat.with_search_enhance is True
|
||||
|
@ -107,7 +107,7 @@ def test_baichuan_key_masked_when_passed_from_env(
|
||||
"""Test initialization with an API key provided via an env variable"""
|
||||
monkeypatch.setenv("BAICHUAN_API_KEY", "test-api-key")
|
||||
|
||||
chat = ChatBaichuan()
|
||||
chat = ChatBaichuan() # type: ignore[call-arg]
|
||||
print(chat.baichuan_api_key, end="") # noqa: T201
|
||||
captured = capsys.readouterr()
|
||||
assert captured.out == "**********"
|
||||
|
Loading…
Reference in New Issue
Block a user