mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-27 05:47:47 +00:00
259 lines
8.6 KiB
Python
259 lines
8.6 KiB
Python
"""API schema module."""
|
|
|
|
import time
|
|
import uuid
|
|
from enum import IntEnum
|
|
from typing import Any, Dict, Generic, List, Literal, Optional, TypeVar, Union
|
|
|
|
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
|
|
|
|
T = TypeVar("T")
|
|
|
|
|
|
class Result(BaseModel, Generic[T]):
|
|
"""Common result entity for API response."""
|
|
|
|
success: bool = Field(
|
|
..., description="Whether it is successful, True: success, False: failure"
|
|
)
|
|
err_code: str | None = Field(None, description="Error code")
|
|
err_msg: str | None = Field(None, description="Error message")
|
|
data: T | None = Field(None, description="Return data")
|
|
|
|
@staticmethod
|
|
def succ(data: T) -> "Result[T]":
|
|
"""Build a successful result entity.
|
|
|
|
Args:
|
|
data (T): Return data
|
|
|
|
Returns:
|
|
Result[T]: Result entity
|
|
"""
|
|
return Result(success=True, err_code=None, err_msg=None, data=data)
|
|
|
|
@staticmethod
|
|
def failed(msg: str, err_code: Optional[str] = "E000X") -> "Result[Any]":
|
|
"""Build a failed result entity.
|
|
|
|
Args:
|
|
msg (str): Error message
|
|
err_code (Optional[str], optional): Error code. Defaults to "E000X".
|
|
"""
|
|
return Result(success=False, err_code=err_code, err_msg=msg, data=None)
|
|
|
|
def to_dict(self, **kwargs) -> Dict[str, Any]:
|
|
"""Convert to dict."""
|
|
return model_to_dict(self, **kwargs)
|
|
|
|
|
|
class APIChatCompletionRequest(BaseModel):
|
|
"""Chat completion request entity."""
|
|
|
|
model: str = Field(..., description="Model name")
|
|
messages: Union[str, List[Dict[str, str]]] = Field(..., description="Messages")
|
|
temperature: Optional[float] = Field(0.7, description="Temperature")
|
|
top_p: Optional[float] = Field(1.0, description="Top p")
|
|
top_k: Optional[int] = Field(-1, description="Top k")
|
|
n: Optional[int] = Field(1, description="Number of completions")
|
|
max_tokens: Optional[int] = Field(None, description="Max tokens")
|
|
stop: Optional[Union[str, List[str]]] = Field(None, description="Stop")
|
|
stream: Optional[bool] = Field(False, description="Stream")
|
|
user: Optional[str] = Field(None, description="User")
|
|
repetition_penalty: Optional[float] = Field(1.0, description="Repetition penalty")
|
|
frequency_penalty: Optional[float] = Field(0.0, description="Frequency penalty")
|
|
presence_penalty: Optional[float] = Field(0.0, description="Presence penalty")
|
|
|
|
|
|
class DeltaMessage(BaseModel):
|
|
"""Delta message entity for chat completion response."""
|
|
|
|
role: Optional[str] = None
|
|
content: Optional[str] = None
|
|
|
|
|
|
class ChatCompletionResponseStreamChoice(BaseModel):
|
|
"""Chat completion response choice entity."""
|
|
|
|
index: int = Field(..., description="Choice index")
|
|
delta: DeltaMessage = Field(..., description="Delta message")
|
|
finish_reason: Optional[Literal["stop", "length"]] = Field(
|
|
None, description="Finish reason"
|
|
)
|
|
|
|
|
|
class ChatCompletionStreamResponse(BaseModel):
|
|
"""Chat completion response stream entity."""
|
|
|
|
id: str = Field(
|
|
default_factory=lambda: f"chatcmpl-{str(uuid.uuid1())}", description="Stream ID"
|
|
)
|
|
created: int = Field(
|
|
default_factory=lambda: int(time.time()), description="Created time"
|
|
)
|
|
model: str = Field(..., description="Model name")
|
|
choices: List[ChatCompletionResponseStreamChoice] = Field(
|
|
..., description="Chat completion response choices"
|
|
)
|
|
|
|
|
|
class ChatMessage(BaseModel):
|
|
"""Chat message entity."""
|
|
|
|
role: str = Field(..., description="Role of the message")
|
|
content: str = Field(..., description="Content of the message")
|
|
|
|
|
|
class UsageInfo(BaseModel):
|
|
"""Usage info entity."""
|
|
|
|
prompt_tokens: int = Field(0, description="Prompt tokens")
|
|
total_tokens: int = Field(0, description="Total tokens")
|
|
completion_tokens: Optional[int] = Field(0, description="Completion tokens")
|
|
|
|
|
|
class ChatCompletionResponseChoice(BaseModel):
|
|
"""Chat completion response choice entity."""
|
|
|
|
index: int = Field(..., description="Choice index")
|
|
message: ChatMessage = Field(..., description="Chat message")
|
|
finish_reason: Optional[Literal["stop", "length"]] = Field(
|
|
None, description="Finish reason"
|
|
)
|
|
|
|
|
|
class ChatCompletionResponse(BaseModel):
|
|
"""Chat completion response entity."""
|
|
|
|
id: str = Field(
|
|
default_factory=lambda: f"chatcmpl-{str(uuid.uuid1())}", description="Stream ID"
|
|
)
|
|
object: str = "chat.completion"
|
|
created: int = Field(
|
|
default_factory=lambda: int(time.time()), description="Created time"
|
|
)
|
|
model: str = Field(..., description="Model name")
|
|
choices: List[ChatCompletionResponseChoice] = Field(
|
|
..., description="Chat completion response choices"
|
|
)
|
|
usage: UsageInfo = Field(..., description="Usage info")
|
|
|
|
|
|
class ErrorResponse(BaseModel):
|
|
"""Error response entity."""
|
|
|
|
object: str = Field("error", description="Object type")
|
|
message: str = Field(..., description="Error message")
|
|
code: int = Field(..., description="Error code")
|
|
|
|
|
|
class EmbeddingsRequest(BaseModel):
|
|
"""Embeddings request entity."""
|
|
|
|
model: Optional[str] = Field(None, description="Model name")
|
|
engine: Optional[str] = Field(None, description="Engine name")
|
|
input: Union[str, List[Any]] = Field(..., description="Input data")
|
|
user: Optional[str] = Field(None, description="User name")
|
|
encoding_format: Optional[str] = Field(None, description="Encoding format")
|
|
|
|
|
|
class EmbeddingsResponse(BaseModel):
|
|
"""Embeddings response entity."""
|
|
|
|
object: str = Field("list", description="Object type")
|
|
data: List[Dict[str, Any]] = Field(..., description="Data list")
|
|
model: str = Field(..., description="Model name")
|
|
usage: UsageInfo = Field(..., description="Usage info")
|
|
|
|
|
|
class RelevanceRequest(BaseModel):
|
|
"""Relevance request entity."""
|
|
|
|
model: str = Field(..., description="Rerank model name")
|
|
query: str = Field(..., description="Query text")
|
|
documents: List[str] = Field(..., description="Document texts")
|
|
|
|
|
|
class RelevanceResponse(BaseModel):
|
|
"""Relevance response entity."""
|
|
|
|
object: str = Field("list", description="Object type")
|
|
model: str = Field(..., description="Rerank model name")
|
|
data: List[float] = Field(..., description="Data list, relevance scores")
|
|
usage: UsageInfo = Field(..., description="Usage info")
|
|
|
|
|
|
class ModelPermission(BaseModel):
|
|
"""Model permission entity."""
|
|
|
|
id: str = Field(
|
|
default_factory=lambda: f"modelperm-{str(uuid.uuid1())}",
|
|
description="Permission ID",
|
|
)
|
|
object: str = Field("model_permission", description="Object type")
|
|
created: int = Field(
|
|
default_factory=lambda: int(time.time()), description="Created time"
|
|
)
|
|
allow_create_engine: bool = Field(False, description="Allow create engine")
|
|
allow_sampling: bool = Field(True, description="Allow sampling")
|
|
allow_logprobs: bool = Field(True, description="Allow logprobs")
|
|
allow_search_indices: bool = Field(True, description="Allow search indices")
|
|
allow_view: bool = Field(True, description="Allow view")
|
|
allow_fine_tuning: bool = Field(False, description="Allow fine tuning")
|
|
organization: str = Field("*", description="Organization")
|
|
group: Optional[str] = Field(None, description="Group")
|
|
is_blocking: bool = Field(False, description="Is blocking")
|
|
|
|
|
|
class ModelCard(BaseModel):
|
|
"""Model card entity."""
|
|
|
|
id: str = Field(..., description="Model ID")
|
|
object: str = Field("model", description="Object type")
|
|
created: int = Field(
|
|
default_factory=lambda: int(time.time()), description="Created time"
|
|
)
|
|
owned_by: str = Field("DB-GPT", description="Owned by")
|
|
root: Optional[str] = Field(None, description="Root")
|
|
parent: Optional[str] = Field(None, description="Parent")
|
|
permission: List[ModelPermission] = Field(
|
|
default_factory=list, description="Permission"
|
|
)
|
|
|
|
|
|
class ModelList(BaseModel):
|
|
"""Model list entity."""
|
|
|
|
object: str = Field("list", description="Object type")
|
|
data: List[ModelCard] = Field(default_factory=list, description="Model list data")
|
|
|
|
|
|
class ErrorCode(IntEnum):
|
|
"""Error code enumeration.
|
|
|
|
https://platform.openai.com/docs/guides/error-codes/api-errors.
|
|
|
|
Adapted from fastchat.constants.
|
|
"""
|
|
|
|
VALIDATION_TYPE_ERROR = 40001
|
|
|
|
INVALID_AUTH_KEY = 40101
|
|
INCORRECT_AUTH_KEY = 40102
|
|
NO_PERMISSION = 40103
|
|
|
|
INVALID_MODEL = 40301
|
|
PARAM_OUT_OF_RANGE = 40302
|
|
CONTEXT_OVERFLOW = 40303
|
|
|
|
RATE_LIMIT = 42901
|
|
QUOTA_EXCEEDED = 42902
|
|
ENGINE_OVERLOADED = 42903
|
|
|
|
INTERNAL_ERROR = 50001
|
|
CUDA_OUT_OF_MEMORY = 50002
|
|
GRADIO_REQUEST_ERROR = 50003
|
|
GRADIO_STREAM_UNKNOWN_ERROR = 50004
|
|
CONTROLLER_NO_WORKER = 50005
|
|
CONTROLLER_WORKER_TIMEOUT = 50006
|