mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-01 01:04:43 +00:00
feat(core): Upgrade pydantic to 2.x (#1428)
This commit is contained in:
@@ -159,9 +159,9 @@ def setup_dev_environment(
|
||||
|
||||
start_http = _check_has_http_trigger(dags)
|
||||
if start_http:
|
||||
from fastapi import FastAPI
|
||||
from dbgpt.util.fastapi import create_app
|
||||
|
||||
app = FastAPI()
|
||||
app = create_app()
|
||||
else:
|
||||
app = None
|
||||
system_app = SystemApp(app)
|
||||
|
@@ -4,6 +4,7 @@ DAGManager will load DAGs from dag_dirs, and register the trigger nodes
|
||||
to TriggerManager.
|
||||
"""
|
||||
import logging
|
||||
import threading
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from dbgpt.component import BaseComponent, ComponentType, SystemApp
|
||||
@@ -29,6 +30,7 @@ class DAGManager(BaseComponent):
|
||||
from ..trigger.trigger_manager import DefaultTriggerManager
|
||||
|
||||
super().__init__(system_app)
|
||||
self.lock = threading.Lock()
|
||||
self.dag_loader = LocalFileDAGLoader(dag_dirs)
|
||||
self.system_app = system_app
|
||||
self.dag_map: Dict[str, DAG] = {}
|
||||
@@ -61,39 +63,54 @@ class DAGManager(BaseComponent):
|
||||
|
||||
def register_dag(self, dag: DAG, alias_name: Optional[str] = None):
|
||||
"""Register a DAG."""
|
||||
dag_id = dag.dag_id
|
||||
if dag_id in self.dag_map:
|
||||
raise ValueError(f"Register DAG error, DAG ID {dag_id} has already exist")
|
||||
self.dag_map[dag_id] = dag
|
||||
if alias_name:
|
||||
self.dag_alias_map[alias_name] = dag_id
|
||||
with self.lock:
|
||||
dag_id = dag.dag_id
|
||||
if dag_id in self.dag_map:
|
||||
raise ValueError(
|
||||
f"Register DAG error, DAG ID {dag_id} has already exist"
|
||||
)
|
||||
self.dag_map[dag_id] = dag
|
||||
if alias_name:
|
||||
self.dag_alias_map[alias_name] = dag_id
|
||||
|
||||
if self._trigger_manager:
|
||||
for trigger in dag.trigger_nodes:
|
||||
self._trigger_manager.register_trigger(trigger, self.system_app)
|
||||
self._trigger_manager.after_register()
|
||||
else:
|
||||
logger.warning("No trigger manager, not register dag trigger")
|
||||
if self._trigger_manager:
|
||||
for trigger in dag.trigger_nodes:
|
||||
self._trigger_manager.register_trigger(trigger, self.system_app)
|
||||
self._trigger_manager.after_register()
|
||||
else:
|
||||
logger.warning("No trigger manager, not register dag trigger")
|
||||
|
||||
def unregister_dag(self, dag_id: str):
|
||||
"""Unregister a DAG."""
|
||||
if dag_id not in self.dag_map:
|
||||
raise ValueError(f"Unregister DAG error, DAG ID {dag_id} does not exist")
|
||||
dag = self.dag_map[dag_id]
|
||||
# Clear the alias map
|
||||
for alias_name, _dag_id in self.dag_alias_map.items():
|
||||
if _dag_id == dag_id:
|
||||
with self.lock:
|
||||
if dag_id not in self.dag_map:
|
||||
raise ValueError(
|
||||
f"Unregister DAG error, DAG ID {dag_id} does not exist"
|
||||
)
|
||||
dag = self.dag_map[dag_id]
|
||||
|
||||
# Collect aliases to remove
|
||||
# TODO(fangyinc): It can be faster if we maintain a reverse map
|
||||
aliases_to_remove = [
|
||||
alias_name
|
||||
for alias_name, _dag_id in self.dag_alias_map.items()
|
||||
if _dag_id == dag_id
|
||||
]
|
||||
# Remove collected aliases
|
||||
for alias_name in aliases_to_remove:
|
||||
del self.dag_alias_map[alias_name]
|
||||
|
||||
if self._trigger_manager:
|
||||
for trigger in dag.trigger_nodes:
|
||||
self._trigger_manager.unregister_trigger(trigger, self.system_app)
|
||||
del self.dag_map[dag_id]
|
||||
if self._trigger_manager:
|
||||
for trigger in dag.trigger_nodes:
|
||||
self._trigger_manager.unregister_trigger(trigger, self.system_app)
|
||||
# Finally remove the DAG from the map
|
||||
del self.dag_map[dag_id]
|
||||
|
||||
def get_dag(
|
||||
self, dag_id: Optional[str] = None, alias_name: Optional[str] = None
|
||||
) -> Optional[DAG]:
|
||||
"""Get a DAG by dag_id or alias_name."""
|
||||
# Not lock, because it is read only and need to be fast
|
||||
if dag_id and dag_id in self.dag_map:
|
||||
return self.dag_map[dag_id]
|
||||
if alias_name in self.dag_alias_map:
|
||||
|
@@ -7,7 +7,13 @@ from datetime import date, datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Type, TypeVar, Union, cast
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field, ValidationError, root_validator
|
||||
from dbgpt._private.pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
ValidationError,
|
||||
model_to_dict,
|
||||
model_validator,
|
||||
)
|
||||
from dbgpt.core.awel.util.parameter_util import BaseDynamicOptions, OptionValue
|
||||
from dbgpt.core.interface.serialization import Serializable
|
||||
|
||||
@@ -281,7 +287,7 @@ class TypeMetadata(BaseModel):
|
||||
|
||||
def new(self: TM) -> TM:
|
||||
"""Copy the metadata."""
|
||||
return self.__class__(**self.dict())
|
||||
return self.__class__(**self.model_dump(exclude_defaults=True))
|
||||
|
||||
|
||||
class Parameter(TypeMetadata, Serializable):
|
||||
@@ -332,12 +338,15 @@ class Parameter(TypeMetadata, Serializable):
|
||||
None, description="The value of the parameter(Saved in the dag file)"
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre fill the metadata.
|
||||
|
||||
Transform the value to the real type.
|
||||
"""
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
type_cls = values.get("type_cls")
|
||||
to_handle_values = {
|
||||
"value": values.get("value"),
|
||||
@@ -443,7 +452,7 @@ class Parameter(TypeMetadata, Serializable):
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert current metadata to json dict."""
|
||||
dict_value = self.dict(exclude={"options"})
|
||||
dict_value = model_to_dict(self, exclude={"options"})
|
||||
if not self.options:
|
||||
dict_value["options"] = None
|
||||
elif isinstance(self.options, BaseDynamicOptions):
|
||||
@@ -535,7 +544,7 @@ class BaseResource(Serializable, BaseModel):
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert current metadata to json dict."""
|
||||
return self.dict()
|
||||
return model_to_dict(self)
|
||||
|
||||
|
||||
class Resource(BaseResource, TypeMetadata):
|
||||
@@ -693,9 +702,12 @@ class BaseMetadata(BaseResource):
|
||||
)
|
||||
return runnable_parameters
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def base_pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre fill the metadata."""
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
if "category_label" not in values:
|
||||
category = values["category"]
|
||||
if isinstance(category, str):
|
||||
@@ -713,7 +725,7 @@ class BaseMetadata(BaseResource):
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert current metadata to json dict."""
|
||||
dict_value = self.dict(exclude={"parameters"})
|
||||
dict_value = model_to_dict(self, exclude={"parameters"})
|
||||
dict_value["parameters"] = [
|
||||
parameter.to_dict() for parameter in self.parameters
|
||||
]
|
||||
@@ -738,9 +750,12 @@ class ResourceMetadata(BaseMetadata, TypeMetadata):
|
||||
],
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre fill the metadata."""
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
if "flow_type" not in values:
|
||||
values["flow_type"] = "resource"
|
||||
if "id" not in values:
|
||||
@@ -846,9 +861,12 @@ class ViewMetadata(BaseMetadata):
|
||||
examples=["dbgpt.model.operators.LLMOperator"],
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre fill the metadata."""
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
if "flow_type" not in values:
|
||||
values["flow_type"] = "operator"
|
||||
if "id" not in values:
|
||||
|
@@ -6,7 +6,13 @@ from contextlib import suppress
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field, root_validator, validator
|
||||
from dbgpt._private.pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
field_validator,
|
||||
model_to_dict,
|
||||
model_validator,
|
||||
)
|
||||
from dbgpt.core.awel.dag.base import DAG, DAGNode
|
||||
|
||||
from .base import (
|
||||
@@ -73,7 +79,8 @@ class FlowNodeData(BaseModel):
|
||||
..., description="Absolute position of the node"
|
||||
)
|
||||
|
||||
@validator("data", pre=True)
|
||||
@field_validator("data", mode="before")
|
||||
@classmethod
|
||||
def parse_data(cls, value: Any):
|
||||
"""Parse the data."""
|
||||
if isinstance(value, dict):
|
||||
@@ -123,9 +130,12 @@ class FlowEdgeData(BaseModel):
|
||||
examples=["buttonedge"],
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre fill the metadata."""
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
if (
|
||||
"source_order" not in values
|
||||
and "source_handle" in values
|
||||
@@ -315,9 +325,12 @@ class FlowPanel(BaseModel):
|
||||
examples=["2021-08-01 12:00:00", "2021-08-01 12:00:01", "2021-08-01 12:00:02"],
|
||||
)
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre fill the metadata."""
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
label = values.get("label")
|
||||
name = values.get("name")
|
||||
flow_category = str(values.get("flow_category", ""))
|
||||
@@ -329,6 +342,10 @@ class FlowPanel(BaseModel):
|
||||
values["name"] = name
|
||||
return values
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dict."""
|
||||
return model_to_dict(self)
|
||||
|
||||
|
||||
class FlowFactory:
|
||||
"""Flow factory."""
|
||||
|
@@ -15,7 +15,14 @@ from typing import (
|
||||
get_origin,
|
||||
)
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt._private.pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
field_is_required,
|
||||
field_outer_type,
|
||||
model_fields,
|
||||
model_to_dict,
|
||||
)
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
from ..dag.base import DAG
|
||||
@@ -61,7 +68,7 @@ class AWELHttpError(RuntimeError):
|
||||
|
||||
def _default_streaming_predict_func(body: "CommonRequestType") -> bool:
|
||||
if isinstance(body, BaseModel):
|
||||
body = body.dict()
|
||||
body = model_to_dict(body)
|
||||
elif isinstance(body, str):
|
||||
try:
|
||||
body = json.loads(body)
|
||||
@@ -254,7 +261,7 @@ class CommonLLMHttpRequestBody(BaseHttpBody):
|
||||
"or in full each time. "
|
||||
"If this parameter is not provided, the default is full return.",
|
||||
)
|
||||
enable_vis: str = Field(
|
||||
enable_vis: bool = Field(
|
||||
default=True, description="response content whether to output vis label"
|
||||
)
|
||||
extra: Optional[Dict[str, Any]] = Field(
|
||||
@@ -574,18 +581,20 @@ class HttpTrigger(Trigger):
|
||||
if isinstance(req_body_cls, type) and issubclass(
|
||||
req_body_cls, BaseModel
|
||||
):
|
||||
fields = req_body_cls.__fields__ # type: ignore
|
||||
fields = model_fields(req_body_cls) # type: ignore
|
||||
parameters = []
|
||||
for field_name, field in fields.items():
|
||||
default_value = (
|
||||
Parameter.empty if field.required else field.default
|
||||
Parameter.empty
|
||||
if field_is_required(field)
|
||||
else field.default
|
||||
)
|
||||
parameters.append(
|
||||
Parameter(
|
||||
name=field_name,
|
||||
kind=Parameter.KEYWORD_ONLY,
|
||||
default=default_value,
|
||||
annotation=field.outer_type_,
|
||||
annotation=field_outer_type(field),
|
||||
)
|
||||
)
|
||||
elif req_body_cls == Dict[str, Any] or req_body_cls == dict:
|
||||
@@ -1029,7 +1038,7 @@ class RequestBodyToDictOperator(MapOperator[CommonLLMHttpRequestBody, Dict[str,
|
||||
|
||||
async def map(self, request_body: CommonLLMHttpRequestBody) -> Dict[str, Any]:
|
||||
"""Map the request body to response body."""
|
||||
dict_value = request_body.dict()
|
||||
dict_value = model_to_dict(request_body)
|
||||
if not self._key:
|
||||
return dict_value
|
||||
else:
|
||||
@@ -1138,7 +1147,7 @@ class RequestedParsedOperator(MapOperator[CommonLLMHttpRequestBody, str]):
|
||||
|
||||
async def map(self, request_body: CommonLLMHttpRequestBody) -> str:
|
||||
"""Map the request body to response body."""
|
||||
dict_value = request_body.dict()
|
||||
dict_value = model_to_dict(request_body)
|
||||
if not self._key or self._key not in dict_value:
|
||||
raise ValueError(
|
||||
f"Prefix key {self._key} is not a valid key of the request body"
|
||||
|
@@ -4,7 +4,7 @@ import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field, root_validator
|
||||
from dbgpt._private.pydantic import BaseModel, Field, model_validator
|
||||
from dbgpt.core.interface.serialization import Serializable
|
||||
|
||||
_DEFAULT_DYNAMIC_REGISTRY = {}
|
||||
@@ -44,9 +44,12 @@ class FunctionDynamicOptions(BaseDynamicOptions):
|
||||
"""Return the option values of the parameter."""
|
||||
return self.func()
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre fill the function id."""
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
func = values.get("func")
|
||||
if func is None:
|
||||
raise ValueError(
|
||||
|
Reference in New Issue
Block a user