diff --git a/dbgpt/_private/llm_metadata.py b/dbgpt/_private/llm_metadata.py index a661c04ef..e77908a2b 100644 --- a/dbgpt/_private/llm_metadata.py +++ b/dbgpt/_private/llm_metadata.py @@ -1,10 +1,12 @@ -from dbgpt._private.pydantic import BaseModel, Field +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field DEFAULT_CONTEXT_WINDOW = 3900 DEFAULT_NUM_OUTPUTS = 256 class LLMMetadata(BaseModel): + model_config = ConfigDict(protected_namespaces=()) + context_window: int = Field( default=DEFAULT_CONTEXT_WINDOW, description=( diff --git a/dbgpt/_private/pydantic.py b/dbgpt/_private/pydantic.py index 6cc3f9a76..44e939e22 100644 --- a/dbgpt/_private/pydantic.py +++ b/dbgpt/_private/pydantic.py @@ -1,7 +1,14 @@ +from typing import get_origin + import pydantic if pydantic.VERSION.startswith("1."): PYDANTIC_VERSION = 1 + raise NotImplementedError("pydantic 1.x is not supported, please upgrade to 2.x.") +else: + PYDANTIC_VERSION = 2 + # pydantic 2.x + # Now we upgrade to pydantic 2.x from pydantic import ( BaseModel, ConfigDict, @@ -13,33 +20,72 @@ if pydantic.VERSION.startswith("1."): PositiveInt, PrivateAttr, ValidationError, - root_validator, - validator, - ) -else: - PYDANTIC_VERSION = 2 - # pydantic 2.x - from pydantic.v1 import ( - BaseModel, - ConfigDict, - Extra, - Field, - NonNegativeFloat, - NonNegativeInt, - PositiveFloat, - PositiveInt, - PrivateAttr, - ValidationError, + field_validator, + model_validator, root_validator, validator, ) + EXTRA_FORBID = "forbid" -def model_to_json(model, **kwargs): - """Convert a pydantic model to json""" + +def model_to_json(model, **kwargs) -> str: + """Convert a pydantic model to json.""" if PYDANTIC_VERSION == 1: return model.json(**kwargs) else: if "ensure_ascii" in kwargs: del kwargs["ensure_ascii"] return model.model_dump_json(**kwargs) + + +def model_to_dict(model, **kwargs) -> dict: + """Convert a pydantic model to dict.""" + if PYDANTIC_VERSION == 1: + return model.dict(**kwargs) + else: + return 
model.model_dump(**kwargs) + + +def model_fields(model): + """Return the fields of a pydantic model.""" + if PYDANTIC_VERSION == 1: + return model.__fields__ + else: + return model.model_fields + + +def field_is_required(field) -> bool: + """Return if a field is required.""" + if PYDANTIC_VERSION == 1: + return field.required + else: + return field.is_required() + + +def field_outer_type(field): + """Return the outer type of a field.""" + if PYDANTIC_VERSION == 1: + return field.outer_type_ + else: + # https://github.com/pydantic/pydantic/discussions/7217 + origin = get_origin(field.annotation) + if origin is None: + return field.annotation + return origin + + +def field_description(field): + """Return the description of a field.""" + if PYDANTIC_VERSION == 1: + return field.field_info.description + else: + return field.description + + +def field_default(field): + """Return the default value of a field.""" + if PYDANTIC_VERSION == 1: + return field.field_info.default + else: + return field.default diff --git a/dbgpt/agent/actions/action.py b/dbgpt/agent/actions/action.py index 4e598b317..f2fd950b5 100644 --- a/dbgpt/agent/actions/action.py +++ b/dbgpt/agent/actions/action.py @@ -15,7 +15,13 @@ from typing import ( get_origin, ) -from dbgpt._private.pydantic import BaseModel +from dbgpt._private.pydantic import ( + BaseModel, + field_default, + field_description, + model_fields, + model_to_dict, +) from dbgpt.util.json_utils import find_json_objects from ...vis.base import Vis @@ -45,6 +51,10 @@ class ActionOutput(BaseModel): return None return cls.parse_obj(param) + def to_dict(self) -> Dict[str, Any]: + """Convert the object to a dictionary.""" + return model_to_dict(self) + class Action(ABC, Generic[T]): """Base Action class for defining agent actions.""" @@ -85,12 +95,13 @@ class Action(ABC, Generic[T]): if origin is None: example = {} single_model_type = cast(Type[BaseModel], model_type) - for field_name, field in single_model_type.__fields__.items(): - 
field_info = field.field_info - if field_info.description: - example[field_name] = field_info.description - elif field_info.default: - example[field_name] = field_info.default + for field_name, field in model_fields(single_model_type).items(): + description = field_description(field) + default_value = field_default(field) + if description: + example[field_name] = description + elif default_value: + example[field_name] = default_value else: example[field_name] = "" return example diff --git a/dbgpt/agent/actions/chart_action.py b/dbgpt/agent/actions/chart_action.py index 5f68fb745..76df3a77a 100644 --- a/dbgpt/agent/actions/chart_action.py +++ b/dbgpt/agent/actions/chart_action.py @@ -3,7 +3,7 @@ import json import logging from typing import Optional -from dbgpt._private.pydantic import BaseModel, Field +from dbgpt._private.pydantic import BaseModel, Field, model_to_json from dbgpt.vis.tags.vis_chart import Vis, VisChart from ..resource.resource_api import AgentResource, ResourceType @@ -86,13 +86,13 @@ class ChartAction(Action[SqlInput]): if not self.render_protocol: raise ValueError("The rendering protocol is not initialized!") view = await self.render_protocol.display( - chart=json.loads(param.json()), data_df=data_df + chart=json.loads(model_to_json(param)), data_df=data_df ) if not self.resource_need: raise ValueError("The resource type is not found!") return ActionOutput( is_exe_success=True, - content=param.json(), + content=model_to_json(param), view=view, resource_type=self.resource_need.value, resource_value=resource.value, diff --git a/dbgpt/agent/actions/dashboard_action.py b/dbgpt/agent/actions/dashboard_action.py index a28dfcb8d..098f81835 100644 --- a/dbgpt/agent/actions/dashboard_action.py +++ b/dbgpt/agent/actions/dashboard_action.py @@ -3,7 +3,7 @@ import json import logging from typing import List, Optional -from dbgpt._private.pydantic import BaseModel, Field +from dbgpt._private.pydantic import BaseModel, Field, model_to_dict from 
dbgpt.vis.tags.vis_dashboard import Vis, VisDashboard from ..resource.resource_api import AgentResource, ResourceType @@ -30,6 +30,10 @@ class ChartItem(BaseModel): ) thought: str = Field(..., description="Summary of thoughts to the user") + def to_dict(self): + """Convert to dict.""" + return model_to_dict(self) + class DashboardAction(Action[List[ChartItem]]): """Dashboard action class.""" @@ -101,7 +105,7 @@ class DashboardAction(Action[List[ChartItem]]): sql_df = await resource_db_client.query_to_df( resource.value, chart_item.sql ) - chart_dict = chart_item.dict() + chart_dict = chart_item.to_dict() chart_dict["data"] = sql_df except Exception as e: @@ -113,7 +117,9 @@ class DashboardAction(Action[List[ChartItem]]): view = await self.render_protocol.display(charts=chart_params) return ActionOutput( is_exe_success=True, - content=json.dumps([chart_item.dict() for chart_item in chart_items]), + content=json.dumps( + [chart_item.to_dict() for chart_item in chart_items] + ), view=view, ) except Exception as e: diff --git a/dbgpt/agent/core/agent_manage.py b/dbgpt/agent/core/agent_manage.py index 8997b6c3a..3cc85fa62 100644 --- a/dbgpt/agent/core/agent_manage.py +++ b/dbgpt/agent/core/agent_manage.py @@ -112,7 +112,7 @@ class AgentManager(BaseComponent): """Return the description of an agent by name.""" return self._agents[name][1].desc - def all_agents(self): + def all_agents(self) -> Dict[str, str]: """Return a dictionary of all registered agents and their descriptions.""" result = {} for name, value in self._agents.items(): diff --git a/dbgpt/agent/core/base_agent.py b/dbgpt/agent/core/base_agent.py index b5b5ad9e2..6dcaa1ee4 100644 --- a/dbgpt/agent/core/base_agent.py +++ b/dbgpt/agent/core/base_agent.py @@ -5,7 +5,7 @@ import json import logging from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast -from dbgpt._private.pydantic import Field +from dbgpt._private.pydantic import ConfigDict, Field from dbgpt.core import LLMClient, 
ModelMessageRoleType from dbgpt.util.error_types import LLMChatError from dbgpt.util.tracer import SpanType, root_tracer @@ -27,6 +27,8 @@ logger = logging.getLogger(__name__) class ConversableAgent(Role, Agent): """ConversableAgent is an agent that can communicate with other agents.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + agent_context: Optional[AgentContext] = Field(None, description="Agent context") actions: List[Action] = Field(default_factory=list) resources: List[AgentResource] = Field(default_factory=list) @@ -38,11 +40,6 @@ class ConversableAgent(Role, Agent): llm_client: Optional[AIWrapper] = None oai_system_message: List[Dict] = Field(default_factory=list) - class Config: - """Pydantic configuration.""" - - arbitrary_types_allowed = True - def __init__(self, **kwargs): """Create a new agent.""" Role.__init__(self, **kwargs) @@ -377,8 +374,10 @@ class ConversableAgent(Role, Agent): **act_extent_param, ) if act_out: - reply_message.action_report = act_out.dict() - span.metadata["action_report"] = act_out.dict() if act_out else None + reply_message.action_report = act_out.to_dict() + span.metadata["action_report"] = ( + act_out.to_dict() if act_out else None + ) with root_tracer.start_span( "agent.generate_reply.verify", @@ -496,7 +495,7 @@ class ConversableAgent(Role, Agent): "recipient": self.get_name(), "reviewer": reviewer.get_name() if reviewer else None, "need_resource": need_resource.to_dict() if need_resource else None, - "rely_action_out": last_out.dict() if last_out else None, + "rely_action_out": last_out.to_dict() if last_out else None, "conv_uid": self.not_null_agent_context.conv_id, "action_index": i, "total_action": len(self.actions), @@ -508,7 +507,7 @@ class ConversableAgent(Role, Agent): rely_action_out=last_out, **kwargs, ) - span.metadata["action_out"] = last_out.dict() if last_out else None + span.metadata["action_out"] = last_out.to_dict() if last_out else None return last_out async def correctness_check( diff 
--git a/dbgpt/agent/core/base_team.py b/dbgpt/agent/core/base_team.py index d04e0ca64..2460086e3 100644 --- a/dbgpt/agent/core/base_team.py +++ b/dbgpt/agent/core/base_team.py @@ -3,7 +3,7 @@ import logging from typing import Dict, List, Optional, Tuple, Union -from dbgpt._private.pydantic import BaseModel, Field +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field from ..actions.action import ActionOutput from .agent import Agent, AgentMessage @@ -68,16 +68,13 @@ def _content_str(content: Union[str, List, None]) -> str: class Team(BaseModel): """Team class for managing a group of agents in a team chat.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + agents: List[Agent] = Field(default_factory=list) messages: List[Dict] = Field(default_factory=list) max_round: int = 100 is_team: bool = True - class Config: - """Pydantic model configuration.""" - - arbitrary_types_allowed = True - def __init__(self, **kwargs): """Create a new Team instance.""" super().__init__(**kwargs) @@ -122,6 +119,8 @@ class Team(BaseModel): class ManagerAgent(ConversableAgent, Team): """Manager Agent class.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + profile: str = "TeamManager" goal: str = "manage all hired intelligent agents to complete mission objectives" constraints: List[str] = [] @@ -132,11 +131,6 @@ class ManagerAgent(ConversableAgent, Team): # of the agent has already been retried. 
max_retry_count: int = 1 - class Config: - """Pydantic model configuration.""" - - arbitrary_types_allowed = True - def __init__(self, **kwargs): """Create a new ManagerAgent instance.""" ConversableAgent.__init__(self, **kwargs) diff --git a/dbgpt/agent/core/llm/llm.py b/dbgpt/agent/core/llm/llm.py index b0e655dc5..9dc57f4dc 100644 --- a/dbgpt/agent/core/llm/llm.py +++ b/dbgpt/agent/core/llm/llm.py @@ -4,7 +4,7 @@ from collections import defaultdict from enum import Enum from typing import Any, Dict, List, Optional, Type -from dbgpt._private.pydantic import BaseModel, Field +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field from dbgpt.core import LLMClient, ModelMetadata, ModelRequest logger = logging.getLogger(__name__) @@ -106,11 +106,8 @@ def register_llm_strategy( class LLMConfig(BaseModel): """LLM configuration.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + llm_client: Optional[LLMClient] = Field(default_factory=LLMClient) llm_strategy: LLMStrategyType = Field(default=LLMStrategyType.Default) strategy_context: Optional[Any] = None - - class Config: - """Pydantic model config.""" - - arbitrary_types_allowed = True diff --git a/dbgpt/agent/core/role.py b/dbgpt/agent/core/role.py index 1a97af64f..04d50c204 100644 --- a/dbgpt/agent/core/role.py +++ b/dbgpt/agent/core/role.py @@ -2,33 +2,30 @@ from abc import ABC from typing import List, Optional -from dbgpt._private.pydantic import BaseModel +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field class Role(ABC, BaseModel): """Role class for role-based conversation.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + profile: str = "" name: str = "" - resource_introduction = "" + resource_introduction: str = "" goal: str = "" expand_prompt: str = "" - fixed_subgoal: Optional[str] = None + fixed_subgoal: Optional[str] = Field(None, description="Fixed subgoal") - constraints: List[str] = [] + constraints: List[str] = Field(default_factory=list, 
description="Constraints") examples: str = "" desc: str = "" language: str = "en" is_human: bool = False is_team: bool = False - class Config: - """Pydantic config.""" - - arbitrary_types_allowed = True - def prompt_template( self, specified_prompt: Optional[str] = None, diff --git a/dbgpt/agent/core/user_proxy_agent.py b/dbgpt/agent/core/user_proxy_agent.py index 62cb7a112..61087ae12 100644 --- a/dbgpt/agent/core/user_proxy_agent.py +++ b/dbgpt/agent/core/user_proxy_agent.py @@ -8,7 +8,7 @@ class UserProxyAgent(ConversableAgent): That can execute code and provide feedback to the other agents. """ - name = "User" + name: str = "User" profile: str = "Human" desc: str = ( @@ -16,4 +16,4 @@ class UserProxyAgent(ConversableAgent): "Plan execution needs to be approved by this admin." ) - is_human = True + is_human: bool = True diff --git a/dbgpt/agent/expand/dashboard_assistant_agent.py b/dbgpt/agent/expand/dashboard_assistant_agent.py index bdfd2656e..337ead0e4 100644 --- a/dbgpt/agent/expand/dashboard_assistant_agent.py +++ b/dbgpt/agent/expand/dashboard_assistant_agent.py @@ -34,8 +34,6 @@ class DashboardAssistantAgent(ConversableAgent): "professional reports" ) - max_retry_count: int = 3 - def __init__(self, **kwargs): """Create a new instance of DashboardAssistantAgent.""" super().__init__(**kwargs) diff --git a/dbgpt/agent/expand/data_scientist_agent.py b/dbgpt/agent/expand/data_scientist_agent.py index 55fd73ccb..cbe8150c8 100644 --- a/dbgpt/agent/expand/data_scientist_agent.py +++ b/dbgpt/agent/expand/data_scientist_agent.py @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) class DataScientistAgent(ConversableAgent): """Data Scientist Agent.""" - name = "Edgar" + name: str = "Edgar" profile: str = "DataScientist" goal: str = ( "Use correct {dialect} SQL to analyze and solve tasks based on the data" diff --git a/dbgpt/agent/expand/plugin_assistant_agent.py b/dbgpt/agent/expand/plugin_assistant_agent.py index 0829d4c6a..9c7137d66 100644 --- 
a/dbgpt/agent/expand/plugin_assistant_agent.py +++ b/dbgpt/agent/expand/plugin_assistant_agent.py @@ -16,7 +16,7 @@ class PluginAssistantAgent(ConversableAgent): plugin_generator: Optional[PluginPromptGenerator] = None - name = "LuBan" + name: str = "LuBan" profile: str = "ToolExpert" goal: str = ( "Read and understand the tool information given in the resources below to " diff --git a/dbgpt/agent/expand/retrieve_summary_assistant_agent.py b/dbgpt/agent/expand/retrieve_summary_assistant_agent.py index fc839f9e0..9e1df8954 100644 --- a/dbgpt/agent/expand/retrieve_summary_assistant_agent.py +++ b/dbgpt/agent/expand/retrieve_summary_assistant_agent.py @@ -244,7 +244,7 @@ class RetrieveSummaryAssistantAgent(ConversableAgent): **act_extent_param, ) if act_out: - reply_message.action_report = act_out.dict() + reply_message.action_report = act_out.to_dict() # 4.Reply information verification check_pass, reason = await self.verify(reply_message, sender, reviewer) is_success = check_pass diff --git a/dbgpt/agent/expand/summary_assistant_agent.py b/dbgpt/agent/expand/summary_assistant_agent.py index f7dc65ace..8f6521d81 100644 --- a/dbgpt/agent/expand/summary_assistant_agent.py +++ b/dbgpt/agent/expand/summary_assistant_agent.py @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) class SummaryAssistantAgent(ConversableAgent): """Summary Assistant Agent.""" - name = "Aristotle" + name: str = "Aristotle" profile: str = "Summarizer" goal: str = ( "Summarize answer summaries based on user questions from provided " diff --git a/dbgpt/agent/plan/awel/agent_operator_resource.py b/dbgpt/agent/plan/awel/agent_operator_resource.py index 02325e0e1..94577b249 100644 --- a/dbgpt/agent/plan/awel/agent_operator_resource.py +++ b/dbgpt/agent/plan/awel/agent_operator_resource.py @@ -1,7 +1,7 @@ """The AWEL Agent Operator Resource.""" from typing import Any, Dict, List, Optional -from dbgpt._private.pydantic import BaseModel, Field, root_validator +from dbgpt._private.pydantic import 
BaseModel, ConfigDict, Field, model_validator from dbgpt.core import LLMClient from dbgpt.core.awel.flow import ( FunctionDynamicOptions, @@ -55,9 +55,12 @@ from ...resource.resource_api import AgentResource, ResourceType class AWELAgentResource(AgentResource): """AWEL Agent Resource.""" - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Pre fill the agent ResourceType.""" + if not isinstance(values, dict): + return values name = values.pop("agent_resource_name") type = values.pop("agent_resource_type") value = values.pop("agent_resource_value") @@ -109,10 +112,7 @@ class AWELAgentResource(AgentResource): class AWELAgentConfig(LLMConfig): """AWEL Agent Config.""" - @root_validator(pre=True) - def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: - """Pre fill the agent ResourceType.""" - return values + pass def _agent_resource_option_values() -> List[OptionValue]: @@ -173,20 +173,20 @@ def _agent_resource_option_values() -> List[OptionValue]: class AWELAgent(BaseModel): """AWEL Agent.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + agent_profile: str role_name: Optional[str] = None llm_config: Optional[LLMConfig] = None resources: List[AgentResource] = Field(default_factory=list) fixed_subgoal: Optional[str] = None - class Config: - """Config for the BaseModel.""" - - arbitrary_types_allowed = True - - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Pre fill the agent ResourceType.""" + if not isinstance(values, dict): + return values resource = values.pop("agent_resource") llm_config = values.pop("agent_llm_Config") diff --git a/dbgpt/agent/plan/awel/team_awel_layout.py b/dbgpt/agent/plan/awel/team_awel_layout.py index 61e6a064e..681d554ef 100644 --- a/dbgpt/agent/plan/awel/team_awel_layout.py +++ b/dbgpt/agent/plan/awel/team_awel_layout.py @@ -5,7 +5,13 @@ from 
abc import ABC, abstractmethod from typing import List, Optional, cast from dbgpt._private.config import Config -from dbgpt._private.pydantic import BaseModel, Field, validator +from dbgpt._private.pydantic import ( + BaseModel, + ConfigDict, + Field, + model_to_dict, + validator, +) from dbgpt.core.awel import DAG from dbgpt.core.awel.dag.dag_manager import DAGManager @@ -70,23 +76,20 @@ class AWELTeamContext(BaseModel): def to_dict(self): """Convert the object to a dictionary.""" - return self.dict() + return model_to_dict(self) class AWELBaseManager(ManagerAgent, ABC): """AWEL base manager.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + goal: str = ( "Promote and solve user problems according to the process arranged by AWEL." ) constraints: List[str] = [] desc: str = goal - class Config: - """Config for the BaseModel.""" - - arbitrary_types_allowed = True - async def _a_process_received_message(self, message: AgentMessage, sender: Agent): """Process the received message.""" pass @@ -157,15 +160,12 @@ class WrappedAWELLayoutManager(AWELBaseManager): Receives a DAG or builds a DAG from the agents. """ + model_config = ConfigDict(arbitrary_types_allowed=True) + profile: str = "WrappedAWELLayoutManager" dag: Optional[DAG] = Field(None, description="The DAG of the manager") - class Config: - """Config for the BaseModel.""" - - arbitrary_types_allowed = True - def get_dag(self) -> DAG: """Get the DAG of the manager.""" if self.dag: @@ -236,15 +236,12 @@ class WrappedAWELLayoutManager(AWELBaseManager): class DefaultAWELLayoutManager(AWELBaseManager): """The manager of the team for the AWEL layout.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + profile: str = "DefaultAWELLayoutManager" dag: AWELTeamContext = Field(...) 
- class Config: - """Config for the BaseModel.""" - - arbitrary_types_allowed = True - @validator("dag") def check_dag(cls, value): """Check the DAG of the manager.""" diff --git a/dbgpt/agent/plan/planner_agent.py b/dbgpt/agent/plan/planner_agent.py index e8ef08252..e9c7e6254 100644 --- a/dbgpt/agent/plan/planner_agent.py +++ b/dbgpt/agent/plan/planner_agent.py @@ -84,7 +84,7 @@ class PlannerAgent(ConversableAgent): " and allocate resources to achieve complex task goals." ) - examples = """ + examples: str = """ user:help me build a sales report summarizing our key metrics and trends assistants:[ {{ diff --git a/dbgpt/agent/plugin/commands/command_manage.py b/dbgpt/agent/plugin/commands/command_manage.py index 959ead090..32cc18276 100644 --- a/dbgpt/agent/plugin/commands/command_manage.py +++ b/dbgpt/agent/plugin/commands/command_manage.py @@ -192,7 +192,7 @@ class PluginStatus(BaseModel): logo_url: Optional[str] = None api_result: Optional[str] = None err_msg: Optional[str] = None - start_time = datetime.now().timestamp() * 1000 + start_time: float = datetime.now().timestamp() * 1000 end_time: Optional[str] = None df: Any = None diff --git a/dbgpt/agent/resource/resource_api.py b/dbgpt/agent/resource/resource_api.py index 85293cbb8..4531d5a0e 100644 --- a/dbgpt/agent/resource/resource_api.py +++ b/dbgpt/agent/resource/resource_api.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from enum import Enum from typing import Any, Dict, List, Optional, Union -from dbgpt._private.pydantic import BaseModel +from dbgpt._private.pydantic import BaseModel, model_to_dict class ResourceType(Enum): @@ -67,7 +67,7 @@ class AgentResource(BaseModel): def to_dict(self) -> Dict[str, Any]: """Convert the AgentResource object to a dictionary.""" - temp = self.dict() + temp = model_to_dict(self) for field, value in temp.items(): if isinstance(value, Enum): temp[field] = value.value diff --git a/dbgpt/app/dbgpt_server.py b/dbgpt/app/dbgpt_server.py index da2cf8efc..76d6a0952 
100644 --- a/dbgpt/app/dbgpt_server.py +++ b/dbgpt/app/dbgpt_server.py @@ -5,7 +5,6 @@ from typing import List from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from fastapi.openapi.docs import get_swagger_ui_html # fastapi import time cost about 0.05s from fastapi.staticfiles import StaticFiles @@ -24,7 +23,7 @@ from dbgpt.app.component_configs import initialize_components from dbgpt.component import SystemApp from dbgpt.configs.model_config import EMBEDDING_MODEL_CONFIG, LLM_MODEL_CONFIG, LOGDIR from dbgpt.serve.core import add_exception_handler -from dbgpt.util.fastapi import PriorityAPIRouter +from dbgpt.util.fastapi import create_app, replace_router from dbgpt.util.i18n_utils import _, set_default_language from dbgpt.util.parameter_utils import _get_dict_from_obj from dbgpt.util.system_utils import get_system_info @@ -45,16 +44,14 @@ static_file_path = os.path.join(ROOT_PATH, "dbgpt", "app/static") CFG = Config() set_default_language(CFG.LANGUAGE) - -app = FastAPI( +app = create_app( title=_("DB-GPT Open API"), description=_("DB-GPT Open API"), version=version, openapi_tags=[], ) # Use custom router to support priority -app.router = PriorityAPIRouter() -app.setup() +replace_router(app) app.mount( "/swagger_static", diff --git a/dbgpt/app/knowledge/chunk_db.py b/dbgpt/app/knowledge/chunk_db.py index e6c3018e2..fc40573f8 100644 --- a/dbgpt/app/knowledge/chunk_db.py +++ b/dbgpt/app/knowledge/chunk_db.py @@ -4,6 +4,7 @@ from typing import List from sqlalchemy import Column, DateTime, Integer, String, Text, func from dbgpt._private.config import Config +from dbgpt.serve.rag.api.schemas import DocumentChunkVO from dbgpt.storage.metadata import BaseDao, Model CFG = Config() @@ -23,6 +24,22 @@ class DocumentChunkEntity(Model): def __repr__(self): return f"DocumentChunkEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', document_id='{self.document_id}', content='{self.content}', meta_info='{self.meta_info}', 
gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" + @classmethod + def to_to_document_chunk_vo(cls, entity_list: List["DocumentChunkEntity"]): + return [ + DocumentChunkVO( + id=entity.id, + document_id=entity.document_id, + doc_name=entity.doc_name, + doc_type=entity.doc_type, + content=entity.content, + meta_info=entity.meta_info, + gmt_created=entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S"), + gmt_modified=entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S"), + ) + for entity in entity_list + ] + class DocumentChunkDao(BaseDao): def create_documents_chunks(self, documents: List): @@ -45,7 +62,7 @@ class DocumentChunkDao(BaseDao): def get_document_chunks( self, query: DocumentChunkEntity, page=1, page_size=20, document_ids=None - ): + ) -> List[DocumentChunkVO]: session = self.get_raw_session() document_chunks = session.query(DocumentChunkEntity) if query.id is not None: @@ -81,7 +98,7 @@ class DocumentChunkDao(BaseDao): ) result = document_chunks.all() session.close() - return result + return DocumentChunkEntity.to_to_document_chunk_vo(result) def get_document_chunks_count(self, query: DocumentChunkEntity): session = self.get_raw_session() diff --git a/dbgpt/app/knowledge/document_db.py b/dbgpt/app/knowledge/document_db.py index 1165b7296..c8ce94da6 100644 --- a/dbgpt/app/knowledge/document_db.py +++ b/dbgpt/app/knowledge/document_db.py @@ -4,8 +4,13 @@ from typing import Any, Dict, List, Union from sqlalchemy import Column, DateTime, Integer, String, Text, func from dbgpt._private.config import Config +from dbgpt._private.pydantic import model_to_dict from dbgpt.serve.conversation.api.schemas import ServeRequest -from dbgpt.serve.rag.api.schemas import DocumentServeRequest, DocumentServeResponse +from dbgpt.serve.rag.api.schemas import ( + DocumentServeRequest, + DocumentServeResponse, + DocumentVO, +) from dbgpt.storage.metadata import BaseDao, Model CFG = Config() @@ -30,6 +35,55 @@ class KnowledgeDocumentEntity(Model): def __repr__(self): 
return f"KnowledgeDocumentEntity(id={self.id}, doc_name='{self.doc_name}', doc_type='{self.doc_type}', chunk_size='{self.chunk_size}', status='{self.status}', last_sync='{self.last_sync}', content='{self.content}', result='{self.result}', summary='{self.summary}', gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')" + @classmethod + def to_document_vo( + cls, entity_list: List["KnowledgeDocumentEntity"] + ) -> List[DocumentVO]: + vo_results = [] + for item in entity_list: + vo_results.append( + DocumentVO( + id=item.id, + doc_name=item.doc_name, + doc_type=item.doc_type, + space=item.space, + chunk_size=item.chunk_size, + status=item.status, + last_sync=item.last_sync.strftime("%Y-%m-%d %H:%M:%S"), + content=item.content, + result=item.result, + vector_ids=item.vector_ids, + summary=item.summary, + gmt_created=item.gmt_created.strftime("%Y-%m-%d %H:%M:%S"), + gmt_modified=item.gmt_modified.strftime("%Y-%m-%d %H:%M:%S"), + ) + ) + return vo_results + + @classmethod + def from_document_vo(cls, vo: DocumentVO) -> "KnowledgeDocumentEntity": + entity = KnowledgeDocumentEntity( + id=vo.id, + doc_name=vo.doc_name, + doc_type=vo.doc_type, + space=vo.space, + chunk_size=vo.chunk_size, + status=vo.status, + content=vo.content, + result=vo.result, + vector_ids=vo.vector_ids, + summary=vo.summary, + ) + if vo.last_sync: + entity.last_sync = datetime.strptime(vo.last_sync, "%Y-%m-%d %H:%M:%S") + if vo.gmt_created: + entity.gmt_created = datetime.strptime(vo.gmt_created, "%Y-%m-%d %H:%M:%S") + if vo.gmt_modified: + entity.gmt_modified = datetime.strptime( + vo.gmt_modified, "%Y-%m-%d %H:%M:%S" + ) + return entity + class KnowledgeDocumentDao(BaseDao): def create_knowledge_document(self, document: KnowledgeDocumentEntity): @@ -53,7 +107,7 @@ class KnowledgeDocumentDao(BaseDao): session.close() return doc_id - def get_knowledge_documents(self, query, page=1, page_size=20): + def get_knowledge_documents(self, query, page=1, page_size=20) -> List[DocumentVO]: 
"""Get a list of documents that match the given query. Args: query: A KnowledgeDocumentEntity object containing the query parameters. @@ -92,9 +146,9 @@ class KnowledgeDocumentDao(BaseDao): ) result = knowledge_documents.all() session.close() - return result + return KnowledgeDocumentEntity.to_document_vo(result) - def documents_by_ids(self, ids) -> List[KnowledgeDocumentEntity]: + def documents_by_ids(self, ids) -> List[DocumentVO]: """Get a list of documents by their IDs. Args: ids: A list of document IDs. @@ -109,7 +163,7 @@ class KnowledgeDocumentDao(BaseDao): ) result = knowledge_documents.all() session.close() - return result + return KnowledgeDocumentEntity.to_document_vo(result) def get_documents(self, query): session = self.get_raw_session() @@ -233,7 +287,9 @@ class KnowledgeDocumentDao(BaseDao): T: The entity """ request_dict = ( - request.dict() if isinstance(request, DocumentServeRequest) else request + model_to_dict(request) + if isinstance(request, DocumentServeRequest) + else request ) entity = KnowledgeDocumentEntity(**request_dict) return entity diff --git a/dbgpt/app/knowledge/request/request.py b/dbgpt/app/knowledge/request/request.py index 14e12ce90..7c4897e03 100644 --- a/dbgpt/app/knowledge/request/request.py +++ b/dbgpt/app/knowledge/request/request.py @@ -1,9 +1,6 @@ from typing import List, Optional -from fastapi import UploadFile - -from dbgpt._private.pydantic import BaseModel -from dbgpt.rag.chunk_manager import ChunkParameters +from dbgpt._private.pydantic import BaseModel, ConfigDict class KnowledgeQueryRequest(BaseModel): @@ -59,6 +56,8 @@ class DocumentQueryRequest(BaseModel): class DocumentSyncRequest(BaseModel): """Sync request""" + model_config = ConfigDict(protected_namespaces=()) + """doc_ids: doc ids""" doc_ids: List @@ -104,6 +103,8 @@ class SpaceArgumentRequest(BaseModel): class DocumentSummaryRequest(BaseModel): """Sync request""" + model_config = ConfigDict(protected_namespaces=()) + """doc_ids: doc ids""" doc_id: int 
model_name: str @@ -113,5 +114,7 @@ class DocumentSummaryRequest(BaseModel): class EntityExtractRequest(BaseModel): """argument: argument""" + model_config = ConfigDict(protected_namespaces=()) + text: str model_name: str diff --git a/dbgpt/app/knowledge/request/response.py b/dbgpt/app/knowledge/request/response.py index ccdc6519f..d1530b98a 100644 --- a/dbgpt/app/knowledge/request/response.py +++ b/dbgpt/app/knowledge/request/response.py @@ -1,28 +1,29 @@ -from typing import List +from typing import List, Optional -from dbgpt._private.pydantic import BaseModel +from dbgpt._private.pydantic import BaseModel, Field +from dbgpt.serve.rag.api.schemas import DocumentChunkVO, DocumentVO class ChunkQueryResponse(BaseModel): """data: data""" - data: List = None + data: List[DocumentChunkVO] = Field(..., description="document chunk list") """summary: document summary""" - summary: str = None + summary: Optional[str] = Field(None, description="document summary") """total: total size""" - total: int = None + total: Optional[int] = Field(None, description="total size") """page: current page""" - page: int = None + page: Optional[int] = Field(None, description="current page") class DocumentQueryResponse(BaseModel): """data: data""" - data: List = None + data: List[DocumentVO] = Field(..., description="document list") """total: total size""" - total: int = None + total: Optional[int] = Field(None, description="total size") """page: current page""" - page: int = None + page: Optional[int] = Field(None, description="current page") class SpaceQueryResponse(BaseModel): diff --git a/dbgpt/app/knowledge/service.py b/dbgpt/app/knowledge/service.py index 58aff202c..39d6d91b7 100644 --- a/dbgpt/app/knowledge/service.py +++ b/dbgpt/app/knowledge/service.py @@ -174,9 +174,11 @@ class KnowledgeService: Returns: - res DocumentQueryResponse """ - res = DocumentQueryResponse() + + total = None + page = request.page if request.doc_ids and len(request.doc_ids) > 0: - res.data = 
knowledge_document_dao.documents_by_ids(request.doc_ids) + data = knowledge_document_dao.documents_by_ids(request.doc_ids) else: query = KnowledgeDocumentEntity( doc_name=request.doc_name, @@ -184,12 +186,11 @@ class KnowledgeService: space=space, status=request.status, ) - res.data = knowledge_document_dao.get_knowledge_documents( + data = knowledge_document_dao.get_knowledge_documents( query, page=request.page, page_size=request.page_size ) - res.total = knowledge_document_dao.get_knowledge_documents_count(query) - res.page = request.page - return res + total = knowledge_document_dao.get_knowledge_documents_count(query) + return DocumentQueryResponse(data=data, total=total, page=page) def batch_document_sync( self, @@ -505,13 +506,15 @@ class KnowledgeService: document_query = KnowledgeDocumentEntity(id=request.document_id) documents = knowledge_document_dao.get_documents(document_query) - res = ChunkQueryResponse() - res.data = document_chunk_dao.get_document_chunks( + data = document_chunk_dao.get_document_chunks( query, page=request.page, page_size=request.page_size ) - res.summary = documents[0].summary - res.total = document_chunk_dao.get_document_chunks_count(query) - res.page = request.page + res = ChunkQueryResponse( + data=data, + summary=documents[0].summary, + total=document_chunk_dao.get_document_chunks_count(query), + page=request.page, + ) return res @trace("async_doc_embedding") diff --git a/dbgpt/app/openapi/api_v1/api_v1.py b/dbgpt/app/openapi/api_v1/api_v1.py index f782e5655..55f58bb0e 100644 --- a/dbgpt/app/openapi/api_v1/api_v1.py +++ b/dbgpt/app/openapi/api_v1/api_v1.py @@ -10,6 +10,7 @@ from fastapi import APIRouter, Body, Depends, File, UploadFile from fastapi.responses import StreamingResponse from dbgpt._private.config import Config +from dbgpt._private.pydantic import model_to_dict, model_to_json from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest from dbgpt.app.knowledge.service import KnowledgeService from 
dbgpt.app.openapi.api_view_model import ( @@ -147,7 +148,7 @@ def get_executor() -> Executor: ).create() -@router.get("/v1/chat/db/list", response_model=Result[DBConfig]) +@router.get("/v1/chat/db/list", response_model=Result[List[DBConfig]]) async def db_connect_list(): return Result.succ(CFG.local_db_manager.get_db_list()) @@ -189,7 +190,7 @@ async def db_summary(db_name: str, db_type: str): return Result.succ(True) -@router.get("/v1/chat/db/support/type", response_model=Result[DbTypeInfo]) +@router.get("/v1/chat/db/support/type", response_model=Result[List[DbTypeInfo]]) async def db_support_types(): support_types = CFG.local_db_manager.get_all_completed_types() db_type_infos = [] @@ -223,7 +224,7 @@ async def dialogue_scenes(): return Result.succ(scene_vos) -@router.post("/v1/chat/mode/params/list", response_model=Result[dict]) +@router.post("/v1/chat/mode/params/list", response_model=Result[dict | list]) async def params_list(chat_mode: str = ChatScene.ChatNormal.value()): if ChatScene.ChatWithDbQA.value() == chat_mode: return Result.succ(get_db_list()) @@ -378,7 +379,9 @@ async def chat_completions( ) else: with root_tracer.start_span( - "get_chat_instance", span_type=SpanType.CHAT, metadata=dialogue.dict() + "get_chat_instance", + span_type=SpanType.CHAT, + metadata=model_to_dict(dialogue), ): chat: BaseChat = await get_chat_instance(dialogue) @@ -458,7 +461,10 @@ async def stream_generator(chat, incremental: bool, model_name: str): chunk = ChatCompletionStreamResponse( id=chat.chat_session_id, choices=[choice_data], model=model_name ) - yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + json_chunk = model_to_json( + chunk, exclude_unset=True, ensure_ascii=False + ) + yield f"data: {json_chunk}\n\n" else: # TODO generate an openai-compatible streaming responses msg = msg.replace("\n", "\\n") diff --git a/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py b/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py index 73c0a95de..8931e4dfe 100644 
--- a/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py +++ b/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py @@ -43,7 +43,7 @@ def get_edit_service() -> EditorService: return EditorService.get_instance(CFG.SYSTEM_APP) -@router.get("/v1/editor/db/tables", response_model=Result[DbTable]) +@router.get("/v1/editor/db/tables", response_model=Result[DataNode]) async def get_editor_tables( db_name: str, page_index: int, page_size: int, search_str: str = "" ): @@ -70,15 +70,15 @@ async def get_editor_tables( return Result.succ(db_node) -@router.get("/v1/editor/sql/rounds", response_model=Result[ChatDbRounds]) +@router.get("/v1/editor/sql/rounds", response_model=Result[List[ChatDbRounds]]) async def get_editor_sql_rounds( con_uid: str, editor_service: EditorService = Depends(get_edit_service) ): - logger.info("get_editor_sql_rounds:{con_uid}") + logger.info(f"get_editor_sql_rounds:{ con_uid}") return Result.succ(editor_service.get_editor_sql_rounds(con_uid)) -@router.get("/v1/editor/sql", response_model=Result[dict]) +@router.get("/v1/editor/sql", response_model=Result[dict | list]) async def get_editor_sql( con_uid: str, round: int, editor_service: EditorService = Depends(get_edit_service) ): @@ -107,7 +107,7 @@ async def editor_sql_run(run_param: dict = Body()): end_time = time.time() * 1000 sql_run_data: SqlRunData = SqlRunData( result_info="", - run_cost=(end_time - start_time) / 1000, + run_cost=int((end_time - start_time) / 1000), colunms=colunms, values=sql_result, ) diff --git a/dbgpt/app/openapi/api_v1/editor/service.py b/dbgpt/app/openapi/api_v1/editor/service.py index bb27e0a92..b06cf02ed 100644 --- a/dbgpt/app/openapi/api_v1/editor/service.py +++ b/dbgpt/app/openapi/api_v1/editor/service.py @@ -2,7 +2,7 @@ from __future__ import annotations import json import logging -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Union from dbgpt._private.config import Config from dbgpt.app.openapi.api_view_model 
import Result @@ -70,7 +70,7 @@ class EditorService(BaseComponent): def get_editor_sql_by_round( self, conv_uid: str, round_index: int - ) -> Optional[Dict]: + ) -> Optional[Union[List, Dict]]: storage_conv: StorageConversation = self.get_storage_conv(conv_uid) messages_by_round = _split_messages_by_round(storage_conv.messages) for one_round_message in messages_by_round: @@ -184,7 +184,7 @@ class EditorService(BaseComponent): return Result.failed(msg="Can't Find Chart Detail Info!") -def _parse_pure_dict(res_str: str) -> Dict: +def _parse_pure_dict(res_str: str) -> Union[Dict, List]: output_parser = BaseOutputParser() context = output_parser.parse_prompt_response(res_str) return json.loads(context) diff --git a/dbgpt/app/openapi/api_v1/editor/sql_editor.py b/dbgpt/app/openapi/api_v1/editor/sql_editor.py index 5a05199e0..f03c43efd 100644 --- a/dbgpt/app/openapi/api_v1/editor/sql_editor.py +++ b/dbgpt/app/openapi/api_v1/editor/sql_editor.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Any, List, Optional from dbgpt._private.pydantic import BaseModel from dbgpt.app.scene.chat_dashboard.data_preparation.report_schma import ValueItem @@ -9,15 +9,15 @@ class DataNode(BaseModel): key: str type: str = "" - default_value: str = None + default_value: Optional[Any] = None can_null: str = "YES" - comment: str = None + comment: Optional[str] = None children: List = [] class SqlRunData(BaseModel): result_info: str - run_cost: str + run_cost: int colunms: List[str] values: List diff --git a/dbgpt/app/openapi/api_v2.py b/dbgpt/app/openapi/api_v2.py index a43a895ba..684d8c5c1 100644 --- a/dbgpt/app/openapi/api_v2.py +++ b/dbgpt/app/openapi/api_v2.py @@ -8,6 +8,7 @@ from fastapi import APIRouter, Body, Depends, HTTPException from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from starlette.responses import JSONResponse, StreamingResponse +from dbgpt._private.pydantic import model_to_dict, model_to_json from dbgpt.app.openapi.api_v1.api_v1 import ( 
CHAT_FACTORY, __new_conversation, @@ -130,7 +131,9 @@ async def chat_completions( or request.chat_mode == ChatMode.CHAT_DATA.value ): with root_tracer.start_span( - "get_chat_instance", span_type=SpanType.CHAT, metadata=request.dict() + "get_chat_instance", + span_type=SpanType.CHAT, + metadata=model_to_dict(request), ): chat: BaseChat = await get_chat_instance(request) @@ -243,21 +246,22 @@ async def chat_app_stream_wrapper(request: ChatCompletionRequestBody = None): model=request.model, created=int(time.time()), ) - content = ( - f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + json_content = model_to_json( + chunk, exclude_unset=True, ensure_ascii=False ) + content = f"data: {json_content}\n\n" yield content yield "data: [DONE]\n\n" async def chat_flow_wrapper(request: ChatCompletionRequestBody): flow_service = get_chat_flow() - flow_req = CommonLLMHttpRequestBody(**request.dict()) + flow_req = CommonLLMHttpRequestBody(**model_to_dict(request)) flow_uid = request.chat_param output = await flow_service.safe_chat_flow(flow_uid, flow_req) if not output.success: return JSONResponse( - ErrorResponse(message=output.text, code=output.error_code).dict(), + model_to_dict(ErrorResponse(message=output.text, code=output.error_code)), status_code=400, ) else: @@ -282,7 +286,7 @@ async def chat_flow_stream_wrapper( request (OpenAPIChatCompletionRequest): request """ flow_service = get_chat_flow() - flow_req = CommonLLMHttpRequestBody(**request.dict()) + flow_req = CommonLLMHttpRequestBody(**model_to_dict(request)) flow_uid = request.chat_param async for output in flow_service.chat_stream_openai(flow_uid, flow_req): diff --git a/dbgpt/app/openapi/api_view_model.py b/dbgpt/app/openapi/api_view_model.py index f3655e30a..7d00c4a69 100644 --- a/dbgpt/app/openapi/api_view_model.py +++ b/dbgpt/app/openapi/api_view_model.py @@ -1,17 +1,15 @@ -import time -import uuid -from typing import Any, Generic, List, Literal, Optional, TypeVar +from typing import Any, Dict, 
Generic, Optional, TypeVar -from dbgpt._private.pydantic import BaseModel, Field +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict T = TypeVar("T") -class Result(Generic[T], BaseModel): +class Result(BaseModel, Generic[T]): success: bool - err_code: str = None - err_msg: str = None - data: T = None + err_code: Optional[str] = None + err_msg: Optional[str] = None + data: Optional[T] = None @classmethod def succ(cls, data: T): @@ -21,6 +19,9 @@ class Result(Generic[T], BaseModel): def failed(cls, code: str = "E000X", msg=None): return Result(success=False, err_code=code, err_msg=msg, data=None) + def to_dict(self) -> Dict[str, Any]: + return model_to_dict(self) + class ChatSceneVo(BaseModel): chat_scene: str = Field(..., description="chat_scene") @@ -31,6 +32,8 @@ class ChatSceneVo(BaseModel): class ConversationVo(BaseModel): + model_config = ConfigDict(protected_namespaces=()) + """ dialogue_uid """ @@ -43,7 +46,7 @@ class ConversationVo(BaseModel): """ user """ - user_name: str = None + user_name: Optional[str] = Field(None, description="user name") """ the scene of chat """ @@ -52,21 +55,23 @@ class ConversationVo(BaseModel): """ chat scene select param """ - select_param: str = None + select_param: Optional[str] = Field(None, description="chat scene select param") """ llm model name """ - model_name: str = None + model_name: Optional[str] = Field(None, description="llm model name") """Used to control whether the content is returned incrementally or in full each time. If this parameter is not provided, the default is full return. 
""" incremental: bool = False - sys_code: Optional[str] = None + sys_code: Optional[str] = Field(None, description="System code") class MessageVo(BaseModel): + model_config = ConfigDict(protected_namespaces=()) + """ role that sends out the current message """ @@ -83,7 +88,9 @@ class MessageVo(BaseModel): """ time the current message was sent """ - time_stamp: Any = None + time_stamp: Optional[Any] = Field( + None, description="time the current message was sent" + ) """ model_name diff --git a/dbgpt/app/openapi/base.py b/dbgpt/app/openapi/base.py index 04de7b920..b3f537e15 100644 --- a/dbgpt/app/openapi/base.py +++ b/dbgpt/app/openapi/base.py @@ -11,4 +11,4 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE loc = ".".join(list(map(str, error.get("loc")))) message += loc + ":" + error.get("msg") + ";" res = Result.failed(code="E0001", msg=message) - return JSONResponse(status_code=400, content=res.dict()) + return JSONResponse(status_code=400, content=res.to_dict()) diff --git a/dbgpt/app/scene/base.py b/dbgpt/app/scene/base.py index c36501662..08c3cec27 100644 --- a/dbgpt/app/scene/base.py +++ b/dbgpt/app/scene/base.py @@ -1,7 +1,7 @@ from enum import Enum from typing import List, Optional -from dbgpt._private.pydantic import BaseModel, Field +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field from dbgpt.core import BaseOutputParser, ChatPromptTemplate from dbgpt.core._private.example_base import ExampleSelector @@ -154,10 +154,7 @@ class AppScenePromptTemplateAdapter(BaseModel): Include some fields that in :class:`dbgpt.core.PromptTemplate` """ - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) prompt: ChatPromptTemplate = Field(..., description="The prompt of this scene") template_scene: Optional[str] = Field( diff --git a/dbgpt/app/scene/base_chat.py b/dbgpt/app/scene/base_chat.py index d36d1d05e..899d35b40 
100644 --- a/dbgpt/app/scene/base_chat.py +++ b/dbgpt/app/scene/base_chat.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from typing import Any, AsyncIterator, Dict from dbgpt._private.config import Config -from dbgpt._private.pydantic import Extra +from dbgpt._private.pydantic import EXTRA_FORBID from dbgpt.app.scene.base import AppScenePromptTemplateAdapter, ChatScene from dbgpt.app.scene.operators.app_operator import ( AppChatComposerOperator, @@ -72,11 +72,6 @@ class BaseChat(ABC): # convert system message to human message auto_convert_message: bool = True - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - @trace("BaseChat.__init__") def __init__(self, chat_param: Dict): """Chat Module Initialization @@ -142,12 +137,6 @@ class BaseChat(ABC): self._message_version = chat_param.get("message_version", "v2") self._chat_param = chat_param - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - @property def chat_type(self) -> str: raise NotImplementedError("Not supported for this chat type.") diff --git a/dbgpt/app/scene/chat_dashboard/data_preparation/report_schma.py b/dbgpt/app/scene/chat_dashboard/data_preparation/report_schma.py index 8a8c12749..4d0b5088c 100644 --- a/dbgpt/app/scene/chat_dashboard/data_preparation/report_schma.py +++ b/dbgpt/app/scene/chat_dashboard/data_preparation/report_schma.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any, List, Optional from dbgpt._private.pydantic import BaseModel @@ -38,7 +38,7 @@ class ChartData(BaseModel): class ReportData(BaseModel): conv_uid: str template_name: str - template_introduce: str = None + template_introduce: Optional[str] = None charts: List[ChartData] def prepare_dict(self): diff --git a/dbgpt/client/client.py b/dbgpt/client/client.py index a30aba8b9..da5f235d8 100644 --- a/dbgpt/client/client.py +++ b/dbgpt/client/client.py @@ -7,6 +7,7 @@ from 
urllib.parse import urlparse import httpx +from dbgpt._private.pydantic import model_to_dict from dbgpt.core.schema.api import ChatCompletionResponse, ChatCompletionStreamResponse from .schema import ChatCompletionRequestBody @@ -167,7 +168,7 @@ class Client: enable_vis=enable_vis, ) response = await self._http_client.post( - self._api_url + "/chat/completions", json=request.dict() + self._api_url + "/chat/completions", json=model_to_dict(request) ) if response.status_code == 200: json_data = json.loads(response.text) @@ -242,7 +243,7 @@ class Client: incremental=incremental, enable_vis=enable_vis, ) - async for chat_completion_response in self._chat_stream(request.dict()): + async for chat_completion_response in self._chat_stream(model_to_dict(request)): yield chat_completion_response async def _chat_stream( @@ -262,6 +263,7 @@ class Client: headers={}, ) as response: if response.status_code == 200: + sse_data = "" async for line in response.aiter_lines(): try: if line.strip() == "data: [DONE]": @@ -277,7 +279,9 @@ class Client: ) yield chat_completion_response except Exception as e: - raise e + raise Exception( + f"Failed to parse SSE data: {e}, sse_data: {sse_data}" + ) else: try: diff --git a/dbgpt/client/datasource.py b/dbgpt/client/datasource.py index 47244b47e..09cd938c9 100644 --- a/dbgpt/client/datasource.py +++ b/dbgpt/client/datasource.py @@ -1,6 +1,7 @@ """this module contains the datasource client functions.""" from typing import List +from dbgpt._private.pydantic import model_to_dict from dbgpt.core.schema.api import Result from .client import Client, ClientException @@ -17,7 +18,7 @@ async def create_datasource( datasource (DatasourceModel): The datasource model. 
""" try: - res = await client.get("/datasources", datasource.dict()) + res = await client.get("/datasources", model_to_dict(datasource)) result: Result = res.json() if result["success"]: return DatasourceModel(**result["data"]) @@ -41,7 +42,7 @@ async def update_datasource( ClientException: If the request failed. """ try: - res = await client.put("/datasources", datasource.dict()) + res = await client.put("/datasources", model_to_dict(datasource)) result: Result = res.json() if result["success"]: return DatasourceModel(**result["data"]) diff --git a/dbgpt/client/flow.py b/dbgpt/client/flow.py index 15a453593..ec1b7edb4 100644 --- a/dbgpt/client/flow.py +++ b/dbgpt/client/flow.py @@ -15,7 +15,7 @@ async def create_flow(client: Client, flow: FlowPanel) -> FlowPanel: flow (FlowPanel): The flow panel. """ try: - res = await client.get("/awel/flows", flow.dict()) + res = await client.get("/awel/flows", flow.to_dict()) result: Result = res.json() if result["success"]: return FlowPanel(**result["data"]) @@ -37,7 +37,7 @@ async def update_flow(client: Client, flow: FlowPanel) -> FlowPanel: ClientException: If the request failed. """ try: - res = await client.put("/awel/flows", flow.dict()) + res = await client.put("/awel/flows", flow.to_dict()) result: Result = res.json() if result["success"]: return FlowPanel(**result["data"]) diff --git a/dbgpt/client/knowledge.py b/dbgpt/client/knowledge.py index f8b3404ec..bfd762770 100644 --- a/dbgpt/client/knowledge.py +++ b/dbgpt/client/knowledge.py @@ -2,6 +2,7 @@ import json from typing import List +from dbgpt._private.pydantic import model_to_dict, model_to_json from dbgpt.core.schema.api import Result from .client import Client, ClientException @@ -20,7 +21,7 @@ async def create_space(client: Client, space_model: SpaceModel) -> SpaceModel: ClientException: If the request failed. 
""" try: - res = await client.post("/knowledge/spaces", space_model.dict()) + res = await client.post("/knowledge/spaces", model_to_dict(space_model)) result: Result = res.json() if result["success"]: return SpaceModel(**result["data"]) @@ -42,7 +43,7 @@ async def update_space(client: Client, space_model: SpaceModel) -> SpaceModel: ClientException: If the request failed. """ try: - res = await client.put("/knowledge/spaces", space_model.dict()) + res = await client.put("/knowledge/spaces", model_to_dict(space_model)) result: Result = res.json() if result["success"]: return SpaceModel(**result["data"]) @@ -126,7 +127,7 @@ async def create_document(client: Client, doc_model: DocumentModel) -> DocumentM """ try: - res = await client.post_param("/knowledge/documents", doc_model.dict()) + res = await client.post_param("/knowledge/documents", model_to_dict(doc_model)) result: Result = res.json() if result["success"]: return DocumentModel(**result["data"]) @@ -210,7 +211,7 @@ async def sync_document(client: Client, sync_model: SyncModel) -> List: """ try: res = await client.post( - "/knowledge/documents/sync", [json.loads(sync_model.json())] + "/knowledge/documents/sync", [json.loads(model_to_json(sync_model))] ) result: Result = res.json() if result["success"]: diff --git a/dbgpt/client/schema.py b/dbgpt/client/schema.py index d89033203..d534e2839 100644 --- a/dbgpt/client/schema.py +++ b/dbgpt/client/schema.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Union from fastapi import File, UploadFile -from dbgpt._private.pydantic import BaseModel, Field +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field from dbgpt.rag.chunk_manager import ChunkParameters @@ -60,7 +60,7 @@ class ChatCompletionRequestBody(BaseModel): "or in full each time. 
" "If this parameter is not provided, the default is full return.", ) - enable_vis: str = Field( + enable_vis: bool = Field( default=True, description="response content whether to output vis label" ) @@ -267,6 +267,8 @@ class DocumentModel(BaseModel): class SyncModel(BaseModel): """Sync model.""" + model_config = ConfigDict(protected_namespaces=()) + """doc_id: doc id""" doc_id: str = Field(None, description="The doc id") diff --git a/dbgpt/component.py b/dbgpt/component.py index 100f2dd9d..5310c8036 100644 --- a/dbgpt/component.py +++ b/dbgpt/component.py @@ -297,8 +297,8 @@ class SystemApp(LifeCycle): if not self.app: self._register_exit_handler() return + from dbgpt.util.fastapi import register_event_handler - @self.app.on_event("startup") async def startup_event(): """ASGI app startup event handler.""" @@ -312,12 +312,14 @@ class SystemApp(LifeCycle): asyncio.create_task(_startup_func()) self.after_start() - @self.app.on_event("shutdown") async def shutdown_event(): """ASGI app shutdown event handler.""" await self.async_before_stop() self.before_stop() + register_event_handler(self.app, "startup", startup_event) + register_event_handler(self.app, "shutdown", shutdown_event) + def _register_exit_handler(self): """Register an exit handler to stop the system app.""" atexit.register(self.before_stop) diff --git a/dbgpt/core/awel/__init__.py b/dbgpt/core/awel/__init__.py index 25dae7f0b..a0f850f86 100644 --- a/dbgpt/core/awel/__init__.py +++ b/dbgpt/core/awel/__init__.py @@ -159,9 +159,9 @@ def setup_dev_environment( start_http = _check_has_http_trigger(dags) if start_http: - from fastapi import FastAPI + from dbgpt.util.fastapi import create_app - app = FastAPI() + app = create_app() else: app = None system_app = SystemApp(app) diff --git a/dbgpt/core/awel/dag/dag_manager.py b/dbgpt/core/awel/dag/dag_manager.py index 83b9e8946..611988145 100644 --- a/dbgpt/core/awel/dag/dag_manager.py +++ b/dbgpt/core/awel/dag/dag_manager.py @@ -4,6 +4,7 @@ DAGManager will load 
DAGs from dag_dirs, and register the trigger nodes to TriggerManager. """ import logging +import threading from typing import Dict, List, Optional from dbgpt.component import BaseComponent, ComponentType, SystemApp @@ -29,6 +30,7 @@ class DAGManager(BaseComponent): from ..trigger.trigger_manager import DefaultTriggerManager super().__init__(system_app) + self.lock = threading.Lock() self.dag_loader = LocalFileDAGLoader(dag_dirs) self.system_app = system_app self.dag_map: Dict[str, DAG] = {} @@ -61,39 +63,54 @@ class DAGManager(BaseComponent): def register_dag(self, dag: DAG, alias_name: Optional[str] = None): """Register a DAG.""" - dag_id = dag.dag_id - if dag_id in self.dag_map: - raise ValueError(f"Register DAG error, DAG ID {dag_id} has already exist") - self.dag_map[dag_id] = dag - if alias_name: - self.dag_alias_map[alias_name] = dag_id + with self.lock: + dag_id = dag.dag_id + if dag_id in self.dag_map: + raise ValueError( + f"Register DAG error, DAG ID {dag_id} has already exist" + ) + self.dag_map[dag_id] = dag + if alias_name: + self.dag_alias_map[alias_name] = dag_id - if self._trigger_manager: - for trigger in dag.trigger_nodes: - self._trigger_manager.register_trigger(trigger, self.system_app) - self._trigger_manager.after_register() - else: - logger.warning("No trigger manager, not register dag trigger") + if self._trigger_manager: + for trigger in dag.trigger_nodes: + self._trigger_manager.register_trigger(trigger, self.system_app) + self._trigger_manager.after_register() + else: + logger.warning("No trigger manager, not register dag trigger") def unregister_dag(self, dag_id: str): """Unregister a DAG.""" - if dag_id not in self.dag_map: - raise ValueError(f"Unregister DAG error, DAG ID {dag_id} does not exist") - dag = self.dag_map[dag_id] - # Clear the alias map - for alias_name, _dag_id in self.dag_alias_map.items(): - if _dag_id == dag_id: + with self.lock: + if dag_id not in self.dag_map: + raise ValueError( + f"Unregister DAG error, DAG ID 
{dag_id} does not exist" + ) + dag = self.dag_map[dag_id] + + # Collect aliases to remove + # TODO(fangyinc): It can be faster if we maintain a reverse map + aliases_to_remove = [ + alias_name + for alias_name, _dag_id in self.dag_alias_map.items() + if _dag_id == dag_id + ] + # Remove collected aliases + for alias_name in aliases_to_remove: del self.dag_alias_map[alias_name] - if self._trigger_manager: - for trigger in dag.trigger_nodes: - self._trigger_manager.unregister_trigger(trigger, self.system_app) - del self.dag_map[dag_id] + if self._trigger_manager: + for trigger in dag.trigger_nodes: + self._trigger_manager.unregister_trigger(trigger, self.system_app) + # Finally remove the DAG from the map + del self.dag_map[dag_id] def get_dag( self, dag_id: Optional[str] = None, alias_name: Optional[str] = None ) -> Optional[DAG]: """Get a DAG by dag_id or alias_name.""" + # Not lock, because it is read only and need to be fast if dag_id and dag_id in self.dag_map: return self.dag_map[dag_id] if alias_name in self.dag_alias_map: diff --git a/dbgpt/core/awel/flow/base.py b/dbgpt/core/awel/flow/base.py index 215679e05..8e9441818 100644 --- a/dbgpt/core/awel/flow/base.py +++ b/dbgpt/core/awel/flow/base.py @@ -7,7 +7,13 @@ from datetime import date, datetime from enum import Enum from typing import Any, Dict, List, Optional, Type, TypeVar, Union, cast -from dbgpt._private.pydantic import BaseModel, Field, ValidationError, root_validator +from dbgpt._private.pydantic import ( + BaseModel, + Field, + ValidationError, + model_to_dict, + model_validator, +) from dbgpt.core.awel.util.parameter_util import BaseDynamicOptions, OptionValue from dbgpt.core.interface.serialization import Serializable @@ -281,7 +287,7 @@ class TypeMetadata(BaseModel): def new(self: TM) -> TM: """Copy the metadata.""" - return self.__class__(**self.dict()) + return self.__class__(**self.model_dump(exclude_defaults=True)) class Parameter(TypeMetadata, Serializable): @@ -332,12 +338,15 @@ class 
Parameter(TypeMetadata, Serializable): None, description="The value of the parameter(Saved in the dag file)" ) - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Pre fill the metadata. Transform the value to the real type. """ + if not isinstance(values, dict): + return values type_cls = values.get("type_cls") to_handle_values = { "value": values.get("value"), @@ -443,7 +452,7 @@ class Parameter(TypeMetadata, Serializable): def to_dict(self) -> Dict: """Convert current metadata to json dict.""" - dict_value = self.dict(exclude={"options"}) + dict_value = model_to_dict(self, exclude={"options"}) if not self.options: dict_value["options"] = None elif isinstance(self.options, BaseDynamicOptions): @@ -535,7 +544,7 @@ class BaseResource(Serializable, BaseModel): def to_dict(self) -> Dict: """Convert current metadata to json dict.""" - return self.dict() + return model_to_dict(self) class Resource(BaseResource, TypeMetadata): @@ -693,9 +702,12 @@ class BaseMetadata(BaseResource): ) return runnable_parameters - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def base_pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Pre fill the metadata.""" + if not isinstance(values, dict): + return values if "category_label" not in values: category = values["category"] if isinstance(category, str): @@ -713,7 +725,7 @@ class BaseMetadata(BaseResource): def to_dict(self) -> Dict: """Convert current metadata to json dict.""" - dict_value = self.dict(exclude={"parameters"}) + dict_value = model_to_dict(self, exclude={"parameters"}) dict_value["parameters"] = [ parameter.to_dict() for parameter in self.parameters ] @@ -738,9 +750,12 @@ class ResourceMetadata(BaseMetadata, TypeMetadata): ], ) - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Pre fill the metadata.""" + if not 
isinstance(values, dict): + return values if "flow_type" not in values: values["flow_type"] = "resource" if "id" not in values: @@ -846,9 +861,12 @@ class ViewMetadata(BaseMetadata): examples=["dbgpt.model.operators.LLMOperator"], ) - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Pre fill the metadata.""" + if not isinstance(values, dict): + return values if "flow_type" not in values: values["flow_type"] = "operator" if "id" not in values: diff --git a/dbgpt/core/awel/flow/flow_factory.py b/dbgpt/core/awel/flow/flow_factory.py index 4eebda2c6..e57cc1937 100644 --- a/dbgpt/core/awel/flow/flow_factory.py +++ b/dbgpt/core/awel/flow/flow_factory.py @@ -6,7 +6,13 @@ from contextlib import suppress from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast -from dbgpt._private.pydantic import BaseModel, Field, root_validator, validator +from dbgpt._private.pydantic import ( + BaseModel, + Field, + field_validator, + model_to_dict, + model_validator, +) from dbgpt.core.awel.dag.base import DAG, DAGNode from .base import ( @@ -73,7 +79,8 @@ class FlowNodeData(BaseModel): ..., description="Absolute position of the node" ) - @validator("data", pre=True) + @field_validator("data", mode="before") + @classmethod def parse_data(cls, value: Any): """Parse the data.""" if isinstance(value, dict): @@ -123,9 +130,12 @@ class FlowEdgeData(BaseModel): examples=["buttonedge"], ) - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Pre fill the metadata.""" + if not isinstance(values, dict): + return values if ( "source_order" not in values and "source_handle" in values @@ -315,9 +325,12 @@ class FlowPanel(BaseModel): examples=["2021-08-01 12:00:00", "2021-08-01 12:00:01", "2021-08-01 12:00:02"], ) - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod 
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Pre fill the metadata.""" + if not isinstance(values, dict): + return values label = values.get("label") name = values.get("name") flow_category = str(values.get("flow_category", "")) @@ -329,6 +342,10 @@ class FlowPanel(BaseModel): values["name"] = name return values + def to_dict(self) -> Dict[str, Any]: + """Convert to dict.""" + return model_to_dict(self) + class FlowFactory: """Flow factory.""" diff --git a/dbgpt/core/awel/trigger/http_trigger.py b/dbgpt/core/awel/trigger/http_trigger.py index f3ee7a31e..18f7d5b9b 100644 --- a/dbgpt/core/awel/trigger/http_trigger.py +++ b/dbgpt/core/awel/trigger/http_trigger.py @@ -15,7 +15,14 @@ from typing import ( get_origin, ) -from dbgpt._private.pydantic import BaseModel, Field +from dbgpt._private.pydantic import ( + BaseModel, + Field, + field_is_required, + field_outer_type, + model_fields, + model_to_dict, +) from dbgpt.util.i18n_utils import _ from ..dag.base import DAG @@ -61,7 +68,7 @@ class AWELHttpError(RuntimeError): def _default_streaming_predict_func(body: "CommonRequestType") -> bool: if isinstance(body, BaseModel): - body = body.dict() + body = model_to_dict(body) elif isinstance(body, str): try: body = json.loads(body) @@ -254,7 +261,7 @@ class CommonLLMHttpRequestBody(BaseHttpBody): "or in full each time. 
" "If this parameter is not provided, the default is full return.", ) - enable_vis: str = Field( + enable_vis: bool = Field( default=True, description="response content whether to output vis label" ) extra: Optional[Dict[str, Any]] = Field( @@ -574,18 +581,20 @@ class HttpTrigger(Trigger): if isinstance(req_body_cls, type) and issubclass( req_body_cls, BaseModel ): - fields = req_body_cls.__fields__ # type: ignore + fields = model_fields(req_body_cls) # type: ignore parameters = [] for field_name, field in fields.items(): default_value = ( - Parameter.empty if field.required else field.default + Parameter.empty + if field_is_required(field) + else field.default ) parameters.append( Parameter( name=field_name, kind=Parameter.KEYWORD_ONLY, default=default_value, - annotation=field.outer_type_, + annotation=field_outer_type(field), ) ) elif req_body_cls == Dict[str, Any] or req_body_cls == dict: @@ -1029,7 +1038,7 @@ class RequestBodyToDictOperator(MapOperator[CommonLLMHttpRequestBody, Dict[str, async def map(self, request_body: CommonLLMHttpRequestBody) -> Dict[str, Any]: """Map the request body to response body.""" - dict_value = request_body.dict() + dict_value = model_to_dict(request_body) if not self._key: return dict_value else: @@ -1138,7 +1147,7 @@ class RequestedParsedOperator(MapOperator[CommonLLMHttpRequestBody, str]): async def map(self, request_body: CommonLLMHttpRequestBody) -> str: """Map the request body to response body.""" - dict_value = request_body.dict() + dict_value = model_to_dict(request_body) if not self._key or self._key not in dict_value: raise ValueError( f"Prefix key {self._key} is not a valid key of the request body" diff --git a/dbgpt/core/awel/util/parameter_util.py b/dbgpt/core/awel/util/parameter_util.py index d68e03f97..defd99a3b 100644 --- a/dbgpt/core/awel/util/parameter_util.py +++ b/dbgpt/core/awel/util/parameter_util.py @@ -4,7 +4,7 @@ import inspect from abc import ABC, abstractmethod from typing import Any, Callable, Dict, 
List -from dbgpt._private.pydantic import BaseModel, Field, root_validator +from dbgpt._private.pydantic import BaseModel, Field, model_validator from dbgpt.core.interface.serialization import Serializable _DEFAULT_DYNAMIC_REGISTRY = {} @@ -44,9 +44,12 @@ class FunctionDynamicOptions(BaseDynamicOptions): """Return the option values of the parameter.""" return self.func() - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Pre fill the function id.""" + if not isinstance(values, dict): + return values func = values.get("func") if func is None: raise ValueError( diff --git a/dbgpt/core/interface/knowledge.py b/dbgpt/core/interface/knowledge.py index 7fe75b0a5..369108883 100644 --- a/dbgpt/core/interface/knowledge.py +++ b/dbgpt/core/interface/knowledge.py @@ -4,7 +4,7 @@ import json import uuid from typing import Any, Dict -from dbgpt._private.pydantic import BaseModel, Field +from dbgpt._private.pydantic import BaseModel, Field, model_to_dict class Document(BaseModel): @@ -64,7 +64,7 @@ class Chunk(Document): def to_dict(self, **kwargs: Any) -> Dict[str, Any]: """Convert Chunk to dict.""" - data = self.dict(**kwargs) + data = model_to_dict(self, **kwargs) data["class_name"] = self.class_name() return data diff --git a/dbgpt/core/interface/llm.py b/dbgpt/core/interface/llm.py index 4d26ddb84..e6a5d24d4 100644 --- a/dbgpt/core/interface/llm.py +++ b/dbgpt/core/interface/llm.py @@ -10,7 +10,7 @@ from typing import Any, AsyncIterator, Dict, List, Optional, Union from cachetools import TTLCache -from dbgpt._private.pydantic import BaseModel +from dbgpt._private.pydantic import BaseModel, model_to_dict from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType from dbgpt.util import BaseParameters from dbgpt.util.annotations import PublicAPI @@ -312,7 +312,7 @@ class ModelRequest: if isinstance(context, dict): context_dict = context elif isinstance(context, 
BaseModel): - context_dict = context.dict() + context_dict = model_to_dict(context) if context_dict and "stream" not in context_dict: context_dict["stream"] = stream if context_dict: diff --git a/dbgpt/core/interface/message.py b/dbgpt/core/interface/message.py index 238ce012a..4d832e0af 100755 --- a/dbgpt/core/interface/message.py +++ b/dbgpt/core/interface/message.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from datetime import datetime from typing import Callable, Dict, List, Optional, Tuple, Union, cast -from dbgpt._private.pydantic import BaseModel, Field +from dbgpt._private.pydantic import BaseModel, Field, model_to_dict from dbgpt.core.interface.storage import ( InMemoryStorage, ResourceIdentifier, @@ -42,7 +42,7 @@ class BaseMessage(BaseModel, ABC): """ return { "type": self.type, - "data": self.dict(), + "data": model_to_dict(self), "index": self.index, "round_index": self.round_index, } @@ -264,7 +264,7 @@ class ModelMessage(BaseModel): Returns: List[Dict[str, str]]: The dict list """ - return list(map(lambda m: m.dict(), messages)) + return list(map(lambda m: model_to_dict(m), messages)) @staticmethod def build_human_message(content: str) -> "ModelMessage": diff --git a/dbgpt/core/interface/operators/prompt_operator.py b/dbgpt/core/interface/operators/prompt_operator.py index 6b6d31089..c3765aa67 100644 --- a/dbgpt/core/interface/operators/prompt_operator.py +++ b/dbgpt/core/interface/operators/prompt_operator.py @@ -2,7 +2,7 @@ from abc import ABC from typing import Any, Dict, List, Optional, Union -from dbgpt._private.pydantic import root_validator +from dbgpt._private.pydantic import model_validator from dbgpt.core import ( ModelMessage, ModelMessageRoleType, @@ -71,9 +71,12 @@ from dbgpt.util.i18n_utils import _ class CommonChatPromptTemplate(ChatPromptTemplate): """The common chat prompt template.""" - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: 
"""Pre fill the messages.""" + if not isinstance(values, dict): + return values if "system_message" not in values: values["system_message"] = "You are a helpful AI Assistant." if "human_message" not in values: diff --git a/dbgpt/core/interface/prompt.py b/dbgpt/core/interface/prompt.py index 1351eb09a..54d98d861 100644 --- a/dbgpt/core/interface/prompt.py +++ b/dbgpt/core/interface/prompt.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from string import Formatter from typing import Any, Callable, Dict, List, Optional, Set, Union -from dbgpt._private.pydantic import BaseModel, root_validator +from dbgpt._private.pydantic import BaseModel, ConfigDict, model_validator from dbgpt.core.interface.message import BaseMessage, HumanMessage, SystemMessage from dbgpt.core.interface.storage import ( InMemoryStorage, @@ -51,6 +51,8 @@ class BasePromptTemplate(BaseModel): class PromptTemplate(BasePromptTemplate): """Prompt template.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + template: str """The prompt template.""" @@ -69,11 +71,6 @@ class PromptTemplate(BasePromptTemplate): template_define: Optional[str] = None """this template define""" - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - @property def _prompt_type(self) -> str: """Return the prompt type key.""" @@ -239,9 +236,12 @@ class ChatPromptTemplate(BasePromptTemplate): raise ValueError(f"Unsupported message type: {type(message)}") return result_messages - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def base_pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Pre-fill the messages.""" + if not isinstance(values, dict): + return values input_variables = values.get("input_variables", {}) messages = values.get("messages", []) if not input_variables: diff --git a/dbgpt/core/interface/serialization.py b/dbgpt/core/interface/serialization.py index 72f1fe3b0..51ba29dcf 100644 --- 
a/dbgpt/core/interface/serialization.py +++ b/dbgpt/core/interface/serialization.py @@ -9,7 +9,7 @@ from typing import Dict, Optional, Type class Serializable(ABC): """The serializable abstract class.""" - serializer: Optional["Serializer"] = None + _serializer: Optional["Serializer"] = None @abstractmethod def to_dict(self) -> Dict: @@ -21,11 +21,12 @@ class Serializable(ABC): Returns: bytes: The byte array after serialization """ - if self.serializer is None: + if self._serializer is None: raise ValueError( - "Serializer is not set. Please set the serializer before serialization." + "Serializer is not set. Please set the serializer before " + "serialization." ) - return self.serializer.serialize(self) + return self._serializer.serialize(self) def set_serializer(self, serializer: "Serializer") -> None: """Set the serializer for current serializable object. @@ -33,7 +34,7 @@ class Serializable(ABC): Args: serializer (Serializer): The serializer to set """ - self.serializer = serializer + self._serializer = serializer class Serializer(ABC): diff --git a/dbgpt/core/interface/storage.py b/dbgpt/core/interface/storage.py index 07eefcbab..2a61746ec 100644 --- a/dbgpt/core/interface/storage.py +++ b/dbgpt/core/interface/storage.py @@ -426,7 +426,7 @@ class InMemoryStorage(StorageInterface[T, T]): """ if not data: raise StorageError("Data cannot be None") - if not data.serializer: + if not data._serializer: data.set_serializer(self.serializer) if data.identifier.str_identifier in self._data: @@ -439,7 +439,7 @@ class InMemoryStorage(StorageInterface[T, T]): """Update the data to the storage.""" if not data: raise StorageError("Data cannot be None") - if not data.serializer: + if not data._serializer: data.set_serializer(self.serializer) self._data[data.identifier.str_identifier] = data.serialize() diff --git a/dbgpt/core/schema/api.py b/dbgpt/core/schema/api.py index 1904d841b..ca8b4c4c4 100644 --- a/dbgpt/core/schema/api.py +++ b/dbgpt/core/schema/api.py @@ -2,9 +2,10 @@ 
import time import uuid -from typing import Any, Generic, List, Literal, Optional, TypeVar +from enum import IntEnum +from typing import Any, Dict, Generic, List, Literal, Optional, TypeVar, Union -from dbgpt._private.pydantic import BaseModel, Field +from dbgpt._private.pydantic import BaseModel, Field, model_to_dict T = TypeVar("T") @@ -41,6 +42,28 @@ class Result(BaseModel, Generic[T]): """ return Result(success=False, err_code=err_code, err_msg=msg, data=None) + def to_dict(self, **kwargs) -> Dict[str, Any]: + """Convert to dict.""" + return model_to_dict(self, **kwargs) + + +class APIChatCompletionRequest(BaseModel): + """Chat completion request entity.""" + + model: str = Field(..., description="Model name") + messages: Union[str, List[Dict[str, str]]] = Field(..., description="Messages") + temperature: Optional[float] = Field(0.7, description="Temperature") + top_p: Optional[float] = Field(1.0, description="Top p") + top_k: Optional[int] = Field(-1, description="Top k") + n: Optional[int] = Field(1, description="Number of completions") + max_tokens: Optional[int] = Field(None, description="Max tokens") + stop: Optional[Union[str, List[str]]] = Field(None, description="Stop") + stream: Optional[bool] = Field(False, description="Stream") + user: Optional[str] = Field(None, description="User") + repetition_penalty: Optional[float] = Field(1.0, description="Repetition penalty") + frequency_penalty: Optional[float] = Field(0.0, description="Frequency penalty") + presence_penalty: Optional[float] = Field(0.0, description="Presence penalty") + class DeltaMessage(BaseModel): """Delta message entity for chat completion response.""" @@ -122,3 +145,97 @@ class ErrorResponse(BaseModel): object: str = Field("error", description="Object type") message: str = Field(..., description="Error message") code: int = Field(..., description="Error code") + + +class EmbeddingsRequest(BaseModel): + """Embeddings request entity.""" + + model: Optional[str] = Field(None, 
description="Model name") + engine: Optional[str] = Field(None, description="Engine name") + input: Union[str, List[Any]] = Field(..., description="Input data") + user: Optional[str] = Field(None, description="User name") + encoding_format: Optional[str] = Field(None, description="Encoding format") + + +class EmbeddingsResponse(BaseModel): + """Embeddings response entity.""" + + object: str = Field("list", description="Object type") + data: List[Dict[str, Any]] = Field(..., description="Data list") + model: str = Field(..., description="Model name") + usage: UsageInfo = Field(..., description="Usage info") + + +class ModelPermission(BaseModel): + """Model permission entity.""" + + id: str = Field( + default_factory=lambda: f"modelperm-{str(uuid.uuid1())}", + description="Permission ID", + ) + object: str = Field("model_permission", description="Object type") + created: int = Field( + default_factory=lambda: int(time.time()), description="Created time" + ) + allow_create_engine: bool = Field(False, description="Allow create engine") + allow_sampling: bool = Field(True, description="Allow sampling") + allow_logprobs: bool = Field(True, description="Allow logprobs") + allow_search_indices: bool = Field(True, description="Allow search indices") + allow_view: bool = Field(True, description="Allow view") + allow_fine_tuning: bool = Field(False, description="Allow fine tuning") + organization: str = Field("*", description="Organization") + group: Optional[str] = Field(None, description="Group") + is_blocking: bool = Field(False, description="Is blocking") + + +class ModelCard(BaseModel): + """Model card entity.""" + + id: str = Field(..., description="Model ID") + object: str = Field("model", description="Object type") + created: int = Field( + default_factory=lambda: int(time.time()), description="Created time" + ) + owned_by: str = Field("DB-GPT", description="Owned by") + root: Optional[str] = Field(None, description="Root") + parent: Optional[str] = Field(None, 
description="Parent") + permission: List[ModelPermission] = Field( + default_factory=list, description="Permission" + ) + + +class ModelList(BaseModel): + """Model list entity.""" + + object: str = Field("list", description="Object type") + data: List[ModelCard] = Field(default_factory=list, description="Model list data") + + +class ErrorCode(IntEnum): + """Error code enumeration. + + https://platform.openai.com/docs/guides/error-codes/api-errors. + + Adapted from fastchat.constants. + """ + + VALIDATION_TYPE_ERROR = 40001 + + INVALID_AUTH_KEY = 40101 + INCORRECT_AUTH_KEY = 40102 + NO_PERMISSION = 40103 + + INVALID_MODEL = 40301 + PARAM_OUT_OF_RANGE = 40302 + CONTEXT_OVERFLOW = 40303 + + RATE_LIMIT = 42901 + QUOTA_EXCEEDED = 42902 + ENGINE_OVERLOADED = 42903 + + INTERNAL_ERROR = 50001 + CUDA_OUT_OF_MEMORY = 50002 + GRADIO_REQUEST_ERROR = 50003 + GRADIO_STREAM_UNKNOWN_ERROR = 50004 + CONTROLLER_NO_WORKER = 50005 + CONTROLLER_WORKER_TIMEOUT = 50006 diff --git a/dbgpt/model/cluster/apiserver/api.py b/dbgpt/model/cluster/apiserver/api.py index c3cdbc5b2..ab58d75fb 100644 --- a/dbgpt/model/cluster/apiserver/api.py +++ b/dbgpt/model/cluster/apiserver/api.py @@ -9,14 +9,18 @@ import logging from typing import Any, Dict, Generator, List, Optional import shortuuid -from fastapi import APIRouter, Depends, FastAPI, HTTPException +from fastapi import APIRouter, Depends, HTTPException from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer -from fastchat.constants import ErrorCode -from fastchat.protocol.api_protocol import APIChatCompletionRequest, ErrorResponse -from fastchat.protocol.openai_api_protocol import ( + +from dbgpt._private.pydantic import BaseModel, model_to_dict, model_to_json +from dbgpt.component import BaseComponent, ComponentType, SystemApp +from dbgpt.core import ModelOutput 
+from dbgpt.core.interface.message import ModelMessage +from dbgpt.core.schema.api import ( + APIChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, @@ -25,20 +29,18 @@ from fastchat.protocol.openai_api_protocol import ( DeltaMessage, EmbeddingsRequest, EmbeddingsResponse, + ErrorCode, + ErrorResponse, ModelCard, ModelList, ModelPermission, UsageInfo, ) - -from dbgpt._private.pydantic import BaseModel -from dbgpt.component import BaseComponent, ComponentType, SystemApp -from dbgpt.core import ModelOutput -from dbgpt.core.interface.message import ModelMessage from dbgpt.model.base import ModelInstance from dbgpt.model.cluster.manager_base import WorkerManager, WorkerManagerFactory from dbgpt.model.cluster.registry import ModelRegistry from dbgpt.model.parameter import ModelAPIServerParameters, WorkerType +from dbgpt.util.fastapi import create_app from dbgpt.util.parameter_utils import EnvArgumentParser from dbgpt.util.utils import setup_logging @@ -88,7 +90,7 @@ def create_error_response(code: int, message: str) -> JSONResponse: We can't use fastchat.serve.openai_api_server because it has too many dependencies. 
""" return JSONResponse( - ErrorResponse(message=message, code=code).dict(), status_code=400 + model_to_dict(ErrorResponse(message=message, code=code)), status_code=400 ) @@ -266,7 +268,8 @@ class APIServer(BaseComponent): chunk = ChatCompletionStreamResponse( id=id, choices=[choice_data], model=model_name ) - yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False) + yield f"data: {json_data}\n\n" previous_text = "" async for model_output in worker_manager.generate_stream(params): @@ -297,10 +300,15 @@ class APIServer(BaseComponent): if model_output.finish_reason is not None: finish_stream_events.append(chunk) continue - yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False) + yield f"data: {json_data}\n\n" + # There is no "content" field in the last delta message, so use exclude_none to drop the "content" field. 
for finish_chunk in finish_stream_events: - yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n" + json_data = model_to_json( + finish_chunk, exclude_none=True, ensure_ascii=False + ) + yield f"data: {json_data}\n\n" yield "data: [DONE]\n\n" async def chat_completion_generate( @@ -335,8 +343,8 @@ class APIServer(BaseComponent): ) ) if model_output.usage: - task_usage = UsageInfo.parse_obj(model_output.usage) - for usage_key, usage_value in task_usage.dict().items(): + task_usage = UsageInfo.model_validate(model_output.usage) + for usage_key, usage_value in model_to_dict(task_usage).items(): setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) return ChatCompletionResponse(model=model_name, choices=choices, usage=usage) @@ -442,8 +450,9 @@ async def create_embeddings( } for i, emb in enumerate(embeddings) ] - return EmbeddingsResponse(data=data, model=request.model, usage=UsageInfo()).dict( - exclude_none=True + return model_to_dict( + EmbeddingsResponse(data=data, model=request.model, usage=UsageInfo()), + exclude_none=True, ) @@ -492,7 +501,7 @@ def initialize_apiserver( embedded_mod = True if not app: embedded_mod = False - app = FastAPI() + app = create_app() if not system_app: system_app = SystemApp(app) diff --git a/dbgpt/model/cluster/apiserver/tests/test_api.py b/dbgpt/model/cluster/apiserver/tests/test_api.py index 00a374780..d9001ed44 100644 --- a/dbgpt/model/cluster/apiserver/tests/test_api.py +++ b/dbgpt/model/cluster/apiserver/tests/test_api.py @@ -22,9 +22,10 @@ from dbgpt.model.cluster.apiserver.api import ( ) from dbgpt.model.cluster.tests.conftest import _new_cluster from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory +from dbgpt.util.fastapi import create_app from dbgpt.util.openai_utils import chat_completion, chat_completion_stream -app = FastAPI() +app = create_app() app.add_middleware( CORSMiddleware, allow_origins=["*"], @@ -69,7 +70,7 @@ async def client(request, system_app: 
SystemApp): async def test_get_all_models(client: AsyncClient): res = await client.get("/api/v1/models") res.status_code == 200 - model_lists = ModelList.parse_obj(res.json()) + model_lists = ModelList.model_validate(res.json()) print(f"model list json: {res.json()}") assert model_lists.object == "list" assert len(model_lists.data) == 2 diff --git a/dbgpt/model/cluster/controller/controller.py b/dbgpt/model/cluster/controller/controller.py index 1591c120b..64311cd96 100644 --- a/dbgpt/model/cluster/controller/controller.py +++ b/dbgpt/model/cluster/controller/controller.py @@ -2,7 +2,7 @@ import logging from abc import ABC, abstractmethod from typing import List -from fastapi import APIRouter, FastAPI +from fastapi import APIRouter from dbgpt.component import BaseComponent, ComponentType, SystemApp from dbgpt.model.base import ModelInstance @@ -10,6 +10,7 @@ from dbgpt.model.cluster.registry import EmbeddedModelRegistry, ModelRegistry from dbgpt.model.parameter import ModelControllerParameters from dbgpt.util.api_utils import _api_remote as api_remote from dbgpt.util.api_utils import _sync_api_remote as sync_api_remote +from dbgpt.util.fastapi import create_app from dbgpt.util.parameter_utils import EnvArgumentParser from dbgpt.util.utils import setup_http_service_logging, setup_logging @@ -152,7 +153,7 @@ def initialize_controller( import uvicorn setup_http_service_logging() - app = FastAPI() + app = create_app() app.include_router(router, prefix="/api", tags=["Model"]) uvicorn.run(app, host=host, port=port, log_level="info") diff --git a/dbgpt/model/cluster/worker/manager.py b/dbgpt/model/cluster/worker/manager.py index 020388771..b57698af9 100644 --- a/dbgpt/model/cluster/worker/manager.py +++ b/dbgpt/model/cluster/worker/manager.py @@ -11,7 +11,7 @@ from concurrent.futures import ThreadPoolExecutor from dataclasses import asdict from typing import Awaitable, Callable, Iterator -from fastapi import APIRouter, FastAPI +from fastapi import APIRouter from 
fastapi.responses import StreamingResponse from dbgpt.component import SystemApp @@ -28,6 +28,7 @@ from dbgpt.model.cluster.registry import ModelRegistry from dbgpt.model.cluster.worker_base import ModelWorker from dbgpt.model.parameter import ModelWorkerParameters, WorkerType from dbgpt.model.utils.llm_utils import list_supported_models +from dbgpt.util.fastapi import create_app, register_event_handler from dbgpt.util.parameter_utils import ( EnvArgumentParser, ParameterDescription, @@ -829,7 +830,7 @@ def _setup_fastapi( worker_params: ModelWorkerParameters, app=None, ignore_exception: bool = False ): if not app: - app = FastAPI() + app = create_app() setup_http_service_logging() if worker_params.standalone: @@ -850,7 +851,6 @@ def _setup_fastapi( initialize_controller(app=app) app.include_router(controller_router, prefix="/api") - @app.on_event("startup") async def startup_event(): async def start_worker_manager(): try: @@ -865,10 +865,11 @@ def _setup_fastapi( # the fastapi app (registered to the controller) asyncio.create_task(start_worker_manager()) - @app.on_event("shutdown") - async def startup_event(): + async def shutdown_event(): await worker_manager.stop(ignore_exception=ignore_exception) + register_event_handler(app, "startup", startup_event) + register_event_handler(app, "shutdown", shutdown_event) return app diff --git a/dbgpt/rag/embedding/embeddings.py b/dbgpt/rag/embedding/embeddings.py index 0e957cf49..59d08b9fd 100644 --- a/dbgpt/rag/embedding/embeddings.py +++ b/dbgpt/rag/embedding/embeddings.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional import aiohttp import requests -from dbgpt._private.pydantic import BaseModel, Extra, Field +from dbgpt._private.pydantic import EXTRA_FORBID, BaseModel, ConfigDict, Field from dbgpt.core import Embeddings from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource from dbgpt.util.i18n_utils import _ @@ -64,10 +64,12 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): ) 
""" + model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=()) + client: Any #: :meta private: model_name: str = DEFAULT_MODEL_NAME """Model name to use.""" - cache_folder: Optional[str] = None + cache_folder: Optional[str] = Field(None, description="Path of the cache folder.") """Path to store models. Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable.""" model_kwargs: Dict[str, Any] = Field(default_factory=dict) @@ -79,7 +81,6 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): def __init__(self, **kwargs: Any): """Initialize the sentence_transformer.""" - super().__init__(**kwargs) try: import sentence_transformers @@ -89,14 +90,12 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): "Please install it with `pip install sentence-transformers`." ) from exc - self.client = sentence_transformers.SentenceTransformer( - self.model_name, cache_folder=self.cache_folder, **self.model_kwargs + kwargs["client"] = sentence_transformers.SentenceTransformer( + kwargs.get("model_name", DEFAULT_MODEL_NAME), + cache_folder=kwargs.get("cache_folder"), + **kwargs.get("model_kwargs", {}), ) - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid + super().__init__(**kwargs) def embed_documents(self, texts: List[str]) -> List[List[float]]: """Compute doc embeddings using a HuggingFace transformer model. 
@@ -184,6 +183,8 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings): ) """ + model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=()) + client: Any #: :meta private: model_name: str = DEFAULT_INSTRUCT_MODEL """Model name to use.""" @@ -201,20 +202,18 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings): def __init__(self, **kwargs: Any): """Initialize the sentence_transformer.""" - super().__init__(**kwargs) try: from InstructorEmbedding import INSTRUCTOR - self.client = INSTRUCTOR( - self.model_name, cache_folder=self.cache_folder, **self.model_kwargs + kwargs["client"] = INSTRUCTOR( + kwargs.get("model_name", DEFAULT_INSTRUCT_MODEL), + cache_folder=kwargs.get("cache_folder"), + **kwargs.get("model_kwargs", {}), ) except ImportError as e: raise ImportError("Dependencies for InstructorEmbedding not found.") from e - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid + super().__init__(**kwargs) def embed_documents(self, texts: List[str]) -> List[List[float]]: """Compute doc embeddings using a HuggingFace instruct model. @@ -267,6 +266,8 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings): ) """ + model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=()) + client: Any #: :meta private: model_name: str = DEFAULT_BGE_MODEL """Model name to use.""" @@ -282,7 +283,6 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings): def __init__(self, **kwargs: Any): """Initialize the sentence_transformer.""" - super().__init__(**kwargs) try: import sentence_transformers @@ -292,17 +292,16 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings): "Please install it with `pip install sentence_transformers`." 
) from exc - self.client = sentence_transformers.SentenceTransformer( - self.model_name, cache_folder=self.cache_folder, **self.model_kwargs + kwargs["client"] = sentence_transformers.SentenceTransformer( + kwargs.get("model_name", DEFAULT_BGE_MODEL), + cache_folder=kwargs.get("cache_folder"), + **kwargs.get("model_kwargs", {}), ) + + super().__init__(**kwargs) if "-zh" in self.model_name: self.query_instruction = DEFAULT_QUERY_BGE_INSTRUCTION_ZH - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - def embed_documents(self, texts: List[str]) -> List[List[float]]: """Compute doc embeddings using a HuggingFace transformer model. @@ -360,6 +359,8 @@ class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings): Requires a HuggingFace Inference API key and a model name. """ + model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=()) + api_key: str """Your API key for the HuggingFace Inference API.""" model_name: str = "sentence-transformers/all-MiniLM-L6-v2" @@ -475,6 +476,8 @@ class JinaEmbeddings(BaseModel, Embeddings): "jina-embeddings-v2-base-en". """ + model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=()) + api_url: Any #: :meta private: session: Any #: :meta private: api_key: str @@ -485,7 +488,6 @@ class JinaEmbeddings(BaseModel, Embeddings): def __init__(self, **kwargs): """Create a new JinaEmbeddings instance.""" - super().__init__(**kwargs) try: import requests except ImportError: raise ValueError( "The requests python package is not installed. 
Please install it with " "`pip install requests`" ) - self.api_url = "https://api.jina.ai/v1/embeddings" - self.session = requests.Session() - self.session.headers.update( - {"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"} - ) + if "api_url" not in kwargs: + kwargs["api_url"] = "https://api.jina.ai/v1/embeddings" + if "session" not in kwargs: # noqa: SIM401 + session = requests.Session() + else: + session = kwargs["session"] + api_key = kwargs.get("api_key") + if api_key: + session.headers.update( + { + "Authorization": f"Bearer {api_key}", + "Accept-Encoding": "identity", + } + ) + kwargs["session"] = session + + super().__init__(**kwargs) def embed_documents(self, texts: List[str]) -> List[List[float]]: """Get the embeddings for a list of texts. @@ -627,6 +641,8 @@ class OpenAPIEmbeddings(BaseModel, Embeddings): openai_embeddings.embed_documents(texts) """ + model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=()) + api_url: str = Field( default="http://localhost:8100/api/v1/embeddings", description="The URL of the embeddings API.", @@ -643,14 +659,8 @@ class OpenAPIEmbeddings(BaseModel, Embeddings): session: Optional[requests.Session] = None - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - def __init__(self, **kwargs): """Initialize the OpenAPIEmbeddings.""" - super().__init__(**kwargs) try: import requests except ImportError: @@ -658,8 +668,15 @@ class OpenAPIEmbeddings(BaseModel, Embeddings): "The requests python package is not installed. 
" "Please install it with `pip install requests`" ) - self.session = requests.Session() - self.session.headers.update({"Authorization": f"Bearer {self.api_key}"}) + if "session" not in kwargs: # noqa: SIM401 + session = requests.Session() + else: + session = kwargs["session"] + api_key = kwargs.get("api_key") + if api_key: + session.headers.update({"Authorization": f"Bearer {api_key}"}) + kwargs["session"] = session + super().__init__(**kwargs) def embed_documents(self, texts: List[str]) -> List[List[float]]: """Get the embeddings for a list of texts. diff --git a/dbgpt/rag/text_splitter/token_splitter.py b/dbgpt/rag/text_splitter/token_splitter.py index 5ae06967c..1236ce196 100644 --- a/dbgpt/rag/text_splitter/token_splitter.py +++ b/dbgpt/rag/text_splitter/token_splitter.py @@ -58,7 +58,6 @@ class TokenTextSplitter(BaseModel): tokenizer = tokenizer or globals_helper.tokenizer all_seps = [separator] + (backup_separators or []) - self._split_fns = [split_by_sep(sep) for sep in all_seps] + [split_by_char()] super().__init__( chunk_size=chunk_size, @@ -68,6 +67,7 @@ class TokenTextSplitter(BaseModel): # callback_manager=callback_manager, tokenizer=tokenizer, ) + self._split_fns = [split_by_sep(sep) for sep in all_seps] + [split_by_char()] @classmethod def class_name(cls) -> str: diff --git a/dbgpt/serve/agent/agents/controller.py b/dbgpt/serve/agent/agents/controller.py index 1776f54bd..0dbad248e 100644 --- a/dbgpt/serve/agent/agents/controller.py +++ b/dbgpt/serve/agent/agents/controller.py @@ -86,7 +86,9 @@ class MultiAgents(BaseComponent, ABC): def gpts_create(self, entity: GptsInstanceEntity): self.gpts_intance.add(entity) - def get_dbgpts(self, user_code: str = None, sys_code: str = None): + def get_dbgpts( + self, user_code: str = None, sys_code: str = None + ) -> Optional[List[GptsApp]]: apps = self.gpts_app.app_list( GptsAppQuery(user_code=user_code, sys_code=sys_code) ).app_list @@ -338,7 +340,7 @@ class MultiAgents(BaseComponent, ABC): multi_agents = 
MultiAgents() -@router.post("/v1/dbgpts/agents/list", response_model=Result[str]) +@router.post("/v1/dbgpts/agents/list", response_model=Result[Dict[str, str]]) async def agents_list(): logger.info("agents_list!") try: @@ -348,7 +350,7 @@ async def agents_list(): return Result.failed(code="E30001", msg=str(e)) -@router.get("/v1/dbgpts/list", response_model=Result[str]) +@router.get("/v1/dbgpts/list", response_model=Result[List[GptsApp]]) async def get_dbgpts(user_code: str = None, sys_code: str = None): logger.info(f"get_dbgpts:{user_code},{sys_code}") try: @@ -359,14 +361,14 @@ async def get_dbgpts(user_code: str = None, sys_code: str = None): @router.post("/v1/dbgpts/chat/completions", response_model=Result[str]) -async def dgpts_completions( +async def dbgpts_completions( gpts_name: str, user_query: str, conv_id: str = None, user_code: str = None, sys_code: str = None, ): - logger.info(f"dgpts_completions:{gpts_name},{user_query},{conv_id}") + logger.info(f"dbgpts_completions:{gpts_name},{user_query},{conv_id}") if conv_id is None: conv_id = str(uuid.uuid1()) @@ -390,12 +392,12 @@ async def dgpts_completions( @router.post("/v1/dbgpts/chat/cancel", response_model=Result[str]) -async def dgpts_chat_cancel( +async def dbgpts_chat_cancel( conv_id: str = None, user_code: str = None, sys_code: str = None ): pass @router.post("/v1/dbgpts/chat/feedback", response_model=Result[str]) -async def dgpts_chat_feedback(filter: PagenationFilter[PluginHubFilter] = Body()): +async def dbgpts_chat_feedback(filter: PagenationFilter[PluginHubFilter] = Body()): pass diff --git a/dbgpt/serve/agent/db/gpts_app.py b/dbgpt/serve/agent/db/gpts_app.py index 7396f4e20..b3f7b32be 100644 --- a/dbgpt/serve/agent/db/gpts_app.py +++ b/dbgpt/serve/agent/db/gpts_app.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Union from sqlalchemy import Column, DateTime, Integer, String, Text, UniqueConstraint -from dbgpt._private.pydantic import BaseModel +from dbgpt._private.pydantic import 
BaseModel, ConfigDict, Field, model_to_json from dbgpt.agent.plan.awel.team_awel_layout import AWELTeamContext from dbgpt.agent.resource.resource_api import AgentResource from dbgpt.serve.agent.team.base import TeamMode @@ -17,6 +17,8 @@ logger = logging.getLogger(__name__) class GptsAppDetail(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + app_code: Optional[str] = None app_name: Optional[str] = None agent_name: Optional[str] = None @@ -28,11 +30,6 @@ class GptsAppDetail(BaseModel): created_at: datetime = datetime.now() updated_at: datetime = datetime.now() - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - def to_dict(self): return {k: self._serialize(v) for k, v in self.__dict__.items()} @@ -86,6 +83,8 @@ class GptsAppDetail(BaseModel): class GptsApp(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + app_code: Optional[str] = None app_name: Optional[str] = None app_describe: Optional[str] = None @@ -100,11 +99,6 @@ class GptsApp(BaseModel): updated_at: datetime = datetime.now() details: List[GptsAppDetail] = [] - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - def to_dict(self): return {k: self._serialize(v) for k, v in self.__dict__.items()} @@ -146,7 +140,9 @@ class GptsAppResponse(BaseModel): total_count: Optional[int] = 0 total_page: Optional[int] = 0 current_page: Optional[int] = 0 - app_list: Optional[List[GptsApp]] = [] + app_list: Optional[List[GptsApp]] = Field( + default_factory=list, description="app list" + ) class GptsAppCollection(BaseModel): @@ -207,7 +203,8 @@ class GptsAppEntity(Model): team_context = Column( Text, nullable=True, - comment="The execution logic and team member content that teams with different working modes rely on", + comment="The execution logic and team member content that teams with different " + "working modes rely on", ) user_code = Column(String(255), nullable=True, comment="user 
code") @@ -565,7 +562,7 @@ def _parse_team_context(team_context: Optional[Union[str, AWELTeamContext]] = No parse team_context to str """ if isinstance(team_context, AWELTeamContext): - return team_context.json() + return model_to_json(team_context) return team_context diff --git a/dbgpt/serve/agent/db/my_plugin_db.py b/dbgpt/serve/agent/db/my_plugin_db.py index bd202dab0..eb995722b 100644 --- a/dbgpt/serve/agent/db/my_plugin_db.py +++ b/dbgpt/serve/agent/db/my_plugin_db.py @@ -1,9 +1,12 @@ from datetime import datetime +from typing import List from sqlalchemy import Column, DateTime, Integer, String, UniqueConstraint, func from dbgpt.storage.metadata import BaseDao, Model +from ..model import MyPluginVO + class MyPluginEntity(Model): __tablename__ = "my_plugin" @@ -27,6 +30,28 @@ class MyPluginEntity(Model): ) UniqueConstraint("user_code", "name", name="uk_name") + @classmethod + def to_vo(cls, entities: List["MyPluginEntity"]) -> List[MyPluginVO]: + results = [] + for entity in entities: + results.append( + MyPluginVO( + id=entity.id, + tenant=entity.tenant, + user_code=entity.user_code, + user_name=entity.user_name, + sys_code=entity.sys_code, + name=entity.name, + file_name=entity.file_name, + type=entity.type, + version=entity.version, + use_count=entity.use_count, + succ_count=entity.succ_count, + gmt_created=entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S"), + ) + ) + return results + class MyPluginDao(BaseDao): def add(self, engity: MyPluginEntity): diff --git a/dbgpt/serve/agent/db/plugin_hub_db.py b/dbgpt/serve/agent/db/plugin_hub_db.py index 8230682ab..0e619d273 100644 --- a/dbgpt/serve/agent/db/plugin_hub_db.py +++ b/dbgpt/serve/agent/db/plugin_hub_db.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import List import pytz from sqlalchemy import ( @@ -14,6 +15,8 @@ from sqlalchemy import ( from dbgpt.storage.metadata import BaseDao, Model +from ..model import PluginHubVO + # TODO We should consider that the production environment does not 
have permission to execute the DDL char_set_sql = DDL("ALTER TABLE plugin_hub CONVERT TO CHARACTER SET utf8mb4") @@ -40,6 +43,27 @@ class PluginHubEntity(Model): UniqueConstraint("name", name="uk_name") Index("idx_q_type", "type") + @classmethod + def to_vo(cls, entities: List["PluginHubEntity"]) -> List[PluginHubVO]: + results = [] + for entity in entities: + vo = PluginHubVO( + id=entity.id, + name=entity.name, + description=entity.description, + author=entity.author, + email=entity.email, + type=entity.type, + version=entity.version, + storage_channel=entity.storage_channel, + storage_url=entity.storage_url, + download_param=entity.download_param, + installed=entity.installed, + gmt_created=entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S"), + ) + results.append(vo) + return results + class PluginHubDao(BaseDao): def add(self, engity: PluginHubEntity): diff --git a/dbgpt/serve/agent/hub/controller.py b/dbgpt/serve/agent/hub/controller.py index cffc42cb5..31ab92025 100644 --- a/dbgpt/serve/agent/hub/controller.py +++ b/dbgpt/serve/agent/hub/controller.py @@ -18,6 +18,9 @@ from dbgpt.serve.agent.model import ( PluginHubParam, ) +from ..db import MyPluginEntity +from ..model import MyPluginVO, PluginHubVO + router = APIRouter() logger = logging.getLogger(__name__) @@ -73,7 +76,7 @@ async def plugin_hub_update(update_param: PluginHubParam = Body()): return Result.failed(code="E0020", msg=f"Agent Hub Update Error! 
{e}") -@router.post("/v1/agent/query", response_model=Result[str]) +@router.post("/v1/agent/query", response_model=Result[dict]) async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()): logger.info(f"get_agent_list:{filter.__dict__}") filter_enetity: PluginHubEntity = PluginHubEntity() @@ -85,24 +88,21 @@ async def get_agent_list(filter: PagenationFilter[PluginHubFilter] = Body()): datas, total_pages, total_count = plugin_hub.hub_dao.list( filter_enetity, filter.page_index, filter.page_size ) - result: PagenationResult[PluginHubEntity] = PagenationResult[PluginHubEntity]() + result: PagenationResult[PluginHubVO] = PagenationResult[PluginHubVO]() result.page_index = filter.page_index result.page_size = filter.page_size result.total_page = total_pages result.total_row_count = total_count - result.datas = datas + result.datas = PluginHubEntity.to_vo(datas) # print(json.dumps(result.to_dic())) return Result.succ(result.to_dic()) -@router.post("/v1/agent/my", response_model=Result[str]) +@router.post("/v1/agent/my", response_model=Result[List[MyPluginVO]]) async def my_agents(user: str = None): logger.info(f"my_agents:{user}") agents = plugin_hub.get_my_plugin(user) - agent_dicts = [] - for agent in agents: - agent_dicts.append(agent.__dict__) - + agent_dicts = MyPluginEntity.to_vo(agents) return Result.succ(agent_dicts) diff --git a/dbgpt/serve/agent/model.py b/dbgpt/serve/agent/model.py index 6984f835f..f5d71d9b5 100644 --- a/dbgpt/serve/agent/model.py +++ b/dbgpt/serve/agent/model.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from typing import Generic, List, Optional, TypeVar -from dbgpt._private.pydantic import BaseModel, Field +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field T = TypeVar("T") @@ -13,6 +13,7 @@ class PagenationFilter(BaseModel, Generic[T]): class PagenationResult(BaseModel, Generic[T]): + model_config = ConfigDict(arbitrary_types_allowed=True) page_index: int = 1 page_size: int = 20 total_page: int = 0 @@ 
-34,14 +35,14 @@ class PagenationResult(BaseModel, Generic[T]): @dataclass class PluginHubFilter(BaseModel): - name: str - description: str - author: str - email: str - type: str - version: str - storage_channel: str - storage_url: str + name: Optional[str] = None + description: Optional[str] = None + author: Optional[str] = None + email: Optional[str] = None + type: Optional[str] = None + version: Optional[str] = None + storage_channel: Optional[str] = None + storage_url: Optional[str] = None @dataclass @@ -67,3 +68,33 @@ class PluginHubParam(BaseModel): authorization: Optional[str] = Field( None, description="github download authorization", nullable=True ) + + +class PluginHubVO(BaseModel): + id: int = Field(..., description="Plugin id") + name: str = Field(..., description="Plugin name") + description: str = Field(..., description="Plugin description") + author: Optional[str] = Field(None, description="Plugin author") + email: Optional[str] = Field(None, description="Plugin email") + type: Optional[str] = Field(None, description="Plugin type") + version: Optional[str] = Field(None, description="Plugin version") + storage_channel: Optional[str] = Field(None, description="Plugin storage channel") + storage_url: Optional[str] = Field(None, description="Plugin storage url") + download_param: Optional[str] = Field(None, description="Plugin download param") + installed: Optional[int] = Field(None, description="Plugin installed") + gmt_created: Optional[str] = Field(None, description="Plugin upload time") + + +class MyPluginVO(BaseModel): + id: int = Field(..., description="My Plugin") + tenant: Optional[str] = Field(None, description="My Plugin tenant") + user_code: Optional[str] = Field(None, description="My Plugin user code") + user_name: Optional[str] = Field(None, description="My Plugin user name") + sys_code: Optional[str] = Field(None, description="My Plugin sys code") + name: str = Field(..., description="My Plugin name") + file_name: str = Field(..., 
description="My Plugin file name") + type: Optional[str] = Field(None, description="My Plugin type") + version: Optional[str] = Field(None, description="My Plugin version") + use_count: Optional[int] = Field(None, description="My Plugin use count") + succ_count: Optional[int] = Field(None, description="My Plugin succ count") + gmt_created: Optional[str] = Field(None, description="My Plugin install time") diff --git a/dbgpt/serve/conversation/api/schemas.py b/dbgpt/serve/conversation/api/schemas.py index 2558d8ad9..02bdc0f81 100644 --- a/dbgpt/serve/conversation/api/schemas.py +++ b/dbgpt/serve/conversation/api/schemas.py @@ -1,7 +1,7 @@ # Define your Pydantic schemas here -from typing import Any, Optional +from typing import Any, Dict, Optional -from dbgpt._private.pydantic import BaseModel, Field +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict from ..config import SERVE_APP_NAME_HUMP @@ -9,8 +9,7 @@ from ..config import SERVE_APP_NAME_HUMP class ServeRequest(BaseModel): """Conversation request model""" - class Config: - title = f"ServeRequest for {SERVE_APP_NAME_HUMP}" + model_config = ConfigDict(title=f"ServeRequest for {SERVE_APP_NAME_HUMP}") # Just for query chat_mode: str = Field( @@ -42,12 +41,17 @@ class ServeRequest(BaseModel): ], ) + def to_dict(self, **kwargs) -> Dict[str, Any]: + """Convert the model to a dictionary""" + return model_to_dict(self, **kwargs) + class ServerResponse(BaseModel): """Conversation response model""" - class Config: - title = f"ServerResponse for {SERVE_APP_NAME_HUMP}" + model_config = ConfigDict( + title=f"ServerResponse for {SERVE_APP_NAME_HUMP}", protected_namespaces=() + ) conv_uid: str = Field( ..., @@ -99,8 +103,13 @@ class ServerResponse(BaseModel): ], ) + def to_dict(self, **kwargs) -> Dict[str, Any]: + """Convert the model to a dictionary""" + return model_to_dict(self, **kwargs) + class MessageVo(BaseModel): + model_config = ConfigDict(protected_namespaces=()) role: str = Field( ..., 
description="The role that sends out the current message.", @@ -139,3 +148,7 @@ class MessageVo(BaseModel): "vicuna-13b-v1.5", ], ) + + def to_dict(self, **kwargs) -> Dict[str, Any]: + """Convert the model to a dictionary""" + return model_to_dict(self, **kwargs) diff --git a/dbgpt/serve/conversation/models/models.py b/dbgpt/serve/conversation/models/models.py index 860f456f4..aa1574987 100644 --- a/dbgpt/serve/conversation/models/models.py +++ b/dbgpt/serve/conversation/models/models.py @@ -2,7 +2,6 @@ You can define your own models and DAOs here """ import json -from datetime import datetime from typing import Any, Dict, List, Optional, Union from dbgpt.core import MessageStorageItem @@ -31,7 +30,9 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): Returns: T: The entity """ - request_dict = request.dict() if isinstance(request, ServeRequest) else request + request_dict = ( + request.to_dict() if isinstance(request, ServeRequest) else request + ) entity = ServeEntity(**request_dict) # TODO implement your own logic here, transfer the request_dict to an entity return entity diff --git a/dbgpt/serve/core/schemas.py b/dbgpt/serve/core/schemas.py index 698ac6805..e67fde198 100644 --- a/dbgpt/serve/core/schemas.py +++ b/dbgpt/serve/core/schemas.py @@ -30,7 +30,7 @@ async def validation_exception_handler( message += loc + ":" + error.get("msg") + ";" res = Result.failed(msg=message, err_code="E0001") logger.error(f"validation_exception_handler catch RequestValidationError: {res}") - return JSONResponse(status_code=400, content=res.dict()) + return JSONResponse(status_code=400, content=res.to_dict()) async def http_exception_handler(request: Request, exc: HTTPException): @@ -39,7 +39,7 @@ async def http_exception_handler(request: Request, exc: HTTPException): err_code=str(exc.status_code), ) logger.error(f"http_exception_handler catch HTTPException: {res}") - return JSONResponse(status_code=exc.status_code, content=res.dict()) + return 
JSONResponse(status_code=exc.status_code, content=res.to_dict()) async def common_exception_handler(request: Request, exc: Exception) -> JSONResponse: @@ -57,7 +57,7 @@ async def common_exception_handler(request: Request, exc: Exception) -> JSONResp err_code="E0003", ) logger.error(f"common_exception_handler catch Exception: {res}") - return JSONResponse(status_code=400, content=res.dict()) + return JSONResponse(status_code=400, content=res.to_dict()) def add_exception_handler(app: "FastAPI"): diff --git a/dbgpt/serve/core/tests/conftest.py b/dbgpt/serve/core/tests/conftest.py index 5091d2ed2..efc02cf47 100644 --- a/dbgpt/serve/core/tests/conftest.py +++ b/dbgpt/serve/core/tests/conftest.py @@ -2,12 +2,12 @@ from typing import Dict import pytest import pytest_asyncio -from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from httpx import AsyncClient from dbgpt.component import SystemApp from dbgpt.util import AppConfig +from dbgpt.util.fastapi import create_app def create_system_app(param: Dict) -> SystemApp: @@ -17,7 +17,7 @@ def create_system_app(param: Dict) -> SystemApp: elif not isinstance(app_config, AppConfig): raise RuntimeError("app_config must be AppConfig or dict") - test_app = FastAPI() + test_app = create_app() test_app.add_middleware( CORSMiddleware, allow_origins=["*"], diff --git a/dbgpt/serve/datasource/api/schemas.py b/dbgpt/serve/datasource/api/schemas.py index f3d28554a..c4aa4ec99 100644 --- a/dbgpt/serve/datasource/api/schemas.py +++ b/dbgpt/serve/datasource/api/schemas.py @@ -1,6 +1,6 @@ from typing import Optional -from dbgpt._private.pydantic import BaseModel, Field +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field from ..config import SERVE_APP_NAME_HUMP @@ -23,6 +23,8 @@ class DatasourceServeRequest(BaseModel): class DatasourceServeResponse(BaseModel): """Flow response model""" + model_config = ConfigDict(title=f"ServeResponse for {SERVE_APP_NAME_HUMP}") + """name: knowledge space name""" 
"""vector_type: vector type""" @@ -35,7 +37,3 @@ class DatasourceServeResponse(BaseModel): db_user: str = Field("", description="Database user.") db_pwd: str = Field("", description="Database password.") comment: str = Field("", description="Comment for the database.") - - # TODO define your own fields here - class Config: - title = f"ServerResponse for {SERVE_APP_NAME_HUMP}" diff --git a/dbgpt/serve/datasource/service/service.py b/dbgpt/serve/datasource/service/service.py index c409c0bb3..44c9c8274 100644 --- a/dbgpt/serve/datasource/service/service.py +++ b/dbgpt/serve/datasource/service/service.py @@ -4,6 +4,7 @@ from typing import List, Optional from fastapi import HTTPException from dbgpt._private.config import Config +from dbgpt._private.pydantic import model_to_dict from dbgpt.component import ComponentType, SystemApp from dbgpt.core.awel.dag.dag_manager import DAGManager from dbgpt.datasource.db_conn_info import DBConfig @@ -129,9 +130,9 @@ class Service( status_code=400, detail=f"there is no datasource name:{request.db_name} exists", ) - db_config = DBConfig(**request.dict()) + db_config = DBConfig(**model_to_dict(request)) if CFG.local_db_manager.edit_db(db_config): - return DatasourceServeResponse(**db_config.dict()) + return DatasourceServeResponse(**model_to_dict(db_config)) else: raise HTTPException( status_code=400, diff --git a/dbgpt/serve/flow/api/endpoints.py b/dbgpt/serve/flow/api/endpoints.py index 746254f59..6cb5ef879 100644 --- a/dbgpt/serve/flow/api/endpoints.py +++ b/dbgpt/serve/flow/api/endpoints.py @@ -138,7 +138,9 @@ async def update( @router.delete("/flows/{uid}") -async def delete(uid: str, service: Service = Depends(get_service)) -> Result[None]: +async def delete( + uid: str, service: Service = Depends(get_service) +) -> Result[ServerResponse]: """Delete a Flow entity Args: diff --git a/dbgpt/serve/flow/api/schemas.py b/dbgpt/serve/flow/api/schemas.py index a25d4fe3d..6fb8c1924 100644 --- a/dbgpt/serve/flow/api/schemas.py +++ 
b/dbgpt/serve/flow/api/schemas.py @@ -1,3 +1,5 @@ +from dbgpt._private.pydantic import ConfigDict + # Define your Pydantic schemas here from dbgpt.core.awel.flow.flow_factory import FlowPanel @@ -10,5 +12,5 @@ class ServerResponse(FlowPanel): """Flow response model""" # TODO define your own fields here - class Config: - title = f"ServerResponse for {SERVE_APP_NAME_HUMP}" + + model_config = ConfigDict(title=f"ServerResponse for {SERVE_APP_NAME_HUMP}") diff --git a/dbgpt/serve/flow/models/models.py b/dbgpt/serve/flow/models/models.py index 285015846..025bb95e0 100644 --- a/dbgpt/serve/flow/models/models.py +++ b/dbgpt/serve/flow/models/models.py @@ -7,6 +7,7 @@ from typing import Any, Dict, Union from sqlalchemy import Column, DateTime, Integer, String, Text, UniqueConstraint +from dbgpt._private.pydantic import model_to_dict from dbgpt.core.awel.flow.flow_factory import State from dbgpt.storage.metadata import BaseDao, Model from dbgpt.storage.metadata._base_dao import QUERY_SPEC @@ -82,7 +83,9 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): Returns: T: The entity """ - request_dict = request.dict() if isinstance(request, ServeRequest) else request + request_dict = ( + model_to_dict(request) if isinstance(request, ServeRequest) else request + ) flow_data = json.dumps(request_dict.get("flow_data"), ensure_ascii=False) state = request_dict.get("state", State.INITIALIZING.value) error_message = request_dict.get("error_message") @@ -184,7 +187,7 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): entry.flow_category = update_request.flow_category if update_request.flow_data: entry.flow_data = json.dumps( - update_request.flow_data.dict(), ensure_ascii=False + model_to_dict(update_request.flow_data), ensure_ascii=False ) if update_request.description: entry.description = update_request.description diff --git a/dbgpt/serve/flow/service/service.py b/dbgpt/serve/flow/service/service.py index ed2547f2f..8f7a3dad7 100644 --- 
a/dbgpt/serve/flow/service/service.py +++ b/dbgpt/serve/flow/service/service.py @@ -1,6 +1,5 @@ import json import logging -import time import traceback from typing import Any, AsyncIterator, List, Optional, cast @@ -236,6 +235,8 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): flow.uid = exist_inst.uid self.update_flow(flow, check_editable=False, save_failed_flow=True) except Exception as e: + import traceback + message = traceback.format_exc() logger.warning( f"Load DAG {flow.name} from dbgpts error: {str(e)}, detail: {message}" @@ -296,7 +297,10 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): ) return self.create_and_save_dag(update_obj) except Exception as e: - if old_data: + if old_data and old_data.state == State.RUNNING: + # Old flow is running, try to recover it + # first set the state to DEPLOYED + old_data.state = State.DEPLOYED self.create_and_save_dag(old_data) raise e @@ -387,8 +391,8 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): request.incremental = False async for output in self.safe_chat_stream_flow(flow_uid, request): text = output.text - # if text: - # text = text.replace("\n", "\\n") + if text: + text = text.replace("\n", "\\n") if output.error_code != 0: yield f"data:[SERVER_ERROR]{text}\n\n" break @@ -407,7 +411,9 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]): chunk = ChatCompletionStreamResponse( id=conv_uid, choices=[choice_data], model=request.model ) - yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" + json_data = model_to_json(chunk, exclude_unset=True, ensure_ascii=False) + + yield f"data: {json_data}\n\n" request.incremental = True async for output in self.safe_chat_stream_flow(flow_uid, request): diff --git a/dbgpt/serve/prompt/api/schemas.py b/dbgpt/serve/prompt/api/schemas.py index 6d8d67924..370ce8364 100644 --- a/dbgpt/serve/prompt/api/schemas.py +++ b/dbgpt/serve/prompt/api/schemas.py @@ -1,16 +1,15 @@ # 
Define your Pydantic schemas here from typing import Optional -from dbgpt._private.pydantic import BaseModel, Field +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field from ..config import SERVE_APP_NAME_HUMP class ServeRequest(BaseModel): - """Prompt request model""" + """Prompt request model.""" - class Config: - title = f"ServeRequest for {SERVE_APP_NAME_HUMP}" + model_config = ConfigDict(title=f"ServeRequest for {SERVE_APP_NAME_HUMP}") chat_scene: Optional[str] = Field( None, @@ -69,8 +68,7 @@ class ServeRequest(BaseModel): class ServerResponse(ServeRequest): """Prompt response model""" - class Config: - title = f"ServerResponse for {SERVE_APP_NAME_HUMP}" + model_config = ConfigDict(title=f"ServerResponse for {SERVE_APP_NAME_HUMP}") id: Optional[int] = Field( None, diff --git a/dbgpt/serve/prompt/models/models.py b/dbgpt/serve/prompt/models/models.py index 61346db47..3baecb722 100644 --- a/dbgpt/serve/prompt/models/models.py +++ b/dbgpt/serve/prompt/models/models.py @@ -6,6 +6,7 @@ from typing import Any, Dict, Union from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint +from dbgpt._private.pydantic import model_to_dict from dbgpt.storage.metadata import BaseDao, Model, db from ..api.schemas import ServeRequest, ServerResponse @@ -78,7 +79,9 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): Returns: T: The entity """ - request_dict = request.dict() if isinstance(request, ServeRequest) else request + request_dict = ( + model_to_dict(request) if isinstance(request, ServeRequest) else request + ) entity = ServeEntity(**request_dict) return entity diff --git a/dbgpt/serve/prompt/tests/test_endpoints.py b/dbgpt/serve/prompt/tests/test_endpoints.py index 701b43523..9bde556b8 100644 --- a/dbgpt/serve/prompt/tests/test_endpoints.py +++ b/dbgpt/serve/prompt/tests/test_endpoints.py @@ -8,7 +8,7 @@ from dbgpt.storage.metadata import db from dbgpt.util import PaginationResult from ..api.endpoints import 
init_endpoints, router -from ..api.schemas import ServeRequest, ServerResponse +from ..api.schemas import ServerResponse from ..config import SERVE_CONFIG_KEY_PREFIX diff --git a/dbgpt/serve/rag/api/schemas.py b/dbgpt/serve/rag/api/schemas.py index c648080a3..e94819a5f 100644 --- a/dbgpt/serve/rag/api/schemas.py +++ b/dbgpt/serve/rag/api/schemas.py @@ -2,7 +2,7 @@ from typing import Optional from fastapi import File, UploadFile -from dbgpt._private.pydantic import BaseModel, Field +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field from dbgpt.rag.chunk_manager import ChunkParameters from ..config import SERVE_APP_NAME_HUMP @@ -29,71 +29,102 @@ class SpaceServeRequest(BaseModel): class DocumentServeRequest(BaseModel): - id: int = Field(None, description="The doc id") - doc_name: str = Field(None, description="doc name") + id: Optional[int] = Field(None, description="The doc id") + doc_name: Optional[str] = Field(None, description="doc name") """doc_type: document type""" - doc_type: str = Field(None, description="The doc type") + doc_type: Optional[str] = Field(None, description="The doc type") """content: description""" - content: str = Field(None, description="content") + content: Optional[str] = Field(None, description="content") """doc file""" doc_file: UploadFile = File(...) 
"""doc_source: doc source""" - doc_source: str = None + doc_source: Optional[str] = Field(None, description="doc source") """doc_source: doc source""" - space_id: str = None + space_id: Optional[str] = Field(None, description="space id") class DocumentServeResponse(BaseModel): - id: int = Field(None, description="The doc id") - doc_name: str = Field(None, description="doc type") + id: Optional[int] = Field(None, description="The doc id") + doc_name: Optional[str] = Field(None, description="doc type") """vector_type: vector type""" - doc_type: str = Field(None, description="The doc content") + doc_type: Optional[str] = Field(None, description="The doc content") """desc: description""" - content: str = Field(None, description="content") + content: Optional[str] = Field(None, description="content") """vector ids""" - vector_ids: str = Field(None, description="vector ids") + vector_ids: Optional[str] = Field(None, description="vector ids") """doc_source: doc source""" - doc_source: str = None + doc_source: Optional[str] = Field(None, description="doc source") """doc_source: doc source""" - space: str = None + space: Optional[str] = Field(None, description="space name") class KnowledgeSyncRequest(BaseModel): """Sync request""" """doc_ids: doc ids""" - doc_id: int = Field(None, description="The doc id") + doc_id: Optional[int] = Field(None, description="The doc id") """space id""" - space_id: str = Field(None, description="space id") + space_id: Optional[str] = Field(None, description="space id") """model_name: model name""" model_name: Optional[str] = Field(None, description="model name") """chunk_parameters: chunk parameters """ - chunk_parameters: ChunkParameters = Field(None, description="chunk parameters") + chunk_parameters: Optional[ChunkParameters] = Field( + None, description="chunk parameters" + ) class SpaceServeResponse(BaseModel): """Flow response model""" + model_config = ConfigDict(title=f"ServeResponse for {SERVE_APP_NAME_HUMP}") + """name: knowledge 
space name""" """vector_type: vector type""" - id: int = Field(None, description="The space id") - name: str = Field(None, description="The space name") + id: Optional[int] = Field(None, description="The space id") + name: Optional[str] = Field(None, description="The space name") """vector_type: vector type""" - vector_type: str = Field(None, description="The vector type") + vector_type: Optional[str] = Field(None, description="The vector type") """desc: description""" - desc: str = Field(None, description="The description") + desc: Optional[str] = Field(None, description="The description") """context: argument context""" - context: str = Field(None, description="The context") + context: Optional[str] = Field(None, description="The context") """owner: owner""" - owner: str = Field(None, description="The owner") + owner: Optional[str] = Field(None, description="The owner") """sys code""" - sys_code: str = Field(None, description="The sys code") + sys_code: Optional[str] = Field(None, description="The sys code") # TODO define your own fields here - class Config: - title = f"ServerResponse for {SERVE_APP_NAME_HUMP}" + + +class DocumentChunkVO(BaseModel): + id: int = Field(..., description="document chunk id") + document_id: int = Field(..., description="document id") + doc_name: str = Field(..., description="document name") + doc_type: str = Field(..., description="document type") + content: str = Field(..., description="document content") + meta_info: str = Field(..., description="document meta info") + gmt_created: str = Field(..., description="document create time") + gmt_modified: str = Field(..., description="document modify time") + + +class DocumentVO(BaseModel): + """Document Entity.""" + + id: int = Field(..., description="document id") + doc_name: str = Field(..., description="document name") + doc_type: str = Field(..., description="document type") + space: str = Field(..., description="document space name") + chunk_size: int = Field(..., 
description="document chunk size") + status: str = Field(..., description="document status") + last_sync: str = Field(..., description="document last sync time") + content: str = Field(..., description="document content") + result: Optional[str] = Field(None, description="document result") + vector_ids: Optional[str] = Field(None, description="document vector ids") + summary: Optional[str] = Field(None, description="document summary") + gmt_created: str = Field(..., description="document create time") + gmt_modified: str = Field(..., description="document modify time") diff --git a/dbgpt/serve/rag/models/models.py b/dbgpt/serve/rag/models/models.py index f8fd2986c..2b47df8ba 100644 --- a/dbgpt/serve/rag/models/models.py +++ b/dbgpt/serve/rag/models/models.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Union from sqlalchemy import Column, DateTime, Integer, String, Text +from dbgpt._private.pydantic import model_to_dict from dbgpt.serve.rag.api.schemas import SpaceServeRequest, SpaceServeResponse from dbgpt.storage.metadata import BaseDao, Model @@ -89,7 +90,7 @@ class KnowledgeSpaceDao(BaseDao): entry = query.first() if entry is None: raise Exception("Invalid request") - for key, value in update_request.dict().items(): # type: ignore + for key, value in model_to_dict(update_request).items(): # type: ignore if value is not None: setattr(entry, key, value) session.merge(entry) @@ -117,7 +118,9 @@ class KnowledgeSpaceDao(BaseDao): T: The entity """ request_dict = ( - request.dict() if isinstance(request, SpaceServeRequest) else request + model_to_dict(request) + if isinstance(request, SpaceServeRequest) + else request ) entity = KnowledgeSpaceEntity(**request_dict) return entity diff --git a/dbgpt/serve/rag/service/service.py b/dbgpt/serve/rag/service/service.py index 6a167c87d..ff2a17133 100644 --- a/dbgpt/serve/rag/service/service.py +++ b/dbgpt/serve/rag/service/service.py @@ -40,6 +40,7 @@ from dbgpt.util.tracer import root_tracer, trace from ..api.schemas 
import ( DocumentServeRequest, DocumentServeResponse, + DocumentVO, KnowledgeSyncRequest, SpaceServeRequest, SpaceServeResponse, @@ -419,7 +420,7 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes def _sync_knowledge_document( self, space_id, - doc: KnowledgeDocumentEntity, + doc_vo: DocumentVO, chunk_parameters: ChunkParameters, ) -> List[Chunk]: """sync knowledge document chunk into vector store""" @@ -431,6 +432,8 @@ class Service(BaseService[KnowledgeSpaceEntity, SpaceServeRequest, SpaceServeRes ) from dbgpt.storage.vector_store.base import VectorStoreConfig + doc = KnowledgeDocumentEntity.from_document_vo(doc_vo) + space = self.get({"id": space_id}) config = VectorStoreConfig( name=space.name, diff --git a/dbgpt/serve/utils/_template_files/default_serve_template/api/schemas.py b/dbgpt/serve/utils/_template_files/default_serve_template/api/schemas.py index d84369475..92fd8c68f 100644 --- a/dbgpt/serve/utils/_template_files/default_serve_template/api/schemas.py +++ b/dbgpt/serve/utils/_template_files/default_serve_template/api/schemas.py @@ -1,5 +1,7 @@ # Define your Pydantic schemas here -from dbgpt._private.pydantic import BaseModel, Field +from typing import Any, Dict + +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict from ..config import SERVE_APP_NAME_HUMP @@ -9,13 +11,20 @@ class ServeRequest(BaseModel): # TODO define your own fields here - class Config: - title = f"ServeRequest for {SERVE_APP_NAME_HUMP}" + model_config = ConfigDict(title=f"ServeRequest for {SERVE_APP_NAME_HUMP}") + + def to_dict(self, **kwargs) -> Dict[str, Any]: + """Convert the model to a dictionary""" + return model_to_dict(self, **kwargs) class ServerResponse(BaseModel): """{__template_app_name__hump__} response model""" # TODO define your own fields here - class Config: - title = f"ServerResponse for {SERVE_APP_NAME_HUMP}" + + model_config = ConfigDict(title=f"ServerResponse for {SERVE_APP_NAME_HUMP}") + + def 
to_dict(self, **kwargs) -> Dict[str, Any]: + """Convert the model to a dictionary""" + return model_to_dict(self, **kwargs) diff --git a/dbgpt/serve/utils/_template_files/default_serve_template/models/models.py b/dbgpt/serve/utils/_template_files/default_serve_template/models/models.py index 039ed289d..3d6732736 100644 --- a/dbgpt/serve/utils/_template_files/default_serve_template/models/models.py +++ b/dbgpt/serve/utils/_template_files/default_serve_template/models/models.py @@ -41,7 +41,9 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): Returns: T: The entity """ - request_dict = request.dict() if isinstance(request, ServeRequest) else request + request_dict = ( + request.to_dict() if isinstance(request, ServeRequest) else request + ) entity = ServeEntity(**request_dict) # TODO implement your own logic here, transfer the request_dict to an entity return entity diff --git a/dbgpt/storage/metadata/_base_dao.py b/dbgpt/storage/metadata/_base_dao.py index 1b84b34b3..44abdfece 100644 --- a/dbgpt/storage/metadata/_base_dao.py +++ b/dbgpt/storage/metadata/_base_dao.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar, Union from sqlalchemy.orm.session import Session +from dbgpt._private.pydantic import model_to_dict from dbgpt.util.pagination_utils import PaginationResult from .db_manager import BaseQuery, DatabaseManager, db @@ -165,7 +166,7 @@ class BaseDao(Generic[T, REQ, RES]): entry = query.first() if entry is None: raise Exception("Invalid request") - for key, value in update_request.dict().items(): # type: ignore + for key, value in model_to_dict(update_request).items(): # type: ignore if value is not None: setattr(entry, key, value) session.merge(entry) @@ -272,7 +273,9 @@ class BaseDao(Generic[T, REQ, RES]): model_cls = type(self.from_request(query_request)) query = session.query(model_cls) query_dict = ( - query_request if isinstance(query_request, dict) else query_request.dict() + query_request + if 
isinstance(query_request, dict) + else model_to_dict(query_request) ) for key, value in query_dict.items(): if value is not None: diff --git a/dbgpt/storage/metadata/tests/test_base_dao.py b/dbgpt/storage/metadata/tests/test_base_dao.py index cfa86ba3c..59cf98356 100644 --- a/dbgpt/storage/metadata/tests/test_base_dao.py +++ b/dbgpt/storage/metadata/tests/test_base_dao.py @@ -4,7 +4,7 @@ import pytest from sqlalchemy import Column, Integer, String from dbgpt._private.pydantic import BaseModel as PydanticBaseModel -from dbgpt._private.pydantic import Field +from dbgpt._private.pydantic import Field, model_to_dict from dbgpt.storage.metadata.db_manager import ( BaseModel, DatabaseManager, @@ -61,7 +61,7 @@ def user_dao(db, User): class UserDao(BaseDao[User, UserRequest, UserResponse]): def from_request(self, request: Union[UserRequest, Dict[str, Any]]) -> User: if isinstance(request, UserRequest): - return User(**request.dict()) + return User(**model_to_dict(request)) else: return User(**request) @@ -71,7 +71,7 @@ def user_dao(db, User): ) def from_response(self, response: UserResponse) -> User: - return User(**response.dict()) + return User(**model_to_dict(response)) def to_response(self, entity: User): return UserResponse(id=entity.id, name=entity.name, age=entity.age) diff --git a/dbgpt/storage/vector_store/base.py b/dbgpt/storage/vector_store/base.py index e14d0956a..4682f086a 100644 --- a/dbgpt/storage/vector_store/base.py +++ b/dbgpt/storage/vector_store/base.py @@ -4,9 +4,9 @@ import math import time from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional -from dbgpt._private.pydantic import BaseModel, Field +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict from dbgpt.core import Chunk, Embeddings from dbgpt.core.awel.flow import Parameter from dbgpt.storage.vector_store.filters import MetadataFilters @@ -87,10 
+87,7 @@ _COMMON_PARAMETERS = [ class VectorStoreConfig(BaseModel): """Vector store config.""" - class Config: - """Config for BaseModel.""" - - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) name: str = Field( default="dbgpt_collection", @@ -122,6 +119,10 @@ class VectorStoreConfig(BaseModel): "bigger than 1, please make sure your vector store is thread-safe.", ) + def to_dict(self, **kwargs) -> Dict[str, Any]: + """Convert to dict.""" + return model_to_dict(self, **kwargs) + class VectorStoreBase(ABC): """Vector store base class.""" diff --git a/dbgpt/storage/vector_store/chroma_store.py b/dbgpt/storage/vector_store/chroma_store.py index 8615363d0..16e282dd8 100644 --- a/dbgpt/storage/vector_store/chroma_store.py +++ b/dbgpt/storage/vector_store/chroma_store.py @@ -6,7 +6,7 @@ from typing import List, Optional from chromadb import PersistentClient from chromadb.config import Settings -from dbgpt._private.pydantic import Field +from dbgpt._private.pydantic import ConfigDict, Field from dbgpt.configs.model_config import PILOT_PATH from dbgpt.core import Chunk from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource @@ -38,16 +38,13 @@ logger = logging.getLogger(__name__) class ChromaVectorConfig(VectorStoreConfig): """Chroma vector store config.""" - class Config: - """Config for BaseModel.""" + model_config = ConfigDict(arbitrary_types_allowed=True) - arbitrary_types_allowed = True - - persist_path: str = Field( + persist_path: Optional[str] = Field( default=os.getenv("CHROMA_PERSIST_PATH", None), description="the persist path of vector store.", ) - collection_metadata: dict = Field( + collection_metadata: Optional[dict] = Field( default=None, description="the index metadata of vector store, if not set, will use the " "default metadata.", @@ -61,7 +58,7 @@ class ChromaStore(VectorStoreBase): """Create a ChromaStore instance.""" from langchain.vectorstores import Chroma - chroma_vector_config = 
vector_store_config.dict(exclude_none=True) + chroma_vector_config = vector_store_config.to_dict(exclude_none=True) chroma_path = chroma_vector_config.get( "persist_path", os.path.join(PILOT_PATH, "data") ) diff --git a/dbgpt/storage/vector_store/filters.py b/dbgpt/storage/vector_store/filters.py index c67da68ac..10398fccc 100644 --- a/dbgpt/storage/vector_store/filters.py +++ b/dbgpt/storage/vector_store/filters.py @@ -2,7 +2,7 @@ from enum import Enum from typing import List, Union -from pydantic import BaseModel, Field +from dbgpt._private.pydantic import BaseModel, Field class FilterOperator(str, Enum): diff --git a/dbgpt/storage/vector_store/milvus_store.py b/dbgpt/storage/vector_store/milvus_store.py index ebd3737ac..0d5616e03 100644 --- a/dbgpt/storage/vector_store/milvus_store.py +++ b/dbgpt/storage/vector_store/milvus_store.py @@ -6,7 +6,7 @@ import logging import os from typing import Any, Iterable, List, Optional -from dbgpt._private.pydantic import Field +from dbgpt._private.pydantic import ConfigDict, Field from dbgpt.core import Chunk, Embeddings from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource from dbgpt.storage.vector_store.base import ( @@ -96,10 +96,7 @@ logger = logging.getLogger(__name__) class MilvusVectorConfig(VectorStoreConfig): """Milvus vector store config.""" - class Config: - """Config for BaseModel.""" - - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) uri: str = Field( default="localhost", @@ -155,7 +152,7 @@ class MilvusStore(VectorStoreBase): from pymilvus import connections connect_kwargs = {} - milvus_vector_config = vector_store_config.dict() + milvus_vector_config = vector_store_config.to_dict() self.uri = milvus_vector_config.get("uri") or os.getenv( "MILVUS_URL", "localhost" ) diff --git a/dbgpt/storage/vector_store/pgvector_store.py b/dbgpt/storage/vector_store/pgvector_store.py index b7dbfbd79..02ab4e1ec 100644 --- 
a/dbgpt/storage/vector_store/pgvector_store.py +++ b/dbgpt/storage/vector_store/pgvector_store.py @@ -2,7 +2,7 @@ import logging from typing import List, Optional -from dbgpt._private.pydantic import Field +from dbgpt._private.pydantic import ConfigDict, Field from dbgpt.core import Chunk from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource from dbgpt.storage.vector_store.base import ( @@ -39,10 +39,7 @@ logger = logging.getLogger(__name__) class PGVectorConfig(VectorStoreConfig): """PG vector store config.""" - class Config: - """Config for BaseModel.""" - - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) connection_string: str = Field( default=None, diff --git a/dbgpt/storage/vector_store/weaviate_store.py b/dbgpt/storage/vector_store/weaviate_store.py index 8c25064dc..8daf67827 100644 --- a/dbgpt/storage/vector_store/weaviate_store.py +++ b/dbgpt/storage/vector_store/weaviate_store.py @@ -3,7 +3,7 @@ import logging import os from typing import List, Optional -from dbgpt._private.pydantic import Field +from dbgpt._private.pydantic import ConfigDict, Field from dbgpt.core import Chunk from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource from dbgpt.util.i18n_utils import _ @@ -44,10 +44,7 @@ logger = logging.getLogger(__name__) class WeaviateVectorConfig(VectorStoreConfig): """Weaviate vector store config.""" - class Config: - """Config for BaseModel.""" - - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) weaviate_url: str = Field( default=os.getenv("WEAVIATE_URL", None), diff --git a/dbgpt/util/benchmarks/llm/llm_benchmarks.py b/dbgpt/util/benchmarks/llm/llm_benchmarks.py index 862a79807..231e33d95 100644 --- a/dbgpt/util/benchmarks/llm/llm_benchmarks.py +++ b/dbgpt/util/benchmarks/llm/llm_benchmarks.py @@ -218,9 +218,9 @@ async def run_model(wh: WorkerManager) -> None: def startup_llm_env(): - from fastapi import FastAPI + from 
dbgpt.util.fastapi import create_app - app = FastAPI() + app = create_app() initialize_worker_manager_in_client( app=app, model_name=model_name, diff --git a/dbgpt/util/dbgpts/loader.py b/dbgpt/util/dbgpts/loader.py index 355b6c930..d24e63fc3 100644 --- a/dbgpt/util/dbgpts/loader.py +++ b/dbgpt/util/dbgpts/loader.py @@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional, Type, TypeVar, cast import schedule import tomlkit -from dbgpt._private.pydantic import BaseModel, Field, root_validator +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_validator from dbgpt.component import BaseComponent, SystemApp from dbgpt.core.awel.flow.flow_factory import FlowPanel from dbgpt.util.dbgpts.base import ( @@ -22,8 +22,7 @@ T = TypeVar("T") class BasePackage(BaseModel): - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) name: str = Field(..., description="The name of the package") label: str = Field(..., description="The label of the package") @@ -48,9 +47,12 @@ class BasePackage(BaseModel): def build_from(cls, values: Dict[str, Any], ext_dict: Dict[str, Any]): return cls(**values) - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Pre-fill the definition_file""" + if not isinstance(values, dict): + return values import importlib.resources as pkg_resources name = values.get("name") @@ -97,7 +99,7 @@ class BasePackage(BaseModel): class FlowPackage(BasePackage): - package_type = "flow" + package_type: str = "flow" @classmethod def build_from( @@ -126,7 +128,7 @@ class FlowJsonPackage(FlowPackage): class OperatorPackage(BasePackage): - package_type = "operator" + package_type: str = "operator" operators: List[type] = Field( default_factory=list, description="The operators of the package" @@ -141,7 +143,7 @@ class OperatorPackage(BasePackage): class AgentPackage(BasePackage): - package_type = "agent" + 
package_type: str = "agent" agents: List[type] = Field( default_factory=list, description="The agents of the package" @@ -240,7 +242,7 @@ def _load_package_from_path(path: str): class DBGPTsLoader(BaseComponent): """The loader of the dbgpts packages""" - name = "dbgpt_dbgpts_loader" + name: str = "dbgpt_dbgpts_loader" def __init__( self, diff --git a/dbgpt/util/fastapi.py b/dbgpt/util/fastapi.py index 8493e32b2..0ed4cc97a 100644 --- a/dbgpt/util/fastapi.py +++ b/dbgpt/util/fastapi.py @@ -1,9 +1,14 @@ """FastAPI utilities.""" -from typing import Any, Callable, Dict +import importlib.metadata as metadata +from contextlib import asynccontextmanager +from typing import Any, Callable, Dict, List, Optional +from fastapi import FastAPI from fastapi.routing import APIRouter +_FASTAPI_VERSION = metadata.version("fastapi") + class PriorityAPIRouter(APIRouter): """A router with priority. @@ -41,3 +46,85 @@ class PriorityAPIRouter(APIRouter): return self.route_priority.get(route.path, 0) self.routes.sort(key=my_func, reverse=True) + + +_HAS_STARTUP = False +_HAS_SHUTDOWN = False +_GLOBAL_STARTUP_HANDLERS: List[Callable] = [] + +_GLOBAL_SHUTDOWN_HANDLERS: List[Callable] = [] + + +def register_event_handler(app: FastAPI, event: str, handler: Callable): + """Register an event handler. + + Args: + app (FastAPI): The FastAPI app. + event (str): The event type. + handler (Callable): The handler function. + + """ + if _FASTAPI_VERSION >= "0.109.1": + # https://fastapi.tiangolo.com/release-notes/#01091 + if event == "startup": + if _HAS_STARTUP: + raise ValueError( + "FastAPI app already started. Cannot add startup handler." + ) + _GLOBAL_STARTUP_HANDLERS.append(handler) + elif event == "shutdown": + if _HAS_SHUTDOWN: + raise ValueError( + "FastAPI app already shutdown. Cannot add shutdown handler." 
+ ) + _GLOBAL_SHUTDOWN_HANDLERS.append(handler) + else: + raise ValueError(f"Invalid event: {event}") + else: + if event == "startup": + app.add_event_handler("startup", handler) + elif event == "shutdown": + app.add_event_handler("shutdown", handler) + else: + raise ValueError(f"Invalid event: {event}") + + +@asynccontextmanager +async def lifespan(app: FastAPI): + # Trigger the startup event. + global _HAS_STARTUP, _HAS_SHUTDOWN + for handler in _GLOBAL_STARTUP_HANDLERS: + await handler() + _HAS_STARTUP = True + yield + # Trigger the shutdown event. + for handler in _GLOBAL_SHUTDOWN_HANDLERS: + await handler() + _HAS_SHUTDOWN = True + + +def create_app(*args, **kwargs) -> FastAPI: + """Create a FastAPI app.""" + _sp = None + if _FASTAPI_VERSION >= "0.109.1": + if "lifespan" not in kwargs: + kwargs["lifespan"] = lifespan + _sp = kwargs["lifespan"] + app = FastAPI(*args, **kwargs) + if _sp: + app.__dbgpt_custom_lifespan = _sp + return app + + +def replace_router(app: FastAPI, router: Optional[APIRouter] = None): + """Replace the router of the FastAPI app.""" + if not router: + router = PriorityAPIRouter() + if _FASTAPI_VERSION >= "0.109.1": + if hasattr(app, "__dbgpt_custom_lifespan"): + _sp = getattr(app, "__dbgpt_custom_lifespan") + router.lifespan_context = _sp + + app.router = router + app.setup() + return app diff --git a/dbgpt/util/network/_cli.py b/dbgpt/util/network/_cli.py index d880bb070..f4a0565c2 100644 --- a/dbgpt/util/network/_cli.py +++ b/dbgpt/util/network/_cli.py @@ -197,10 +197,12 @@ def _start_http_forward( ): import httpx import uvicorn - from fastapi import BackgroundTasks, FastAPI, Request, Response + from fastapi import BackgroundTasks, Request, Response from fastapi.responses import StreamingResponse - app = FastAPI() + from dbgpt.util.fastapi import create_app + + app = create_app() @app.middleware("http") async def forward_http_request(request: Request, call_next): diff --git a/dbgpt/util/pagination_utils.py 
b/dbgpt/util/pagination_utils.py index 4b9288cb8..f8c20ccd9 100644 --- a/dbgpt/util/pagination_utils.py +++ b/dbgpt/util/pagination_utils.py @@ -1,6 +1,6 @@ from typing import Generic, List, TypeVar -from dbgpt._private.pydantic import BaseModel, Field +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field T = TypeVar("T") @@ -8,6 +8,8 @@ T = TypeVar("T") class PaginationResult(BaseModel, Generic[T]): """Pagination result""" + model_config = ConfigDict(arbitrary_types_allowed=True) + items: List[T] = Field(..., description="The items in the current page") total_count: int = Field(..., description="Total number of items") total_pages: int = Field(..., description="total number of pages") diff --git a/dbgpt/util/prompt_util.py b/dbgpt/util/prompt_util.py index bad972a7a..54aa918e3 100644 --- a/dbgpt/util/prompt_util.py +++ b/dbgpt/util/prompt_util.py @@ -13,7 +13,7 @@ from string import Formatter from typing import Callable, List, Optional, Sequence, Set from dbgpt._private.llm_metadata import LLMMetadata -from dbgpt._private.pydantic import BaseModel, Field, PrivateAttr +from dbgpt._private.pydantic import BaseModel, Field, PrivateAttr, model_validator from dbgpt.core.interface.prompt import get_template_vars from dbgpt.rag.text_splitter.token_splitter import TokenTextSplitter from dbgpt.util.global_helper import globals_helper @@ -62,12 +62,14 @@ class PromptHelper(BaseModel): default=DEFAULT_CHUNK_OVERLAP_RATIO, description="The percentage token amount that each chunk should overlap.", ) - chunk_size_limit: Optional[int] = Field(description="The maximum size of a chunk.") + chunk_size_limit: Optional[int] = Field( + None, description="The maximum size of a chunk." + ) separator: str = Field( default=" ", description="The separator when chunking tokens." 
) - _tokenizer: Callable[[str], List] = PrivateAttr() + _tokenizer: Optional[Callable[[str], List]] = PrivateAttr() def __init__( self, @@ -77,21 +79,22 @@ class PromptHelper(BaseModel): chunk_size_limit: Optional[int] = None, tokenizer: Optional[Callable[[str], List]] = None, separator: str = " ", + **kwargs, ) -> None: """Init params.""" if chunk_overlap_ratio > 1.0 or chunk_overlap_ratio < 0.0: raise ValueError("chunk_overlap_ratio must be a float between 0. and 1.") - # TODO: make configurable - self._tokenizer = tokenizer or globals_helper.tokenizer - super().__init__( context_window=context_window, num_output=num_output, chunk_overlap_ratio=chunk_overlap_ratio, chunk_size_limit=chunk_size_limit, separator=separator, + **kwargs, ) + # TODO: make configurable + self._tokenizer = tokenizer or globals_helper.tokenizer def token_count(self, prompt_template: str) -> int: """Get token count of prompt template.""" diff --git a/docs/docs/awel/awel_tutorial/network_program/3.2_http_trigger_get.md b/docs/docs/awel/awel_tutorial/network_program/3.2_http_trigger_get.md index 073092c8a..7f02a1eb1 100644 --- a/docs/docs/awel/awel_tutorial/network_program/3.2_http_trigger_get.md +++ b/docs/docs/awel/awel_tutorial/network_program/3.2_http_trigger_get.md @@ -10,7 +10,7 @@ Before we start writing the code, we need to install the `pydantic` package in y project [awel-tutorial](/docs/awel/awel_tutorial/getting_started/1.1_hello_world#creating-a-projec) ```bash -poetry add "pydantic<2,>=1" +poetry add "pydantic>=2.6.0" ``` Then create a new file named `http_trigger_say_hello.py` in the `awel_tutorial` directory and add the following code: diff --git a/setup.py b/setup.py index 14f2481a4..0cefec7dc 100644 --- a/setup.py +++ b/setup.py @@ -142,6 +142,13 @@ class SetupSpec: self.extras: dict = {} self.install_requires: List[str] = [] + @property + def unique_extras(self) -> dict[str, list[str]]: + unique_extras = {} + for k, v in self.extras.items(): + unique_extras[k] = list(set(v)) 
+ return unique_extras + setup_spec = SetupSpec() @@ -405,14 +412,14 @@ def core_requires(): "importlib-resources==5.12.0", "python-dotenv==1.0.0", "cachetools", - "pydantic<2,>=1", + "pydantic>=2.6.0", # For AWEL type checking "typeguard", ] # For DB-GPT python client SDK setup_spec.extras["client"] = setup_spec.extras["core"] + [ "httpx", - "fastapi==0.98.0", + "fastapi>=0.100.0", ] # Simple command line dependencies setup_spec.extras["cli"] = setup_spec.extras["client"] + [ @@ -490,8 +497,7 @@ def knowledge_requires(): """ setup_spec.extras["rag"] = setup_spec.extras["vstore"] + [ "langchain>=0.0.286", - "spacy==3.5.3", - "chromadb==0.4.10", + "spacy>=3.7", "markdown", "bs4", "python-pptx", @@ -554,9 +560,23 @@ def all_vector_store_requires(): pip install "dbgpt[vstore]" """ setup_spec.extras["vstore"] = [ - "pymilvus", + "chromadb>=0.4.22", + ] + setup_spec.extras["vstore_weaviate"] = setup_spec.extras["vstore"] + [ + # "protobuf", + # "grpcio", + # weaviate depends on grpc which version is very low, we should install it + # manually. "weaviate-client", ] + setup_spec.extras["vstore_milvus"] = setup_spec.extras["vstore"] + [ + "pymilvus", + ] + setup_spec.extras["vstore_all"] = ( + setup_spec.extras["vstore"] + + setup_spec.extras["vstore_weaviate"] + + setup_spec.extras["vstore_milvus"] + ) def all_datasource_requires(): @@ -630,7 +650,6 @@ def default_requires(): # "tokenizers==0.13.3", "tokenizers>=0.14", "accelerate>=0.20.3", - "protobuf==3.20.3", "zhipuai", "dashscope", "chardet", @@ -734,7 +753,7 @@ setuptools.setup( url="https://github.com/eosphoros-ai/DB-GPT", license="https://opensource.org/license/mit/", python_requires=">=3.10", - extras_require=setup_spec.extras, + extras_require=setup_spec.unique_extras, entry_points={ "console_scripts": [ "dbgpt=dbgpt.cli.cli_scripts:main",