Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-12 20:53:48 +00:00)
feat(core): Upgrade pydantic to 2.x (#1428)
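The migration is mechanical but pervasive: every inner `class Config` becomes a `model_config = ConfigDict(...)` attribute, `Extra.forbid` becomes the `EXTRA_FORBID` constant from DB-GPT's pydantic wrapper, `protected_namespaces=()` is added because the models declare fields such as `model_name` and `model_kwargs` that collide with pydantic 2's reserved `model_` prefix, and each `__init__` is reordered so derived values (clients, sessions, URLs) are written into `kwargs` before a single trailing `super().__init__(**kwargs)` call.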
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional
 import aiohttp
 import requests
 
-from dbgpt._private.pydantic import BaseModel, Extra, Field
+from dbgpt._private.pydantic import EXTRA_FORBID, BaseModel, ConfigDict, Field
 from dbgpt.core import Embeddings
 from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
 from dbgpt.util.i18n_utils import _
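The new names are imported from DB-GPT's private pydantic wrapper rather than from pydantic directly, so the rest of the codebase stays version-agnostic. The wrapper itself is not part of this diff; the following is only a sketch of the kind of shim the import implies, assuming it re-exports pydantic 2.x symbols with a 1.x fallback (all fallback details here are assumptions):

    # Hypothetical sketch of a shim like dbgpt._private.pydantic (the real
    # module is not shown in this diff).
    try:
        # pydantic 2.x: ConfigDict exists and "extra" accepts a plain string.
        from pydantic import BaseModel, ConfigDict, Field  # noqa: F401

        EXTRA_FORBID = "forbid"
    except ImportError:
        # pydantic 1.x: no ConfigDict; fall back to Extra.forbid and let
        # ConfigDict degrade to a plain dict of config keys.
        from pydantic import BaseModel, Extra, Field  # noqa: F401

        EXTRA_FORBID = Extra.forbid

        def ConfigDict(**kwargs):
            return dict(kwargs)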
@@ -64,10 +64,12 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
         )
     """
 
+    model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=())
+
     client: Any  #: :meta private:
     model_name: str = DEFAULT_MODEL_NAME
     """Model name to use."""
-    cache_folder: Optional[str] = None
+    cache_folder: Optional[str] = Field(None, description="Path of the cache folder.")
     """Path to store models. Can be also set by SENTENCE_TRANSFORMERS_HOME
     environment variable."""
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
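Two pydantic 2.x idioms appear in this hunk and repeat for every model below. Per-model configuration moves from an inner `class Config` into a `model_config` attribute, and `protected_namespaces=()` is required because pydantic 2 reserves the `model_` prefix for its own API (`model_dump`, `model_validate`, ...) and warns about user-defined fields such as `model_name` and `model_kwargs` unless the protected namespace is cleared. A minimal, self-contained illustration (the class and values are made up):

    from pydantic import BaseModel, ConfigDict


    class ExampleEmbeddings(BaseModel):
        # Replaces the pydantic 1.x inner class:
        #     class Config:
        #         extra = Extra.forbid
        model_config = ConfigDict(extra="forbid", protected_namespaces=())

        model_name: str = "sentence-transformers/all-MiniLM-L6-v2"


    ExampleEmbeddings(model_name="BAAI/bge-large-en")  # accepted
    # ExampleEmbeddings(unknown=1)  # rejected: extra="forbid" raises ValidationError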
@@ -79,7 +81,6 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
 
     def __init__(self, **kwargs: Any):
         """Initialize the sentence_transformer."""
-        super().__init__(**kwargs)
         try:
             import sentence_transformers
 
@@ -89,14 +90,12 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
                 "Please install it with `pip install sentence-transformers`."
             ) from exc
 
-        self.client = sentence_transformers.SentenceTransformer(
-            self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
+        kwargs["client"] = sentence_transformers.SentenceTransformer(
+            kwargs.get("model_name"),
+            cache_folder=kwargs.get("cache_folder"),
+            **kwargs.get("model_kwargs"),
         )
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
+        super().__init__(**kwargs)
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Compute doc embeddings using a HuggingFace transformer model.
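The constructor rewrite above is the other recurring change. Under pydantic 1.x the code validated first (`super().__init__(**kwargs)`) and assigned `self.client` afterwards; under 2.x the commit builds the client first, stores it in `kwargs["client"]`, and validates everything in one trailing `super().__init__(**kwargs)`. Note that `kwargs.get("model_name")` and `kwargs.get("model_kwargs")` read the raw keyword arguments rather than validated fields, so the declared field defaults no longer apply inside `__init__` (and `**kwargs.get("model_kwargs")` fails if the caller omits `model_kwargs` entirely). A minimal sketch of the pattern with a stand-in client (hypothetical class, not DB-GPT code):

    from typing import Any, Dict

    from pydantic import BaseModel, ConfigDict, Field


    class ClientBackedModel(BaseModel):
        model_config = ConfigDict(protected_namespaces=())

        client: Any = None
        model_name: str = "all-MiniLM-L6-v2"
        model_kwargs: Dict[str, Any] = Field(default_factory=dict)

        def __init__(self, **kwargs: Any):
            # Build the heavyweight object first, then let pydantic validate
            # the fully-populated kwargs in a single pass.
            kwargs["client"] = ("loaded", kwargs.get("model_name"))  # stand-in
            super().__init__(**kwargs)


    m = ClientBackedModel(model_name="my-model", model_kwargs={})
    assert m.client == ("loaded", "my-model")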
@@ -184,6 +183,8 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
         )
     """
 
+    model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=())
+
     client: Any  #: :meta private:
     model_name: str = DEFAULT_INSTRUCT_MODEL
     """Model name to use."""
@@ -201,20 +202,18 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
 
     def __init__(self, **kwargs: Any):
         """Initialize the sentence_transformer."""
-        super().__init__(**kwargs)
         try:
             from InstructorEmbedding import INSTRUCTOR
 
-            self.client = INSTRUCTOR(
-                self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
+            kwargs["client"] = INSTRUCTOR(
+                kwargs.get("model_name"),
+                cache_folder=kwargs.get("cache_folder"),
+                **kwargs.get("model_kwargs"),
             )
         except ImportError as e:
             raise ImportError("Dependencies for InstructorEmbedding not found.") from e
 
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
+        super().__init__(**kwargs)
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Compute doc embeddings using a HuggingFace instruct model.
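`HuggingFaceInstructEmbeddings` gets the same treatment: the INSTRUCTOR client is built from the raw `kwargs` inside the `try` block, and the single `super().__init__(**kwargs)` moves to the end of the constructor.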
@@ -267,6 +266,8 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
         )
     """
 
+    model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=())
+
     client: Any  #: :meta private:
     model_name: str = DEFAULT_BGE_MODEL
     """Model name to use."""
@@ -282,7 +283,6 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
 
     def __init__(self, **kwargs: Any):
         """Initialize the sentence_transformer."""
-        super().__init__(**kwargs)
         try:
             import sentence_transformers
 
@@ -292,17 +292,16 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
                 "Please install it with `pip install sentence_transformers`."
             ) from exc
 
-        self.client = sentence_transformers.SentenceTransformer(
-            self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
+        kwargs["client"] = sentence_transformers.SentenceTransformer(
+            kwargs.get("model_name"),
+            cache_folder=kwargs.get("cache_folder"),
+            **kwargs.get("model_kwargs"),
         )
 
+        super().__init__(**kwargs)
         if "-zh" in self.model_name:
             self.query_instruction = DEFAULT_QUERY_BGE_INSTRUCTION_ZH
 
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Compute doc embeddings using a HuggingFace transformer model.
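For the BGE variant the ordering is load-bearing in both directions: the client must be built before validation, but `super().__init__(**kwargs)` must run before the `"-zh"` check, since `self.model_name` is only populated once pydantic has processed the keyword arguments.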
@@ -360,6 +359,8 @@ class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings):
     Requires a HuggingFace Inference API key and a model name.
     """
 
+    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
+
     api_key: str
     """Your API key for the HuggingFace Inference API."""
     model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
@@ -475,6 +476,8 @@ class JinaEmbeddings(BaseModel, Embeddings):
     "jina-embeddings-v2-base-en".
     """
 
+    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
+
     api_url: Any  #: :meta private:
     session: Any  #: :meta private:
     api_key: str
@@ -485,7 +488,6 @@ class JinaEmbeddings(BaseModel, Embeddings):
 
     def __init__(self, **kwargs):
         """Create a new JinaEmbeddings instance."""
-        super().__init__(**kwargs)
         try:
             import requests
         except ImportError:
@@ -493,11 +495,23 @@ class JinaEmbeddings(BaseModel, Embeddings):
                 "The requests python package is not installed. Please install it with "
                 "`pip install requests`"
             )
-        self.api_url = "https://api.jina.ai/v1/embeddings"
-        self.session = requests.Session()
-        self.session.headers.update(
-            {"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"}
-        )
+        if "api_url" not in kwargs:
+            kwargs["api_url"] = "https://api.jina.ai/v1/embeddings"
+        if "session" not in kwargs:  # noqa: SIM401
+            session = requests.Session()
+        else:
+            session = kwargs["session"]
+        api_key = kwargs.get("api_key")
+        if api_key:
+            session.headers.update(
+                {
+                    "Authorization": f"Bearer {api_key}",
+                    "Accept-Encoding": "identity",
+                }
+            )
+        kwargs["session"] = session
+
+        super().__init__(**kwargs)
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Get the embeddings for a list of texts.
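Beyond the pydantic mechanics, the rewritten `JinaEmbeddings.__init__` makes `api_url` and `session` injectable: values already present in `kwargs` win, a default `requests.Session` is created otherwise, and the Authorization header is only attached when an `api_key` was actually supplied. A hypothetical usage sketch (the custom header is made up):

    import requests

    # Inject a pre-configured session (retries, proxies, tracing headers, ...);
    # the constructor reuses it instead of creating its own.
    session = requests.Session()
    session.headers.update({"X-Request-Source": "batch-job"})  # hypothetical

    embeddings = JinaEmbeddings(api_key="my-jina-key", session=session)
    # __init__ adds "Authorization: Bearer my-jina-key" to this same session
    # and falls back to the default https://api.jina.ai/v1/embeddings api_url.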
@@ -627,6 +641,8 @@ class OpenAPIEmbeddings(BaseModel, Embeddings):
         openai_embeddings.embed_documents(texts)
     """
 
+    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
+
     api_url: str = Field(
         default="http://localhost:8100/api/v1/embeddings",
         description="The URL of the embeddings API.",
@@ -643,14 +659,8 @@ class OpenAPIEmbeddings(BaseModel, Embeddings):
 
     session: Optional[requests.Session] = None
 
-    class Config:
-        """Configuration for this pydantic object."""
-
-        arbitrary_types_allowed = True
-
     def __init__(self, **kwargs):
         """Initialize the OpenAPIEmbeddings."""
-        super().__init__(**kwargs)
         try:
             import requests
         except ImportError:
@@ -658,8 +668,15 @@ class OpenAPIEmbeddings(BaseModel, Embeddings):
                 "The requests python package is not installed. "
                 "Please install it with `pip install requests`"
             )
-        self.session = requests.Session()
-        self.session.headers.update({"Authorization": f"Bearer {self.api_key}"})
+        if "session" not in kwargs:  # noqa: SIM401
+            session = requests.Session()
+        else:
+            session = kwargs["session"]
+        api_key = kwargs.get("api_key")
+        if api_key:
+            session.headers.update({"Authorization": f"Bearer {api_key}"})
+        kwargs["session"] = session
+        super().__init__(**kwargs)
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Get the embeddings for a list of texts.
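`OpenAPIEmbeddings` receives the same injectable-session pattern; because the Bearer header is only set when `api_key` is given, the class can also talk to unauthenticated local endpoints. A hypothetical usage sketch against the default `api_url` declared in the `Field` above:

    # No api_key: no Authorization header, suitable for a local, open endpoint
    # at the default http://localhost:8100/api/v1/embeddings.
    local = OpenAPIEmbeddings()

    # With a key, the Bearer header is attached to the (possibly injected) session.
    secured = OpenAPIEmbeddings(api_key="my-secret-key")
    vectors = secured.embed_documents(["hello", "world"])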
@@ -58,7 +58,6 @@ class TokenTextSplitter(BaseModel):
         tokenizer = tokenizer or globals_helper.tokenizer
 
         all_seps = [separator] + (backup_separators or [])
-        self._split_fns = [split_by_sep(sep) for sep in all_seps] + [split_by_char()]
 
         super().__init__(
             chunk_size=chunk_size,
@@ -68,6 +67,7 @@ class TokenTextSplitter(BaseModel):
             # callback_manager=callback_manager,
             tokenizer=tokenizer,
         )
+        self._split_fns = [split_by_sep(sep) for sep in all_seps] + [split_by_char()]
 
     @classmethod
     def class_name(cls) -> str:
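The `TokenTextSplitter` hunks are the mirror image of the embedding constructors: the write to the private `_split_fns` attribute moves after `super().__init__(...)`. In pydantic 2.x, underscore-prefixed private attributes live in `__pydantic_private__`, which is only set up once `BaseModel.__init__` has run, so assigning them earlier raises an `AttributeError`. A minimal sketch of the constraint (hypothetical model):

    from typing import Callable, List

    from pydantic import BaseModel, PrivateAttr


    class MiniSplitter(BaseModel):
        chunk_size: int = 512
        _split_fns: List[Callable] = PrivateAttr(default_factory=list)

        def __init__(self, **kwargs):
            # Assigning self._split_fns here, before super().__init__, fails
            # under pydantic 2.x because __pydantic_private__ does not exist yet.
            super().__init__(**kwargs)
            self._split_fns = [str.split]  # after init: fine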