From d72a08a60d70b271e6b37d5d82bffd174b2b5265 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 15 Aug 2024 14:46:52 -0400 Subject: [PATCH] groq[patch]: Update root validators for pydantic 2 migration (#25402) --- .../groq/langchain_groq/chat_models.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/libs/partners/groq/langchain_groq/chat_models.py b/libs/partners/groq/langchain_groq/chat_models.py index fe506407172..23aeae9f11b 100644 --- a/libs/partners/groq/langchain_groq/chat_models.py +++ b/libs/partners/groq/langchain_groq/chat_models.py @@ -3,7 +3,6 @@ from __future__ import annotations import json -import os import warnings from operator import itemgetter from typing import ( @@ -75,9 +74,9 @@ from langchain_core.pydantic_v1 import ( from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.tools import BaseTool from langchain_core.utils import ( - convert_to_secret_str, - get_from_dict_or_env, + from_env, get_pydantic_field_names, + secret_from_env, ) from langchain_core.utils.function_calling import ( convert_to_openai_function, @@ -308,13 +307,19 @@ class ChatGroq(BaseChatModel): """Default stop sequences.""" model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" - groq_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") + groq_api_key: Optional[SecretStr] = Field( + alias="api_key", default_factory=secret_from_env("GROQ_API_KEY", default=None) + ) """Automatically inferred from env var `GROQ_API_KEY` if not provided.""" - groq_api_base: Optional[str] = Field(default=None, alias="base_url") + groq_api_base: Optional[str] = Field( + alias="base_url", default_factory=from_env("GROQ_API_BASE", default=None) + ) """Base URL path for API requests, leave blank if not using a proxy or service emulator.""" # to support explicit proxy for Groq - groq_proxy: Optional[str] = None + groq_proxy: Optional[str] = Field( + default_factory=from_env("GROQ_PROXY", default=None) + ) request_timeout: Union[float, Tuple[float, float], Any, None] = Field( default=None, alias="timeout" ) @@ -369,25 +374,20 @@ class ChatGroq(BaseChatModel): values["model_kwargs"] = extra return values - @root_validator() + @root_validator(pre=False, skip_on_failure=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" if values["n"] < 1: raise ValueError("n must be at least 1.") if values["n"] > 1 and values["streaming"]: raise ValueError("n must be 1 when streaming.") - if values["temperature"] == 0: values["temperature"] = 1e-8 - values["groq_api_key"] = convert_to_secret_str( - get_from_dict_or_env(values, "groq_api_key", "GROQ_API_KEY") - ) - values["groq_api_base"] = values["groq_api_base"] or os.getenv("GROQ_API_BASE") - values["groq_proxy"] = values["groq_proxy"] = os.getenv("GROQ_PROXY") - client_params = { - "api_key": values["groq_api_key"].get_secret_value(), + "api_key": values["groq_api_key"].get_secret_value() + if values["groq_api_key"] + else None, "base_url": values["groq_api_base"], "timeout": values["request_timeout"], "max_retries": values["max_retries"],