feat(core): Upgrade pydantic to 2.x (#1428)

This commit is contained in:
Fangyin Cheng
2024-04-20 09:41:16 +08:00
committed by GitHub
parent baa1e3f9f6
commit 57be1ece18
103 changed files with 1146 additions and 534 deletions

View File

@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional
import aiohttp
import requests
from dbgpt._private.pydantic import BaseModel, Extra, Field
from dbgpt._private.pydantic import EXTRA_FORBID, BaseModel, ConfigDict, Field
from dbgpt.core import Embeddings
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.util.i18n_utils import _
@@ -64,10 +64,12 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
)
"""
model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=())
client: Any #: :meta private:
model_name: str = DEFAULT_MODEL_NAME
"""Model name to use."""
cache_folder: Optional[str] = None
cache_folder: Optional[str] = Field(None, description="Path of the cache folder.")
"""Path to store models. Can be also set by SENTENCE_TRANSFORMERS_HOME
environment variable."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
@@ -79,7 +81,6 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
def __init__(self, **kwargs: Any):
"""Initialize the sentence_transformer."""
super().__init__(**kwargs)
try:
import sentence_transformers
@@ -89,14 +90,12 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
"Please install it with `pip install sentence-transformers`."
) from exc
self.client = sentence_transformers.SentenceTransformer(
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
kwargs["client"] = sentence_transformers.SentenceTransformer(
kwargs.get("model_name"),
cache_folder=kwargs.get("cache_folder"),
**kwargs.get("model_kwargs"),
)
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
super().__init__(**kwargs)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace transformer model.
@@ -184,6 +183,8 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
)
"""
model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=())
client: Any #: :meta private:
model_name: str = DEFAULT_INSTRUCT_MODEL
"""Model name to use."""
@@ -201,20 +202,18 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
def __init__(self, **kwargs: Any):
"""Initialize the sentence_transformer."""
super().__init__(**kwargs)
try:
from InstructorEmbedding import INSTRUCTOR
self.client = INSTRUCTOR(
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
kwargs["client"] = INSTRUCTOR(
kwargs.get("model_name"),
cache_folder=kwargs.get("cache_folder"),
**kwargs.get("model_kwargs"),
)
except ImportError as e:
raise ImportError("Dependencies for InstructorEmbedding not found.") from e
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
super().__init__(**kwargs)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace instruct model.
@@ -267,6 +266,8 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
)
"""
model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=())
client: Any #: :meta private:
model_name: str = DEFAULT_BGE_MODEL
"""Model name to use."""
@@ -282,7 +283,6 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
def __init__(self, **kwargs: Any):
"""Initialize the sentence_transformer."""
super().__init__(**kwargs)
try:
import sentence_transformers
@@ -292,17 +292,16 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
"Please install it with `pip install sentence_transformers`."
) from exc
self.client = sentence_transformers.SentenceTransformer(
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
kwargs["client"] = sentence_transformers.SentenceTransformer(
kwargs.get("model_name"),
cache_folder=kwargs.get("cache_folder"),
**kwargs.get("model_kwargs"),
)
super().__init__(**kwargs)
if "-zh" in self.model_name:
self.query_instruction = DEFAULT_QUERY_BGE_INSTRUCTION_ZH
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a HuggingFace transformer model.
@@ -360,6 +359,8 @@ class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings):
Requires a HuggingFace Inference API key and a model name.
"""
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
api_key: str
"""Your API key for the HuggingFace Inference API."""
model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
@@ -475,6 +476,8 @@ class JinaEmbeddings(BaseModel, Embeddings):
"jina-embeddings-v2-base-en".
"""
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
api_url: Any #: :meta private:
session: Any #: :meta private:
api_key: str
@@ -485,7 +488,6 @@ class JinaEmbeddings(BaseModel, Embeddings):
def __init__(self, **kwargs):
"""Create a new JinaEmbeddings instance."""
super().__init__(**kwargs)
try:
import requests
except ImportError:
@@ -493,11 +495,23 @@ class JinaEmbeddings(BaseModel, Embeddings):
"The requests python package is not installed. Please install it with "
"`pip install requests`"
)
self.api_url = "https://api.jina.ai/v1/embeddings"
self.session = requests.Session()
self.session.headers.update(
{"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"}
)
if "api_url" not in kwargs:
kwargs["api_url"] = "https://api.jina.ai/v1/embeddings"
if "session" not in kwargs: # noqa: SIM401
session = requests.Session()
else:
session = kwargs["session"]
api_key = kwargs.get("api_key")
if api_key:
session.headers.update(
{
"Authorization": f"Bearer {api_key}",
"Accept-Encoding": "identity",
}
)
kwargs["session"] = session
super().__init__(**kwargs)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Get the embeddings for a list of texts.
@@ -627,6 +641,8 @@ class OpenAPIEmbeddings(BaseModel, Embeddings):
openai_embeddings.embed_documents(texts)
"""
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
api_url: str = Field(
default="http://localhost:8100/api/v1/embeddings",
description="The URL of the embeddings API.",
@@ -643,14 +659,8 @@ class OpenAPIEmbeddings(BaseModel, Embeddings):
session: Optional[requests.Session] = None
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def __init__(self, **kwargs):
"""Initialize the OpenAPIEmbeddings."""
super().__init__(**kwargs)
try:
import requests
except ImportError:
@@ -658,8 +668,15 @@ class OpenAPIEmbeddings(BaseModel, Embeddings):
"The requests python package is not installed. "
"Please install it with `pip install requests`"
)
self.session = requests.Session()
self.session.headers.update({"Authorization": f"Bearer {self.api_key}"})
if "session" not in kwargs: # noqa: SIM401
session = requests.Session()
else:
session = kwargs["session"]
api_key = kwargs.get("api_key")
if api_key:
session.headers.update({"Authorization": f"Bearer {api_key}"})
kwargs["session"] = session
super().__init__(**kwargs)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Get the embeddings for a list of texts.

View File

@@ -58,7 +58,6 @@ class TokenTextSplitter(BaseModel):
tokenizer = tokenizer or globals_helper.tokenizer
all_seps = [separator] + (backup_separators or [])
self._split_fns = [split_by_sep(sep) for sep in all_seps] + [split_by_char()]
super().__init__(
chunk_size=chunk_size,
@@ -68,6 +67,7 @@ class TokenTextSplitter(BaseModel):
# callback_manager=callback_manager,
tokenizer=tokenizer,
)
self._split_fns = [split_by_sep(sep) for sep in all_seps] + [split_by_char()]
@classmethod
def class_name(cls) -> str: