feat(core): Upgrade pydantic to 2.x (#1428)

This commit is contained in:
Fangyin Cheng
2024-04-20 09:41:16 +08:00
committed by GitHub
parent baa1e3f9f6
commit 57be1ece18
103 changed files with 1146 additions and 534 deletions

View File

@@ -86,7 +86,9 @@ class MultiAgents(BaseComponent, ABC):
def gpts_create(self, entity: GptsInstanceEntity):
self.gpts_intance.add(entity)
def get_dbgpts(self, user_code: str = None, sys_code: str = None):
def get_dbgpts(
self, user_code: str = None, sys_code: str = None
) -> Optional[List[GptsApp]]:
apps = self.gpts_app.app_list(
GptsAppQuery(user_code=user_code, sys_code=sys_code)
).app_list
@@ -338,7 +340,7 @@ class MultiAgents(BaseComponent, ABC):
multi_agents = MultiAgents()
@router.post("/v1/dbgpts/agents/list", response_model=Result[str])
@router.post("/v1/dbgpts/agents/list", response_model=Result[Dict[str, str]])
async def agents_list():
logger.info("agents_list!")
try:
@@ -348,7 +350,7 @@ async def agents_list():
return Result.failed(code="E30001", msg=str(e))
@router.get("/v1/dbgpts/list", response_model=Result[str])
@router.get("/v1/dbgpts/list", response_model=Result[List[GptsApp]])
async def get_dbgpts(user_code: str = None, sys_code: str = None):
logger.info(f"get_dbgpts:{user_code},{sys_code}")
try:
@@ -359,14 +361,14 @@ async def get_dbgpts(user_code: str = None, sys_code: str = None):
@router.post("/v1/dbgpts/chat/completions", response_model=Result[str])
async def dgpts_completions(
async def dbgpts_completions(
gpts_name: str,
user_query: str,
conv_id: str = None,
user_code: str = None,
sys_code: str = None,
):
logger.info(f"dgpts_completions:{gpts_name},{user_query},{conv_id}")
logger.info(f"dbgpts_completions:{gpts_name},{user_query},{conv_id}")
if conv_id is None:
conv_id = str(uuid.uuid1())
@@ -390,12 +392,12 @@ async def dgpts_completions(
@router.post("/v1/dbgpts/chat/cancel", response_model=Result[str])
async def dgpts_chat_cancel(
async def dbgpts_chat_cancel(
conv_id: str = None, user_code: str = None, sys_code: str = None
):
pass
@router.post("/v1/dbgpts/chat/feedback", response_model=Result[str])
async def dgpts_chat_feedback(filter: PagenationFilter[PluginHubFilter] = Body()):
async def dbgpts_chat_feedback(filter: PagenationFilter[PluginHubFilter] = Body()):
pass

View File

@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Union
from sqlalchemy import Column, DateTime, Integer, String, Text, UniqueConstraint
from dbgpt._private.pydantic import BaseModel
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_json
from dbgpt.agent.plan.awel.team_awel_layout import AWELTeamContext
from dbgpt.agent.resource.resource_api import AgentResource
from dbgpt.serve.agent.team.base import TeamMode
@@ -17,6 +17,8 @@ logger = logging.getLogger(__name__)
class GptsAppDetail(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
app_code: Optional[str] = None
app_name: Optional[str] = None
agent_name: Optional[str] = None
@@ -28,11 +30,6 @@ class GptsAppDetail(BaseModel):
created_at: datetime = datetime.now()
updated_at: datetime = datetime.now()
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def to_dict(self):
return {k: self._serialize(v) for k, v in self.__dict__.items()}
@@ -86,6 +83,8 @@ class GptsAppDetail(BaseModel):
class GptsApp(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
app_code: Optional[str] = None
app_name: Optional[str] = None
app_describe: Optional[str] = None
@@ -100,11 +99,6 @@ class GptsApp(BaseModel):
updated_at: datetime = datetime.now()
details: List[GptsAppDetail] = []
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def to_dict(self):
return {k: self._serialize(v) for k, v in self.__dict__.items()}
@@ -146,7 +140,9 @@ class GptsAppResponse(BaseModel):
total_count: Optional[int] = 0
total_page: Optional[int] = 0
current_page: Optional[int] = 0
app_list: Optional[List[GptsApp]] = []
app_list: Optional[List[GptsApp]] = Field(
default_factory=list, description="app list"
)
class GptsAppCollection(BaseModel):
@@ -207,7 +203,8 @@ class GptsAppEntity(Model):
team_context = Column(
Text,
nullable=True,
comment="The execution logic and team member content that teams with different working modes rely on",
comment="The execution logic and team member content that teams with different "
"working modes rely on",
)
user_code = Column(String(255), nullable=True, comment="user code")
@@ -565,7 +562,7 @@ def _parse_team_context(team_context: Optional[Union[str, AWELTeamContext]] = No
parse team_context to str
"""
if isinstance(team_context, AWELTeamContext):
return team_context.json()
return model_to_json(team_context)
return team_context

View File

@@ -1,9 +1,12 @@
from datetime import datetime
from typing import List
from sqlalchemy import Column, DateTime, Integer, String, UniqueConstraint, func
from dbgpt.storage.metadata import BaseDao, Model
from ..model import MyPluginVO
class MyPluginEntity(Model):
__tablename__ = "my_plugin"
@@ -27,6 +30,28 @@ class MyPluginEntity(Model):
)
UniqueConstraint("user_code", "name", name="uk_name")
@classmethod
def to_vo(cls, entities: List["MyPluginEntity"]) -> List[MyPluginVO]:
results = []
for entity in entities:
results.append(
MyPluginVO(
id=entity.id,
tenant=entity.tenant,
user_code=entity.user_code,
user_name=entity.user_name,
sys_code=entity.sys_code,
name=entity.name,
file_name=entity.file_name,
type=entity.type,
version=entity.version,
use_count=entity.use_count,
succ_count=entity.succ_count,
gmt_created=entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S"),
)
)
return results
class MyPluginDao(BaseDao):
def add(self, engity: MyPluginEntity):

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from typing import List
import pytz
from sqlalchemy import (
@@ -14,6 +15,8 @@ from sqlalchemy import (
from dbgpt.storage.metadata import BaseDao, Model
from ..model import PluginHubVO
# TODO We should consider that the production environment does not have permission to execute the DDL
char_set_sql = DDL("ALTER TABLE plugin_hub CONVERT TO CHARACTER SET utf8mb4")
@@ -40,6 +43,27 @@ class PluginHubEntity(Model):
UniqueConstraint("name", name="uk_name")
Index("idx_q_type", "type")
@classmethod
def to_vo(cls, entities: List["PluginHubEntity"]) -> List[PluginHubVO]:
results = []
for entity in entities:
vo = PluginHubVO(
id=entity.id,
name=entity.name,
description=entity.description,
author=entity.author,
email=entity.email,
type=entity.type,
version=entity.version,
storage_channel=entity.storage_channel,
storage_url=entity.storage_url,
download_param=entity.download_param,
installed=entity.installed,
gmt_created=entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S"),
)
results.append(vo)
return results
class PluginHubDao(BaseDao):
def add(self, engity: PluginHubEntity):

View File

@@ -18,6 +18,9 @@ from dbgpt.serve.agent.model import (
PluginHubParam,
)
from ..db import MyPluginEntity
from ..model import MyPluginVO, PluginHubVO
router = APIRouter()
logger = logging.getLogger(__name__)
@@ -73,7 +76,7 @@ async def plugin_hub_update(update_param: PluginHubParam = Body()):
return Result.failed(code="E0020", msg=f"Agent Hub Update Error! {e}")
@router.post("/v1/agent/query", response_model=Result[str])
@router.post("/v1/agent/query", response_model=Result[dict])
async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()):
logger.info(f"get_agent_list:{filter.__dict__}")
filter_enetity: PluginHubEntity = PluginHubEntity()
@@ -85,24 +88,21 @@ async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()):
datas, total_pages, total_count = plugin_hub.hub_dao.list(
filter_enetity, filter.page_index, filter.page_size
)
result: PagenationResult[PluginHubEntity] = PagenationResult[PluginHubEntity]()
result: PagenationResult[PluginHubVO] = PagenationResult[PluginHubVO]()
result.page_index = filter.page_index
result.page_size = filter.page_size
result.total_page = total_pages
result.total_row_count = total_count
result.datas = datas
result.datas = PluginHubEntity.to_vo(datas)
# print(json.dumps(result.to_dic()))
return Result.succ(result.to_dic())
@router.post("/v1/agent/my", response_model=Result[str])
@router.post("/v1/agent/my", response_model=Result[List[MyPluginVO]])
async def my_agents(user: str = None):
logger.info(f"my_agents:{user}")
agents = plugin_hub.get_my_plugin(user)
agent_dicts = []
for agent in agents:
agent_dicts.append(agent.__dict__)
agent_dicts = MyPluginEntity.to_vo(agents)
return Result.succ(agent_dicts)

View File

@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import Generic, List, Optional, TypeVar
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
T = TypeVar("T")
@@ -13,6 +13,7 @@ class PagenationFilter(BaseModel, Generic[T]):
class PagenationResult(BaseModel, Generic[T]):
model_config = ConfigDict(arbitrary_types_allowed=True)
page_index: int = 1
page_size: int = 20
total_page: int = 0
@@ -34,14 +35,14 @@ class PagenationResult(BaseModel, Generic[T]):
@dataclass
class PluginHubFilter(BaseModel):
name: str
description: str
author: str
email: str
type: str
version: str
storage_channel: str
storage_url: str
name: Optional[str] = None
description: Optional[str] = None
author: Optional[str] = None
email: Optional[str] = None
type: Optional[str] = None
version: Optional[str] = None
storage_channel: Optional[str] = None
storage_url: Optional[str] = None
@dataclass
@@ -67,3 +68,33 @@ class PluginHubParam(BaseModel):
authorization: Optional[str] = Field(
None, description="github download authorization", nullable=True
)
class PluginHubVO(BaseModel):
id: int = Field(..., description="Plugin id")
name: str = Field(..., description="Plugin name")
description: str = Field(..., description="Plugin description")
author: Optional[str] = Field(None, description="Plugin author")
email: Optional[str] = Field(None, description="Plugin email")
type: Optional[str] = Field(None, description="Plugin type")
version: Optional[str] = Field(None, description="Plugin version")
storage_channel: Optional[str] = Field(None, description="Plugin storage channel")
storage_url: Optional[str] = Field(None, description="Plugin storage url")
download_param: Optional[str] = Field(None, description="Plugin download param")
installed: Optional[int] = Field(None, description="Plugin installed")
gmt_created: Optional[str] = Field(None, description="Plugin upload time")
class MyPluginVO(BaseModel):
id: int = Field(..., description="My Plugin")
tenant: Optional[str] = Field(None, description="My Plugin tenant")
user_code: Optional[str] = Field(None, description="My Plugin user code")
user_name: Optional[str] = Field(None, description="My Plugin user name")
sys_code: Optional[str] = Field(None, description="My Plugin sys code")
name: str = Field(..., description="My Plugin name")
file_name: str = Field(..., description="My Plugin file name")
type: Optional[str] = Field(None, description="My Plugin type")
version: Optional[str] = Field(None, description="My Plugin version")
use_count: Optional[int] = Field(None, description="My Plugin use count")
succ_count: Optional[int] = Field(None, description="My Plugin succ count")
gmt_created: Optional[str] = Field(None, description="My Plugin install time")