mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 13:54:48 +00:00
together[patch]: Update @root_validator for pydantic 2 compatibility (#25423)
This PR updates usage of @root_validator to be compatible with pydantic 2.
This commit is contained in:
parent
a114255b82
commit
831708beb7
@ -1,6 +1,5 @@
|
|||||||
"""Wrapper around Together AI's Chat Completions API."""
|
"""Wrapper around Together AI's Chat Completions API."""
|
||||||
|
|
||||||
import os
|
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Dict,
|
Dict,
|
||||||
@ -12,8 +11,8 @@ import openai
|
|||||||
from langchain_core.language_models.chat_models import LangSmithParams
|
from langchain_core.language_models.chat_models import LangSmithParams
|
||||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||||
from langchain_core.utils import (
|
from langchain_core.utils import (
|
||||||
convert_to_secret_str,
|
from_env,
|
||||||
get_from_dict_or_env,
|
secret_from_env,
|
||||||
)
|
)
|
||||||
from langchain_openai.chat_models.base import BaseChatOpenAI
|
from langchain_openai.chat_models.base import BaseChatOpenAI
|
||||||
|
|
||||||
@ -311,13 +310,27 @@ class ChatTogether(BaseChatOpenAI):
|
|||||||
|
|
||||||
model_name: str = Field(default="meta-llama/Llama-3-8b-chat-hf", alias="model")
|
model_name: str = Field(default="meta-llama/Llama-3-8b-chat-hf", alias="model")
|
||||||
"""Model name to use."""
|
"""Model name to use."""
|
||||||
together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
|
together_api_key: Optional[SecretStr] = Field(
|
||||||
"""Automatically inferred from env are `TOGETHER_API_KEY` if not provided."""
|
alias="api_key",
|
||||||
together_api_base: Optional[str] = Field(
|
default_factory=secret_from_env("TOGETHER_API_KEY", default=None),
|
||||||
default="https://api.together.ai/v1/", alias="base_url"
|
)
|
||||||
|
"""Together AI API key.
|
||||||
|
|
||||||
|
Automatically read from env variable `TOGETHER_API_KEY` if not provided.
|
||||||
|
"""
|
||||||
|
together_api_base: str = Field(
|
||||||
|
default_factory=from_env(
|
||||||
|
"TOGETHER_API_BASE", default="https://api.together.ai/v1/"
|
||||||
|
),
|
||||||
|
alias="base_url",
|
||||||
)
|
)
|
||||||
|
|
||||||
@root_validator()
|
class Config:
|
||||||
|
"""Pydantic config."""
|
||||||
|
|
||||||
|
allow_population_by_field_name = True
|
||||||
|
|
||||||
|
@root_validator(pre=False, skip_on_failure=True)
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Validate that api key and python package exists in environment."""
|
||||||
if values["n"] < 1:
|
if values["n"] < 1:
|
||||||
@ -325,13 +338,6 @@ class ChatTogether(BaseChatOpenAI):
|
|||||||
if values["n"] > 1 and values["streaming"]:
|
if values["n"] > 1 and values["streaming"]:
|
||||||
raise ValueError("n must be 1 when streaming.")
|
raise ValueError("n must be 1 when streaming.")
|
||||||
|
|
||||||
values["together_api_key"] = convert_to_secret_str(
|
|
||||||
get_from_dict_or_env(values, "together_api_key", "TOGETHER_API_KEY")
|
|
||||||
)
|
|
||||||
values["together_api_base"] = values["together_api_base"] or os.getenv(
|
|
||||||
"TOGETHER_API_BASE"
|
|
||||||
)
|
|
||||||
|
|
||||||
client_params = {
|
client_params = {
|
||||||
"api_key": (
|
"api_key": (
|
||||||
values["together_api_key"].get_secret_value()
|
values["together_api_key"].get_secret_value()
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
"""Wrapper around Together AI's Embeddings API."""
|
"""Wrapper around Together AI's Embeddings API."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
@ -25,9 +24,9 @@ from langchain_core.pydantic_v1 import (
|
|||||||
root_validator,
|
root_validator,
|
||||||
)
|
)
|
||||||
from langchain_core.utils import (
|
from langchain_core.utils import (
|
||||||
convert_to_secret_str,
|
from_env,
|
||||||
get_from_dict_or_env,
|
|
||||||
get_pydantic_field_names,
|
get_pydantic_field_names,
|
||||||
|
secret_from_env,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -115,10 +114,19 @@ class TogetherEmbeddings(BaseModel, Embeddings):
|
|||||||
|
|
||||||
Not yet supported.
|
Not yet supported.
|
||||||
"""
|
"""
|
||||||
together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
|
together_api_key: Optional[SecretStr] = Field(
|
||||||
"""API Key for Solar API."""
|
alias="api_key",
|
||||||
|
default_factory=secret_from_env("TOGETHER_API_KEY", default=None),
|
||||||
|
)
|
||||||
|
"""Together AI API key.
|
||||||
|
|
||||||
|
Automatically read from env variable `TOGETHER_API_KEY` if not provided.
|
||||||
|
"""
|
||||||
together_api_base: str = Field(
|
together_api_base: str = Field(
|
||||||
default="https://api.together.ai/v1/", alias="base_url"
|
default_factory=from_env(
|
||||||
|
"TOGETHER_API_BASE", default="https://api.together.ai/v1/"
|
||||||
|
),
|
||||||
|
alias="base_url",
|
||||||
)
|
)
|
||||||
"""Endpoint URL to use."""
|
"""Endpoint URL to use."""
|
||||||
embedding_ctx_length: int = 4096
|
embedding_ctx_length: int = 4096
|
||||||
@ -198,18 +206,9 @@ class TogetherEmbeddings(BaseModel, Embeddings):
|
|||||||
values["model_kwargs"] = extra
|
values["model_kwargs"] = extra
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@root_validator()
|
@root_validator(pre=False, skip_on_failure=True)
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def post_init(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key and python package exists in environment."""
|
"""Logic that will post Pydantic initialization."""
|
||||||
together_api_key = get_from_dict_or_env(
|
|
||||||
values, "together_api_key", "TOGETHER_API_KEY"
|
|
||||||
)
|
|
||||||
values["together_api_key"] = (
|
|
||||||
convert_to_secret_str(together_api_key) if together_api_key else None
|
|
||||||
)
|
|
||||||
values["together_api_base"] = values["together_api_base"] or os.getenv(
|
|
||||||
"TOGETHER_API_BASE"
|
|
||||||
)
|
|
||||||
client_params = {
|
client_params = {
|
||||||
"api_key": (
|
"api_key": (
|
||||||
values["together_api_key"].get_secret_value()
|
values["together_api_key"].get_secret_value()
|
||||||
|
@ -11,8 +11,10 @@ from langchain_core.callbacks import (
|
|||||||
CallbackManagerForLLMRun,
|
CallbackManagerForLLMRun,
|
||||||
)
|
)
|
||||||
from langchain_core.language_models.llms import LLM
|
from langchain_core.language_models.llms import LLM
|
||||||
from langchain_core.pydantic_v1 import SecretStr, root_validator
|
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
from langchain_core.utils import (
|
||||||
|
secret_from_env,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -36,8 +38,14 @@ class Together(LLM):
|
|||||||
|
|
||||||
base_url: str = "https://api.together.ai/v1/completions"
|
base_url: str = "https://api.together.ai/v1/completions"
|
||||||
"""Base completions API URL."""
|
"""Base completions API URL."""
|
||||||
together_api_key: SecretStr
|
together_api_key: SecretStr = Field(
|
||||||
"""Together AI API key. Get it here: https://api.together.ai/settings/api-keys"""
|
alias="api_key",
|
||||||
|
default_factory=secret_from_env("TOGETHER_API_KEY"),
|
||||||
|
)
|
||||||
|
"""Together AI API key.
|
||||||
|
|
||||||
|
Automatically read from env variable `TOGETHER_API_KEY` if not provided.
|
||||||
|
"""
|
||||||
model: str
|
model: str
|
||||||
"""Model name. Available models listed here:
|
"""Model name. Available models listed here:
|
||||||
Base Models: https://docs.together.ai/docs/inference-models#language-models
|
Base Models: https://docs.together.ai/docs/inference-models#language-models
|
||||||
@ -74,21 +82,11 @@ class Together(LLM):
|
|||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
extra = "forbid"
|
extra = "forbid"
|
||||||
|
allow_population_by_field_name = True
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that api key exists in environment."""
|
"""Validate that api key exists in environment."""
|
||||||
values["together_api_key"] = convert_to_secret_str(
|
|
||||||
get_from_dict_or_env(values, "together_api_key", "TOGETHER_API_KEY")
|
|
||||||
)
|
|
||||||
return values
|
|
||||||
|
|
||||||
@root_validator()
|
|
||||||
def validate_max_tokens(cls, values: Dict) -> Dict:
|
|
||||||
"""The v1 completions endpoint, has max_tokens as required parameter.
|
|
||||||
|
|
||||||
Set a default value and warn if the parameter is missing.
|
|
||||||
"""
|
|
||||||
if values.get("max_tokens") is None:
|
if values.get("max_tokens") is None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The completions endpoint, has 'max_tokens' as required argument. "
|
"The completions endpoint, has 'max_tokens' as required argument. "
|
||||||
|
@ -9,7 +9,7 @@ from langchain_together import Together
|
|||||||
def test_together_api_key_is_secret_string() -> None:
|
def test_together_api_key_is_secret_string() -> None:
|
||||||
"""Test that the API key is stored as a SecretStr."""
|
"""Test that the API key is stored as a SecretStr."""
|
||||||
llm = Together(
|
llm = Together(
|
||||||
together_api_key="secret-api-key", # type: ignore[arg-type]
|
together_api_key="secret-api-key", # type: ignore[call-arg]
|
||||||
model="togethercomputer/RedPajama-INCITE-7B-Base",
|
model="togethercomputer/RedPajama-INCITE-7B-Base",
|
||||||
temperature=0.2,
|
temperature=0.2,
|
||||||
max_tokens=250,
|
max_tokens=250,
|
||||||
@ -38,7 +38,7 @@ def test_together_api_key_masked_when_passed_via_constructor(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Test that the API key is masked when passed via the constructor."""
|
"""Test that the API key is masked when passed via the constructor."""
|
||||||
llm = Together(
|
llm = Together(
|
||||||
together_api_key="secret-api-key", # type: ignore[arg-type]
|
together_api_key="secret-api-key", # type: ignore[call-arg]
|
||||||
model="togethercomputer/RedPajama-INCITE-7B-Base",
|
model="togethercomputer/RedPajama-INCITE-7B-Base",
|
||||||
temperature=0.2,
|
temperature=0.2,
|
||||||
max_tokens=250,
|
max_tokens=250,
|
||||||
@ -52,7 +52,18 @@ def test_together_api_key_masked_when_passed_via_constructor(
|
|||||||
def test_together_uses_actual_secret_value_from_secretstr() -> None:
|
def test_together_uses_actual_secret_value_from_secretstr() -> None:
|
||||||
"""Test that the actual secret value is correctly retrieved."""
|
"""Test that the actual secret value is correctly retrieved."""
|
||||||
llm = Together(
|
llm = Together(
|
||||||
together_api_key="secret-api-key", # type: ignore[arg-type]
|
together_api_key="secret-api-key", # type: ignore[call-arg]
|
||||||
|
model="togethercomputer/RedPajama-INCITE-7B-Base",
|
||||||
|
temperature=0.2,
|
||||||
|
max_tokens=250,
|
||||||
|
)
|
||||||
|
assert cast(SecretStr, llm.together_api_key).get_secret_value() == "secret-api-key"
|
||||||
|
|
||||||
|
|
||||||
|
def test_together_uses_actual_secret_value_from_secretstr_api_key() -> None:
|
||||||
|
"""Test that the actual secret value is correctly retrieved."""
|
||||||
|
llm = Together(
|
||||||
|
api_key="secret-api-key", # type: ignore[arg-type]
|
||||||
model="togethercomputer/RedPajama-INCITE-7B-Base",
|
model="togethercomputer/RedPajama-INCITE-7B-Base",
|
||||||
temperature=0.2,
|
temperature=0.2,
|
||||||
max_tokens=250,
|
max_tokens=250,
|
||||||
|
Loading…
Reference in New Issue
Block a user