Compare commits

...

7 Commits

Author SHA1 Message Date
Eugene Yurtsev
721b6e14ed x 2024-07-02 10:26:51 -04:00
Eugene Yurtsev
e108810d5c x 2024-07-02 10:13:27 -04:00
Eugene Yurtsev
b71aefc1e4 lint 2024-07-02 09:28:32 -04:00
Eugene Yurtsev
76a6b43616 Update @root_validator 2024-07-02 09:15:42 -04:00
Eugene Yurtsev
88ca0316f9 Update @root_validator 2024-07-02 09:14:50 -04:00
Eugene Yurtsev
fee6cc33a3 Update @root_validator 2024-07-02 09:13:01 -04:00
Eugene Yurtsev
59fd10c9f1 Update @root_validators 2024-07-01 16:12:19 -04:00
9 changed files with 40 additions and 27 deletions

View File

@@ -217,7 +217,7 @@ class JinaChat(BaseChatModel):
values["model_kwargs"] = extra
return values
@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["jinachat_api_key"] = convert_to_secret_str(

View File

@@ -341,7 +341,7 @@ class ChatKinetica(BaseChatModel):
kdbc: Any = Field(exclude=True)
""" Kinetica DB connection. """
@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Pydantic object validator."""

View File

@@ -84,8 +84,8 @@ class ChatKonko(ChatOpenAI):
max_tokens: int = 20
"""Maximum number of tokens to generate."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@root_validator(pre=True)
def pre_init(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["konko_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "konko_api_key", "KONKO_API_KEY")
@@ -115,7 +115,10 @@ class ChatKonko(ChatOpenAI):
"You are using an older version of the 'konko' package. "
"Please consider upgrading to access new features."
)
return values
@root_validator(pre=False, skip_on_failure=True)
def validate_n(cls, values: Dict) -> Dict:
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:

View File

@@ -239,7 +239,7 @@ class ChatLiteLLM(BaseChatModel):
return _completion_with_retry(**kwargs)
@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate api key, python package exists, temperature, top_p, and top_k."""
try:
@@ -275,6 +275,11 @@ class ChatLiteLLM(BaseChatModel):
values, "together_ai_api_key", "TOGETHERAI_API_KEY", default=""
)
values["client"] = litellm
return values
@root_validator(pre=False, skip_on_failure=True)
def post_init(cls, values: Dict) -> Dict:
"""Validate api key, python package exists, temperature, top_p, and top_k."""
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
raise ValueError("temperature must be in the range [0.0, 1.0]")

View File

@@ -28,7 +28,7 @@ class MoonshotChat(MoonshotCommon, ChatOpenAI): # type: ignore[misc]
moonshot = MoonshotChat(model="moonshot-v1-8k")
"""
@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the environment is set up correctly."""
values["moonshot_api_key"] = convert_to_secret_str(

View File

@@ -47,7 +47,7 @@ class ChatOctoAI(ChatOpenAI):
def is_lc_serializable(cls) -> bool:
return False
@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["octoai_api_base"] = get_from_dict_or_env(

View File

@@ -145,6 +145,9 @@ def _convert_delta_to_message_chunk(
return default_class(content=content) # type: ignore[call-arg]
DEFAULT_MAX_RETRIES = 2
@deprecated(
since="0.0.10", removal="0.3.0", alternative_import="langchain_openai.ChatOpenAI"
)
@@ -218,7 +221,7 @@ class ChatOpenAI(BaseChatModel):
)
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
None."""
max_retries: int = Field(default=2)
max_retries: int = Field(default=DEFAULT_MAX_RETRIES)
"""Maximum number of retries to make when generating."""
streaming: bool = False
"""Whether to stream the results or not."""
@@ -274,14 +277,9 @@ class ChatOpenAI(BaseChatModel):
values["model_kwargs"] = extra
return values
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@root_validator(pre=True)
def pre_init(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.")
values["openai_api_key"] = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
@@ -291,7 +289,7 @@ class ChatOpenAI(BaseChatModel):
or os.getenv("OPENAI_ORG_ID")
or os.getenv("OPENAI_ORGANIZATION")
)
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
values["openai_api_base"] = values.get("openai_api_base") or os.getenv(
"OPENAI_API_BASE"
)
values["openai_proxy"] = get_from_dict_or_env(
@@ -311,14 +309,14 @@ class ChatOpenAI(BaseChatModel):
if is_openai_v1():
client_params = {
"api_key": values["openai_api_key"],
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"timeout": values["request_timeout"],
"max_retries": values["max_retries"],
"default_headers": values["default_headers"],
"default_query": values["default_query"],
"http_client": values["http_client"],
"api_key": values.get("openai_api_key"),
"organization": values.get("openai_organization"),
"base_url": values.get("openai_api_base"),
"timeout": values.get("request_timeout"),
"max_retries": values.get("max_retries", DEFAULT_MAX_RETRIES),
"default_headers": values.get("default_headers"),
"default_query": values.get("default_query"),
"http_client": values.get("http_client"),
}
if not values.get("client"):
@@ -333,6 +331,14 @@ class ChatOpenAI(BaseChatModel):
pass
return values
@root_validator(pre=True, skip_on_failure=True)
def validate_environment(self, values: Dict) -> Dict:
if values["n"] < 1:
raise ValueError("n must be at least 1.")
if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.")
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""

View File

@@ -169,7 +169,7 @@ class ChatYuan2(BaseChatModel):
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["yuan2_api_key"] = get_from_dict_or_env(
values, "yuan2_api_key", "YUAN2_API_KEY"
values, "yuan2_api_key", "YUAN2_API_KEY", default="EMPTY"
)
try:

View File

@@ -96,8 +96,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
# See: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings
values["chunk_size"] = min(values["chunk_size"], 16)
try:
import openai
import openai # noqa: F401
except ImportError:
raise ImportError(
"Could not import openai python package. "