mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-24 03:52:10 +00:00
community[patch]: Remove usage of @root_validator(allow_reuse=True) (#25235)
Remove usage of @root_validator(allow_reuse=True)
This commit is contained in:
@@ -12,8 +12,10 @@ from wsgiref.handlers import format_date_time
|
||||
import numpy as np
|
||||
import requests
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
|
||||
from langchain_core.utils import (
|
||||
secret_from_env,
|
||||
)
|
||||
from numpy import ndarray
|
||||
|
||||
# SparkLLMTextEmbeddings is an embedding model provided by iFLYTEK Co., Ltd.. (https://iflytek.com/en/).
|
||||
@@ -102,11 +104,18 @@ class SparkLLMTextEmbeddings(BaseModel, Embeddings):
|
||||
]
|
||||
""" # noqa: E501
|
||||
|
||||
spark_app_id: Optional[SecretStr] = Field(default=None, alias="app_id")
|
||||
spark_app_id: SecretStr = Field(
|
||||
alias="app_id", default_factory=secret_from_env("SPARK_APP_ID")
|
||||
)
|
||||
"""Automatically inferred from env var `SPARK_APP_ID` if not provided."""
|
||||
spark_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
|
||||
spark_api_key: Optional[SecretStr] = Field(
|
||||
alias="api_key", default_factory=secret_from_env("SPARK_API_KEY", default=None)
|
||||
)
|
||||
"""Automatically inferred from env var `SPARK_API_KEY` if not provided."""
|
||||
spark_api_secret: Optional[SecretStr] = Field(default=None, alias="api_secret")
|
||||
spark_api_secret: Optional[SecretStr] = Field(
|
||||
alias="api_secret",
|
||||
default_factory=secret_from_env("SPARK_API_SECRET", default=None),
|
||||
)
|
||||
"""Automatically inferred from env var `SPARK_API_SECRET` if not provided."""
|
||||
base_url: str = Field(default="https://emb-cn-huabei-1.xf-yun.com/")
|
||||
"""Base URL path for API requests"""
|
||||
@@ -118,20 +127,6 @@ class SparkLLMTextEmbeddings(BaseModel, Embeddings):
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@root_validator(allow_reuse=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that auth token exists in environment."""
|
||||
values["spark_app_id"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "spark_app_id", "SPARK_APP_ID")
|
||||
)
|
||||
values["spark_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "spark_api_key", "SPARK_API_KEY")
|
||||
)
|
||||
values["spark_api_secret"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "spark_api_secret", "SPARK_API_SECRET")
|
||||
)
|
||||
return values
|
||||
|
||||
def _embed(self, texts: List[str], host: str) -> Optional[List[List[float]]]:
|
||||
"""Internal method to call Spark Embedding API and return embeddings.
|
||||
|
||||
|
Reference in New Issue
Block a user