feat(core): Upgrade pydantic to 2.x (#1428)

This commit is contained in:
Fangyin Cheng
2024-04-20 09:41:16 +08:00
committed by GitHub
parent baa1e3f9f6
commit 57be1ece18
103 changed files with 1146 additions and 534 deletions

View File

@@ -218,9 +218,9 @@ async def run_model(wh: WorkerManager) -> None:
def startup_llm_env():
from fastapi import FastAPI
from dbgpt.util.fastapi import create_app
app = FastAPI()
app = create_app()
initialize_worker_manager_in_client(
app=app,
model_name=model_name,

View File

@@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional, Type, TypeVar, cast
import schedule
import tomlkit
from dbgpt._private.pydantic import BaseModel, Field, root_validator
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_validator
from dbgpt.component import BaseComponent, SystemApp
from dbgpt.core.awel.flow.flow_factory import FlowPanel
from dbgpt.util.dbgpts.base import (
@@ -22,8 +22,7 @@ T = TypeVar("T")
class BasePackage(BaseModel):
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)
name: str = Field(..., description="The name of the package")
label: str = Field(..., description="The label of the package")
@@ -48,9 +47,12 @@ class BasePackage(BaseModel):
def build_from(cls, values: Dict[str, Any], ext_dict: Dict[str, Any]):
return cls(**values)
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Pre-fill the definition_file"""
if not isinstance(values, dict):
return values
import importlib.resources as pkg_resources
name = values.get("name")
@@ -97,7 +99,7 @@ class BasePackage(BaseModel):
class FlowPackage(BasePackage):
package_type = "flow"
package_type: str = "flow"
@classmethod
def build_from(
@@ -126,7 +128,7 @@ class FlowJsonPackage(FlowPackage):
class OperatorPackage(BasePackage):
package_type = "operator"
package_type: str = "operator"
operators: List[type] = Field(
default_factory=list, description="The operators of the package"
@@ -141,7 +143,7 @@ class OperatorPackage(BasePackage):
class AgentPackage(BasePackage):
package_type = "agent"
package_type: str = "agent"
agents: List[type] = Field(
default_factory=list, description="The agents of the package"
@@ -240,7 +242,7 @@ def _load_package_from_path(path: str):
class DBGPTsLoader(BaseComponent):
"""The loader of the dbgpts packages"""
name = "dbgpt_dbgpts_loader"
name: str = "dbgpt_dbgpts_loader"
def __init__(
self,

View File

@@ -1,9 +1,14 @@
"""FastAPI utilities."""
from typing import Any, Callable, Dict
import importlib.metadata as metadata
from contextlib import asynccontextmanager
from typing import Any, Callable, Dict, List, Optional
from fastapi import FastAPI
from fastapi.routing import APIRouter
_FASTAPI_VERSION = metadata.version("fastapi")
class PriorityAPIRouter(APIRouter):
"""A router with priority.
@@ -41,3 +46,85 @@ class PriorityAPIRouter(APIRouter):
return self.route_priority.get(route.path, 0)
self.routes.sort(key=my_func, reverse=True)
# Set to True by ``lifespan`` once the startup / shutdown handlers have run.
_HAS_STARTUP = False
_HAS_SHUTDOWN = False
# Handlers collected by ``register_event_handler`` and executed by ``lifespan``.
_GLOBAL_STARTUP_HANDLERS: List[Callable] = []
_GLOBAL_SHUTDOWN_HANDLERS: List[Callable] = []

# Releases at or above this version are driven through the lifespan context
# instead of ``app.add_event_handler``.
# https://fastapi.tiangolo.com/release-notes/#01091
_LIFESPAN_MIN_VERSION = (0, 109, 1)


def _version_tuple(version: str) -> tuple:
    """Convert a dotted version string into a tuple of ints.

    Only the leading numeric part of each dot-separated component is used and
    parsing stops at the first component that does not start with a digit, so
    ``"0.110.0"`` becomes ``(0, 110, 0)`` and ``"0.109.1rc1"`` becomes
    ``(0, 109, 1)``.

    Comparing the resulting tuples orders versions numerically. Comparing the
    raw strings does not: ``"0.110.0" >= "0.109.1"`` is ``False`` and
    ``"0.95.0" >= "0.109.1"`` is ``True``.
    """
    import re

    parts = []
    for piece in version.split("."):
        matched = re.match(r"\d+", piece)
        if matched is None:
            break
        parts.append(int(matched.group()))
    return tuple(parts)


# Evaluated once: every function below must take the same branch.
_USE_LIFESPAN = _version_tuple(_FASTAPI_VERSION) >= _LIFESPAN_MIN_VERSION


def register_event_handler(app: FastAPI, event: str, handler: Callable):
    """Register an event handler.

    On FastAPI versions that use the lifespan path the handler is stored in a
    module-level list and executed by ``lifespan``; on older versions it is
    forwarded to ``app.add_event_handler``.

    Args:
        app (FastAPI): The FastAPI app.
        event (str): The event type, ``"startup"`` or ``"shutdown"``.
        handler (Callable): The handler function.

    Raises:
        ValueError: If ``event`` is not ``"startup"`` or ``"shutdown"``, or if
            the lifespan has already run the corresponding phase.
    """
    if event not in ("startup", "shutdown"):
        raise ValueError(f"Invalid event: {event}")

    if not _USE_LIFESPAN:
        app.add_event_handler(event, handler)
        return

    if event == "startup":
        if _HAS_STARTUP:
            raise ValueError(
                "FastAPI app already started. Cannot add startup handler."
            )
        _GLOBAL_STARTUP_HANDLERS.append(handler)
    else:
        if _HAS_SHUTDOWN:
            raise ValueError(
                "FastAPI app already shutdown. Cannot add shutdown handler."
            )
        _GLOBAL_SHUTDOWN_HANDLERS.append(handler)


async def _run_handler(handler: Callable) -> None:
    """Call ``handler`` and await its result only when it is awaitable.

    This lets plain functions be registered as handlers as well as
    coroutine functions; awaiting the ``None`` returned by a plain function
    would raise ``TypeError``.
    """
    import inspect

    result = handler()
    if inspect.isawaitable(result):
        await result


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the registered startup handlers, yield, then run shutdown handlers.

    Args:
        app (FastAPI): The FastAPI app (not used by the handlers themselves).
    """
    global _HAS_STARTUP, _HAS_SHUTDOWN

    # Trigger the startup event.
    for handler in _GLOBAL_STARTUP_HANDLERS:
        await _run_handler(handler)
    _HAS_STARTUP = True

    yield

    # Trigger the shutdown event.
    for handler in _GLOBAL_SHUTDOWN_HANDLERS:
        await _run_handler(handler)
    _HAS_SHUTDOWN = True


def create_app(*args, **kwargs) -> FastAPI:
    """Create a FastAPI app.

    On the lifespan path, ``lifespan`` from this module is installed unless
    the caller passed its own ``lifespan`` keyword argument. The lifespan that
    ends up being used is remembered on the app so ``replace_router`` can carry
    it over to a new router.

    Args:
        *args: Positional arguments forwarded to ``FastAPI``.
        **kwargs: Keyword arguments forwarded to ``FastAPI``.

    Returns:
        FastAPI: The created app.
    """
    _sp = None
    if _USE_LIFESPAN:
        kwargs.setdefault("lifespan", lifespan)
        _sp = kwargs["lifespan"]
    app = FastAPI(*args, **kwargs)
    if _sp:
        app.__dbgpt_custom_lifespan = _sp
    return app


def replace_router(app: FastAPI, router: Optional[APIRouter] = None):
    """Replace the router of the FastAPI app.

    Args:
        app (FastAPI): The FastAPI app whose router is replaced.
        router (Optional[APIRouter]): The new router; a ``PriorityAPIRouter``
            is created when omitted.

    Returns:
        FastAPI: The same app, after ``app.setup()`` has been called.
    """
    if not router:
        router = PriorityAPIRouter()
    if _USE_LIFESPAN and hasattr(app, "__dbgpt_custom_lifespan"):
        # Carry over the lifespan recorded by ``create_app`` to the new router.
        router.lifespan_context = getattr(app, "__dbgpt_custom_lifespan")

    app.router = router
    app.setup()
    return app

View File

@@ -197,10 +197,12 @@ def _start_http_forward(
):
import httpx
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request, Response
from fastapi import BackgroundTasks, Request, Response
from fastapi.responses import StreamingResponse
app = FastAPI()
from dbgpt.util.fastapi import create_app
app = create_app()
@app.middleware("http")
async def forward_http_request(request: Request, call_next):

View File

@@ -1,6 +1,6 @@
from typing import Generic, List, TypeVar
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
T = TypeVar("T")
@@ -8,6 +8,8 @@ T = TypeVar("T")
class PaginationResult(BaseModel, Generic[T]):
"""Pagination result"""
model_config = ConfigDict(arbitrary_types_allowed=True)
items: List[T] = Field(..., description="The items in the current page")
total_count: int = Field(..., description="Total number of items")
total_pages: int = Field(..., description="total number of pages")

View File

@@ -13,7 +13,7 @@ from string import Formatter
from typing import Callable, List, Optional, Sequence, Set
from dbgpt._private.llm_metadata import LLMMetadata
from dbgpt._private.pydantic import BaseModel, Field, PrivateAttr
from dbgpt._private.pydantic import BaseModel, Field, PrivateAttr, model_validator
from dbgpt.core.interface.prompt import get_template_vars
from dbgpt.rag.text_splitter.token_splitter import TokenTextSplitter
from dbgpt.util.global_helper import globals_helper
@@ -62,12 +62,14 @@ class PromptHelper(BaseModel):
default=DEFAULT_CHUNK_OVERLAP_RATIO,
description="The percentage token amount that each chunk should overlap.",
)
chunk_size_limit: Optional[int] = Field(description="The maximum size of a chunk.")
chunk_size_limit: Optional[int] = Field(
None, description="The maximum size of a chunk."
)
separator: str = Field(
default=" ", description="The separator when chunking tokens."
)
_tokenizer: Callable[[str], List] = PrivateAttr()
_tokenizer: Optional[Callable[[str], List]] = PrivateAttr()
def __init__(
self,
@@ -77,21 +79,22 @@ class PromptHelper(BaseModel):
chunk_size_limit: Optional[int] = None,
tokenizer: Optional[Callable[[str], List]] = None,
separator: str = " ",
**kwargs,
) -> None:
"""Init params."""
if chunk_overlap_ratio > 1.0 or chunk_overlap_ratio < 0.0:
raise ValueError("chunk_overlap_ratio must be a float between 0. and 1.")
# TODO: make configurable
self._tokenizer = tokenizer or globals_helper.tokenizer
super().__init__(
context_window=context_window,
num_output=num_output,
chunk_overlap_ratio=chunk_overlap_ratio,
chunk_size_limit=chunk_size_limit,
separator=separator,
**kwargs,
)
# TODO: make configurable
self._tokenizer = tokenizer or globals_helper.tokenizer
def token_count(self, prompt_template: str) -> int:
"""Get token count of prompt template."""