mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-26 22:05:29 +00:00
pinecone[major]: Update to pydantic v2
This commit is contained in:
@@ -1,16 +1,19 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Dict, Iterable, List, Optional
|
from typing import Any, Dict, Iterable, List, Optional
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.pydantic_v1 import (
|
|
||||||
BaseModel,
|
|
||||||
Field,
|
|
||||||
SecretStr,
|
|
||||||
root_validator,
|
|
||||||
)
|
|
||||||
from langchain_core.utils import secret_from_env
|
from langchain_core.utils import secret_from_env
|
||||||
from pinecone import Pinecone as PineconeClient # type: ignore
|
from pinecone import Pinecone as PineconeClient # type: ignore[import-untyped]
|
||||||
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
PrivateAttr,
|
||||||
|
SecretStr,
|
||||||
|
model_validator,
|
||||||
|
)
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -29,8 +32,8 @@ class PineconeEmbeddings(BaseModel, Embeddings):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Clients
|
# Clients
|
||||||
_client: PineconeClient = Field(default=None, exclude=True)
|
_client: PineconeClient = PrivateAttr(default=None)
|
||||||
_async_client: aiohttp.ClientSession = Field(default=None, exclude=True)
|
_async_client: aiohttp.ClientSession = PrivateAttr(default=None)
|
||||||
model: str
|
model: str
|
||||||
"""Model to use for example 'multilingual-e5-large'."""
|
"""Model to use for example 'multilingual-e5-large'."""
|
||||||
# Config
|
# Config
|
||||||
@@ -44,7 +47,7 @@ class PineconeEmbeddings(BaseModel, Embeddings):
|
|||||||
dimension: Optional[int] = None
|
dimension: Optional[int] = None
|
||||||
#
|
#
|
||||||
show_progress_bar: bool = False
|
show_progress_bar: bool = False
|
||||||
pinecone_api_key: Optional[SecretStr] = Field(
|
pinecone_api_key: SecretStr = Field(
|
||||||
default_factory=secret_from_env(
|
default_factory=secret_from_env(
|
||||||
"PINECONE_API_KEY",
|
"PINECONE_API_KEY",
|
||||||
error_message="Pinecone API key not found. Please set the PINECONE_API_KEY "
|
error_message="Pinecone API key not found. Please set the PINECONE_API_KEY "
|
||||||
@@ -56,12 +59,14 @@ class PineconeEmbeddings(BaseModel, Embeddings):
|
|||||||
|
|
||||||
If not provided, will look for the PINECONE_API_KEY environment variable."""
|
If not provided, will look for the PINECONE_API_KEY environment variable."""
|
||||||
|
|
||||||
class Config:
|
model_config = ConfigDict(
|
||||||
extra = "forbid"
|
extra="forbid",
|
||||||
allow_population_by_field_name = True
|
populate_by_name=True,
|
||||||
|
)
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@model_validator(mode="before")
|
||||||
def set_default_config(cls, values: dict) -> dict:
|
@classmethod
|
||||||
|
def set_default_config(cls, values: dict) -> Any:
|
||||||
"""Set default configuration based on model."""
|
"""Set default configuration based on model."""
|
||||||
default_config_map = {
|
default_config_map = {
|
||||||
"multilingual-e5-large": {
|
"multilingual-e5-large": {
|
||||||
@@ -79,23 +84,23 @@ class PineconeEmbeddings(BaseModel, Embeddings):
|
|||||||
values[key] = value
|
values[key] = value
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@root_validator(pre=False, skip_on_failure=True)
|
@model_validator(mode="after")
|
||||||
def validate_environment(cls, values: dict) -> dict:
|
def validate_environment(self) -> Self:
|
||||||
"""Validate that Pinecone version and credentials exist in environment."""
|
"""Validate that Pinecone version and credentials exist in environment."""
|
||||||
api_key_str = values["pinecone_api_key"].get_secret_value()
|
api_key_str = self.pinecone_api_key.get_secret_value()
|
||||||
client = PineconeClient(api_key=api_key_str, source_tag="langchain")
|
client = PineconeClient(api_key=api_key_str, source_tag="langchain")
|
||||||
values["_client"] = client
|
self._client = client
|
||||||
|
|
||||||
# initialize async client
|
# initialize async client
|
||||||
if not values.get("_async_client"):
|
if not (self._async_client or None):
|
||||||
values["_async_client"] = aiohttp.ClientSession(
|
self._async_client = aiohttp.ClientSession(
|
||||||
headers={
|
headers={
|
||||||
"Api-Key": api_key_str,
|
"Api-Key": api_key_str,
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"X-Pinecone-API-Version": "2024-07",
|
"X-Pinecone-API-Version": "2024-07",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return values
|
return self
|
||||||
|
|
||||||
def _get_batch_iterator(self, texts: List[str]) -> Iterable:
|
def _get_batch_iterator(self, texts: List[str]) -> Iterable:
|
||||||
if self.batch_size is None:
|
if self.batch_size is None:
|
||||||
|
Reference in New Issue
Block a user