mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-01 17:16:51 +00:00
feat(core): Upgrade pydantic to 2.x (#1428)
This commit is contained in:
@@ -2,9 +2,10 @@
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Generic, List, Literal, Optional, TypeVar
|
||||
from enum import IntEnum
|
||||
from typing import Any, Dict, Generic, List, Literal, Optional, TypeVar, Union
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -41,6 +42,28 @@ class Result(BaseModel, Generic[T]):
|
||||
"""
|
||||
return Result(success=False, err_code=err_code, err_msg=msg, data=None)
|
||||
|
||||
def to_dict(self, **kwargs) -> Dict[str, Any]:
|
||||
"""Convert to dict."""
|
||||
return model_to_dict(self, **kwargs)
|
||||
|
||||
|
||||
class APIChatCompletionRequest(BaseModel):
|
||||
"""Chat completion request entity."""
|
||||
|
||||
model: str = Field(..., description="Model name")
|
||||
messages: Union[str, List[Dict[str, str]]] = Field(..., description="Messages")
|
||||
temperature: Optional[float] = Field(0.7, description="Temperature")
|
||||
top_p: Optional[float] = Field(1.0, description="Top p")
|
||||
top_k: Optional[int] = Field(-1, description="Top k")
|
||||
n: Optional[int] = Field(1, description="Number of completions")
|
||||
max_tokens: Optional[int] = Field(None, description="Max tokens")
|
||||
stop: Optional[Union[str, List[str]]] = Field(None, description="Stop")
|
||||
stream: Optional[bool] = Field(False, description="Stream")
|
||||
user: Optional[str] = Field(None, description="User")
|
||||
repetition_penalty: Optional[float] = Field(1.0, description="Repetition penalty")
|
||||
frequency_penalty: Optional[float] = Field(0.0, description="Frequency penalty")
|
||||
presence_penalty: Optional[float] = Field(0.0, description="Presence penalty")
|
||||
|
||||
|
||||
class DeltaMessage(BaseModel):
|
||||
"""Delta message entity for chat completion response."""
|
||||
@@ -122,3 +145,97 @@ class ErrorResponse(BaseModel):
|
||||
object: str = Field("error", description="Object type")
|
||||
message: str = Field(..., description="Error message")
|
||||
code: int = Field(..., description="Error code")
|
||||
|
||||
|
||||
class EmbeddingsRequest(BaseModel):
|
||||
"""Embeddings request entity."""
|
||||
|
||||
model: Optional[str] = Field(None, description="Model name")
|
||||
engine: Optional[str] = Field(None, description="Engine name")
|
||||
input: Union[str, List[Any]] = Field(..., description="Input data")
|
||||
user: Optional[str] = Field(None, description="User name")
|
||||
encoding_format: Optional[str] = Field(None, description="Encoding format")
|
||||
|
||||
|
||||
class EmbeddingsResponse(BaseModel):
|
||||
"""Embeddings response entity."""
|
||||
|
||||
object: str = Field("list", description="Object type")
|
||||
data: List[Dict[str, Any]] = Field(..., description="Data list")
|
||||
model: str = Field(..., description="Model name")
|
||||
usage: UsageInfo = Field(..., description="Usage info")
|
||||
|
||||
|
||||
class ModelPermission(BaseModel):
|
||||
"""Model permission entity."""
|
||||
|
||||
id: str = Field(
|
||||
default_factory=lambda: f"modelperm-{str(uuid.uuid1())}",
|
||||
description="Permission ID",
|
||||
)
|
||||
object: str = Field("model_permission", description="Object type")
|
||||
created: int = Field(
|
||||
default_factory=lambda: int(time.time()), description="Created time"
|
||||
)
|
||||
allow_create_engine: bool = Field(False, description="Allow create engine")
|
||||
allow_sampling: bool = Field(True, description="Allow sampling")
|
||||
allow_logprobs: bool = Field(True, description="Allow logprobs")
|
||||
allow_search_indices: bool = Field(True, description="Allow search indices")
|
||||
allow_view: bool = Field(True, description="Allow view")
|
||||
allow_fine_tuning: bool = Field(False, description="Allow fine tuning")
|
||||
organization: str = Field("*", description="Organization")
|
||||
group: Optional[str] = Field(None, description="Group")
|
||||
is_blocking: bool = Field(False, description="Is blocking")
|
||||
|
||||
|
||||
class ModelCard(BaseModel):
|
||||
"""Model card entity."""
|
||||
|
||||
id: str = Field(..., description="Model ID")
|
||||
object: str = Field("model", description="Object type")
|
||||
created: int = Field(
|
||||
default_factory=lambda: int(time.time()), description="Created time"
|
||||
)
|
||||
owned_by: str = Field("DB-GPT", description="Owned by")
|
||||
root: Optional[str] = Field(None, description="Root")
|
||||
parent: Optional[str] = Field(None, description="Parent")
|
||||
permission: List[ModelPermission] = Field(
|
||||
default_factory=list, description="Permission"
|
||||
)
|
||||
|
||||
|
||||
class ModelList(BaseModel):
|
||||
"""Model list entity."""
|
||||
|
||||
object: str = Field("list", description="Object type")
|
||||
data: List[ModelCard] = Field(default_factory=list, description="Model list data")
|
||||
|
||||
|
||||
class ErrorCode(IntEnum):
|
||||
"""Error code enumeration.
|
||||
|
||||
https://platform.openai.com/docs/guides/error-codes/api-errors.
|
||||
|
||||
Adapted from fastchat.constants.
|
||||
"""
|
||||
|
||||
VALIDATION_TYPE_ERROR = 40001
|
||||
|
||||
INVALID_AUTH_KEY = 40101
|
||||
INCORRECT_AUTH_KEY = 40102
|
||||
NO_PERMISSION = 40103
|
||||
|
||||
INVALID_MODEL = 40301
|
||||
PARAM_OUT_OF_RANGE = 40302
|
||||
CONTEXT_OVERFLOW = 40303
|
||||
|
||||
RATE_LIMIT = 42901
|
||||
QUOTA_EXCEEDED = 42902
|
||||
ENGINE_OVERLOADED = 42903
|
||||
|
||||
INTERNAL_ERROR = 50001
|
||||
CUDA_OUT_OF_MEMORY = 50002
|
||||
GRADIO_REQUEST_ERROR = 50003
|
||||
GRADIO_STREAM_UNKNOWN_ERROR = 50004
|
||||
CONTROLLER_NO_WORKER = 50005
|
||||
CONTROLLER_WORKER_TIMEOUT = 50006
|
||||
|
Reference in New Issue
Block a user