feat(core): Upgrade pydantic to 2.x (#1428)

This commit is contained in:
Fangyin Cheng
2024-04-20 09:41:16 +08:00
committed by GitHub
parent baa1e3f9f6
commit 57be1ece18
103 changed files with 1146 additions and 534 deletions

View File

@@ -15,7 +15,13 @@ from typing import (
get_origin,
)
from dbgpt._private.pydantic import BaseModel
from dbgpt._private.pydantic import (
BaseModel,
field_default,
field_description,
model_fields,
model_to_dict,
)
from dbgpt.util.json_utils import find_json_objects
from ...vis.base import Vis
@@ -45,6 +51,10 @@ class ActionOutput(BaseModel):
return None
return cls.parse_obj(param)
def to_dict(self) -> Dict[str, Any]:
"""Convert the object to a dictionary."""
return model_to_dict(self)
class Action(ABC, Generic[T]):
"""Base Action class for defining agent actions."""
@@ -85,12 +95,13 @@ class Action(ABC, Generic[T]):
if origin is None:
example = {}
single_model_type = cast(Type[BaseModel], model_type)
for field_name, field in single_model_type.__fields__.items():
field_info = field.field_info
if field_info.description:
example[field_name] = field_info.description
elif field_info.default:
example[field_name] = field_info.default
for field_name, field in model_fields(single_model_type).items():
description = field_description(field)
default_value = field_default(field)
if description:
example[field_name] = description
elif default_value:
example[field_name] = default_value
else:
example[field_name] = ""
return example

View File

@@ -3,7 +3,7 @@ import json
import logging
from typing import Optional
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt._private.pydantic import BaseModel, Field, model_to_json
from dbgpt.vis.tags.vis_chart import Vis, VisChart
from ..resource.resource_api import AgentResource, ResourceType
@@ -86,13 +86,13 @@ class ChartAction(Action[SqlInput]):
if not self.render_protocol:
raise ValueError("The rendering protocol is not initialized")
view = await self.render_protocol.display(
chart=json.loads(param.json()), data_df=data_df
chart=json.loads(model_to_json(param)), data_df=data_df
)
if not self.resource_need:
raise ValueError("The resource type is not found")
return ActionOutput(
is_exe_success=True,
content=param.json(),
content=model_to_json(param),
view=view,
resource_type=self.resource_need.value,
resource_value=resource.value,

View File

@@ -3,7 +3,7 @@ import json
import logging
from typing import List, Optional
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
from dbgpt.vis.tags.vis_dashboard import Vis, VisDashboard
from ..resource.resource_api import AgentResource, ResourceType
@@ -30,6 +30,10 @@ class ChartItem(BaseModel):
)
thought: str = Field(..., description="Summary of thoughts to the user")
def to_dict(self):
"""Convert to dict."""
return model_to_dict(self)
class DashboardAction(Action[List[ChartItem]]):
"""Dashboard action class."""
@@ -101,7 +105,7 @@ class DashboardAction(Action[List[ChartItem]]):
sql_df = await resource_db_client.query_to_df(
resource.value, chart_item.sql
)
chart_dict = chart_item.dict()
chart_dict = chart_item.to_dict()
chart_dict["data"] = sql_df
except Exception as e:
@@ -113,7 +117,9 @@ class DashboardAction(Action[List[ChartItem]]):
view = await self.render_protocol.display(charts=chart_params)
return ActionOutput(
is_exe_success=True,
content=json.dumps([chart_item.dict() for chart_item in chart_items]),
content=json.dumps(
[chart_item.to_dict() for chart_item in chart_items]
),
view=view,
)
except Exception as e:

View File

@@ -112,7 +112,7 @@ class AgentManager(BaseComponent):
"""Return the description of an agent by name."""
return self._agents[name][1].desc
def all_agents(self):
def all_agents(self) -> Dict[str, str]:
"""Return a dictionary of all registered agents and their descriptions."""
result = {}
for name, value in self._agents.items():

View File

@@ -5,7 +5,7 @@ import json
import logging
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast
from dbgpt._private.pydantic import Field
from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.core import LLMClient, ModelMessageRoleType
from dbgpt.util.error_types import LLMChatError
from dbgpt.util.tracer import SpanType, root_tracer
@@ -27,6 +27,8 @@ logger = logging.getLogger(__name__)
class ConversableAgent(Role, Agent):
"""ConversableAgent is an agent that can communicate with other agents."""
model_config = ConfigDict(arbitrary_types_allowed=True)
agent_context: Optional[AgentContext] = Field(None, description="Agent context")
actions: List[Action] = Field(default_factory=list)
resources: List[AgentResource] = Field(default_factory=list)
@@ -38,11 +40,6 @@ class ConversableAgent(Role, Agent):
llm_client: Optional[AIWrapper] = None
oai_system_message: List[Dict] = Field(default_factory=list)
class Config:
"""Pydantic configuration."""
arbitrary_types_allowed = True
def __init__(self, **kwargs):
"""Create a new agent."""
Role.__init__(self, **kwargs)
@@ -377,8 +374,10 @@ class ConversableAgent(Role, Agent):
**act_extent_param,
)
if act_out:
reply_message.action_report = act_out.dict()
span.metadata["action_report"] = act_out.dict() if act_out else None
reply_message.action_report = act_out.to_dict()
span.metadata["action_report"] = (
act_out.to_dict() if act_out else None
)
with root_tracer.start_span(
"agent.generate_reply.verify",
@@ -496,7 +495,7 @@ class ConversableAgent(Role, Agent):
"recipient": self.get_name(),
"reviewer": reviewer.get_name() if reviewer else None,
"need_resource": need_resource.to_dict() if need_resource else None,
"rely_action_out": last_out.dict() if last_out else None,
"rely_action_out": last_out.to_dict() if last_out else None,
"conv_uid": self.not_null_agent_context.conv_id,
"action_index": i,
"total_action": len(self.actions),
@@ -508,7 +507,7 @@ class ConversableAgent(Role, Agent):
rely_action_out=last_out,
**kwargs,
)
span.metadata["action_out"] = last_out.dict() if last_out else None
span.metadata["action_out"] = last_out.to_dict() if last_out else None
return last_out
async def correctness_check(

View File

@@ -3,7 +3,7 @@
import logging
from typing import Dict, List, Optional, Tuple, Union
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
from ..actions.action import ActionOutput
from .agent import Agent, AgentMessage
@@ -68,16 +68,13 @@ def _content_str(content: Union[str, List, None]) -> str:
class Team(BaseModel):
"""Team class for managing a group of agents in a team chat."""
model_config = ConfigDict(arbitrary_types_allowed=True)
agents: List[Agent] = Field(default_factory=list)
messages: List[Dict] = Field(default_factory=list)
max_round: int = 100
is_team: bool = True
class Config:
"""Pydantic model configuration."""
arbitrary_types_allowed = True
def __init__(self, **kwargs):
"""Create a new Team instance."""
super().__init__(**kwargs)
@@ -122,6 +119,8 @@ class Team(BaseModel):
class ManagerAgent(ConversableAgent, Team):
"""Manager Agent class."""
model_config = ConfigDict(arbitrary_types_allowed=True)
profile: str = "TeamManager"
goal: str = "manage all hired intelligent agents to complete mission objectives"
constraints: List[str] = []
@@ -132,11 +131,6 @@ class ManagerAgent(ConversableAgent, Team):
# of the agent has already been retried.
max_retry_count: int = 1
class Config:
"""Pydantic model configuration."""
arbitrary_types_allowed = True
def __init__(self, **kwargs):
"""Create a new ManagerAgent instance."""
ConversableAgent.__init__(self, **kwargs)

View File

@@ -4,7 +4,7 @@ from collections import defaultdict
from enum import Enum
from typing import Any, Dict, List, Optional, Type
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
from dbgpt.core import LLMClient, ModelMetadata, ModelRequest
logger = logging.getLogger(__name__)
@@ -106,11 +106,8 @@ def register_llm_strategy(
class LLMConfig(BaseModel):
"""LLM configuration."""
model_config = ConfigDict(arbitrary_types_allowed=True)
llm_client: Optional[LLMClient] = Field(default_factory=LLMClient)
llm_strategy: LLMStrategyType = Field(default=LLMStrategyType.Default)
strategy_context: Optional[Any] = None
class Config:
"""Pydantic model config."""
arbitrary_types_allowed = True

View File

@@ -2,33 +2,30 @@
from abc import ABC
from typing import List, Optional
from dbgpt._private.pydantic import BaseModel
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
class Role(ABC, BaseModel):
"""Role class for role-based conversation."""
model_config = ConfigDict(arbitrary_types_allowed=True)
profile: str = ""
name: str = ""
resource_introduction = ""
resource_introduction: str = ""
goal: str = ""
expand_prompt: str = ""
fixed_subgoal: Optional[str] = None
fixed_subgoal: Optional[str] = Field(None, description="Fixed subgoal")
constraints: List[str] = []
constraints: List[str] = Field(default_factory=list, description="Constraints")
examples: str = ""
desc: str = ""
language: str = "en"
is_human: bool = False
is_team: bool = False
class Config:
"""Pydantic config."""
arbitrary_types_allowed = True
def prompt_template(
self,
specified_prompt: Optional[str] = None,

View File

@@ -8,7 +8,7 @@ class UserProxyAgent(ConversableAgent):
That can execute code and provide feedback to the other agents.
"""
name = "User"
name: str = "User"
profile: str = "Human"
desc: str = (
@@ -16,4 +16,4 @@ class UserProxyAgent(ConversableAgent):
"Plan execution needs to be approved by this admin."
)
is_human = True
is_human: bool = True

View File

@@ -34,8 +34,6 @@ class DashboardAssistantAgent(ConversableAgent):
"professional reports"
)
max_retry_count: int = 3
def __init__(self, **kwargs):
"""Create a new instance of DashboardAssistantAgent."""
super().__init__(**kwargs)

View File

@@ -17,7 +17,7 @@ logger = logging.getLogger(__name__)
class DataScientistAgent(ConversableAgent):
"""Data Scientist Agent."""
name = "Edgar"
name: str = "Edgar"
profile: str = "DataScientist"
goal: str = (
"Use correct {dialect} SQL to analyze and solve tasks based on the data"

View File

@@ -16,7 +16,7 @@ class PluginAssistantAgent(ConversableAgent):
plugin_generator: Optional[PluginPromptGenerator] = None
name = "LuBan"
name: str = "LuBan"
profile: str = "ToolExpert"
goal: str = (
"Read and understand the tool information given in the resources below to "

View File

@@ -244,7 +244,7 @@ class RetrieveSummaryAssistantAgent(ConversableAgent):
**act_extent_param,
)
if act_out:
reply_message.action_report = act_out.dict()
reply_message.action_report = act_out.to_dict()
# 4.Reply information verification
check_pass, reason = await self.verify(reply_message, sender, reviewer)
is_success = check_pass

View File

@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
class SummaryAssistantAgent(ConversableAgent):
"""Summary Assistant Agent."""
name = "Aristotle"
name: str = "Aristotle"
profile: str = "Summarizer"
goal: str = (
"Summarize answer summaries based on user questions from provided "

View File

@@ -1,7 +1,7 @@
"""The AWEL Agent Operator Resource."""
from typing import Any, Dict, List, Optional
from dbgpt._private.pydantic import BaseModel, Field, root_validator
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_validator
from dbgpt.core import LLMClient
from dbgpt.core.awel.flow import (
FunctionDynamicOptions,
@@ -55,9 +55,12 @@ from ...resource.resource_api import AgentResource, ResourceType
class AWELAgentResource(AgentResource):
"""AWEL Agent Resource."""
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Pre fill the agent ResourceType."""
if not isinstance(values, dict):
return values
name = values.pop("agent_resource_name")
type = values.pop("agent_resource_type")
value = values.pop("agent_resource_value")
@@ -109,10 +112,7 @@ class AWELAgentResource(AgentResource):
class AWELAgentConfig(LLMConfig):
"""AWEL Agent Config."""
@root_validator(pre=True)
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Pre fill the agent ResourceType."""
return values
pass
def _agent_resource_option_values() -> List[OptionValue]:
@@ -173,20 +173,20 @@ def _agent_resource_option_values() -> List[OptionValue]:
class AWELAgent(BaseModel):
"""AWEL Agent."""
model_config = ConfigDict(arbitrary_types_allowed=True)
agent_profile: str
role_name: Optional[str] = None
llm_config: Optional[LLMConfig] = None
resources: List[AgentResource] = Field(default_factory=list)
fixed_subgoal: Optional[str] = None
class Config:
"""Config for the BaseModel."""
arbitrary_types_allowed = True
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Pre fill the agent ResourceType."""
if not isinstance(values, dict):
return values
resource = values.pop("agent_resource")
llm_config = values.pop("agent_llm_Config")

View File

@@ -5,7 +5,13 @@ from abc import ABC, abstractmethod
from typing import List, Optional, cast
from dbgpt._private.config import Config
from dbgpt._private.pydantic import BaseModel, Field, validator
from dbgpt._private.pydantic import (
BaseModel,
ConfigDict,
Field,
model_to_dict,
validator,
)
from dbgpt.core.awel import DAG
from dbgpt.core.awel.dag.dag_manager import DAGManager
@@ -70,23 +76,20 @@ class AWELTeamContext(BaseModel):
def to_dict(self):
"""Convert the object to a dictionary."""
return self.dict()
return model_to_dict(self)
class AWELBaseManager(ManagerAgent, ABC):
"""AWEL base manager."""
model_config = ConfigDict(arbitrary_types_allowed=True)
goal: str = (
"Promote and solve user problems according to the process arranged by AWEL."
)
constraints: List[str] = []
desc: str = goal
class Config:
"""Config for the BaseModel."""
arbitrary_types_allowed = True
async def _a_process_received_message(self, message: AgentMessage, sender: Agent):
"""Process the received message."""
pass
@@ -157,15 +160,12 @@ class WrappedAWELLayoutManager(AWELBaseManager):
Receives a DAG or builds a DAG from the agents.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
profile: str = "WrappedAWELLayoutManager"
dag: Optional[DAG] = Field(None, description="The DAG of the manager")
class Config:
"""Config for the BaseModel."""
arbitrary_types_allowed = True
def get_dag(self) -> DAG:
"""Get the DAG of the manager."""
if self.dag:
@@ -236,15 +236,12 @@ class WrappedAWELLayoutManager(AWELBaseManager):
class DefaultAWELLayoutManager(AWELBaseManager):
"""The manager of the team for the AWEL layout."""
model_config = ConfigDict(arbitrary_types_allowed=True)
profile: str = "DefaultAWELLayoutManager"
dag: AWELTeamContext = Field(...)
class Config:
"""Config for the BaseModel."""
arbitrary_types_allowed = True
@validator("dag")
def check_dag(cls, value):
"""Check the DAG of the manager."""

View File

@@ -84,7 +84,7 @@ class PlannerAgent(ConversableAgent):
" and allocate resources to achieve complex task goals."
)
examples = """
examples: str = """
user:help me build a sales report summarizing our key metrics and trends
assistants:[
{{

View File

@@ -192,7 +192,7 @@ class PluginStatus(BaseModel):
logo_url: Optional[str] = None
api_result: Optional[str] = None
err_msg: Optional[str] = None
start_time = datetime.now().timestamp() * 1000
start_time: float = datetime.now().timestamp() * 1000
end_time: Optional[str] = None
df: Any = None

View File

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from dbgpt._private.pydantic import BaseModel
from dbgpt._private.pydantic import BaseModel, model_to_dict
class ResourceType(Enum):
@@ -67,7 +67,7 @@ class AgentResource(BaseModel):
def to_dict(self) -> Dict[str, Any]:
"""Convert the AgentResource object to a dictionary."""
temp = self.dict()
temp = model_to_dict(self)
for field, value in temp.items():
if isinstance(value, Enum):
temp[field] = value.value