DB-GPT/dbgpt/core/schema/api.py
2024-05-16 14:50:16 +08:00

259 lines
8.6 KiB
Python

"""API schema module."""
import time
import uuid
from enum import IntEnum
from typing import Any, Dict, Generic, List, Literal, Optional, TypeVar, Union
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
T = TypeVar("T")
class Result(BaseModel, Generic[T]):
"""Common result entity for API response."""
success: bool = Field(
..., description="Whether it is successful, True: success, False: failure"
)
err_code: str | None = Field(None, description="Error code")
err_msg: str | None = Field(None, description="Error message")
data: T | None = Field(None, description="Return data")
@staticmethod
def succ(data: T) -> "Result[T]":
"""Build a successful result entity.
Args:
data (T): Return data
Returns:
Result[T]: Result entity
"""
return Result(success=True, err_code=None, err_msg=None, data=data)
@staticmethod
def failed(msg: str, err_code: Optional[str] = "E000X") -> "Result[Any]":
"""Build a failed result entity.
Args:
msg (str): Error message
err_code (Optional[str], optional): Error code. Defaults to "E000X".
"""
return Result(success=False, err_code=err_code, err_msg=msg, data=None)
def to_dict(self, **kwargs) -> Dict[str, Any]:
"""Convert to dict."""
return model_to_dict(self, **kwargs)
class APIChatCompletionRequest(BaseModel):
"""Chat completion request entity."""
model: str = Field(..., description="Model name")
messages: Union[str, List[Dict[str, str]]] = Field(..., description="Messages")
temperature: Optional[float] = Field(0.7, description="Temperature")
top_p: Optional[float] = Field(1.0, description="Top p")
top_k: Optional[int] = Field(-1, description="Top k")
n: Optional[int] = Field(1, description="Number of completions")
max_tokens: Optional[int] = Field(None, description="Max tokens")
stop: Optional[Union[str, List[str]]] = Field(None, description="Stop")
stream: Optional[bool] = Field(False, description="Stream")
user: Optional[str] = Field(None, description="User")
repetition_penalty: Optional[float] = Field(1.0, description="Repetition penalty")
frequency_penalty: Optional[float] = Field(0.0, description="Frequency penalty")
presence_penalty: Optional[float] = Field(0.0, description="Presence penalty")
class DeltaMessage(BaseModel):
"""Delta message entity for chat completion response."""
role: Optional[str] = None
content: Optional[str] = None
class ChatCompletionResponseStreamChoice(BaseModel):
"""Chat completion response choice entity."""
index: int = Field(..., description="Choice index")
delta: DeltaMessage = Field(..., description="Delta message")
finish_reason: Optional[Literal["stop", "length"]] = Field(
None, description="Finish reason"
)
class ChatCompletionStreamResponse(BaseModel):
"""Chat completion response stream entity."""
id: str = Field(
default_factory=lambda: f"chatcmpl-{str(uuid.uuid1())}", description="Stream ID"
)
created: int = Field(
default_factory=lambda: int(time.time()), description="Created time"
)
model: str = Field(..., description="Model name")
choices: List[ChatCompletionResponseStreamChoice] = Field(
..., description="Chat completion response choices"
)
class ChatMessage(BaseModel):
"""Chat message entity."""
role: str = Field(..., description="Role of the message")
content: str = Field(..., description="Content of the message")
class UsageInfo(BaseModel):
"""Usage info entity."""
prompt_tokens: int = Field(0, description="Prompt tokens")
total_tokens: int = Field(0, description="Total tokens")
completion_tokens: Optional[int] = Field(0, description="Completion tokens")
class ChatCompletionResponseChoice(BaseModel):
"""Chat completion response choice entity."""
index: int = Field(..., description="Choice index")
message: ChatMessage = Field(..., description="Chat message")
finish_reason: Optional[Literal["stop", "length"]] = Field(
None, description="Finish reason"
)
class ChatCompletionResponse(BaseModel):
"""Chat completion response entity."""
id: str = Field(
default_factory=lambda: f"chatcmpl-{str(uuid.uuid1())}", description="Stream ID"
)
object: str = "chat.completion"
created: int = Field(
default_factory=lambda: int(time.time()), description="Created time"
)
model: str = Field(..., description="Model name")
choices: List[ChatCompletionResponseChoice] = Field(
..., description="Chat completion response choices"
)
usage: UsageInfo = Field(..., description="Usage info")
class ErrorResponse(BaseModel):
"""Error response entity."""
object: str = Field("error", description="Object type")
message: str = Field(..., description="Error message")
code: int = Field(..., description="Error code")
class EmbeddingsRequest(BaseModel):
"""Embeddings request entity."""
model: Optional[str] = Field(None, description="Model name")
engine: Optional[str] = Field(None, description="Engine name")
input: Union[str, List[Any]] = Field(..., description="Input data")
user: Optional[str] = Field(None, description="User name")
encoding_format: Optional[str] = Field(None, description="Encoding format")
class EmbeddingsResponse(BaseModel):
"""Embeddings response entity."""
object: str = Field("list", description="Object type")
data: List[Dict[str, Any]] = Field(..., description="Data list")
model: str = Field(..., description="Model name")
usage: UsageInfo = Field(..., description="Usage info")
class RelevanceRequest(BaseModel):
"""Relevance request entity."""
model: str = Field(..., description="Rerank model name")
query: str = Field(..., description="Query text")
documents: List[str] = Field(..., description="Document texts")
class RelevanceResponse(BaseModel):
"""Relevance response entity."""
object: str = Field("list", description="Object type")
model: str = Field(..., description="Rerank model name")
data: List[float] = Field(..., description="Data list, relevance scores")
usage: UsageInfo = Field(..., description="Usage info")
class ModelPermission(BaseModel):
"""Model permission entity."""
id: str = Field(
default_factory=lambda: f"modelperm-{str(uuid.uuid1())}",
description="Permission ID",
)
object: str = Field("model_permission", description="Object type")
created: int = Field(
default_factory=lambda: int(time.time()), description="Created time"
)
allow_create_engine: bool = Field(False, description="Allow create engine")
allow_sampling: bool = Field(True, description="Allow sampling")
allow_logprobs: bool = Field(True, description="Allow logprobs")
allow_search_indices: bool = Field(True, description="Allow search indices")
allow_view: bool = Field(True, description="Allow view")
allow_fine_tuning: bool = Field(False, description="Allow fine tuning")
organization: str = Field("*", description="Organization")
group: Optional[str] = Field(None, description="Group")
is_blocking: bool = Field(False, description="Is blocking")
class ModelCard(BaseModel):
"""Model card entity."""
id: str = Field(..., description="Model ID")
object: str = Field("model", description="Object type")
created: int = Field(
default_factory=lambda: int(time.time()), description="Created time"
)
owned_by: str = Field("DB-GPT", description="Owned by")
root: Optional[str] = Field(None, description="Root")
parent: Optional[str] = Field(None, description="Parent")
permission: List[ModelPermission] = Field(
default_factory=list, description="Permission"
)
class ModelList(BaseModel):
"""Model list entity."""
object: str = Field("list", description="Object type")
data: List[ModelCard] = Field(default_factory=list, description="Model list data")
class ErrorCode(IntEnum):
"""Error code enumeration.
https://platform.openai.com/docs/guides/error-codes/api-errors.
Adapted from fastchat.constants.
"""
VALIDATION_TYPE_ERROR = 40001
INVALID_AUTH_KEY = 40101
INCORRECT_AUTH_KEY = 40102
NO_PERMISSION = 40103
INVALID_MODEL = 40301
PARAM_OUT_OF_RANGE = 40302
CONTEXT_OVERFLOW = 40303
RATE_LIMIT = 42901
QUOTA_EXCEEDED = 42902
ENGINE_OVERLOADED = 42903
INTERNAL_ERROR = 50001
CUDA_OUT_OF_MEMORY = 50002
GRADIO_REQUEST_ERROR = 50003
GRADIO_STREAM_UNKNOWN_ERROR = 50004
CONTROLLER_NO_WORKER = 50005
CONTROLLER_WORKER_TIMEOUT = 50006