feat(core): Upgrade pydantic to 2.x (#1428)

Fangyin Cheng
2024-04-20 09:41:16 +08:00
committed by GitHub
parent baa1e3f9f6
commit 57be1ece18
103 changed files with 1146 additions and 534 deletions
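Most of the changes below follow the same Pydantic 1.x to 2.x migration: @root_validator(pre=True) becomes @model_validator(mode="before") stacked with an explicit @classmethod, @validator becomes @field_validator, and direct calls such as .dict() or __fields__ are routed through helpers imported from dbgpt._private.pydantic. That compatibility module is not part of this diff, so the sketch below is only an assumption about what such helpers typically wrap on Pydantic 2.x.

# Hypothetical sketch of the helpers imported as dbgpt._private.pydantic below;
# the real shim is not shown in this commit, so treat every body as an assumption.
from typing import Any, Dict, Type

from pydantic import BaseModel
from pydantic.fields import FieldInfo


def model_to_dict(model: BaseModel, **kwargs: Any) -> Dict[str, Any]:
    # Pydantic 2.x renamed BaseModel.dict() to model_dump().
    return model.model_dump(**kwargs)


def model_fields(model_cls: Type[BaseModel]) -> Dict[str, FieldInfo]:
    # Pydantic 2.x exposes fields via the model_fields mapping (__fields__ is deprecated).
    return model_cls.model_fields


def field_is_required(field: FieldInfo) -> bool:
    # Replaces the removed ModelField.required attribute.
    return field.is_required()


def field_outer_type(field: FieldInfo) -> Any:
    # Replaces the removed ModelField.outer_type_ attribute.
    return field.annotation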

View File

@@ -159,9 +159,9 @@ def setup_dev_environment(
     start_http = _check_has_http_trigger(dags)
     if start_http:
-        from fastapi import FastAPI
+        from dbgpt.util.fastapi import create_app
 
-        app = FastAPI()
+        app = create_app()
     else:
         app = None
     system_app = SystemApp(app)
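The dev-environment setup now builds its FastAPI instance through dbgpt.util.fastapi.create_app() instead of calling FastAPI() directly. The dbgpt.util.fastapi module is not shown in this diff, so the following is only a guess at the shape of such a factory: a single place to construct the app so every caller shares the same defaults.

# Hypothetical sketch of an app factory in the spirit of dbgpt.util.fastapi.create_app;
# the real implementation is not part of this diff.
from typing import Any

from fastapi import FastAPI


def create_app(**kwargs: Any) -> FastAPI:
    """Build the FastAPI instance in one place so every caller gets the same defaults."""
    app = FastAPI(**kwargs)
    # Project-wide middleware, exception handlers, or lifespan hooks would be
    # registered here rather than at each call site.
    return app

Call sites then switch from app = FastAPI() to app = create_app(), as the hunk above does.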

View File

@@ -4,6 +4,7 @@ DAGManager will load DAGs from dag_dirs, and register the trigger nodes
 to TriggerManager.
 """
 import logging
+import threading
 from typing import Dict, List, Optional
 
 from dbgpt.component import BaseComponent, ComponentType, SystemApp
@@ -29,6 +30,7 @@ class DAGManager(BaseComponent):
         from ..trigger.trigger_manager import DefaultTriggerManager
 
         super().__init__(system_app)
+        self.lock = threading.Lock()
         self.dag_loader = LocalFileDAGLoader(dag_dirs)
         self.system_app = system_app
         self.dag_map: Dict[str, DAG] = {}
@@ -61,39 +63,54 @@ class DAGManager(BaseComponent):
     def register_dag(self, dag: DAG, alias_name: Optional[str] = None):
         """Register a DAG."""
-        dag_id = dag.dag_id
-        if dag_id in self.dag_map:
-            raise ValueError(f"Register DAG error, DAG ID {dag_id} has already exist")
-        self.dag_map[dag_id] = dag
-        if alias_name:
-            self.dag_alias_map[alias_name] = dag_id
+        with self.lock:
+            dag_id = dag.dag_id
+            if dag_id in self.dag_map:
+                raise ValueError(
+                    f"Register DAG error, DAG ID {dag_id} has already exist"
+                )
+            self.dag_map[dag_id] = dag
+            if alias_name:
+                self.dag_alias_map[alias_name] = dag_id
 
-        if self._trigger_manager:
-            for trigger in dag.trigger_nodes:
-                self._trigger_manager.register_trigger(trigger, self.system_app)
-            self._trigger_manager.after_register()
-        else:
-            logger.warning("No trigger manager, not register dag trigger")
+            if self._trigger_manager:
+                for trigger in dag.trigger_nodes:
+                    self._trigger_manager.register_trigger(trigger, self.system_app)
+                self._trigger_manager.after_register()
+            else:
+                logger.warning("No trigger manager, not register dag trigger")
 
     def unregister_dag(self, dag_id: str):
         """Unregister a DAG."""
-        if dag_id not in self.dag_map:
-            raise ValueError(f"Unregister DAG error, DAG ID {dag_id} does not exist")
-        dag = self.dag_map[dag_id]
-        # Clear the alias map
-        for alias_name, _dag_id in self.dag_alias_map.items():
-            if _dag_id == dag_id:
+        with self.lock:
+            if dag_id not in self.dag_map:
+                raise ValueError(
+                    f"Unregister DAG error, DAG ID {dag_id} does not exist"
+                )
+            dag = self.dag_map[dag_id]
+            # Collect aliases to remove
+            # TODO(fangyinc): It can be faster if we maintain a reverse map
+            aliases_to_remove = [
+                alias_name
+                for alias_name, _dag_id in self.dag_alias_map.items()
+                if _dag_id == dag_id
+            ]
+            # Remove collected aliases
+            for alias_name in aliases_to_remove:
                 del self.dag_alias_map[alias_name]
 
-        if self._trigger_manager:
-            for trigger in dag.trigger_nodes:
-                self._trigger_manager.unregister_trigger(trigger, self.system_app)
-        del self.dag_map[dag_id]
+            if self._trigger_manager:
+                for trigger in dag.trigger_nodes:
+                    self._trigger_manager.unregister_trigger(trigger, self.system_app)
+            # Finally remove the DAG from the map
+            del self.dag_map[dag_id]
 
     def get_dag(
         self, dag_id: Optional[str] = None, alias_name: Optional[str] = None
     ) -> Optional[DAG]:
         """Get a DAG by dag_id or alias_name."""
+        # Not lock, because it is read only and need to be fast
         if dag_id and dag_id in self.dag_map:
             return self.dag_map[dag_id]
         if alias_name in self.dag_alias_map:
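Two points in this DAGManager hunk: registration and unregistration are now serialized by a threading.Lock, and alias cleanup first collects the matching keys and only then deletes them, whereas the old loop deleted entries from self.dag_alias_map while iterating over it, which raises RuntimeError: dictionary changed size during iteration on Python 3. The TODO suggests a reverse map; below is a hypothetical sketch of that idea, with class and attribute names invented for illustration.

# Illustration of the reverse-map idea from the TODO above: track dag_id -> aliases
# so unregistering touches only that DAG's aliases instead of scanning the whole map.
import threading
from typing import Dict, Set


class AliasIndex:
    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._alias_to_dag: Dict[str, str] = {}
        self._dag_to_aliases: Dict[str, Set[str]] = {}

    def add(self, alias_name: str, dag_id: str) -> None:
        with self._lock:
            self._alias_to_dag[alias_name] = dag_id
            self._dag_to_aliases.setdefault(dag_id, set()).add(alias_name)

    def remove_dag(self, dag_id: str) -> None:
        # O(number of aliases of this DAG) instead of O(all aliases).
        with self._lock:
            for alias_name in self._dag_to_aliases.pop(dag_id, set()):
                self._alias_to_dag.pop(alias_name, None)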

View File

@@ -7,7 +7,13 @@ from datetime import date, datetime
 from enum import Enum
 from typing import Any, Dict, List, Optional, Type, TypeVar, Union, cast
 
-from dbgpt._private.pydantic import BaseModel, Field, ValidationError, root_validator
+from dbgpt._private.pydantic import (
+    BaseModel,
+    Field,
+    ValidationError,
+    model_to_dict,
+    model_validator,
+)
 from dbgpt.core.awel.util.parameter_util import BaseDynamicOptions, OptionValue
 from dbgpt.core.interface.serialization import Serializable
@@ -281,7 +287,7 @@ class TypeMetadata(BaseModel):
     def new(self: TM) -> TM:
         """Copy the metadata."""
-        return self.__class__(**self.dict())
+        return self.__class__(**self.model_dump(exclude_defaults=True))
 
 
 class Parameter(TypeMetadata, Serializable):
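TypeMetadata.new() switches from self.dict() to model_dump(exclude_defaults=True), so the copy is rebuilt only from values that differ from their field defaults rather than from a full dump. A toy model, invented for illustration, shows the difference:

# Minimal demonstration of exclude_defaults, assuming Pydantic 2.x is installed.
from pydantic import BaseModel


class TypeMetaExample(BaseModel):
    type_name: str
    type_cls: str = "builtins.str"


meta = TypeMetaExample(type_name="str")
print(meta.model_dump())                       # {'type_name': 'str', 'type_cls': 'builtins.str'}
print(meta.model_dump(exclude_defaults=True))  # {'type_name': 'str'}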
@@ -332,12 +338,15 @@ class Parameter(TypeMetadata, Serializable):
         None, description="The value of the parameter(Saved in the dag file)"
     )
 
-    @root_validator(pre=True)
+    @model_validator(mode="before")
+    @classmethod
     def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Pre fill the metadata.
 
         Transform the value to the real type.
         """
+        if not isinstance(values, dict):
+            return values
         type_cls = values.get("type_cls")
         to_handle_values = {
             "value": values.get("value"),
@@ -443,7 +452,7 @@ class Parameter(TypeMetadata, Serializable):
     def to_dict(self) -> Dict:
         """Convert current metadata to json dict."""
-        dict_value = self.dict(exclude={"options"})
+        dict_value = model_to_dict(self, exclude={"options"})
         if not self.options:
             dict_value["options"] = None
         elif isinstance(self.options, BaseDynamicOptions):
@@ -535,7 +544,7 @@ class BaseResource(Serializable, BaseModel):
     def to_dict(self) -> Dict:
         """Convert current metadata to json dict."""
-        return self.dict()
+        return model_to_dict(self)
 
 
 class Resource(BaseResource, TypeMetadata):
@@ -693,9 +702,12 @@ class BaseMetadata(BaseResource):
             )
         return runnable_parameters
 
-    @root_validator(pre=True)
+    @model_validator(mode="before")
+    @classmethod
     def base_pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Pre fill the metadata."""
+        if not isinstance(values, dict):
+            return values
         if "category_label" not in values:
             category = values["category"]
             if isinstance(category, str):
@@ -713,7 +725,7 @@ class BaseMetadata(BaseResource):
     def to_dict(self) -> Dict:
         """Convert current metadata to json dict."""
-        dict_value = self.dict(exclude={"parameters"})
+        dict_value = model_to_dict(self, exclude={"parameters"})
         dict_value["parameters"] = [
             parameter.to_dict() for parameter in self.parameters
         ]
@@ -738,9 +750,12 @@ class ResourceMetadata(BaseMetadata, TypeMetadata):
         ],
     )
 
-    @root_validator(pre=True)
+    @model_validator(mode="before")
+    @classmethod
     def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Pre fill the metadata."""
+        if not isinstance(values, dict):
+            return values
         if "flow_type" not in values:
             values["flow_type"] = "resource"
         if "id" not in values:
@@ -846,9 +861,12 @@ class ViewMetadata(BaseMetadata):
         examples=["dbgpt.model.operators.LLMOperator"],
     )
 
-    @root_validator(pre=True)
+    @model_validator(mode="before")
+    @classmethod
     def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Pre fill the metadata."""
+        if not isinstance(values, dict):
+            return values
         if "flow_type" not in values:
             values["flow_type"] = "operator"
         if "id" not in values:

View File

@@ -6,7 +6,13 @@ from contextlib import suppress
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
 
-from dbgpt._private.pydantic import BaseModel, Field, root_validator, validator
+from dbgpt._private.pydantic import (
+    BaseModel,
+    Field,
+    field_validator,
+    model_to_dict,
+    model_validator,
+)
 from dbgpt.core.awel.dag.base import DAG, DAGNode
 
 from .base import (
@@ -73,7 +79,8 @@ class FlowNodeData(BaseModel):
         ..., description="Absolute position of the node"
     )
 
-    @validator("data", pre=True)
+    @field_validator("data", mode="before")
+    @classmethod
     def parse_data(cls, value: Any):
         """Parse the data."""
         if isinstance(value, dict):
@@ -123,9 +130,12 @@ class FlowEdgeData(BaseModel):
         examples=["buttonedge"],
     )
 
-    @root_validator(pre=True)
+    @model_validator(mode="before")
+    @classmethod
     def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Pre fill the metadata."""
+        if not isinstance(values, dict):
+            return values
         if (
             "source_order" not in values
             and "source_handle" in values
@@ -315,9 +325,12 @@ class FlowPanel(BaseModel):
         examples=["2021-08-01 12:00:00", "2021-08-01 12:00:01", "2021-08-01 12:00:02"],
     )
 
-    @root_validator(pre=True)
+    @model_validator(mode="before")
+    @classmethod
     def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Pre fill the metadata."""
+        if not isinstance(values, dict):
+            return values
         label = values.get("label")
         name = values.get("name")
         flow_category = str(values.get("flow_category", ""))
@@ -329,6 +342,10 @@ class FlowPanel(BaseModel):
values["name"] = name
return values
def to_dict(self) -> Dict[str, Any]:
"""Convert to dict."""
return model_to_dict(self)
class FlowFactory:
"""Flow factory."""

View File

@@ -15,7 +15,14 @@ from typing import (
     get_origin,
 )
 
-from dbgpt._private.pydantic import BaseModel, Field
+from dbgpt._private.pydantic import (
+    BaseModel,
+    Field,
+    field_is_required,
+    field_outer_type,
+    model_fields,
+    model_to_dict,
+)
 from dbgpt.util.i18n_utils import _
 
 from ..dag.base import DAG
@@ -61,7 +68,7 @@ class AWELHttpError(RuntimeError):
 def _default_streaming_predict_func(body: "CommonRequestType") -> bool:
     if isinstance(body, BaseModel):
-        body = body.dict()
+        body = model_to_dict(body)
     elif isinstance(body, str):
         try:
             body = json.loads(body)
@@ -254,7 +261,7 @@ class CommonLLMHttpRequestBody(BaseHttpBody):
"or in full each time. "
"If this parameter is not provided, the default is full return.",
)
enable_vis: str = Field(
enable_vis: bool = Field(
default=True, description="response content whether to output vis label"
)
extra: Optional[Dict[str, Any]] = Field(
@@ -574,18 +581,20 @@ class HttpTrigger(Trigger):
                 if isinstance(req_body_cls, type) and issubclass(
                     req_body_cls, BaseModel
                 ):
-                    fields = req_body_cls.__fields__  # type: ignore
+                    fields = model_fields(req_body_cls)  # type: ignore
                     parameters = []
                     for field_name, field in fields.items():
                         default_value = (
-                            Parameter.empty if field.required else field.default
+                            Parameter.empty
+                            if field_is_required(field)
+                            else field.default
                         )
                         parameters.append(
                             Parameter(
                                 name=field_name,
                                 kind=Parameter.KEYWORD_ONLY,
                                 default=default_value,
-                                annotation=field.outer_type_,
+                                annotation=field_outer_type(field),
                             )
                         )
             elif req_body_cls == Dict[str, Any] or req_body_cls == dict:
@@ -1029,7 +1038,7 @@ class RequestBodyToDictOperator(MapOperator[CommonLLMHttpRequestBody, Dict[str,
     async def map(self, request_body: CommonLLMHttpRequestBody) -> Dict[str, Any]:
         """Map the request body to response body."""
-        dict_value = request_body.dict()
+        dict_value = model_to_dict(request_body)
         if not self._key:
             return dict_value
         else:
@@ -1138,7 +1147,7 @@ class RequestedParsedOperator(MapOperator[CommonLLMHttpRequestBody, str]):
     async def map(self, request_body: CommonLLMHttpRequestBody) -> str:
         """Map the request body to response body."""
-        dict_value = request_body.dict()
+        dict_value = model_to_dict(request_body)
         if not self._key or self._key not in dict_value:
             raise ValueError(
                 f"Prefix key {self._key} is not a valid key of the request body"

View File

@@ -4,7 +4,7 @@ import inspect
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, List
 
-from dbgpt._private.pydantic import BaseModel, Field, root_validator
+from dbgpt._private.pydantic import BaseModel, Field, model_validator
 from dbgpt.core.interface.serialization import Serializable
 
 _DEFAULT_DYNAMIC_REGISTRY = {}
@@ -44,9 +44,12 @@ class FunctionDynamicOptions(BaseDynamicOptions):
"""Return the option values of the parameter."""
return self.func()
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Pre fill the function id."""
if not isinstance(values, dict):
return values
func = values.get("func")
if func is None:
raise ValueError(
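FunctionDynamicOptions.pre_fill follows the same validator migration and, judging by the visible lines, derives an identifier from the supplied func before validation. The sketch below shows that style of validator; the helper and field names are assumptions, not code from this commit.

# Hypothetical sketch: fill a serializable id from a callable in a "before" validator.
from typing import Any, Callable

from pydantic import BaseModel, model_validator


def _function_id(func: Callable) -> str:
    return f"{func.__module__}.{func.__qualname__}"


class FunctionOptionsExample(BaseModel):
    func: Callable
    function_id: str = ""

    @model_validator(mode="before")
    @classmethod
    def pre_fill(cls, values: Any) -> Any:
        if not isinstance(values, dict):
            return values
        func = values.get("func")
        if func is None:
            raise ValueError("func is required to create FunctionOptionsExample")
        values.setdefault("function_id", _function_id(func))
        return values


print(FunctionOptionsExample(func=len).function_id)  # builtins.len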