feat(core): Upgrade pydantic to 2.x (#1428)

Fangyin Cheng
2024-04-20 09:41:16 +08:00
committed by GitHub
parent baa1e3f9f6
commit 57be1ece18
103 changed files with 1146 additions and 534 deletions

View File

@@ -159,9 +159,9 @@ def setup_dev_environment(
start_http = _check_has_http_trigger(dags)
if start_http:
from fastapi import FastAPI
from dbgpt.util.fastapi import create_app
app = FastAPI()
app = create_app()
else:
app = None
system_app = SystemApp(app)
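The development entrypoint now builds its app through a shared factory in dbgpt.util.fastapi instead of instantiating FastAPI directly. The factory's body is not part of this view; a minimal sketch of what such a wrapper could look like (the kwargs pass-through and the defaults it applies are assumptions):

# Hedged sketch only: the real dbgpt.util.fastapi.create_app is not shown in this commit view.
from fastapi import FastAPI

def create_app(**kwargs) -> FastAPI:
    """Create a FastAPI app with project-wide defaults (assumed behavior)."""
    app = FastAPI(**kwargs)
    # Shared middleware, exception handlers, lifespan hooks, etc. would be
    # registered here so every entrypoint gets the same configuration.
    return app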

View File

@@ -4,6 +4,7 @@ DAGManager will load DAGs from dag_dirs, and register the trigger nodes
to TriggerManager.
"""
import logging
import threading
from typing import Dict, List, Optional
from dbgpt.component import BaseComponent, ComponentType, SystemApp
@@ -29,6 +30,7 @@ class DAGManager(BaseComponent):
from ..trigger.trigger_manager import DefaultTriggerManager
super().__init__(system_app)
self.lock = threading.Lock()
self.dag_loader = LocalFileDAGLoader(dag_dirs)
self.system_app = system_app
self.dag_map: Dict[str, DAG] = {}
@@ -61,39 +63,54 @@ class DAGManager(BaseComponent):
def register_dag(self, dag: DAG, alias_name: Optional[str] = None):
"""Register a DAG."""
dag_id = dag.dag_id
if dag_id in self.dag_map:
raise ValueError(f"Register DAG error, DAG ID {dag_id} has already exist")
self.dag_map[dag_id] = dag
if alias_name:
self.dag_alias_map[alias_name] = dag_id
with self.lock:
dag_id = dag.dag_id
if dag_id in self.dag_map:
raise ValueError(
f"Register DAG error, DAG ID {dag_id} has already exist"
)
self.dag_map[dag_id] = dag
if alias_name:
self.dag_alias_map[alias_name] = dag_id
if self._trigger_manager:
for trigger in dag.trigger_nodes:
self._trigger_manager.register_trigger(trigger, self.system_app)
self._trigger_manager.after_register()
else:
logger.warning("No trigger manager, not register dag trigger")
if self._trigger_manager:
for trigger in dag.trigger_nodes:
self._trigger_manager.register_trigger(trigger, self.system_app)
self._trigger_manager.after_register()
else:
logger.warning("No trigger manager, not register dag trigger")
def unregister_dag(self, dag_id: str):
"""Unregister a DAG."""
if dag_id not in self.dag_map:
raise ValueError(f"Unregister DAG error, DAG ID {dag_id} does not exist")
dag = self.dag_map[dag_id]
# Clear the alias map
for alias_name, _dag_id in self.dag_alias_map.items():
if _dag_id == dag_id:
with self.lock:
if dag_id not in self.dag_map:
raise ValueError(
f"Unregister DAG error, DAG ID {dag_id} does not exist"
)
dag = self.dag_map[dag_id]
# Collect aliases to remove
# TODO(fangyinc): It can be faster if we maintain a reverse map
aliases_to_remove = [
alias_name
for alias_name, _dag_id in self.dag_alias_map.items()
if _dag_id == dag_id
]
# Remove collected aliases
for alias_name in aliases_to_remove:
del self.dag_alias_map[alias_name]
if self._trigger_manager:
for trigger in dag.trigger_nodes:
self._trigger_manager.unregister_trigger(trigger, self.system_app)
del self.dag_map[dag_id]
if self._trigger_manager:
for trigger in dag.trigger_nodes:
self._trigger_manager.unregister_trigger(trigger, self.system_app)
# Finally remove the DAG from the map
del self.dag_map[dag_id]
def get_dag(
self, dag_id: Optional[str] = None, alias_name: Optional[str] = None
) -> Optional[DAG]:
"""Get a DAG by dag_id or alias_name."""
# Not lock, because it is read only and need to be fast
if dag_id and dag_id in self.dag_map:
return self.dag_map[dag_id]
if alias_name in self.dag_alias_map:
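The hunk above makes DAG registration and unregistration thread-safe by guarding every mutation of dag_map and dag_alias_map with a threading.Lock, while get_dag stays lock-free so reads remain fast. A condensed sketch of the same pattern (names simplified, not the full DAGManager):

import threading
from typing import Dict, Optional

class Registry:
    """Condensed illustration of the locking pattern used by DAGManager."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._items: Dict[str, object] = {}

    def register(self, key: str, value: object) -> None:
        # Mutations are serialized so concurrent registrations cannot race.
        with self._lock:
            if key in self._items:
                raise ValueError(f"{key} is already registered")
            self._items[key] = value

    def unregister(self, key: str) -> None:
        with self._lock:
            if key not in self._items:
                raise ValueError(f"{key} does not exist")
            del self._items[key]

    def get(self, key: str) -> Optional[object]:
        # Reads stay lock-free, mirroring get_dag: lookups must be fast and a
        # plain dict read is safe enough for this use case.
        return self._items.get(key)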

View File

@@ -7,7 +7,13 @@ from datetime import date, datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Type, TypeVar, Union, cast
from dbgpt._private.pydantic import BaseModel, Field, ValidationError, root_validator
from dbgpt._private.pydantic import (
BaseModel,
Field,
ValidationError,
model_to_dict,
model_validator,
)
from dbgpt.core.awel.util.parameter_util import BaseDynamicOptions, OptionValue
from dbgpt.core.interface.serialization import Serializable
@@ -281,7 +287,7 @@ class TypeMetadata(BaseModel):
def new(self: TM) -> TM:
"""Copy the metadata."""
return self.__class__(**self.dict())
return self.__class__(**self.model_dump(exclude_defaults=True))
class Parameter(TypeMetadata, Serializable):
@@ -332,12 +338,15 @@ class Parameter(TypeMetadata, Serializable):
None, description="The value of the parameter(Saved in the dag file)"
)
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Pre fill the metadata.
Transform the value to the real type.
"""
if not isinstance(values, dict):
return values
type_cls = values.get("type_cls")
to_handle_values = {
"value": values.get("value"),
@@ -443,7 +452,7 @@ class Parameter(TypeMetadata, Serializable):
def to_dict(self) -> Dict:
"""Convert current metadata to json dict."""
dict_value = self.dict(exclude={"options"})
dict_value = model_to_dict(self, exclude={"options"})
if not self.options:
dict_value["options"] = None
elif isinstance(self.options, BaseDynamicOptions):
@@ -535,7 +544,7 @@ class BaseResource(Serializable, BaseModel):
def to_dict(self) -> Dict:
"""Convert current metadata to json dict."""
return self.dict()
return model_to_dict(self)
class Resource(BaseResource, TypeMetadata):
@@ -693,9 +702,12 @@ class BaseMetadata(BaseResource):
)
return runnable_parameters
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def base_pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Pre fill the metadata."""
if not isinstance(values, dict):
return values
if "category_label" not in values:
category = values["category"]
if isinstance(category, str):
@@ -713,7 +725,7 @@ class BaseMetadata(BaseResource):
def to_dict(self) -> Dict:
"""Convert current metadata to json dict."""
dict_value = self.dict(exclude={"parameters"})
dict_value = model_to_dict(self, exclude={"parameters"})
dict_value["parameters"] = [
parameter.to_dict() for parameter in self.parameters
]
@@ -738,9 +750,12 @@ class ResourceMetadata(BaseMetadata, TypeMetadata):
],
)
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Pre fill the metadata."""
if not isinstance(values, dict):
return values
if "flow_type" not in values:
values["flow_type"] = "resource"
if "id" not in values:
@@ -846,9 +861,12 @@ class ViewMetadata(BaseMetadata):
examples=["dbgpt.model.operators.LLMOperator"],
)
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Pre fill the metadata."""
if not isinstance(values, dict):
return values
if "flow_type" not in values:
values["flow_type"] = "operator"
if "id" not in values:
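The pattern repeated throughout this file is the core pydantic 1 to 2 migration: @root_validator(pre=True) becomes @model_validator(mode="before") stacked with @classmethod, a guard is added because a "before" validator may now receive something other than a plain dict (for example an already-built model instance), and .dict() calls give way to model_to_dict / model_dump. A self-contained pydantic 2 sketch of the same shape (the class below is a trimmed stand-in, not one of the real metadata models):

from typing import Any

from pydantic import BaseModel, model_validator

class ViewMetadataLike(BaseModel):
    """Trimmed stand-in for the metadata models in this file."""

    flow_type: str = "operator"
    label: str
    name: str

    @model_validator(mode="before")
    @classmethod
    def pre_fill(cls, values: Any) -> Any:
        """Pre fill defaults before field validation."""
        # In pydantic 2 a "before" validator can receive non-dict input,
        # hence the isinstance guard added in this commit.
        if not isinstance(values, dict):
            return values
        if "name" not in values and "label" in values:
            values["name"] = values["label"].lower().replace(" ", "_")
        return values

meta = ViewMetadataLike(label="LLM Operator")
print(meta.name)                                # llm_operator
print(meta.model_dump(exclude_defaults=True))   # non-default fields only, as in new()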

View File

@@ -6,7 +6,13 @@ from contextlib import suppress
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
from dbgpt._private.pydantic import BaseModel, Field, root_validator, validator
from dbgpt._private.pydantic import (
BaseModel,
Field,
field_validator,
model_to_dict,
model_validator,
)
from dbgpt.core.awel.dag.base import DAG, DAGNode
from .base import (
@@ -73,7 +79,8 @@ class FlowNodeData(BaseModel):
..., description="Absolute position of the node"
)
@validator("data", pre=True)
@field_validator("data", mode="before")
@classmethod
def parse_data(cls, value: Any):
"""Parse the data."""
if isinstance(value, dict):
@@ -123,9 +130,12 @@ class FlowEdgeData(BaseModel):
examples=["buttonedge"],
)
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Pre fill the metadata."""
if not isinstance(values, dict):
return values
if (
"source_order" not in values
and "source_handle" in values
@@ -315,9 +325,12 @@ class FlowPanel(BaseModel):
examples=["2021-08-01 12:00:00", "2021-08-01 12:00:01", "2021-08-01 12:00:02"],
)
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Pre fill the metadata."""
if not isinstance(values, dict):
return values
label = values.get("label")
name = values.get("name")
flow_category = str(values.get("flow_category", ""))
@@ -329,6 +342,10 @@ class FlowPanel(BaseModel):
values["name"] = name
return values
def to_dict(self) -> Dict[str, Any]:
"""Convert to dict."""
return model_to_dict(self)
class FlowFactory:
"""Flow factory."""

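Per-field validators migrate the same way: @validator("data", pre=True) becomes @field_validator("data", mode="before") plus @classmethod. A minimal pydantic 2 sketch of that shape (trimmed stand-ins, not the real FlowNodeData):

from typing import Any

from pydantic import BaseModel, field_validator

class NodeData(BaseModel):
    width: int
    height: int

class NodeLike(BaseModel):
    data: NodeData

    @field_validator("data", mode="before")
    @classmethod
    def parse_data(cls, value: Any) -> Any:
        """Accept either a ready NodeData or a raw dict from the flow JSON."""
        if isinstance(value, dict):
            return NodeData(**value)
        return value

node = NodeLike(data={"width": 100, "height": 80})
print(node.data.height)  # 80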
View File

@@ -15,7 +15,14 @@ from typing import (
get_origin,
)
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt._private.pydantic import (
BaseModel,
Field,
field_is_required,
field_outer_type,
model_fields,
model_to_dict,
)
from dbgpt.util.i18n_utils import _
from ..dag.base import DAG
@@ -61,7 +68,7 @@ class AWELHttpError(RuntimeError):
def _default_streaming_predict_func(body: "CommonRequestType") -> bool:
if isinstance(body, BaseModel):
body = body.dict()
body = model_to_dict(body)
elif isinstance(body, str):
try:
body = json.loads(body)
@@ -254,7 +261,7 @@ class CommonLLMHttpRequestBody(BaseHttpBody):
"or in full each time. "
"If this parameter is not provided, the default is full return.",
)
enable_vis: str = Field(
enable_vis: bool = Field(
default=True, description="response content whether to output vis label"
)
extra: Optional[Dict[str, Any]] = Field(
@@ -574,18 +581,20 @@ class HttpTrigger(Trigger):
if isinstance(req_body_cls, type) and issubclass(
req_body_cls, BaseModel
):
fields = req_body_cls.__fields__ # type: ignore
fields = model_fields(req_body_cls) # type: ignore
parameters = []
for field_name, field in fields.items():
default_value = (
Parameter.empty if field.required else field.default
Parameter.empty
if field_is_required(field)
else field.default
)
parameters.append(
Parameter(
name=field_name,
kind=Parameter.KEYWORD_ONLY,
default=default_value,
annotation=field.outer_type_,
annotation=field_outer_type(field),
)
)
elif req_body_cls == Dict[str, Any] or req_body_cls == dict:
@@ -1029,7 +1038,7 @@ class RequestBodyToDictOperator(MapOperator[CommonLLMHttpRequestBody, Dict[str,
async def map(self, request_body: CommonLLMHttpRequestBody) -> Dict[str, Any]:
"""Map the request body to response body."""
dict_value = request_body.dict()
dict_value = model_to_dict(request_body)
if not self._key:
return dict_value
else:
@@ -1138,7 +1147,7 @@ class RequestedParsedOperator(MapOperator[CommonLLMHttpRequestBody, str]):
async def map(self, request_body: CommonLLMHttpRequestBody) -> str:
"""Map the request body to response body."""
dict_value = request_body.dict()
dict_value = model_to_dict(request_body)
if not self._key or self._key not in dict_value:
raise ValueError(
f"Prefix key {self._key} is not a valid key of the request body"
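The trigger code stops reading __fields__, field.required and field.outer_type_ directly and goes through helpers from dbgpt._private.pydantic instead. Their implementation is not part of this view; on pydantic 2 they would plausibly be thin wrappers over model_fields and FieldInfo, roughly like this (an assumed sketch, not the actual module):

from typing import Any, Dict, Type

from pydantic import BaseModel
from pydantic.fields import FieldInfo

def model_fields(model_cls: Type[BaseModel]) -> Dict[str, FieldInfo]:
    """Replacement for pydantic 1's Model.__fields__."""
    return model_cls.model_fields

def field_is_required(field: FieldInfo) -> bool:
    """Replacement for pydantic 1's field.required."""
    return field.is_required()

def field_outer_type(field: FieldInfo) -> Any:
    """Rough replacement for pydantic 1's field.outer_type_."""
    # pydantic 2 keeps the full annotation; the outer type is approximated by it.
    return field.annotation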

View File

@@ -4,7 +4,7 @@ import inspect
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List
from dbgpt._private.pydantic import BaseModel, Field, root_validator
from dbgpt._private.pydantic import BaseModel, Field, model_validator
from dbgpt.core.interface.serialization import Serializable
_DEFAULT_DYNAMIC_REGISTRY = {}
@@ -44,9 +44,12 @@ class FunctionDynamicOptions(BaseDynamicOptions):
"""Return the option values of the parameter."""
return self.func()
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Pre fill the function id."""
if not isinstance(values, dict):
return values
func = values.get("func")
if func is None:
raise ValueError(

View File

@@ -4,7 +4,7 @@ import json
import uuid
from typing import Any, Dict
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
class Document(BaseModel):
@@ -64,7 +64,7 @@ class Chunk(Document):
def to_dict(self, **kwargs: Any) -> Dict[str, Any]:
"""Convert Chunk to dict."""
data = self.dict(**kwargs)
data = model_to_dict(self, **kwargs)
data["class_name"] = self.class_name()
return data
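model_to_dict replaces the scattered .dict() calls, which pydantic 2 deprecates in favor of model_dump(). The helper itself lives in dbgpt._private.pydantic and is not shown in this view; under pydantic 2 it is presumably little more than the following (an assumption):

from typing import Any, Dict

from pydantic import BaseModel

def model_to_dict(model: BaseModel, **kwargs: Any) -> Dict[str, Any]:
    """Serialize a model to a plain dict, forwarding options such as exclude=..."""
    return model.model_dump(**kwargs)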

View File

@@ -10,7 +10,7 @@ from typing import Any, AsyncIterator, Dict, List, Optional, Union
from cachetools import TTLCache
from dbgpt._private.pydantic import BaseModel
from dbgpt._private.pydantic import BaseModel, model_to_dict
from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.util import BaseParameters
from dbgpt.util.annotations import PublicAPI
@@ -312,7 +312,7 @@ class ModelRequest:
if isinstance(context, dict):
context_dict = context
elif isinstance(context, BaseModel):
context_dict = context.dict()
context_dict = model_to_dict(context)
if context_dict and "stream" not in context_dict:
context_dict["stream"] = stream
if context_dict:

View File

@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
from datetime import datetime
from typing import Callable, Dict, List, Optional, Tuple, Union, cast
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
from dbgpt.core.interface.storage import (
InMemoryStorage,
ResourceIdentifier,
@@ -42,7 +42,7 @@ class BaseMessage(BaseModel, ABC):
"""
return {
"type": self.type,
"data": self.dict(),
"data": model_to_dict(self),
"index": self.index,
"round_index": self.round_index,
}
@@ -264,7 +264,7 @@ class ModelMessage(BaseModel):
Returns:
List[Dict[str, str]]: The dict list
"""
return list(map(lambda m: m.dict(), messages))
return list(map(lambda m: model_to_dict(m), messages))
@staticmethod
def build_human_message(content: str) -> "ModelMessage":

View File

@@ -2,7 +2,7 @@
from abc import ABC
from typing import Any, Dict, List, Optional, Union
from dbgpt._private.pydantic import root_validator
from dbgpt._private.pydantic import model_validator
from dbgpt.core import (
ModelMessage,
ModelMessageRoleType,
@@ -71,9 +71,12 @@ from dbgpt.util.i18n_utils import _
class CommonChatPromptTemplate(ChatPromptTemplate):
"""The common chat prompt template."""
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Pre fill the messages."""
if not isinstance(values, dict):
return values
if "system_message" not in values:
values["system_message"] = "You are a helpful AI Assistant."
if "human_message" not in values:

View File

@@ -8,7 +8,7 @@ from abc import ABC, abstractmethod
from string import Formatter
from typing import Any, Callable, Dict, List, Optional, Set, Union
from dbgpt._private.pydantic import BaseModel, root_validator
from dbgpt._private.pydantic import BaseModel, ConfigDict, model_validator
from dbgpt.core.interface.message import BaseMessage, HumanMessage, SystemMessage
from dbgpt.core.interface.storage import (
InMemoryStorage,
@@ -51,6 +51,8 @@ class BasePromptTemplate(BaseModel):
class PromptTemplate(BasePromptTemplate):
"""Prompt template."""
model_config = ConfigDict(arbitrary_types_allowed=True)
template: str
"""The prompt template."""
@@ -69,11 +71,6 @@ class PromptTemplate(BasePromptTemplate):
template_define: Optional[str] = None
"""this template define"""
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
@@ -239,9 +236,12 @@ class ChatPromptTemplate(BasePromptTemplate):
raise ValueError(f"Unsupported message type: {type(message)}")
return result_messages
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def base_pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Pre-fill the messages."""
if not isinstance(values, dict):
return values
input_variables = values.get("input_variables", {})
messages = values.get("messages", [])
if not input_variables:
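Model-level configuration migrates from the nested class Config to the model_config = ConfigDict(...) attribute, as the PromptTemplate hunk above shows. A condensed pydantic 2 sketch of the same change (a stand-in class, not the real PromptTemplate):

from typing import Optional

from pydantic import BaseModel, ConfigDict

class Formatter:
    """A plain class with no pydantic schema of its own."""

class PromptLike(BaseModel):
    # pydantic 2 style, replacing the nested class Config with arbitrary_types_allowed = True
    model_config = ConfigDict(arbitrary_types_allowed=True)

    template: str
    formatter: Optional[Formatter] = None

prompt = PromptLike(template="Hello {name}", formatter=Formatter())
print(prompt.template)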

View File

@@ -9,7 +9,7 @@ from typing import Dict, Optional, Type
class Serializable(ABC):
"""The serializable abstract class."""
serializer: Optional["Serializer"] = None
_serializer: Optional["Serializer"] = None
@abstractmethod
def to_dict(self) -> Dict:
@@ -21,11 +21,12 @@ class Serializable(ABC):
Returns:
bytes: The byte array after serialization
"""
if self.serializer is None:
if self._serializer is None:
raise ValueError(
"Serializer is not set. Please set the serializer before serialization."
"Serializer is not set. Please set the serializer before "
"serialization."
)
return self.serializer.serialize(self)
return self._serializer.serialize(self)
def set_serializer(self, serializer: "Serializer") -> None:
"""Set the serializer for current serializable object.
@@ -33,7 +34,7 @@ class Serializable(ABC):
Args:
serializer (Serializer): The serializer to set
"""
self.serializer = serializer
self._serializer = serializer
class Serializer(ABC):
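Renaming serializer to _serializer keeps the attribute out of pydantic's field collection on models that mix in Serializable (such as BaseResource above): in pydantic 2, underscore-prefixed names on a BaseModel are private attributes, never fields. That motivation is an inference from the rename rather than something stated in the commit; a small illustration:

from typing import Optional

from pydantic import BaseModel

class Item(BaseModel):
    name: str
    # Leading underscore: a pydantic 2 private attribute, not a serialized field.
    _serializer: Optional[object] = None

item = Item(name="example")
item._serializer = object()       # may be set freely, never validated or dumped
print(list(Item.model_fields))    # ['name'] - the private attribute is excluded
print(item.model_dump())          # {'name': 'example'}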

View File

@@ -426,7 +426,7 @@ class InMemoryStorage(StorageInterface[T, T]):
"""
if not data:
raise StorageError("Data cannot be None")
if not data.serializer:
if not data._serializer:
data.set_serializer(self.serializer)
if data.identifier.str_identifier in self._data:
@@ -439,7 +439,7 @@ class InMemoryStorage(StorageInterface[T, T]):
"""Update the data to the storage."""
if not data:
raise StorageError("Data cannot be None")
if not data.serializer:
if not data._serializer:
data.set_serializer(self.serializer)
self._data[data.identifier.str_identifier] = data.serialize()

View File

@@ -2,9 +2,10 @@
import time
import uuid
from typing import Any, Generic, List, Literal, Optional, TypeVar
from enum import IntEnum
from typing import Any, Dict, Generic, List, Literal, Optional, TypeVar, Union
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
T = TypeVar("T")
@@ -41,6 +42,28 @@ class Result(BaseModel, Generic[T]):
"""
return Result(success=False, err_code=err_code, err_msg=msg, data=None)
def to_dict(self, **kwargs) -> Dict[str, Any]:
"""Convert to dict."""
return model_to_dict(self, **kwargs)
class APIChatCompletionRequest(BaseModel):
"""Chat completion request entity."""
model: str = Field(..., description="Model name")
messages: Union[str, List[Dict[str, str]]] = Field(..., description="Messages")
temperature: Optional[float] = Field(0.7, description="Temperature")
top_p: Optional[float] = Field(1.0, description="Top p")
top_k: Optional[int] = Field(-1, description="Top k")
n: Optional[int] = Field(1, description="Number of completions")
max_tokens: Optional[int] = Field(None, description="Max tokens")
stop: Optional[Union[str, List[str]]] = Field(None, description="Stop")
stream: Optional[bool] = Field(False, description="Stream")
user: Optional[str] = Field(None, description="User")
repetition_penalty: Optional[float] = Field(1.0, description="Repetition penalty")
frequency_penalty: Optional[float] = Field(0.0, description="Frequency penalty")
presence_penalty: Optional[float] = Field(0.0, description="Presence penalty")
class DeltaMessage(BaseModel):
"""Delta message entity for chat completion response."""
@@ -122,3 +145,97 @@ class ErrorResponse(BaseModel):
object: str = Field("error", description="Object type")
message: str = Field(..., description="Error message")
code: int = Field(..., description="Error code")
class EmbeddingsRequest(BaseModel):
"""Embeddings request entity."""
model: Optional[str] = Field(None, description="Model name")
engine: Optional[str] = Field(None, description="Engine name")
input: Union[str, List[Any]] = Field(..., description="Input data")
user: Optional[str] = Field(None, description="User name")
encoding_format: Optional[str] = Field(None, description="Encoding format")
class EmbeddingsResponse(BaseModel):
"""Embeddings response entity."""
object: str = Field("list", description="Object type")
data: List[Dict[str, Any]] = Field(..., description="Data list")
model: str = Field(..., description="Model name")
usage: UsageInfo = Field(..., description="Usage info")
class ModelPermission(BaseModel):
"""Model permission entity."""
id: str = Field(
default_factory=lambda: f"modelperm-{str(uuid.uuid1())}",
description="Permission ID",
)
object: str = Field("model_permission", description="Object type")
created: int = Field(
default_factory=lambda: int(time.time()), description="Created time"
)
allow_create_engine: bool = Field(False, description="Allow create engine")
allow_sampling: bool = Field(True, description="Allow sampling")
allow_logprobs: bool = Field(True, description="Allow logprobs")
allow_search_indices: bool = Field(True, description="Allow search indices")
allow_view: bool = Field(True, description="Allow view")
allow_fine_tuning: bool = Field(False, description="Allow fine tuning")
organization: str = Field("*", description="Organization")
group: Optional[str] = Field(None, description="Group")
is_blocking: bool = Field(False, description="Is blocking")
class ModelCard(BaseModel):
"""Model card entity."""
id: str = Field(..., description="Model ID")
object: str = Field("model", description="Object type")
created: int = Field(
default_factory=lambda: int(time.time()), description="Created time"
)
owned_by: str = Field("DB-GPT", description="Owned by")
root: Optional[str] = Field(None, description="Root")
parent: Optional[str] = Field(None, description="Parent")
permission: List[ModelPermission] = Field(
default_factory=list, description="Permission"
)
class ModelList(BaseModel):
"""Model list entity."""
object: str = Field("list", description="Object type")
data: List[ModelCard] = Field(default_factory=list, description="Model list data")
class ErrorCode(IntEnum):
"""Error code enumeration.
https://platform.openai.com/docs/guides/error-codes/api-errors.
Adapted from fastchat.constants.
"""
VALIDATION_TYPE_ERROR = 40001
INVALID_AUTH_KEY = 40101
INCORRECT_AUTH_KEY = 40102
NO_PERMISSION = 40103
INVALID_MODEL = 40301
PARAM_OUT_OF_RANGE = 40302
CONTEXT_OVERFLOW = 40303
RATE_LIMIT = 42901
QUOTA_EXCEEDED = 42902
ENGINE_OVERLOADED = 42903
INTERNAL_ERROR = 50001
CUDA_OUT_OF_MEMORY = 50002
GRADIO_REQUEST_ERROR = 50003
GRADIO_STREAM_UNKNOWN_ERROR = 50004
CONTROLLER_NO_WORKER = 50005
CONTROLLER_WORKER_TIMEOUT = 50006
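The new entities mirror the OpenAI-style API schema. A brief usage sketch (the import path is assumed from context, not shown in this view; only fields defined above are used):

# Usage sketch; the module path of these entities is an assumption.
from dbgpt.core.schema.api import APIChatCompletionRequest, ErrorCode

request = APIChatCompletionRequest(
    model="vicuna-13b-v1.5",
    messages=[{"role": "user", "content": "Hello"}],
    temperature=0.5,
    stream=True,
)
# pydantic 2 serialization; exclude_none drops the optional fields left unset.
print(request.model_dump(exclude_none=True))
print(int(ErrorCode.CONTEXT_OVERFLOW))  # 40303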