community[patch]: set protected namespaces on embeddings (#26156)

Also fix serdes test for langchain-google-genai.
ccurme
2024-09-10 09:28:41 -04:00
committed by GitHub
parent e24259fee7
commit 6208773c77
20 changed files with 24 additions and 63 deletions
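
Background on the change: Pydantic v2 treats "model_" as a protected namespace, so embedding classes that declare fields such as model_name or model_kwargs emit a UserWarning at class-definition time unless that check is relaxed. Passing protected_namespaces=() to ConfigDict, as the embeddings hunks below do, silences the warning. A minimal sketch of the behavior (illustrative class names, not code from this commit):

    from pydantic import BaseModel, ConfigDict

    class Before(BaseModel):
        # Emits a UserWarning about a conflict with the protected namespace "model_"
        model_name: str = "demo"

    class After(BaseModel):
        # An empty tuple disables the protected-namespace check, so no warning is raised
        model_config = ConfigDict(protected_namespaces=())

        model_name: str = "demo"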


@@ -70,9 +70,7 @@ class BaichuanTextEmbeddings(BaseModel, Embeddings):
     chunk_size: int = 16
     """Chunk size when multiple texts are input"""

-    model_config = ConfigDict(
-        populate_by_name=True,
-    )
+    model_config = ConfigDict(populate_by_name=True, protected_namespaces=())

     @model_validator(mode="after")
     def validate_environment(self) -> Self:


@@ -75,9 +75,7 @@ class BedrockEmbeddings(BaseModel, Embeddings):
     normalize: bool = False
     """Whether the embeddings should be normalized to unit vectors"""

-    model_config = ConfigDict(
-        extra="forbid",
-    )
+    model_config = ConfigDict(extra="forbid", protected_namespaces=())

     @model_validator(mode="after")
     def validate_environment(self) -> Self:


@@ -43,9 +43,7 @@ class ClarifaiEmbeddings(BaseModel, Embeddings):
     model: Any = Field(default=None, exclude=True)  #: :meta private:
     api_base: str = "https://api.clarifai.com"

-    model_config = ConfigDict(
-        extra="forbid",
-    )
+    model_config = ConfigDict(extra="forbid", protected_namespaces=())

     @model_validator(mode="before")
     @classmethod


@@ -43,9 +43,7 @@ class CloudflareWorkersAIEmbeddings(BaseModel, Embeddings):
         self.headers = {"Authorization": f"Bearer {self.api_token}"}

-    model_config = ConfigDict(
-        extra="forbid",
-    )
+    model_config = ConfigDict(extra="forbid", protected_namespaces=())

     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Compute doc embeddings using Cloudflare Workers AI.


@@ -54,9 +54,7 @@ class DeepInfraEmbeddings(BaseModel, Embeddings):
     batch_size: int = MAX_BATCH_SIZE
     """Batch size for embedding requests."""

-    model_config = ConfigDict(
-        extra="forbid",
-    )
+    model_config = ConfigDict(extra="forbid", protected_namespaces=())

     @pre_init
     def validate_environment(cls, values: Dict) -> Dict:


@@ -67,9 +67,7 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
     _model: Any  # : :meta private:

-    model_config = ConfigDict(
-        extra="allow",
-    )
+    model_config = ConfigDict(extra="allow", protected_namespaces=())

     @pre_init
     def validate_environment(cls, values: Dict) -> Dict:


@@ -93,9 +93,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
             self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
         )

-    model_config = ConfigDict(
-        extra="forbid",
-    )
+    model_config = ConfigDict(extra="forbid", protected_namespaces=())

     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Compute doc embeddings using a HuggingFace transformer model.

@@ -209,9 +207,7 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
         )
         self.show_progress = self.encode_kwargs.pop("show_progress_bar")

-    model_config = ConfigDict(
-        extra="forbid",
-    )
+    model_config = ConfigDict(extra="forbid", protected_namespaces=())

     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Compute doc embeddings using a HuggingFace instruct model.

@@ -350,9 +346,7 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
         )
         self.show_progress = self.encode_kwargs.pop("show_progress_bar")

-    model_config = ConfigDict(
-        extra="forbid",
-    )
+    model_config = ConfigDict(extra="forbid", protected_namespaces=())

     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Compute doc embeddings using a HuggingFace transformer model.


@@ -48,9 +48,7 @@ class HuggingFaceHubEmbeddings(BaseModel, Embeddings):
     huggingfacehub_api_token: Optional[str] = None

-    model_config = ConfigDict(
-        extra="forbid",
-    )
+    model_config = ConfigDict(extra="forbid", protected_namespaces=())

     @model_validator(mode="before")
     @classmethod


@@ -106,9 +106,7 @@ class IpexLLMBgeEmbeddings(BaseModel, Embeddings):
         if "-zh" in self.model_name:
             self.query_instruction = DEFAULT_QUERY_BGE_INSTRUCTION_ZH

-    model_config = ConfigDict(
-        extra="forbid",
-    )
+    model_config = ConfigDict(extra="forbid", protected_namespaces=())

     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Compute doc embeddings using a HuggingFace transformer model.


@@ -166,9 +166,7 @@ class LocalAIEmbeddings(BaseModel, Embeddings):
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
     """Holds any model parameters valid for `create` call not explicitly specified."""

-    model_config = ConfigDict(
-        extra="forbid",
-    )
+    model_config = ConfigDict(extra="forbid", protected_namespaces=())

     @model_validator(mode="before")
     @classmethod


@@ -39,9 +39,7 @@ class ModelScopeEmbeddings(BaseModel, Embeddings):
             model_revision=self.model_revision,
         )

-    model_config = ConfigDict(
-        extra="forbid",
-    )
+    model_config = ConfigDict(extra="forbid", protected_namespaces=())

     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Compute doc embeddings using a modelscope embedding model.


@@ -85,9 +85,7 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings):
     """Batch size of OCI GenAI embedding requests. OCI GenAI may handle up to 96 texts
     per request"""

-    model_config = ConfigDict(
-        extra="forbid",
-    )
+    model_config = ConfigDict(extra="forbid", protected_namespaces=())

     @pre_init
     def validate_environment(cls, values: Dict) -> Dict:  # pylint: disable=no-self-argument


@@ -141,9 +141,7 @@ class OllamaEmbeddings(BaseModel, Embeddings):
         """Get the identifying parameters."""
         return {**{"model": self.model}, **self._default_params}

-    model_config = ConfigDict(
-        extra="forbid",
-    )
+    model_config = ConfigDict(extra="forbid", protected_namespaces=())

     def _process_emb_response(self, input: str) -> List[float]:
         """Process a response from the API.


@@ -255,8 +255,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     """Optional httpx.Client."""

     model_config = ConfigDict(
-        populate_by_name=True,
-        extra="forbid",
+        populate_by_name=True, extra="forbid", protected_namespaces=()
     )

     @model_validator(mode="before")


@@ -254,9 +254,7 @@ class OpenVINOEmbeddings(BaseModel, Embeddings):
         return all_embeddings

-    model_config = ConfigDict(
-        extra="forbid",
-    )
+    model_config = ConfigDict(extra="forbid", protected_namespaces=())

     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Compute doc embeddings using a HuggingFace transformer model.


@@ -23,9 +23,7 @@ class OVHCloudEmbeddings(BaseModel, Embeddings):
     """ OVHcloud AI Endpoints region"""
     region: str = "kepler"

-    model_config = ConfigDict(
-        extra="forbid",
-    )
+    model_config = ConfigDict(extra="forbid", protected_namespaces=())

     def __init__(self, **kwargs: Any):
         super().__init__(**kwargs)


@@ -111,8 +111,7 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
     """

     model_config = ConfigDict(
-        arbitrary_types_allowed=True,
-        extra="forbid",
+        arbitrary_types_allowed=True, extra="forbid", protected_namespaces=()
     )

     @pre_init


@@ -22,9 +22,7 @@ class SpacyEmbeddings(BaseModel, Embeddings):
     model_name: str = "en_core_web_sm"
     nlp: Optional[Any] = None

-    model_config = ConfigDict(
-        extra="forbid",
-    )
+    model_config = ConfigDict(extra="forbid", protected_namespaces=())

     @model_validator(mode="before")
     @classmethod


@@ -71,9 +71,7 @@ class YandexGPTEmbeddings(BaseModel, Embeddings):
     If you provide personal data, confidential information, disable logging."""
     grpc_metadata: Sequence

-    model_config = ConfigDict(
-        populate_by_name=True,
-    )
+    model_config = ConfigDict(populate_by_name=True, protected_namespaces=())

     @pre_init
     def validate_environment(cls, values: Dict) -> Dict:


@@ -112,8 +112,9 @@ def test_serializable_mapping() -> None:
         "chat_models",
         "ChatGroq",
     ),
-    # TODO(0.3): For now we're skipping this test. Need to fix
-    # so that it only runs when langchain-aws is installed.
+    # TODO(0.3): For now we're skipping the below two tests. Need to fix
+    # so that it only runs when langchain-aws, langchain-google-genai
+    # are installed.
     ("langchain", "chat_models", "bedrock", "ChatBedrock"): (
         "langchain_aws",
         "chat_models",