together[patch]: Update @root_validator for pydantic 2 compatibility (#25423)

This PR updates usage of @root_validator to be compatible with pydantic 2.
This commit is contained in:
Eugene Yurtsev 2024-08-15 11:27:42 -04:00 committed by GitHub
parent a114255b82
commit 831708beb7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 65 additions and 51 deletions

View File

@ -1,6 +1,5 @@
"""Wrapper around Together AI's Chat Completions API.""" """Wrapper around Together AI's Chat Completions API."""
import os
from typing import ( from typing import (
Any, Any,
Dict, Dict,
@ -12,8 +11,8 @@ import openai
from langchain_core.language_models.chat_models import LangSmithParams from langchain_core.language_models.chat_models import LangSmithParams
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import ( from langchain_core.utils import (
convert_to_secret_str, from_env,
get_from_dict_or_env, secret_from_env,
) )
from langchain_openai.chat_models.base import BaseChatOpenAI from langchain_openai.chat_models.base import BaseChatOpenAI
@ -311,13 +310,27 @@ class ChatTogether(BaseChatOpenAI):
model_name: str = Field(default="meta-llama/Llama-3-8b-chat-hf", alias="model") model_name: str = Field(default="meta-llama/Llama-3-8b-chat-hf", alias="model")
"""Model name to use.""" """Model name to use."""
together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") together_api_key: Optional[SecretStr] = Field(
"""Automatically inferred from env are `TOGETHER_API_KEY` if not provided.""" alias="api_key",
together_api_base: Optional[str] = Field( default_factory=secret_from_env("TOGETHER_API_KEY", default=None),
default="https://api.together.ai/v1/", alias="base_url" )
"""Together AI API key.
Automatically read from env variable `TOGETHER_API_KEY` if not provided.
"""
together_api_base: str = Field(
default_factory=from_env(
"TOGETHER_API_BASE", default="https://api.together.ai/v1/"
),
alias="base_url",
) )
@root_validator() class Config:
"""Pydantic config."""
allow_population_by_field_name = True
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
if values["n"] < 1: if values["n"] < 1:
@ -325,13 +338,6 @@ class ChatTogether(BaseChatOpenAI):
if values["n"] > 1 and values["streaming"]: if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.") raise ValueError("n must be 1 when streaming.")
values["together_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "together_api_key", "TOGETHER_API_KEY")
)
values["together_api_base"] = values["together_api_base"] or os.getenv(
"TOGETHER_API_BASE"
)
client_params = { client_params = {
"api_key": ( "api_key": (
values["together_api_key"].get_secret_value() values["together_api_key"].get_secret_value()

View File

@ -1,7 +1,6 @@
"""Wrapper around Together AI's Embeddings API.""" """Wrapper around Together AI's Embeddings API."""
import logging import logging
import os
import warnings import warnings
from typing import ( from typing import (
Any, Any,
@ -25,9 +24,9 @@ from langchain_core.pydantic_v1 import (
root_validator, root_validator,
) )
from langchain_core.utils import ( from langchain_core.utils import (
convert_to_secret_str, from_env,
get_from_dict_or_env,
get_pydantic_field_names, get_pydantic_field_names,
secret_from_env,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -115,10 +114,19 @@ class TogetherEmbeddings(BaseModel, Embeddings):
Not yet supported. Not yet supported.
""" """
together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") together_api_key: Optional[SecretStr] = Field(
"""API Key for Solar API.""" alias="api_key",
default_factory=secret_from_env("TOGETHER_API_KEY", default=None),
)
"""Together AI API key.
Automatically read from env variable `TOGETHER_API_KEY` if not provided.
"""
together_api_base: str = Field( together_api_base: str = Field(
default="https://api.together.ai/v1/", alias="base_url" default_factory=from_env(
"TOGETHER_API_BASE", default="https://api.together.ai/v1/"
),
alias="base_url",
) )
"""Endpoint URL to use.""" """Endpoint URL to use."""
embedding_ctx_length: int = 4096 embedding_ctx_length: int = 4096
@ -198,18 +206,9 @@ class TogetherEmbeddings(BaseModel, Embeddings):
values["model_kwargs"] = extra values["model_kwargs"] = extra
return values return values
@root_validator() @root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict: def post_init(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Logic that will post Pydantic initialization."""
together_api_key = get_from_dict_or_env(
values, "together_api_key", "TOGETHER_API_KEY"
)
values["together_api_key"] = (
convert_to_secret_str(together_api_key) if together_api_key else None
)
values["together_api_base"] = values["together_api_base"] or os.getenv(
"TOGETHER_API_BASE"
)
client_params = { client_params = {
"api_key": ( "api_key": (
values["together_api_key"].get_secret_value() values["together_api_key"].get_secret_value()

View File

@ -11,8 +11,10 @@ from langchain_core.callbacks import (
CallbackManagerForLLMRun, CallbackManagerForLLMRun,
) )
from langchain_core.language_models.llms import LLM from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import SecretStr, root_validator from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from langchain_core.utils import (
secret_from_env,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,8 +38,14 @@ class Together(LLM):
base_url: str = "https://api.together.ai/v1/completions" base_url: str = "https://api.together.ai/v1/completions"
"""Base completions API URL.""" """Base completions API URL."""
together_api_key: SecretStr together_api_key: SecretStr = Field(
"""Together AI API key. Get it here: https://api.together.ai/settings/api-keys""" alias="api_key",
default_factory=secret_from_env("TOGETHER_API_KEY"),
)
"""Together AI API key.
Automatically read from env variable `TOGETHER_API_KEY` if not provided.
"""
model: str model: str
"""Model name. Available models listed here: """Model name. Available models listed here:
Base Models: https://docs.together.ai/docs/inference-models#language-models Base Models: https://docs.together.ai/docs/inference-models#language-models
@ -74,21 +82,11 @@ class Together(LLM):
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
extra = "forbid" extra = "forbid"
allow_population_by_field_name = True
@root_validator(pre=True) @root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key exists in environment.""" """Validate that api key exists in environment."""
values["together_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "together_api_key", "TOGETHER_API_KEY")
)
return values
@root_validator()
def validate_max_tokens(cls, values: Dict) -> Dict:
"""The v1 completions endpoint, has max_tokens as required parameter.
Set a default value and warn if the parameter is missing.
"""
if values.get("max_tokens") is None: if values.get("max_tokens") is None:
warnings.warn( warnings.warn(
"The completions endpoint, has 'max_tokens' as required argument. " "The completions endpoint, has 'max_tokens' as required argument. "

View File

@ -9,7 +9,7 @@ from langchain_together import Together
def test_together_api_key_is_secret_string() -> None: def test_together_api_key_is_secret_string() -> None:
"""Test that the API key is stored as a SecretStr.""" """Test that the API key is stored as a SecretStr."""
llm = Together( llm = Together(
together_api_key="secret-api-key", # type: ignore[arg-type] together_api_key="secret-api-key", # type: ignore[call-arg]
model="togethercomputer/RedPajama-INCITE-7B-Base", model="togethercomputer/RedPajama-INCITE-7B-Base",
temperature=0.2, temperature=0.2,
max_tokens=250, max_tokens=250,
@ -38,7 +38,7 @@ def test_together_api_key_masked_when_passed_via_constructor(
) -> None: ) -> None:
"""Test that the API key is masked when passed via the constructor.""" """Test that the API key is masked when passed via the constructor."""
llm = Together( llm = Together(
together_api_key="secret-api-key", # type: ignore[arg-type] together_api_key="secret-api-key", # type: ignore[call-arg]
model="togethercomputer/RedPajama-INCITE-7B-Base", model="togethercomputer/RedPajama-INCITE-7B-Base",
temperature=0.2, temperature=0.2,
max_tokens=250, max_tokens=250,
@ -52,7 +52,18 @@ def test_together_api_key_masked_when_passed_via_constructor(
def test_together_uses_actual_secret_value_from_secretstr() -> None: def test_together_uses_actual_secret_value_from_secretstr() -> None:
"""Test that the actual secret value is correctly retrieved.""" """Test that the actual secret value is correctly retrieved."""
llm = Together( llm = Together(
together_api_key="secret-api-key", # type: ignore[arg-type] together_api_key="secret-api-key", # type: ignore[call-arg]
model="togethercomputer/RedPajama-INCITE-7B-Base",
temperature=0.2,
max_tokens=250,
)
assert cast(SecretStr, llm.together_api_key).get_secret_value() == "secret-api-key"
def test_together_uses_actual_secret_value_from_secretstr_api_key() -> None:
"""Test that the actual secret value is correctly retrieved."""
llm = Together(
api_key="secret-api-key", # type: ignore[arg-type]
model="togethercomputer/RedPajama-INCITE-7B-Base", model="togethercomputer/RedPajama-INCITE-7B-Base",
temperature=0.2, temperature=0.2,
max_tokens=250, max_tokens=250,