mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-05 04:38:26 +00:00
voyageai[patch]: Upgrade root validators for pydantic 2 (#25455)
Update @root_validators to be consistent with pydantic 2 semantics
This commit is contained in:
parent
4cdaca67dc
commit
b297af5482
@ -1,5 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from typing import Iterable, List, Optional
|
from typing import Iterable, List, Optional
|
||||||
|
|
||||||
import voyageai # type: ignore
|
import voyageai # type: ignore
|
||||||
@ -10,7 +9,7 @@ from langchain_core.pydantic_v1 import (
|
|||||||
SecretStr,
|
SecretStr,
|
||||||
root_validator,
|
root_validator,
|
||||||
)
|
)
|
||||||
from langchain_core.utils import convert_to_secret_str
|
from langchain_core.utils import secret_from_env
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -32,34 +31,32 @@ class VoyageAIEmbeddings(BaseModel, Embeddings):
|
|||||||
batch_size: int
|
batch_size: int
|
||||||
show_progress_bar: bool = False
|
show_progress_bar: bool = False
|
||||||
truncation: Optional[bool] = None
|
truncation: Optional[bool] = None
|
||||||
voyage_api_key: Optional[SecretStr] = None
|
voyage_api_key: SecretStr = Field(
|
||||||
|
alias="api_key",
|
||||||
|
default_factory=secret_from_env(
|
||||||
|
"VOYAGE_API_KEY",
|
||||||
|
error_message="Must set `VOYAGE_API_KEY` environment variable or "
|
||||||
|
"pass `api_key` to VoyageAIEmbeddings constructor.",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
extra = "forbid"
|
extra = "forbid"
|
||||||
|
allow_population_by_field_name = True
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def default_values(cls, values: dict) -> dict:
|
def default_values(cls, values: dict) -> dict:
|
||||||
"""Set default batch size based on model"""
|
"""Set default batch size based on model"""
|
||||||
|
|
||||||
model = values.get("model")
|
model = values.get("model")
|
||||||
batch_size = values.get("batch_size")
|
batch_size = values.get("batch_size")
|
||||||
if batch_size is None:
|
if batch_size is None:
|
||||||
values["batch_size"] = 72 if model in ["voyage-2", "voyage-02"] else 7
|
values["batch_size"] = 72 if model in ["voyage-2", "voyage-02"] else 7
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@root_validator()
|
@root_validator(pre=False, skip_on_failure=True)
|
||||||
def validate_environment(cls, values: dict) -> dict:
|
def validate_environment(cls, values: dict) -> dict:
|
||||||
"""Validate that VoyageAI credentials exist in environment."""
|
"""Validate that VoyageAI credentials exist in environment."""
|
||||||
voyage_api_key = values.get("voyage_api_key") or os.getenv(
|
api_key_str = values["voyage_api_key"].get_secret_value()
|
||||||
"VOYAGE_API_KEY", None
|
|
||||||
)
|
|
||||||
if voyage_api_key:
|
|
||||||
api_key_secretstr = convert_to_secret_str(voyage_api_key)
|
|
||||||
values["voyage_api_key"] = api_key_secretstr
|
|
||||||
|
|
||||||
api_key_str = api_key_secretstr.get_secret_value()
|
|
||||||
else:
|
|
||||||
api_key_str = None
|
|
||||||
values["_client"] = voyageai.Client(api_key=api_key_str)
|
values["_client"] = voyageai.Client(api_key=api_key_str)
|
||||||
values["_aclient"] = voyageai.client_async.AsyncClient(api_key=api_key_str)
|
values["_aclient"] = voyageai.client_async.AsyncClient(api_key=api_key_str)
|
||||||
return values
|
return values
|
||||||
|
@ -9,6 +9,17 @@ MODEL = "voyage-2"
|
|||||||
|
|
||||||
def test_initialization_voyage_2() -> None:
|
def test_initialization_voyage_2() -> None:
|
||||||
"""Test embedding model initialization."""
|
"""Test embedding model initialization."""
|
||||||
|
emb = VoyageAIEmbeddings(api_key="NOT_A_VALID_KEY", model=MODEL)
|
||||||
|
assert isinstance(emb, Embeddings)
|
||||||
|
assert emb.batch_size == 72
|
||||||
|
assert emb.model == MODEL
|
||||||
|
assert emb._client is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialization_voyage_2_with_full_api_key_name() -> None:
|
||||||
|
"""Test embedding model initialization."""
|
||||||
|
# Testing that we can initialize the model using `voyage_api_key`
|
||||||
|
# instead of `api_key`
|
||||||
emb = VoyageAIEmbeddings(voyage_api_key="NOT_A_VALID_KEY", model=MODEL)
|
emb = VoyageAIEmbeddings(voyage_api_key="NOT_A_VALID_KEY", model=MODEL)
|
||||||
assert isinstance(emb, Embeddings)
|
assert isinstance(emb, Embeddings)
|
||||||
assert emb.batch_size == 72
|
assert emb.batch_size == 72
|
||||||
@ -18,7 +29,7 @@ def test_initialization_voyage_2() -> None:
|
|||||||
|
|
||||||
def test_initialization_voyage_1() -> None:
|
def test_initialization_voyage_1() -> None:
|
||||||
"""Test embedding model initialization."""
|
"""Test embedding model initialization."""
|
||||||
emb = VoyageAIEmbeddings(voyage_api_key="NOT_A_VALID_KEY", model="voyage-01")
|
emb = VoyageAIEmbeddings(api_key="NOT_A_VALID_KEY", model="voyage-01")
|
||||||
assert isinstance(emb, Embeddings)
|
assert isinstance(emb, Embeddings)
|
||||||
assert emb.batch_size == 7
|
assert emb.batch_size == 7
|
||||||
assert emb.model == "voyage-01"
|
assert emb.model == "voyage-01"
|
||||||
@ -28,7 +39,7 @@ def test_initialization_voyage_1() -> None:
|
|||||||
def test_initialization_voyage_1_batch_size() -> None:
|
def test_initialization_voyage_1_batch_size() -> None:
|
||||||
"""Test embedding model initialization."""
|
"""Test embedding model initialization."""
|
||||||
emb = VoyageAIEmbeddings(
|
emb = VoyageAIEmbeddings(
|
||||||
voyage_api_key="NOT_A_VALID_KEY", model="voyage-01", batch_size=15
|
api_key="NOT_A_VALID_KEY", model="voyage-01", batch_size=15
|
||||||
)
|
)
|
||||||
assert isinstance(emb, Embeddings)
|
assert isinstance(emb, Embeddings)
|
||||||
assert emb.batch_size == 15
|
assert emb.batch_size == 15
|
||||||
|
Loading…
Reference in New Issue
Block a user