feat(core): Upgrade pydantic to 2.x (#1428)

This commit is contained in:
Fangyin Cheng
2024-04-20 09:41:16 +08:00
committed by GitHub
parent baa1e3f9f6
commit 57be1ece18
103 changed files with 1146 additions and 534 deletions

View File

@@ -2,9 +2,10 @@
import time
import uuid
from typing import Any, Generic, List, Literal, Optional, TypeVar
from enum import IntEnum
from typing import Any, Dict, Generic, List, Literal, Optional, TypeVar, Union
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
T = TypeVar("T")
@@ -41,6 +42,28 @@ class Result(BaseModel, Generic[T]):
"""
return Result(success=False, err_code=err_code, err_msg=msg, data=None)
def to_dict(self, **kwargs) -> Dict[str, Any]:
"""Convert to dict."""
return model_to_dict(self, **kwargs)
class APIChatCompletionRequest(BaseModel):
"""Chat completion request entity."""
model: str = Field(..., description="Model name")
messages: Union[str, List[Dict[str, str]]] = Field(..., description="Messages")
temperature: Optional[float] = Field(0.7, description="Temperature")
top_p: Optional[float] = Field(1.0, description="Top p")
top_k: Optional[int] = Field(-1, description="Top k")
n: Optional[int] = Field(1, description="Number of completions")
max_tokens: Optional[int] = Field(None, description="Max tokens")
stop: Optional[Union[str, List[str]]] = Field(None, description="Stop")
stream: Optional[bool] = Field(False, description="Stream")
user: Optional[str] = Field(None, description="User")
repetition_penalty: Optional[float] = Field(1.0, description="Repetition penalty")
frequency_penalty: Optional[float] = Field(0.0, description="Frequency penalty")
presence_penalty: Optional[float] = Field(0.0, description="Presence penalty")
class DeltaMessage(BaseModel):
"""Delta message entity for chat completion response."""
@@ -122,3 +145,97 @@ class ErrorResponse(BaseModel):
object: str = Field("error", description="Object type")
message: str = Field(..., description="Error message")
code: int = Field(..., description="Error code")
class EmbeddingsRequest(BaseModel):
"""Embeddings request entity."""
model: Optional[str] = Field(None, description="Model name")
engine: Optional[str] = Field(None, description="Engine name")
input: Union[str, List[Any]] = Field(..., description="Input data")
user: Optional[str] = Field(None, description="User name")
encoding_format: Optional[str] = Field(None, description="Encoding format")
class EmbeddingsResponse(BaseModel):
"""Embeddings response entity."""
object: str = Field("list", description="Object type")
data: List[Dict[str, Any]] = Field(..., description="Data list")
model: str = Field(..., description="Model name")
usage: UsageInfo = Field(..., description="Usage info")
class ModelPermission(BaseModel):
"""Model permission entity."""
id: str = Field(
default_factory=lambda: f"modelperm-{str(uuid.uuid1())}",
description="Permission ID",
)
object: str = Field("model_permission", description="Object type")
created: int = Field(
default_factory=lambda: int(time.time()), description="Created time"
)
allow_create_engine: bool = Field(False, description="Allow create engine")
allow_sampling: bool = Field(True, description="Allow sampling")
allow_logprobs: bool = Field(True, description="Allow logprobs")
allow_search_indices: bool = Field(True, description="Allow search indices")
allow_view: bool = Field(True, description="Allow view")
allow_fine_tuning: bool = Field(False, description="Allow fine tuning")
organization: str = Field("*", description="Organization")
group: Optional[str] = Field(None, description="Group")
is_blocking: bool = Field(False, description="Is blocking")
class ModelCard(BaseModel):
"""Model card entity."""
id: str = Field(..., description="Model ID")
object: str = Field("model", description="Object type")
created: int = Field(
default_factory=lambda: int(time.time()), description="Created time"
)
owned_by: str = Field("DB-GPT", description="Owned by")
root: Optional[str] = Field(None, description="Root")
parent: Optional[str] = Field(None, description="Parent")
permission: List[ModelPermission] = Field(
default_factory=list, description="Permission"
)
class ModelList(BaseModel):
"""Model list entity."""
object: str = Field("list", description="Object type")
data: List[ModelCard] = Field(default_factory=list, description="Model list data")
class ErrorCode(IntEnum):
"""Error code enumeration.
https://platform.openai.com/docs/guides/error-codes/api-errors.
Adapted from fastchat.constants.
"""
VALIDATION_TYPE_ERROR = 40001
INVALID_AUTH_KEY = 40101
INCORRECT_AUTH_KEY = 40102
NO_PERMISSION = 40103
INVALID_MODEL = 40301
PARAM_OUT_OF_RANGE = 40302
CONTEXT_OVERFLOW = 40303
RATE_LIMIT = 42901
QUOTA_EXCEEDED = 42902
ENGINE_OVERLOADED = 42903
INTERNAL_ERROR = 50001
CUDA_OUT_OF_MEMORY = 50002
GRADIO_REQUEST_ERROR = 50003
GRADIO_STREAM_UNKNOWN_ERROR = 50004
CONTROLLER_NO_WORKER = 50005
CONTROLLER_WORKER_TIMEOUT = 50006