community[patch]: set protected namespaces on embeddings (#26156)

Also fix serdes test for langchain-google-genai.
This commit is contained in:
ccurme
2024-09-10 09:28:41 -04:00
committed by GitHub
parent e24259fee7
commit 6208773c77
20 changed files with 24 additions and 63 deletions

View File

@@ -70,9 +70,7 @@ class BaichuanTextEmbeddings(BaseModel, Embeddings):
chunk_size: int = 16
"""Chunk size when multiple texts are input"""
model_config = ConfigDict(
populate_by_name=True,
)
model_config = ConfigDict(populate_by_name=True, protected_namespaces=())
@model_validator(mode="after")
def validate_environment(self) -> Self:

View File

@@ -75,9 +75,7 @@ class BedrockEmbeddings(BaseModel, Embeddings):
normalize: bool = False
"""Whether the embeddings should be normalized to unit vectors"""
model_config = ConfigDict(
extra="forbid",
)
model_config = ConfigDict(extra="forbid", protected_namespaces=())
@model_validator(mode="after")
def validate_environment(self) -> Self:

View File

@@ -43,9 +43,7 @@ class ClarifaiEmbeddings(BaseModel, Embeddings):
model: Any = Field(default=None, exclude=True) #: :meta private:
api_base: str = "https://api.clarifai.com"
model_config = ConfigDict(
extra="forbid",
)
model_config = ConfigDict(extra="forbid", protected_namespaces=())
@model_validator(mode="before")
@classmethod

View File

@@ -43,9 +43,7 @@ class CloudflareWorkersAIEmbeddings(BaseModel, Embeddings):
self.headers = {"Authorization": f"Bearer {self.api_token}"}
model_config = ConfigDict(
extra="forbid",
)
model_config = ConfigDict(extra="forbid", protected_namespaces=())
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using Cloudflare Workers AI.

View File

@@ -54,9 +54,7 @@ class DeepInfraEmbeddings(BaseModel, Embeddings):
batch_size: int = MAX_BATCH_SIZE
"""Batch size for embedding requests."""
model_config = ConfigDict(
extra="forbid",
)
model_config = ConfigDict(extra="forbid", protected_namespaces=())
@pre_init
def validate_environment(cls, values: Dict) -> Dict:

View File

@@ -67,9 +67,7 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
_model: Any # : :meta private:
model_config = ConfigDict(
extra="allow",
)
model_config = ConfigDict(extra="allow", protected_namespaces=())
@pre_init
def validate_environment(cls, values: Dict) -> Dict:

View File

@@ -93,9 +93,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
)
model_config = ConfigDict(
extra="forbid",
)
model_config = ConfigDict(extra="forbid", protected_namespaces=())
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace transformer model.
@@ -209,9 +207,7 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
)
self.show_progress = self.encode_kwargs.pop("show_progress_bar")
model_config = ConfigDict(
extra="forbid",
)
model_config = ConfigDict(extra="forbid", protected_namespaces=())
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace instruct model.
@@ -350,9 +346,7 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
)
self.show_progress = self.encode_kwargs.pop("show_progress_bar")
model_config = ConfigDict(
extra="forbid",
)
model_config = ConfigDict(extra="forbid", protected_namespaces=())
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace transformer model.

View File

@@ -48,9 +48,7 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
huggingfacehub_api_token: Optional[str] = None
model_config = ConfigDict(
extra="forbid",
)
model_config = ConfigDict(extra="forbid", protected_namespaces=())
@model_validator(mode="before")
@classmethod

View File

@@ -106,9 +106,7 @@ class IpexLLMBgeEmbeddings(BaseModel, Embeddings):
if "-zh" in self.model_name:
self.query_instruction = DEFAULT_QUERY_BGE_INSTRUCTION_ZH
model_config = ConfigDict(
extra="forbid",
)
model_config = ConfigDict(extra="forbid", protected_namespaces=())
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace transformer model.

View File

@@ -166,9 +166,7 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
model_config = ConfigDict(
extra="forbid",
)
model_config = ConfigDict(extra="forbid", protected_namespaces=())
@model_validator(mode="before")
@classmethod

View File

@@ -39,9 +39,7 @@ class ModelScopeEmbeddings(BaseModel, Embeddings):
model_revision=self.model_revision,
)
model_config = ConfigDict(
extra="forbid",
)
model_config = ConfigDict(extra="forbid", protected_namespaces=())
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a modelscope embedding model.

View File

@@ -85,9 +85,7 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings):
"""Batch size of OCI GenAI embedding requests. OCI GenAI may handle up to 96 texts
per request"""
model_config = ConfigDict(
extra="forbid",
)
model_config = ConfigDict(extra="forbid", protected_namespaces=())
@pre_init
def validate_environment(cls, values: Dict) -> Dict: # pylint: disable=no-self-argument

View File

@@ -141,9 +141,7 @@ class OllamaEmbeddings(BaseModel, Embeddings):
"""Get the identifying parameters."""
return {**{"model": self.model}, **self._default_params}
model_config = ConfigDict(
extra="forbid",
)
model_config = ConfigDict(extra="forbid", protected_namespaces=())
def _process_emb_response(self, input: str) -> List[float]:
"""Process a response from the API.

View File

@@ -255,8 +255,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"""Optional httpx.Client."""
model_config = ConfigDict(
populate_by_name=True,
extra="forbid",
populate_by_name=True, extra="forbid", protected_namespaces=()
)
@model_validator(mode="before")

View File

@@ -254,9 +254,7 @@ class OpenVINOEmbeddings(BaseModel, Embeddings):
return all_embeddings
model_config = ConfigDict(
extra="forbid",
)
model_config = ConfigDict(extra="forbid", protected_namespaces=())
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace transformer model.

View File

@@ -23,9 +23,7 @@ class OVHCloudEmbeddings(BaseModel, Embeddings):
""" OVHcloud AI Endpoints region"""
region: str = "kepler"
model_config = ConfigDict(
extra="forbid",
)
model_config = ConfigDict(extra="forbid", protected_namespaces=())
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)

View File

@@ -111,8 +111,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
"""
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
arbitrary_types_allowed=True, extra="forbid", protected_namespaces=()
)
@pre_init

View File

@@ -22,9 +22,7 @@ class SpacyEmbeddings(BaseModel, Embeddings):
model_name: str = "en_core_web_sm"
nlp: Optional[Any] = None
model_config = ConfigDict(
extra="forbid",
)
model_config = ConfigDict(extra="forbid", protected_namespaces=())
@model_validator(mode="before")
@classmethod

View File

@@ -71,9 +71,7 @@ class YandexGPTEmbeddings(BaseModel, Embeddings):
If you provide personal data, confidential information, disable logging."""
grpc_metadata: Sequence
model_config = ConfigDict(
populate_by_name=True,
)
model_config = ConfigDict(populate_by_name=True, protected_namespaces=())
@pre_init
def validate_environment(cls, values: Dict) -> Dict:

View File

@@ -112,8 +112,9 @@ def test_serializable_mapping() -> None:
"chat_models",
"ChatGroq",
),
# TODO(0.3): For now we're skipping this test. Need to fix
# so that it only runs when langchain-aws is installed.
# TODO(0.3): For now we're skipping the below two tests. Need to fix
# so that they only run when langchain-aws and langchain-google-genai
# are installed.
("langchain", "chat_models", "bedrock", "ChatBedrock"): (
"langchain_aws",
"chat_models",