feat(core): Add common schemas

This commit is contained in:
Fangyin Cheng 2024-03-21 11:23:24 +08:00
parent ab3e8e54a1
commit b4b810d68f
18 changed files with 188 additions and 123 deletions

View File

@ -13,11 +13,8 @@ from dbgpt._private.config import Config
from dbgpt.app.knowledge.request.request import KnowledgeSpaceRequest
from dbgpt.app.knowledge.service import KnowledgeService
from dbgpt.app.openapi.api_view_model import (
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatSceneVo,
ConversationVo,
DeltaMessage,
MessageVo,
Result,
)
@ -25,6 +22,11 @@ from dbgpt.app.scene import BaseChat, ChatFactory, ChatScene
from dbgpt.component import ComponentType
from dbgpt.configs.model_config import KNOWLEDGE_UPLOAD_ROOT_PATH
from dbgpt.core.awel import CommonLLMHttpRequestBody, CommonLLMHTTPRequestContext
from dbgpt.core.schema.api import (
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
DeltaMessage,
)
from dbgpt.datasource.db_conn_info import DBConfig, DbTypeInfo
from dbgpt.model.base import FlatSupportedModel
from dbgpt.model.cluster import BaseModelController, WorkerManager, WorkerManagerFactory

View File

@ -6,15 +6,6 @@ from typing import Optional
from fastapi import APIRouter, Body, Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastchat.protocol.api_protocol import (
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
DeltaMessage,
UsageInfo,
)
from starlette.responses import StreamingResponse
from dbgpt.app.openapi.api_v1.api_v1 import (
@ -26,9 +17,18 @@ from dbgpt.app.openapi.api_v1.api_v1 import (
stream_generator,
)
from dbgpt.app.scene import BaseChat, ChatScene
from dbgpt.client.schemas import ChatCompletionRequestBody, ChatMode
from dbgpt.client.schema import ChatCompletionRequestBody, ChatMode
from dbgpt.component import logger
from dbgpt.core.awel import CommonLLMHttpRequestBody, CommonLLMHTTPRequestContext
from dbgpt.core.schema.api import (
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
DeltaMessage,
UsageInfo,
)
from dbgpt.model.cluster.apiserver.api import APISettings
from dbgpt.serve.agent.agents.controller import multi_agents
from dbgpt.serve.flow.api.endpoints import get_service

View File

@ -89,21 +89,3 @@ class MessageVo(BaseModel):
model_name
"""
model_name: str
class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]] = None
class ChatCompletionStreamResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{str(uuid.uuid1())}")
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]

View File

@ -1,9 +1,10 @@
"""App Client API."""
from typing import List
from dbgpt.client import Client, ClientException
from dbgpt.client.schemas import AppModel
from dbgpt.serve.core import Result
from dbgpt.core.schema.api import Result
from .client import Client, ClientException
from .schema import AppModel
async def get_app(client: Client, app_id: str) -> AppModel:

View File

@ -5,10 +5,10 @@ from typing import Any, AsyncGenerator, List, Optional, Union
from urllib.parse import urlparse
import httpx
from fastchat.protocol.api_protocol import ChatCompletionResponse
from dbgpt.app.openapi.api_view_model import ChatCompletionStreamResponse
from dbgpt.client.schemas import ChatCompletionRequestBody
from dbgpt.core.schema.api import ChatCompletionResponse, ChatCompletionStreamResponse
from .schema import ChatCompletionRequestBody
CLIENT_API_PATH = "/api"
CLIENT_SERVE_PATH = "/serve"
@ -256,14 +256,14 @@ class Client:
)
yield chat_completion_response
except Exception as e:
yield f"data:[SERVER_ERROR]{str(e)}\n\n"
raise e
else:
try:
error = await response.aread()
yield json.loads(error)
except Exception as e:
yield f"data:[SERVER_ERROR]{str(e)}\n\n"
raise e
async def get(self, path: str, *args):
"""Get method.

View File

@ -1,9 +1,10 @@
"""this module contains the flow client functions."""
from typing import List
from dbgpt.client import Client, ClientException
from dbgpt.core.awel.flow.flow_factory import FlowPanel
from dbgpt.serve.core import Result
from dbgpt.core.schema.api import Result
from .client import Client, ClientException
async def create_flow(client: Client, flow: FlowPanel) -> FlowPanel:

View File

@ -2,9 +2,10 @@
import json
from typing import List
from dbgpt.client import Client, ClientException
from dbgpt.client.schemas import DocumentModel, SpaceModel, SyncModel
from dbgpt.serve.core import Result
from dbgpt.core.schema.api import Result
from .client import Client, ClientException
from .schema import DocumentModel, SpaceModel, SyncModel
async def create_space(client: Client, space_model: SpaceModel) -> SpaceModel:

View File

@ -4,8 +4,8 @@ from enum import Enum
from typing import List, Optional, Union
from fastapi import File, UploadFile
from pydantic import BaseModel, Field
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.agent.resource.resource_api import AgentResource
from dbgpt.rag.chunk_manager import ChunkParameters

View File

@ -0,0 +1 @@
"""Module for common schemas."""

116
dbgpt/core/schema/api.py Normal file
View File

@ -0,0 +1,116 @@
"""API schema module."""
import time
import uuid
from typing import Any, Generic, List, Literal, Optional, TypeVar
from dbgpt._private.pydantic import BaseModel, Field
T = TypeVar("T")
class Result(BaseModel, Generic[T]):
    """Common result entity for API response.

    Generic wrapper carrying either a success payload (``data``) or an
    error (``err_code``/``err_msg``). Build instances via :meth:`succ`
    and :meth:`failed` rather than the constructor.
    """

    success: bool = Field(
        ..., description="Whether it is successful, True: success, False: failure"
    )
    # Use Optional[...] instead of ``str | None``: PEP 604 unions are not
    # valid at runtime before Python 3.10, and ``Optional`` matches the
    # annotation style of every other model in this module.
    err_code: Optional[str] = Field(None, description="Error code")
    err_msg: Optional[str] = Field(None, description="Error message")
    data: Optional[T] = Field(None, description="Return data")

    @staticmethod
    def succ(data: T) -> "Result[T]":
        """Build a successful result entity.

        Args:
            data (T): Return data

        Returns:
            Result[T]: Result entity
        """
        return Result(success=True, err_code=None, err_msg=None, data=data)

    @staticmethod
    def failed(msg: str, err_code: Optional[str] = "E000X") -> "Result[Any]":
        """Build a failed result entity.

        Args:
            msg (str): Error message
            err_code (Optional[str], optional): Error code. Defaults to "E000X".

        Returns:
            Result[Any]: Result entity
        """
        return Result(success=False, err_code=err_code, err_msg=msg, data=None)
class DeltaMessage(BaseModel):
    """Delta message entity for chat completion response.

    Incremental piece of a streamed assistant message; both fields are
    optional because a single chunk may update only one of them.
    """

    # Use Field(..., description=...) like every other model in this module
    # so the generated API schema documents these fields too.
    role: Optional[str] = Field(None, description="Role of the message")
    content: Optional[str] = Field(None, description="Content of the message")
class ChatCompletionResponseStreamChoice(BaseModel):
    """Chat completion response choice entity.

    One streamed choice: the incremental ``delta`` for the choice at
    position ``index`` in the response's ``choices`` list.
    """

    # Position of this choice within the parent response's choices list
    index: int = Field(..., description="Choice index")
    # Incremental message content carried by this chunk
    delta: DeltaMessage = Field(..., description="Delta message")
    # None while streaming; "stop" or "length" when generation ends
    finish_reason: Optional[Literal["stop", "length"]] = Field(
        None, description="Finish reason"
    )
class ChatCompletionStreamResponse(BaseModel):
    """Chat completion response stream entity.

    One server-sent chunk of an OpenAI-compatible streaming chat
    completion. ``id`` and ``created`` are auto-generated per instance.
    """

    id: str = Field(
        default_factory=lambda: f"chatcmpl-{str(uuid.uuid1())}", description="Stream ID"
    )
    # Added for OpenAI protocol compatibility and for consistency with
    # ChatCompletionResponse below, which already carries ``object``.
    object: str = Field(
        "chat.completion.chunk", description="Object type, always 'chat.completion.chunk'"
    )
    created: int = Field(
        default_factory=lambda: int(time.time()), description="Created time"
    )
    model: str = Field(..., description="Model name")
    choices: List[ChatCompletionResponseStreamChoice] = Field(
        ..., description="Chat completion response choices"
    )
class ChatMessage(BaseModel):
    """Chat message entity.

    A complete (non-streamed) message; unlike DeltaMessage, both fields
    are required.
    """

    # e.g. "user" / "assistant" / "system" — not constrained by this model
    role: str = Field(..., description="Role of the message")
    content: str = Field(..., description="Content of the message")
class UsageInfo(BaseModel):
    """Usage info entity.

    Token accounting for a chat completion; all counters default to 0.
    """

    # Tokens consumed by the prompt/input
    prompt_tokens: int = Field(0, description="Prompt tokens")
    # Total tokens (presumably prompt + completion — not enforced here)
    total_tokens: int = Field(0, description="Total tokens")
    # Optional because streaming responses may not report completion tokens
    # — NOTE(review): assumption, confirm against producers of this model
    completion_tokens: Optional[int] = Field(0, description="Completion tokens")
class ChatCompletionResponseChoice(BaseModel):
    """Chat completion response choice entity.

    One complete (non-streamed) choice: a full ``message`` for the choice
    at position ``index``.
    """

    # Position of this choice within the parent response's choices list
    index: int = Field(..., description="Choice index")
    # The full assistant message for this choice
    message: ChatMessage = Field(..., description="Chat message")
    # None, or "stop"/"length" describing why generation ended
    finish_reason: Optional[Literal["stop", "length"]] = Field(
        None, description="Finish reason"
    )
class ChatCompletionResponse(BaseModel):
    """Chat completion response entity.

    Complete (non-streaming) OpenAI-compatible chat completion response.
    ``id`` and ``created`` are auto-generated per instance.
    """

    # Description fixed: this is the completion ID, not a "Stream ID"
    # (copy-paste from ChatCompletionStreamResponse).
    id: str = Field(
        default_factory=lambda: f"chatcmpl-{str(uuid.uuid1())}",
        description="Chat completion ID",
    )
    # Wrapped in Field(...) for consistency with the rest of the module;
    # the default value is unchanged.
    object: str = Field(
        "chat.completion", description="Object type, always 'chat.completion'"
    )
    created: int = Field(
        default_factory=lambda: int(time.time()), description="Created time"
    )
    model: str = Field(..., description="Model name")
    choices: List[ChatCompletionResponseChoice] = Field(
        ..., description="Chat completion response choices"
    )
    usage: UsageInfo = Field(..., description="Usage info")

View File

@ -16,14 +16,7 @@ from typing import (
from dbgpt._private.pydantic import model_to_json
from dbgpt.core.awel import TransformStreamAbsOperator
from dbgpt.core.awel.flow import (
IOField,
OperatorCategory,
OperatorType,
Parameter,
ResourceCategory,
ViewMetadata,
)
from dbgpt.core.awel.flow import IOField, OperatorCategory, OperatorType, ViewMetadata
from dbgpt.core.interface.llm import ModelOutput
from dbgpt.core.operators import BaseLLM
@ -217,7 +210,8 @@ async def _to_openai_stream(
import json
import shortuuid
from fastchat.protocol.openai_api_protocol import (
from dbgpt.core.schema.api import (
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
DeltaMessage,

View File

@ -1,12 +1,12 @@
import logging
import sys
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
from typing import TYPE_CHECKING
from fastapi import HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.schema.api import Result
if sys.version_info < (3, 11):
try:
@ -18,40 +18,6 @@ if TYPE_CHECKING:
from fastapi import FastAPI
logger = logging.getLogger(__name__)
T = TypeVar("T")
class Result(BaseModel, Generic[T]):
"""Common result entity class"""
success: bool = Field(
..., description="Whether it is successful, True: success, False: failure"
)
err_code: str | None = Field(None, description="Error code")
err_msg: str | None = Field(None, description="Error message")
data: T | None = Field(None, description="Return data")
@staticmethod
def succ(data: T) -> "Result[T]":
"""Build a successful result entity
Args:
data (T): Return data
Returns:
Result[T]: Result entity
"""
return Result(success=True, err_code=None, err_msg=None, data=data)
@staticmethod
def failed(msg: str, err_code: Optional[str] = "E000X") -> "Result[Any]":
"""Build a failed result entity
Args:
msg (str): Error message
err_code (Optional[str], optional): Error code. Defaults to "E000X".
"""
return Result(success=False, err_code=err_code, err_msg=msg, data=None)
async def validation_exception_handler(

View File

@ -485,9 +485,7 @@ async def _chat_with_dag_task(
if OpenAIStreamingOutputOperator and isinstance(
task, OpenAIStreamingOutputOperator
):
from fastchat.protocol.openai_api_protocol import (
ChatCompletionResponseStreamChoice,
)
from dbgpt.core.schema.api import ChatCompletionResponseStreamChoice
previous_text = ""
async for output in await task.call_stream(request):

View File

@ -50,11 +50,12 @@ APP_ID="{YOUR_APP_ID}"
client = Client(api_key=DBGPT_API_KEY)
async for data in client.chat_stream(
messages="Introduce AWEL",
model="chatgpt_proxyllm",
chat_mode="chat_app",
chat_param=APP_ID):
print(data)
messages="Introduce AWEL",
model="chatgpt_proxyllm",
chat_mode="chat_app",
chat_param=APP_ID
):
print(data)
```
</TabItem>

View File

@ -45,6 +45,7 @@ import TabItem from '@theme/TabItem';
from dbgpt.client import Client
DBGPT_API_KEY = "dbgpt"
client = Client(api_key=DBGPT_API_KEY)
async for data in client.chat_stream(
model="chatgpt_proxyllm",

View File

@ -49,11 +49,12 @@ FLOW_ID="{YOUR_FLOW_ID}"
client = Client(api_key=DBGPT_API_KEY)
async for data in client.chat_stream(
messages="Introduce AWEL",
model="chatgpt_proxyllm",
chat_mode="chat_flow",
chat_param=FLOW_ID):
print(data)
messages="Introduce AWEL",
model="chatgpt_proxyllm",
chat_mode="chat_flow",
chat_param=FLOW_ID
):
print(data)
```
</TabItem>
</Tabs>

View File

@ -49,11 +49,12 @@ SPACE_NAME="{YOUR_SPACE_NAME}"
client = Client(api_key=DBGPT_API_KEY)
async for data in client.chat_stream(
messages="Introduce AWEL",
model="chatgpt_proxyllm",
chat_mode="chat_knowledge",
chat_param=SPACE_NAME):
print(data)
messages="Introduce AWEL",
model="chatgpt_proxyllm",
chat_mode="chat_knowledge",
chat_param=SPACE_NAME
):
print(data)
```
</TabItem>
</Tabs>
@ -343,21 +344,20 @@ POST /api/v2/serve/knowledge/spaces
<TabItem value="python_knowledge">
```python
from dbgpt.client import Client
from dbgpt.client.knowledge import create_space
from dbgpt.client.schemas import SpaceModel
from dbgpt.client.schema import SpaceModel
DBGPT_API_KEY = "dbgpt"
client = Client(api_key=DBGPT_API_KEY)
res = await create_space(client,SpaceModel(
name="test_space",
vector_type="Chroma",
desc="for client space",
owner="dbgpt"))
res = await create_space(client, SpaceModel(
name="test_space",
vector_type="Chroma",
desc="for client space",
owner="dbgpt"
))
```
</TabItem>
@ -420,20 +420,20 @@ PUT /api/v2/serve/knowledge/spaces
<TabItem value="python_update_knowledge">
```python
from dbgpt.client import Client
from dbgpt.client.knowledge import update_space
from dbgpt.client.schemas import SpaceModel
from dbgpt.client.schema import SpaceModel
DBGPT_API_KEY = "dbgpt"
client = Client(api_key=DBGPT_API_KEY)
res = await update_space(client, SpaceModel(
name="test_space",
vector_type="Chroma",
desc="for client space update",
owner="dbgpt"))
name="test_space",
vector_type="Chroma",
desc="for client space update",
owner="dbgpt"
))
```

View File

@ -66,7 +66,7 @@ import asyncio
from dbgpt.client import Client
from dbgpt.client.knowledge import create_space
from dbgpt.client.schemas import SpaceModel
from dbgpt.client.schema import SpaceModel
async def main():