mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-14 05:31:40 +00:00
feat(core): Upgrade pydantic to 2.x (#1428)
This commit is contained in:
@@ -86,7 +86,9 @@ class MultiAgents(BaseComponent, ABC):
|
||||
def gpts_create(self, entity: GptsInstanceEntity):
|
||||
self.gpts_intance.add(entity)
|
||||
|
||||
def get_dbgpts(self, user_code: str = None, sys_code: str = None):
|
||||
def get_dbgpts(
|
||||
self, user_code: str = None, sys_code: str = None
|
||||
) -> Optional[List[GptsApp]]:
|
||||
apps = self.gpts_app.app_list(
|
||||
GptsAppQuery(user_code=user_code, sys_code=sys_code)
|
||||
).app_list
|
||||
@@ -338,7 +340,7 @@ class MultiAgents(BaseComponent, ABC):
|
||||
multi_agents = MultiAgents()
|
||||
|
||||
|
||||
@router.post("/v1/dbgpts/agents/list", response_model=Result[str])
|
||||
@router.post("/v1/dbgpts/agents/list", response_model=Result[Dict[str, str]])
|
||||
async def agents_list():
|
||||
logger.info("agents_list!")
|
||||
try:
|
||||
@@ -348,7 +350,7 @@ async def agents_list():
|
||||
return Result.failed(code="E30001", msg=str(e))
|
||||
|
||||
|
||||
@router.get("/v1/dbgpts/list", response_model=Result[str])
|
||||
@router.get("/v1/dbgpts/list", response_model=Result[List[GptsApp]])
|
||||
async def get_dbgpts(user_code: str = None, sys_code: str = None):
|
||||
logger.info(f"get_dbgpts:{user_code},{sys_code}")
|
||||
try:
|
||||
@@ -359,14 +361,14 @@ async def get_dbgpts(user_code: str = None, sys_code: str = None):
|
||||
|
||||
|
||||
@router.post("/v1/dbgpts/chat/completions", response_model=Result[str])
|
||||
async def dgpts_completions(
|
||||
async def dbgpts_completions(
|
||||
gpts_name: str,
|
||||
user_query: str,
|
||||
conv_id: str = None,
|
||||
user_code: str = None,
|
||||
sys_code: str = None,
|
||||
):
|
||||
logger.info(f"dgpts_completions:{gpts_name},{user_query},{conv_id}")
|
||||
logger.info(f"dbgpts_completions:{gpts_name},{user_query},{conv_id}")
|
||||
if conv_id is None:
|
||||
conv_id = str(uuid.uuid1())
|
||||
|
||||
@@ -390,12 +392,12 @@ async def dgpts_completions(
|
||||
|
||||
|
||||
@router.post("/v1/dbgpts/chat/cancel", response_model=Result[str])
|
||||
async def dgpts_chat_cancel(
|
||||
async def dbgpts_chat_cancel(
|
||||
conv_id: str = None, user_code: str = None, sys_code: str = None
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@router.post("/v1/dbgpts/chat/feedback", response_model=Result[str])
|
||||
async def dgpts_chat_feedback(filter: PagenationFilter[PluginHubFilter] = Body()):
|
||||
async def dbgpts_chat_feedback(filter: PagenationFilter[PluginHubFilter] = Body()):
|
||||
pass
|
||||
|
@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from sqlalchemy import Column, DateTime, Integer, String, Text, UniqueConstraint
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_json
|
||||
from dbgpt.agent.plan.awel.team_awel_layout import AWELTeamContext
|
||||
from dbgpt.agent.resource.resource_api import AgentResource
|
||||
from dbgpt.serve.agent.team.base import TeamMode
|
||||
@@ -17,6 +17,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GptsAppDetail(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
app_code: Optional[str] = None
|
||||
app_name: Optional[str] = None
|
||||
agent_name: Optional[str] = None
|
||||
@@ -28,11 +30,6 @@ class GptsAppDetail(BaseModel):
|
||||
created_at: datetime = datetime.now()
|
||||
updated_at: datetime = datetime.now()
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def to_dict(self):
|
||||
return {k: self._serialize(v) for k, v in self.__dict__.items()}
|
||||
|
||||
@@ -86,6 +83,8 @@ class GptsAppDetail(BaseModel):
|
||||
|
||||
|
||||
class GptsApp(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
app_code: Optional[str] = None
|
||||
app_name: Optional[str] = None
|
||||
app_describe: Optional[str] = None
|
||||
@@ -100,11 +99,6 @@ class GptsApp(BaseModel):
|
||||
updated_at: datetime = datetime.now()
|
||||
details: List[GptsAppDetail] = []
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def to_dict(self):
|
||||
return {k: self._serialize(v) for k, v in self.__dict__.items()}
|
||||
|
||||
@@ -146,7 +140,9 @@ class GptsAppResponse(BaseModel):
|
||||
total_count: Optional[int] = 0
|
||||
total_page: Optional[int] = 0
|
||||
current_page: Optional[int] = 0
|
||||
app_list: Optional[List[GptsApp]] = []
|
||||
app_list: Optional[List[GptsApp]] = Field(
|
||||
default_factory=list, description="app list"
|
||||
)
|
||||
|
||||
|
||||
class GptsAppCollection(BaseModel):
|
||||
@@ -207,7 +203,8 @@ class GptsAppEntity(Model):
|
||||
team_context = Column(
|
||||
Text,
|
||||
nullable=True,
|
||||
comment="The execution logic and team member content that teams with different working modes rely on",
|
||||
comment="The execution logic and team member content that teams with different "
|
||||
"working modes rely on",
|
||||
)
|
||||
|
||||
user_code = Column(String(255), nullable=True, comment="user code")
|
||||
@@ -565,7 +562,7 @@ def _parse_team_context(team_context: Optional[Union[str, AWELTeamContext]] = No
|
||||
parse team_context to str
|
||||
"""
|
||||
if isinstance(team_context, AWELTeamContext):
|
||||
return team_context.json()
|
||||
return model_to_json(team_context)
|
||||
return team_context
|
||||
|
||||
|
||||
|
@@ -1,9 +1,12 @@
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from sqlalchemy import Column, DateTime, Integer, String, UniqueConstraint, func
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
|
||||
from ..model import MyPluginVO
|
||||
|
||||
|
||||
class MyPluginEntity(Model):
|
||||
__tablename__ = "my_plugin"
|
||||
@@ -27,6 +30,28 @@ class MyPluginEntity(Model):
|
||||
)
|
||||
UniqueConstraint("user_code", "name", name="uk_name")
|
||||
|
||||
@classmethod
|
||||
def to_vo(cls, entities: List["MyPluginEntity"]) -> List[MyPluginVO]:
|
||||
results = []
|
||||
for entity in entities:
|
||||
results.append(
|
||||
MyPluginVO(
|
||||
id=entity.id,
|
||||
tenant=entity.tenant,
|
||||
user_code=entity.user_code,
|
||||
user_name=entity.user_name,
|
||||
sys_code=entity.sys_code,
|
||||
name=entity.name,
|
||||
file_name=entity.file_name,
|
||||
type=entity.type,
|
||||
version=entity.version,
|
||||
use_count=entity.use_count,
|
||||
succ_count=entity.succ_count,
|
||||
gmt_created=entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
class MyPluginDao(BaseDao):
|
||||
def add(self, engity: MyPluginEntity):
|
||||
|
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
import pytz
|
||||
from sqlalchemy import (
|
||||
@@ -14,6 +15,8 @@ from sqlalchemy import (
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao, Model
|
||||
|
||||
from ..model import PluginHubVO
|
||||
|
||||
# TODO We should consider that the production environment does not have permission to execute the DDL
|
||||
char_set_sql = DDL("ALTER TABLE plugin_hub CONVERT TO CHARACTER SET utf8mb4")
|
||||
|
||||
@@ -40,6 +43,27 @@ class PluginHubEntity(Model):
|
||||
UniqueConstraint("name", name="uk_name")
|
||||
Index("idx_q_type", "type")
|
||||
|
||||
@classmethod
|
||||
def to_vo(cls, entities: List["PluginHubEntity"]) -> List[PluginHubVO]:
|
||||
results = []
|
||||
for entity in entities:
|
||||
vo = PluginHubVO(
|
||||
id=entity.id,
|
||||
name=entity.name,
|
||||
description=entity.description,
|
||||
author=entity.author,
|
||||
email=entity.email,
|
||||
type=entity.type,
|
||||
version=entity.version,
|
||||
storage_channel=entity.storage_channel,
|
||||
storage_url=entity.storage_url,
|
||||
download_param=entity.download_param,
|
||||
installed=entity.installed,
|
||||
gmt_created=entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
)
|
||||
results.append(vo)
|
||||
return results
|
||||
|
||||
|
||||
class PluginHubDao(BaseDao):
|
||||
def add(self, engity: PluginHubEntity):
|
||||
|
@@ -18,6 +18,9 @@ from dbgpt.serve.agent.model import (
|
||||
PluginHubParam,
|
||||
)
|
||||
|
||||
from ..db import MyPluginEntity
|
||||
from ..model import MyPluginVO, PluginHubVO
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -73,7 +76,7 @@ async def plugin_hub_update(update_param: PluginHubParam = Body()):
|
||||
return Result.failed(code="E0020", msg=f"Agent Hub Update Error! {e}")
|
||||
|
||||
|
||||
@router.post("/v1/agent/query", response_model=Result[str])
|
||||
@router.post("/v1/agent/query", response_model=Result[dict])
|
||||
async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()):
|
||||
logger.info(f"get_agent_list:{filter.__dict__}")
|
||||
filter_enetity: PluginHubEntity = PluginHubEntity()
|
||||
@@ -85,24 +88,21 @@ async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()):
|
||||
datas, total_pages, total_count = plugin_hub.hub_dao.list(
|
||||
filter_enetity, filter.page_index, filter.page_size
|
||||
)
|
||||
result: PagenationResult[PluginHubEntity] = PagenationResult[PluginHubEntity]()
|
||||
result: PagenationResult[PluginHubVO] = PagenationResult[PluginHubVO]()
|
||||
result.page_index = filter.page_index
|
||||
result.page_size = filter.page_size
|
||||
result.total_page = total_pages
|
||||
result.total_row_count = total_count
|
||||
result.datas = datas
|
||||
result.datas = PluginHubEntity.to_vo(datas)
|
||||
# print(json.dumps(result.to_dic()))
|
||||
return Result.succ(result.to_dic())
|
||||
|
||||
|
||||
@router.post("/v1/agent/my", response_model=Result[str])
|
||||
@router.post("/v1/agent/my", response_model=Result[List[MyPluginVO]])
|
||||
async def my_agents(user: str = None):
|
||||
logger.info(f"my_agents:{user}")
|
||||
agents = plugin_hub.get_my_plugin(user)
|
||||
agent_dicts = []
|
||||
for agent in agents:
|
||||
agent_dicts.append(agent.__dict__)
|
||||
|
||||
agent_dicts = MyPluginEntity.to_vo(agents)
|
||||
return Result.succ(agent_dicts)
|
||||
|
||||
|
||||
|
@@ -1,7 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Generic, List, Optional, TypeVar
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -13,6 +13,7 @@ class PagenationFilter(BaseModel, Generic[T]):
|
||||
|
||||
|
||||
class PagenationResult(BaseModel, Generic[T]):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
page_index: int = 1
|
||||
page_size: int = 20
|
||||
total_page: int = 0
|
||||
@@ -34,14 +35,14 @@ class PagenationResult(BaseModel, Generic[T]):
|
||||
|
||||
@dataclass
|
||||
class PluginHubFilter(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
author: str
|
||||
email: str
|
||||
type: str
|
||||
version: str
|
||||
storage_channel: str
|
||||
storage_url: str
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
author: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
type: Optional[str] = None
|
||||
version: Optional[str] = None
|
||||
storage_channel: Optional[str] = None
|
||||
storage_url: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -67,3 +68,33 @@ class PluginHubParam(BaseModel):
|
||||
authorization: Optional[str] = Field(
|
||||
None, description="github download authorization", nullable=True
|
||||
)
|
||||
|
||||
|
||||
class PluginHubVO(BaseModel):
|
||||
id: int = Field(..., description="Plugin id")
|
||||
name: str = Field(..., description="Plugin name")
|
||||
description: str = Field(..., description="Plugin description")
|
||||
author: Optional[str] = Field(None, description="Plugin author")
|
||||
email: Optional[str] = Field(None, description="Plugin email")
|
||||
type: Optional[str] = Field(None, description="Plugin type")
|
||||
version: Optional[str] = Field(None, description="Plugin version")
|
||||
storage_channel: Optional[str] = Field(None, description="Plugin storage channel")
|
||||
storage_url: Optional[str] = Field(None, description="Plugin storage url")
|
||||
download_param: Optional[str] = Field(None, description="Plugin download param")
|
||||
installed: Optional[int] = Field(None, description="Plugin installed")
|
||||
gmt_created: Optional[str] = Field(None, description="Plugin upload time")
|
||||
|
||||
|
||||
class MyPluginVO(BaseModel):
|
||||
id: int = Field(..., description="My Plugin")
|
||||
tenant: Optional[str] = Field(None, description="My Plugin tenant")
|
||||
user_code: Optional[str] = Field(None, description="My Plugin user code")
|
||||
user_name: Optional[str] = Field(None, description="My Plugin user name")
|
||||
sys_code: Optional[str] = Field(None, description="My Plugin sys code")
|
||||
name: str = Field(..., description="My Plugin name")
|
||||
file_name: str = Field(..., description="My Plugin file name")
|
||||
type: Optional[str] = Field(None, description="My Plugin type")
|
||||
version: Optional[str] = Field(None, description="My Plugin version")
|
||||
use_count: Optional[int] = Field(None, description="My Plugin use count")
|
||||
succ_count: Optional[int] = Field(None, description="My Plugin succ count")
|
||||
gmt_created: Optional[str] = Field(None, description="My Plugin install time")
|
||||
|
Reference in New Issue
Block a user