mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-03 18:17:45 +00:00
feat(core): Upgrade pydantic to 2.x (#1428)
This commit is contained in:
@@ -218,9 +218,9 @@ async def run_model(wh: WorkerManager) -> None:
|
||||
|
||||
|
||||
def startup_llm_env():
|
||||
from fastapi import FastAPI
|
||||
from dbgpt.util.fastapi import create_app
|
||||
|
||||
app = FastAPI()
|
||||
app = create_app()
|
||||
initialize_worker_manager_in_client(
|
||||
app=app,
|
||||
model_name=model_name,
|
||||
|
@@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional, Type, TypeVar, cast
|
||||
import schedule
|
||||
import tomlkit
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field, root_validator
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from dbgpt.component import BaseComponent, SystemApp
|
||||
from dbgpt.core.awel.flow.flow_factory import FlowPanel
|
||||
from dbgpt.util.dbgpts.base import (
|
||||
@@ -22,8 +22,7 @@ T = TypeVar("T")
|
||||
|
||||
|
||||
class BasePackage(BaseModel):
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
name: str = Field(..., description="The name of the package")
|
||||
label: str = Field(..., description="The label of the package")
|
||||
@@ -48,9 +47,12 @@ class BasePackage(BaseModel):
|
||||
def build_from(cls, values: Dict[str, Any], ext_dict: Dict[str, Any]):
|
||||
return cls(**values)
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre-fill the definition_file"""
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
import importlib.resources as pkg_resources
|
||||
|
||||
name = values.get("name")
|
||||
@@ -97,7 +99,7 @@ class BasePackage(BaseModel):
|
||||
|
||||
|
||||
class FlowPackage(BasePackage):
|
||||
package_type = "flow"
|
||||
package_type: str = "flow"
|
||||
|
||||
@classmethod
|
||||
def build_from(
|
||||
@@ -126,7 +128,7 @@ class FlowJsonPackage(FlowPackage):
|
||||
|
||||
|
||||
class OperatorPackage(BasePackage):
|
||||
package_type = "operator"
|
||||
package_type: str = "operator"
|
||||
|
||||
operators: List[type] = Field(
|
||||
default_factory=list, description="The operators of the package"
|
||||
@@ -141,7 +143,7 @@ class OperatorPackage(BasePackage):
|
||||
|
||||
|
||||
class AgentPackage(BasePackage):
|
||||
package_type = "agent"
|
||||
package_type: str = "agent"
|
||||
|
||||
agents: List[type] = Field(
|
||||
default_factory=list, description="The agents of the package"
|
||||
@@ -240,7 +242,7 @@ def _load_package_from_path(path: str):
|
||||
class DBGPTsLoader(BaseComponent):
|
||||
"""The loader of the dbgpts packages"""
|
||||
|
||||
name = "dbgpt_dbgpts_loader"
|
||||
name: str = "dbgpt_dbgpts_loader"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@@ -1,9 +1,14 @@
|
||||
"""FastAPI utilities."""
|
||||
|
||||
from typing import Any, Callable, Dict
|
||||
import importlib.metadata as metadata
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.routing import APIRouter
|
||||
|
||||
_FASTAPI_VERSION = metadata.version("fastapi")
|
||||
|
||||
|
||||
class PriorityAPIRouter(APIRouter):
|
||||
"""A router with priority.
|
||||
@@ -41,3 +46,85 @@ class PriorityAPIRouter(APIRouter):
|
||||
return self.route_priority.get(route.path, 0)
|
||||
|
||||
self.routes.sort(key=my_func, reverse=True)
|
||||
|
||||
|
||||
_HAS_STARTUP = False
|
||||
_HAS_SHUTDOWN = False
|
||||
_GLOBAL_STARTUP_HANDLERS: List[Callable] = []
|
||||
|
||||
_GLOBAL_SHUTDOWN_HANDLERS: List[Callable] = []
|
||||
|
||||
|
||||
def register_event_handler(app: FastAPI, event: str, handler: Callable):
    """Register a startup or shutdown event handler.

    On FastAPI >= 0.109.1 the legacy ``add_event_handler`` hooks were replaced
    by the lifespan protocol (https://fastapi.tiangolo.com/release-notes/#01091),
    so handlers are queued in module-level lists and executed by ``lifespan``.
    On older versions the handler is attached directly to the app.

    Args:
        app (FastAPI): The FastAPI app.
        event (str): The event type, either ``"startup"`` or ``"shutdown"``.
        handler (Callable): The handler function.

    Raises:
        ValueError: If ``event`` is not a supported type, or if the app has
            already started/shut down so the handler could never run.
    """
    # BUGFIX: comparing version strings lexicographically is wrong
    # ("0.99.0" > "0.109.1" as plain strings) — compare numeric tuples.
    version_parts = []
    for piece in _FASTAPI_VERSION.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # ignore non-numeric suffixes such as "rc1"
            digits += ch
        version_parts.append(int(digits) if digits else 0)
    if tuple(version_parts) >= (0, 109, 1):
        # https://fastapi.tiangolo.com/release-notes/#01091
        if event == "startup":
            if _HAS_STARTUP:
                raise ValueError(
                    "FastAPI app already started. Cannot add startup handler."
                )
            _GLOBAL_STARTUP_HANDLERS.append(handler)
        elif event == "shutdown":
            if _HAS_SHUTDOWN:
                raise ValueError(
                    "FastAPI app already shutdown. Cannot add shutdown handler."
                )
            _GLOBAL_SHUTDOWN_HANDLERS.append(handler)
        else:
            raise ValueError(f"Invalid event: {event}")
    else:
        if event == "startup":
            app.add_event_handler("startup", handler)
        elif event == "shutdown":
            app.add_event_handler("shutdown", handler)
        else:
            raise ValueError(f"Invalid event: {event}")
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan context: run queued startup handlers, yield to the app,
    then run queued shutdown handlers.

    Flips the module-level ``_HAS_STARTUP`` / ``_HAS_SHUTDOWN`` flags so that
    ``register_event_handler`` can reject handlers registered too late.
    """
    global _HAS_STARTUP, _HAS_SHUTDOWN
    # Startup phase: drain the globally registered startup callbacks.
    for callback in _GLOBAL_STARTUP_HANDLERS:
        await callback()
    _HAS_STARTUP = True
    yield
    # Shutdown phase: drain the globally registered shutdown callbacks.
    for callback in _GLOBAL_SHUTDOWN_HANDLERS:
        await callback()
    _HAS_SHUTDOWN = True
|
||||
|
||||
|
||||
def create_app(*args, **kwargs) -> FastAPI:
    """Create a FastAPI app.

    On FastAPI >= 0.109.1 a default ``lifespan`` context (the module-level
    ``lifespan``) is installed unless the caller supplies one, so handlers
    queued via ``register_event_handler`` are executed. The effective lifespan
    is also stashed on the app so ``replace_router`` can carry it over to a
    replacement router.

    Returns:
        FastAPI: The newly constructed app.
    """
    _sp = None
    # BUGFIX: comparing version strings lexicographically is wrong
    # ("0.99.0" > "0.109.1" as plain strings) — compare numeric tuples.
    version_parts = []
    for piece in _FASTAPI_VERSION.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # ignore non-numeric suffixes such as "rc1"
            digits += ch
        version_parts.append(int(digits) if digits else 0)
    if tuple(version_parts) >= (0, 109, 1):
        if "lifespan" not in kwargs:
            kwargs["lifespan"] = lifespan
        _sp = kwargs["lifespan"]
    app = FastAPI(*args, **kwargs)
    if _sp:
        # Remember the lifespan so replace_router() can re-attach it
        # to a new router (FastAPI runs the router's lifespan_context).
        app.__dbgpt_custom_lifespan = _sp
    return app
|
||||
|
||||
|
||||
def replace_router(app: FastAPI, router: Optional[APIRouter] = None):
    """Replace the router of the FastAPI app.

    Args:
        app (FastAPI): The app whose router is replaced.
        router (Optional[APIRouter]): The new router; a ``PriorityAPIRouter``
            is created when omitted.

    Returns:
        FastAPI: The same app, with the new router installed.
    """
    if not router:
        router = PriorityAPIRouter()
    # BUGFIX: comparing version strings lexicographically is wrong
    # ("0.99.0" > "0.109.1" as plain strings) — compare numeric tuples.
    version_parts = []
    for piece in _FASTAPI_VERSION.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # ignore non-numeric suffixes such as "rc1"
            digits += ch
        version_parts.append(int(digits) if digits else 0)
    if tuple(version_parts) >= (0, 109, 1):
        if hasattr(app, "__dbgpt_custom_lifespan"):
            # Carry the lifespan installed by create_app() over to the new
            # router; on >= 0.109.1 FastAPI runs the router's lifespan_context.
            _sp = getattr(app, "__dbgpt_custom_lifespan")
            router.lifespan_context = _sp

    app.router = router
    app.setup()
    return app
|
||||
|
@@ -197,10 +197,12 @@ def _start_http_forward(
|
||||
):
|
||||
import httpx
|
||||
import uvicorn
|
||||
from fastapi import BackgroundTasks, FastAPI, Request, Response
|
||||
from fastapi import BackgroundTasks, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
app = FastAPI()
|
||||
from dbgpt.util.fastapi import create_app
|
||||
|
||||
app = create_app()
|
||||
|
||||
@app.middleware("http")
|
||||
async def forward_http_request(request: Request, call_next):
|
||||
|
@@ -1,6 +1,6 @@
|
||||
from typing import Generic, List, TypeVar
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -8,6 +8,8 @@ T = TypeVar("T")
|
||||
class PaginationResult(BaseModel, Generic[T]):
|
||||
"""Pagination result"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
items: List[T] = Field(..., description="The items in the current page")
|
||||
total_count: int = Field(..., description="Total number of items")
|
||||
total_pages: int = Field(..., description="total number of pages")
|
||||
|
@@ -13,7 +13,7 @@ from string import Formatter
|
||||
from typing import Callable, List, Optional, Sequence, Set
|
||||
|
||||
from dbgpt._private.llm_metadata import LLMMetadata
|
||||
from dbgpt._private.pydantic import BaseModel, Field, PrivateAttr
|
||||
from dbgpt._private.pydantic import BaseModel, Field, PrivateAttr, model_validator
|
||||
from dbgpt.core.interface.prompt import get_template_vars
|
||||
from dbgpt.rag.text_splitter.token_splitter import TokenTextSplitter
|
||||
from dbgpt.util.global_helper import globals_helper
|
||||
@@ -62,12 +62,14 @@ class PromptHelper(BaseModel):
|
||||
default=DEFAULT_CHUNK_OVERLAP_RATIO,
|
||||
description="The percentage token amount that each chunk should overlap.",
|
||||
)
|
||||
chunk_size_limit: Optional[int] = Field(description="The maximum size of a chunk.")
|
||||
chunk_size_limit: Optional[int] = Field(
|
||||
None, description="The maximum size of a chunk."
|
||||
)
|
||||
separator: str = Field(
|
||||
default=" ", description="The separator when chunking tokens."
|
||||
)
|
||||
|
||||
_tokenizer: Callable[[str], List] = PrivateAttr()
|
||||
_tokenizer: Optional[Callable[[str], List]] = PrivateAttr()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -77,21 +79,22 @@ class PromptHelper(BaseModel):
|
||||
chunk_size_limit: Optional[int] = None,
|
||||
tokenizer: Optional[Callable[[str], List]] = None,
|
||||
separator: str = " ",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Init params."""
|
||||
if chunk_overlap_ratio > 1.0 or chunk_overlap_ratio < 0.0:
|
||||
raise ValueError("chunk_overlap_ratio must be a float between 0. and 1.")
|
||||
|
||||
# TODO: make configurable
|
||||
self._tokenizer = tokenizer or globals_helper.tokenizer
|
||||
|
||||
super().__init__(
|
||||
context_window=context_window,
|
||||
num_output=num_output,
|
||||
chunk_overlap_ratio=chunk_overlap_ratio,
|
||||
chunk_size_limit=chunk_size_limit,
|
||||
separator=separator,
|
||||
**kwargs,
|
||||
)
|
||||
# TODO: make configurable
|
||||
self._tokenizer = tokenizer or globals_helper.tokenizer
|
||||
|
||||
def token_count(self, prompt_template: str) -> int:
|
||||
"""Get token count of prompt template."""
|
||||
|
Reference in New Issue
Block a user