groq[patch]: Update root validators for pydantic 2 migration (#25402)

This commit is contained in:
Eugene Yurtsev 2024-08-15 14:46:52 -04:00 committed by GitHub
parent 8eb63a609e
commit d72a08a60d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,7 +3,6 @@
from __future__ import annotations from __future__ import annotations
import json import json
import os
import warnings import warnings
from operator import itemgetter from operator import itemgetter
from typing import ( from typing import (
@ -75,9 +74,9 @@ from langchain_core.pydantic_v1 import (
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_core.utils import ( from langchain_core.utils import (
convert_to_secret_str, from_env,
get_from_dict_or_env,
get_pydantic_field_names, get_pydantic_field_names,
secret_from_env,
) )
from langchain_core.utils.function_calling import ( from langchain_core.utils.function_calling import (
convert_to_openai_function, convert_to_openai_function,
@ -308,13 +307,19 @@ class ChatGroq(BaseChatModel):
"""Default stop sequences.""" """Default stop sequences."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified.""" """Holds any model parameters valid for `create` call not explicitly specified."""
groq_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") groq_api_key: Optional[SecretStr] = Field(
alias="api_key", default_factory=secret_from_env("GROQ_API_KEY", default=None)
)
"""Automatically inferred from env var `GROQ_API_KEY` if not provided.""" """Automatically inferred from env var `GROQ_API_KEY` if not provided."""
groq_api_base: Optional[str] = Field(default=None, alias="base_url") groq_api_base: Optional[str] = Field(
alias="base_url", default_factory=from_env("GROQ_API_BASE", default=None)
)
"""Base URL path for API requests, leave blank if not using a proxy or service """Base URL path for API requests, leave blank if not using a proxy or service
emulator.""" emulator."""
# to support explicit proxy for Groq # to support explicit proxy for Groq
groq_proxy: Optional[str] = None groq_proxy: Optional[str] = Field(
default_factory=from_env("GROQ_PROXY", default=None)
)
request_timeout: Union[float, Tuple[float, float], Any, None] = Field( request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
default=None, alias="timeout" default=None, alias="timeout"
) )
@ -369,25 +374,20 @@ class ChatGroq(BaseChatModel):
values["model_kwargs"] = extra values["model_kwargs"] = extra
return values return values
@root_validator() @root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
if values["n"] < 1: if values["n"] < 1:
raise ValueError("n must be at least 1.") raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]: if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.") raise ValueError("n must be 1 when streaming.")
if values["temperature"] == 0: if values["temperature"] == 0:
values["temperature"] = 1e-8 values["temperature"] = 1e-8
values["groq_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "groq_api_key", "GROQ_API_KEY")
)
values["groq_api_base"] = values["groq_api_base"] or os.getenv("GROQ_API_BASE")
values["groq_proxy"] = values["groq_proxy"] = os.getenv("GROQ_PROXY")
client_params = { client_params = {
"api_key": values["groq_api_key"].get_secret_value(), "api_key": values["groq_api_key"].get_secret_value()
if values["groq_api_key"]
else None,
"base_url": values["groq_api_base"], "base_url": values["groq_api_base"],
"timeout": values["request_timeout"], "timeout": values["request_timeout"],
"max_retries": values["max_retries"], "max_retries": values["max_retries"],