mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-29 04:16:02 +00:00
pinecone[patch]: Upgrade @root_validators to be consistent with pydantic 2 (#25453)
Upgrade root validators for pydantic 2 migration
This commit is contained in:
parent
b297af5482
commit
34da8be60b
@ -1,5 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from typing import Dict, Iterable, List, Optional
|
from typing import Dict, Iterable, List, Optional
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
@ -10,7 +9,7 @@ from langchain_core.pydantic_v1 import (
|
|||||||
SecretStr,
|
SecretStr,
|
||||||
root_validator,
|
root_validator,
|
||||||
)
|
)
|
||||||
from langchain_core.utils import convert_to_secret_str
|
from langchain_core.utils import secret_from_env
|
||||||
from pinecone import Pinecone as PineconeClient # type: ignore
|
from pinecone import Pinecone as PineconeClient # type: ignore
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -45,10 +44,21 @@ class PineconeEmbeddings(BaseModel, Embeddings):
|
|||||||
dimension: Optional[int] = None
|
dimension: Optional[int] = None
|
||||||
#
|
#
|
||||||
show_progress_bar: bool = False
|
show_progress_bar: bool = False
|
||||||
pinecone_api_key: Optional[SecretStr] = None
|
pinecone_api_key: Optional[SecretStr] = Field(
|
||||||
|
default_factory=secret_from_env(
|
||||||
|
"PINECONE_API_KEY",
|
||||||
|
error_message="Pinecone API key not found. Please set the PINECONE_API_KEY "
|
||||||
|
"environment variable or pass it via `pinecone_api_key`.",
|
||||||
|
),
|
||||||
|
alias="api_key",
|
||||||
|
)
|
||||||
|
"""Pinecone API key.
|
||||||
|
|
||||||
|
If not provided, will look for the PINECONE_API_KEY environment variable."""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
extra = "forbid"
|
extra = "forbid"
|
||||||
|
allow_population_by_field_name = True
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def set_default_config(cls, values: dict) -> dict:
|
def set_default_config(cls, values: dict) -> dict:
|
||||||
@ -69,25 +79,10 @@ class PineconeEmbeddings(BaseModel, Embeddings):
|
|||||||
values[key] = value
|
values[key] = value
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@root_validator()
|
@root_validator(pre=False, skip_on_failure=True)
|
||||||
def validate_environment(cls, values: dict) -> dict:
|
def validate_environment(cls, values: dict) -> dict:
|
||||||
"""Validate that Pinecone version and credentials exist in environment."""
|
"""Validate that Pinecone version and credentials exist in environment."""
|
||||||
|
api_key_str = values["pinecone_api_key"].get_secret_value()
|
||||||
pinecone_api_key = values.get("pinecone_api_key") or os.getenv(
|
|
||||||
"PINECONE_API_KEY", None
|
|
||||||
)
|
|
||||||
if pinecone_api_key:
|
|
||||||
api_key_secretstr = convert_to_secret_str(pinecone_api_key)
|
|
||||||
values["pinecone_api_key"] = api_key_secretstr
|
|
||||||
|
|
||||||
api_key_str = api_key_secretstr.get_secret_value()
|
|
||||||
else:
|
|
||||||
api_key_str = None
|
|
||||||
if api_key_str is None:
|
|
||||||
raise ValueError(
|
|
||||||
"Pinecone API key not found. Please set the PINECONE_API_KEY "
|
|
||||||
"environment variable or pass it via `pinecone_api_key`."
|
|
||||||
)
|
|
||||||
client = PineconeClient(api_key=api_key_str, source_tag="langchain")
|
client = PineconeClient(api_key=api_key_str, source_tag="langchain")
|
||||||
values["_client"] = client
|
values["_client"] = client
|
||||||
|
|
||||||
|
@ -7,10 +7,22 @@ MODEL_NAME = "multilingual-e5-large"
|
|||||||
|
|
||||||
|
|
||||||
def test_default_config() -> None:
|
def test_default_config() -> None:
|
||||||
e = PineconeEmbeddings(pinecone_api_key=API_KEY, model=MODEL_NAME)
|
e = PineconeEmbeddings(
|
||||||
|
pinecone_api_key=API_KEY, # type: ignore[call-arg]
|
||||||
|
model=MODEL_NAME,
|
||||||
|
)
|
||||||
|
assert e.batch_size == 96
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_config_with_api_key() -> None:
|
||||||
|
e = PineconeEmbeddings(api_key=API_KEY, model=MODEL_NAME)
|
||||||
assert e.batch_size == 96
|
assert e.batch_size == 96
|
||||||
|
|
||||||
|
|
||||||
def test_custom_config() -> None:
|
def test_custom_config() -> None:
|
||||||
e = PineconeEmbeddings(pinecone_api_key=API_KEY, model=MODEL_NAME, batch_size=128)
|
e = PineconeEmbeddings(
|
||||||
|
pinecone_api_key=API_KEY, # type: ignore[call-arg]
|
||||||
|
model=MODEL_NAME,
|
||||||
|
batch_size=128,
|
||||||
|
)
|
||||||
assert e.batch_size == 128
|
assert e.batch_size == 128
|
||||||
|
Loading…
Reference in New Issue
Block a user