mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-12 12:37:14 +00:00
feat(core): Upgrade pydantic to 2.x (#1428)
This commit is contained in:
@@ -15,7 +15,13 @@ from typing import (
|
||||
get_origin,
|
||||
)
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
from dbgpt._private.pydantic import (
|
||||
BaseModel,
|
||||
field_default,
|
||||
field_description,
|
||||
model_fields,
|
||||
model_to_dict,
|
||||
)
|
||||
from dbgpt.util.json_utils import find_json_objects
|
||||
|
||||
from ...vis.base import Vis
|
||||
@@ -45,6 +51,10 @@ class ActionOutput(BaseModel):
|
||||
return None
|
||||
return cls.parse_obj(param)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert the object to a dictionary."""
|
||||
return model_to_dict(self)
|
||||
|
||||
|
||||
class Action(ABC, Generic[T]):
|
||||
"""Base Action class for defining agent actions."""
|
||||
@@ -85,12 +95,13 @@ class Action(ABC, Generic[T]):
|
||||
if origin is None:
|
||||
example = {}
|
||||
single_model_type = cast(Type[BaseModel], model_type)
|
||||
for field_name, field in single_model_type.__fields__.items():
|
||||
field_info = field.field_info
|
||||
if field_info.description:
|
||||
example[field_name] = field_info.description
|
||||
elif field_info.default:
|
||||
example[field_name] = field_info.default
|
||||
for field_name, field in model_fields(single_model_type).items():
|
||||
description = field_description(field)
|
||||
default_value = field_default(field)
|
||||
if description:
|
||||
example[field_name] = description
|
||||
elif default_value:
|
||||
example[field_name] = default_value
|
||||
else:
|
||||
example[field_name] = ""
|
||||
return example
|
||||
|
@@ -3,7 +3,7 @@ import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt._private.pydantic import BaseModel, Field, model_to_json
|
||||
from dbgpt.vis.tags.vis_chart import Vis, VisChart
|
||||
|
||||
from ..resource.resource_api import AgentResource, ResourceType
|
||||
@@ -86,13 +86,13 @@ class ChartAction(Action[SqlInput]):
|
||||
if not self.render_protocol:
|
||||
raise ValueError("The rendering protocol is not initialized!")
|
||||
view = await self.render_protocol.display(
|
||||
chart=json.loads(param.json()), data_df=data_df
|
||||
chart=json.loads(model_to_json(param)), data_df=data_df
|
||||
)
|
||||
if not self.resource_need:
|
||||
raise ValueError("The resource type is not found!")
|
||||
return ActionOutput(
|
||||
is_exe_success=True,
|
||||
content=param.json(),
|
||||
content=model_to_json(param),
|
||||
view=view,
|
||||
resource_type=self.resource_need.value,
|
||||
resource_value=resource.value,
|
||||
|
@@ -3,7 +3,7 @@ import json
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
|
||||
from dbgpt.vis.tags.vis_dashboard import Vis, VisDashboard
|
||||
|
||||
from ..resource.resource_api import AgentResource, ResourceType
|
||||
@@ -30,6 +30,10 @@ class ChartItem(BaseModel):
|
||||
)
|
||||
thought: str = Field(..., description="Summary of thoughts to the user")
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert to dict."""
|
||||
return model_to_dict(self)
|
||||
|
||||
|
||||
class DashboardAction(Action[List[ChartItem]]):
|
||||
"""Dashboard action class."""
|
||||
@@ -101,7 +105,7 @@ class DashboardAction(Action[List[ChartItem]]):
|
||||
sql_df = await resource_db_client.query_to_df(
|
||||
resource.value, chart_item.sql
|
||||
)
|
||||
chart_dict = chart_item.dict()
|
||||
chart_dict = chart_item.to_dict()
|
||||
|
||||
chart_dict["data"] = sql_df
|
||||
except Exception as e:
|
||||
@@ -113,7 +117,9 @@ class DashboardAction(Action[List[ChartItem]]):
|
||||
view = await self.render_protocol.display(charts=chart_params)
|
||||
return ActionOutput(
|
||||
is_exe_success=True,
|
||||
content=json.dumps([chart_item.dict() for chart_item in chart_items]),
|
||||
content=json.dumps(
|
||||
[chart_item.to_dict() for chart_item in chart_items]
|
||||
),
|
||||
view=view,
|
||||
)
|
||||
except Exception as e:
|
||||
|
@@ -112,7 +112,7 @@ class AgentManager(BaseComponent):
|
||||
"""Return the description of an agent by name."""
|
||||
return self._agents[name][1].desc
|
||||
|
||||
def all_agents(self):
|
||||
def all_agents(self) -> Dict[str, str]:
|
||||
"""Return a dictionary of all registered agents and their descriptions."""
|
||||
result = {}
|
||||
for name, value in self._agents.items():
|
||||
|
@@ -5,7 +5,7 @@ import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
|
||||
|
||||
from dbgpt._private.pydantic import Field
|
||||
from dbgpt._private.pydantic import ConfigDict, Field
|
||||
from dbgpt.core import LLMClient, ModelMessageRoleType
|
||||
from dbgpt.util.error_types import LLMChatError
|
||||
from dbgpt.util.tracer import SpanType, root_tracer
|
||||
@@ -27,6 +27,8 @@ logger = logging.getLogger(__name__)
|
||||
class ConversableAgent(Role, Agent):
|
||||
"""ConversableAgent is an agent that can communicate with other agents."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
agent_context: Optional[AgentContext] = Field(None, description="Agent context")
|
||||
actions: List[Action] = Field(default_factory=list)
|
||||
resources: List[AgentResource] = Field(default_factory=list)
|
||||
@@ -38,11 +40,6 @@ class ConversableAgent(Role, Agent):
|
||||
llm_client: Optional[AIWrapper] = None
|
||||
oai_system_message: List[Dict] = Field(default_factory=list)
|
||||
|
||||
class Config:
|
||||
"""Pydantic configuration."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Create a new agent."""
|
||||
Role.__init__(self, **kwargs)
|
||||
@@ -377,8 +374,10 @@ class ConversableAgent(Role, Agent):
|
||||
**act_extent_param,
|
||||
)
|
||||
if act_out:
|
||||
reply_message.action_report = act_out.dict()
|
||||
span.metadata["action_report"] = act_out.dict() if act_out else None
|
||||
reply_message.action_report = act_out.to_dict()
|
||||
span.metadata["action_report"] = (
|
||||
act_out.to_dict() if act_out else None
|
||||
)
|
||||
|
||||
with root_tracer.start_span(
|
||||
"agent.generate_reply.verify",
|
||||
@@ -496,7 +495,7 @@ class ConversableAgent(Role, Agent):
|
||||
"recipient": self.get_name(),
|
||||
"reviewer": reviewer.get_name() if reviewer else None,
|
||||
"need_resource": need_resource.to_dict() if need_resource else None,
|
||||
"rely_action_out": last_out.dict() if last_out else None,
|
||||
"rely_action_out": last_out.to_dict() if last_out else None,
|
||||
"conv_uid": self.not_null_agent_context.conv_id,
|
||||
"action_index": i,
|
||||
"total_action": len(self.actions),
|
||||
@@ -508,7 +507,7 @@ class ConversableAgent(Role, Agent):
|
||||
rely_action_out=last_out,
|
||||
**kwargs,
|
||||
)
|
||||
span.metadata["action_out"] = last_out.dict() if last_out else None
|
||||
span.metadata["action_out"] = last_out.to_dict() if last_out else None
|
||||
return last_out
|
||||
|
||||
async def correctness_check(
|
||||
|
@@ -3,7 +3,7 @@
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from ..actions.action import ActionOutput
|
||||
from .agent import Agent, AgentMessage
|
||||
@@ -68,16 +68,13 @@ def _content_str(content: Union[str, List, None]) -> str:
|
||||
class Team(BaseModel):
|
||||
"""Team class for managing a group of agents in a team chat."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
agents: List[Agent] = Field(default_factory=list)
|
||||
messages: List[Dict] = Field(default_factory=list)
|
||||
max_round: int = 100
|
||||
is_team: bool = True
|
||||
|
||||
class Config:
|
||||
"""Pydantic model configuration."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Create a new Team instance."""
|
||||
super().__init__(**kwargs)
|
||||
@@ -122,6 +119,8 @@ class Team(BaseModel):
|
||||
class ManagerAgent(ConversableAgent, Team):
|
||||
"""Manager Agent class."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
profile: str = "TeamManager"
|
||||
goal: str = "manage all hired intelligent agents to complete mission objectives"
|
||||
constraints: List[str] = []
|
||||
@@ -132,11 +131,6 @@ class ManagerAgent(ConversableAgent, Team):
|
||||
# of the agent has already been retried.
|
||||
max_retry_count: int = 1
|
||||
|
||||
class Config:
|
||||
"""Pydantic model configuration."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Create a new ManagerAgent instance."""
|
||||
ConversableAgent.__init__(self, **kwargs)
|
||||
|
@@ -4,7 +4,7 @@ from collections import defaultdict
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
|
||||
from dbgpt.core import LLMClient, ModelMetadata, ModelRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -106,11 +106,8 @@ def register_llm_strategy(
|
||||
class LLMConfig(BaseModel):
|
||||
"""LLM configuration."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
llm_client: Optional[LLMClient] = Field(default_factory=LLMClient)
|
||||
llm_strategy: LLMStrategyType = Field(default=LLMStrategyType.Default)
|
||||
strategy_context: Optional[Any] = None
|
||||
|
||||
class Config:
|
||||
"""Pydantic model config."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
@@ -2,33 +2,30 @@
|
||||
from abc import ABC
|
||||
from typing import List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class Role(ABC, BaseModel):
|
||||
"""Role class for role-based conversation."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
profile: str = ""
|
||||
name: str = ""
|
||||
resource_introduction = ""
|
||||
resource_introduction: str = ""
|
||||
goal: str = ""
|
||||
|
||||
expand_prompt: str = ""
|
||||
|
||||
fixed_subgoal: Optional[str] = None
|
||||
fixed_subgoal: Optional[str] = Field(None, description="Fixed subgoal")
|
||||
|
||||
constraints: List[str] = []
|
||||
constraints: List[str] = Field(default_factory=list, description="Constraints")
|
||||
examples: str = ""
|
||||
desc: str = ""
|
||||
language: str = "en"
|
||||
is_human: bool = False
|
||||
is_team: bool = False
|
||||
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def prompt_template(
|
||||
self,
|
||||
specified_prompt: Optional[str] = None,
|
||||
|
@@ -8,7 +8,7 @@ class UserProxyAgent(ConversableAgent):
|
||||
That can execute code and provide feedback to the other agents.
|
||||
"""
|
||||
|
||||
name = "User"
|
||||
name: str = "User"
|
||||
profile: str = "Human"
|
||||
|
||||
desc: str = (
|
||||
@@ -16,4 +16,4 @@ class UserProxyAgent(ConversableAgent):
|
||||
"Plan execution needs to be approved by this admin."
|
||||
)
|
||||
|
||||
is_human = True
|
||||
is_human: bool = True
|
||||
|
@@ -34,8 +34,6 @@ class DashboardAssistantAgent(ConversableAgent):
|
||||
"professional reports"
|
||||
)
|
||||
|
||||
max_retry_count: int = 3
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Create a new instance of DashboardAssistantAgent."""
|
||||
super().__init__(**kwargs)
|
||||
|
@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
|
||||
class DataScientistAgent(ConversableAgent):
|
||||
"""Data Scientist Agent."""
|
||||
|
||||
name = "Edgar"
|
||||
name: str = "Edgar"
|
||||
profile: str = "DataScientist"
|
||||
goal: str = (
|
||||
"Use correct {dialect} SQL to analyze and solve tasks based on the data"
|
||||
|
@@ -16,7 +16,7 @@ class PluginAssistantAgent(ConversableAgent):
|
||||
|
||||
plugin_generator: Optional[PluginPromptGenerator] = None
|
||||
|
||||
name = "LuBan"
|
||||
name: str = "LuBan"
|
||||
profile: str = "ToolExpert"
|
||||
goal: str = (
|
||||
"Read and understand the tool information given in the resources below to "
|
||||
|
@@ -244,7 +244,7 @@ class RetrieveSummaryAssistantAgent(ConversableAgent):
|
||||
**act_extent_param,
|
||||
)
|
||||
if act_out:
|
||||
reply_message.action_report = act_out.dict()
|
||||
reply_message.action_report = act_out.to_dict()
|
||||
# 4.Reply information verification
|
||||
check_pass, reason = await self.verify(reply_message, sender, reviewer)
|
||||
is_success = check_pass
|
||||
|
@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
|
||||
class SummaryAssistantAgent(ConversableAgent):
|
||||
"""Summary Assistant Agent."""
|
||||
|
||||
name = "Aristotle"
|
||||
name: str = "Aristotle"
|
||||
profile: str = "Summarizer"
|
||||
goal: str = (
|
||||
"Summarize answer summaries based on user questions from provided "
|
||||
|
@@ -1,7 +1,7 @@
|
||||
"""The AWEL Agent Operator Resource."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field, root_validator
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from dbgpt.core import LLMClient
|
||||
from dbgpt.core.awel.flow import (
|
||||
FunctionDynamicOptions,
|
||||
@@ -55,9 +55,12 @@ from ...resource.resource_api import AgentResource, ResourceType
|
||||
class AWELAgentResource(AgentResource):
|
||||
"""AWEL Agent Resource."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre fill the agent ResourceType."""
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
name = values.pop("agent_resource_name")
|
||||
type = values.pop("agent_resource_type")
|
||||
value = values.pop("agent_resource_value")
|
||||
@@ -109,10 +112,7 @@ class AWELAgentResource(AgentResource):
|
||||
class AWELAgentConfig(LLMConfig):
|
||||
"""AWEL Agent Config."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre fill the agent ResourceType."""
|
||||
return values
|
||||
pass
|
||||
|
||||
|
||||
def _agent_resource_option_values() -> List[OptionValue]:
|
||||
@@ -173,20 +173,20 @@ def _agent_resource_option_values() -> List[OptionValue]:
|
||||
class AWELAgent(BaseModel):
|
||||
"""AWEL Agent."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
agent_profile: str
|
||||
role_name: Optional[str] = None
|
||||
llm_config: Optional[LLMConfig] = None
|
||||
resources: List[AgentResource] = Field(default_factory=list)
|
||||
fixed_subgoal: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
"""Config for the BaseModel."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre fill the agent ResourceType."""
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
resource = values.pop("agent_resource")
|
||||
llm_config = values.pop("agent_llm_Config")
|
||||
|
||||
|
@@ -5,7 +5,13 @@ from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, cast
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt._private.pydantic import BaseModel, Field, validator
|
||||
from dbgpt._private.pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
model_to_dict,
|
||||
validator,
|
||||
)
|
||||
from dbgpt.core.awel import DAG
|
||||
from dbgpt.core.awel.dag.dag_manager import DAGManager
|
||||
|
||||
@@ -70,23 +76,20 @@ class AWELTeamContext(BaseModel):
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert the object to a dictionary."""
|
||||
return self.dict()
|
||||
return model_to_dict(self)
|
||||
|
||||
|
||||
class AWELBaseManager(ManagerAgent, ABC):
|
||||
"""AWEL base manager."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
goal: str = (
|
||||
"Promote and solve user problems according to the process arranged by AWEL."
|
||||
)
|
||||
constraints: List[str] = []
|
||||
desc: str = goal
|
||||
|
||||
class Config:
|
||||
"""Config for the BaseModel."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
async def _a_process_received_message(self, message: AgentMessage, sender: Agent):
|
||||
"""Process the received message."""
|
||||
pass
|
||||
@@ -157,15 +160,12 @@ class WrappedAWELLayoutManager(AWELBaseManager):
|
||||
Receives a DAG or builds a DAG from the agents.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
profile: str = "WrappedAWELLayoutManager"
|
||||
|
||||
dag: Optional[DAG] = Field(None, description="The DAG of the manager")
|
||||
|
||||
class Config:
|
||||
"""Config for the BaseModel."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_dag(self) -> DAG:
|
||||
"""Get the DAG of the manager."""
|
||||
if self.dag:
|
||||
@@ -236,15 +236,12 @@ class WrappedAWELLayoutManager(AWELBaseManager):
|
||||
class DefaultAWELLayoutManager(AWELBaseManager):
|
||||
"""The manager of the team for the AWEL layout."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
profile: str = "DefaultAWELLayoutManager"
|
||||
|
||||
dag: AWELTeamContext = Field(...)
|
||||
|
||||
class Config:
|
||||
"""Config for the BaseModel."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("dag")
|
||||
def check_dag(cls, value):
|
||||
"""Check the DAG of the manager."""
|
||||
|
@@ -84,7 +84,7 @@ class PlannerAgent(ConversableAgent):
|
||||
" and allocate resources to achieve complex task goals."
|
||||
)
|
||||
|
||||
examples = """
|
||||
examples: str = """
|
||||
user:help me build a sales report summarizing our key metrics and trends
|
||||
assistants:[
|
||||
{{
|
||||
|
@@ -192,7 +192,7 @@ class PluginStatus(BaseModel):
|
||||
logo_url: Optional[str] = None
|
||||
api_result: Optional[str] = None
|
||||
err_msg: Optional[str] = None
|
||||
start_time = datetime.now().timestamp() * 1000
|
||||
start_time: float = datetime.now().timestamp() * 1000
|
||||
end_time: Optional[str] = None
|
||||
|
||||
df: Any = None
|
||||
|
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
from dbgpt._private.pydantic import BaseModel, model_to_dict
|
||||
|
||||
|
||||
class ResourceType(Enum):
|
||||
@@ -67,7 +67,7 @@ class AgentResource(BaseModel):
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert the AgentResource object to a dictionary."""
|
||||
temp = self.dict()
|
||||
temp = model_to_dict(self)
|
||||
for field, value in temp.items():
|
||||
if isinstance(value, Enum):
|
||||
temp[field] = value.value
|
||||
|
Reference in New Issue
Block a user