feat(core): APP use new SDK component (#1050)

This commit is contained in:
Fangyin Cheng
2024-01-10 10:39:04 +08:00
committed by GitHub
parent e11b72c724
commit fa8b5b190c
242 changed files with 2768 additions and 2163 deletions

View File

@@ -1,3 +1,4 @@
import uuid
from functools import cache
from typing import List, Optional
@@ -10,7 +11,7 @@ from dbgpt.util import PaginationResult
from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..service.service import Service
from .schemas import ServeRequest, ServerResponse
from .schemas import MessageVo, ServeRequest, ServerResponse
router = APIRouter()
@@ -94,40 +95,6 @@ async def test_auth():
return {"status": "ok"}
@router.post(
"/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
)
async def create(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
"""Create a new Conversation entity
Args:
request (ServeRequest): The request
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.create(request))
@router.put(
"/", response_model=Result[ServerResponse], dependencies=[Depends(check_api_key)]
)
async def update(
request: ServeRequest, service: Service = Depends(get_service)
) -> Result[ServerResponse]:
"""Update a Conversation entity
Args:
request (ServeRequest): The request
service (Service): The service
Returns:
ServerResponse: The response
"""
return Result.succ(service.update(request))
@router.post(
"/query",
response_model=Result[ServerResponse],
@@ -147,6 +114,45 @@ async def query(
return Result.succ(service.get(request))
@router.post(
"/new",
response_model=Result[ServerResponse],
dependencies=[Depends(check_api_key)],
)
async def dialogue_new(
chat_mode: str = "chat_normal",
user_name: str = None,
# TODO remove user id
user_id: str = None,
sys_code: str = None,
):
user_name = user_name or user_id
unique_id = uuid.uuid1()
res = ServerResponse(
user_input="",
conv_uid=str(unique_id),
chat_mode=chat_mode,
user_name=user_name,
sys_code=sys_code,
)
return Result.succ(res)
@router.post(
"/delete",
dependencies=[Depends(check_api_key)],
)
async def delete(con_uid: str, service: Service = Depends(get_service)):
"""Delete a Conversation entity
Args:
con_uid (str): The conversation UID
service (Service): The service
"""
service.delete(ServeRequest(conv_uid=con_uid))
return Result.succ(None)
@router.post(
"/query_page",
response_model=Result[PaginationResult[ServerResponse]],
@@ -155,7 +161,7 @@ async def query(
async def query_page(
request: ServeRequest,
page: Optional[int] = Query(default=1, description="current page"),
page_size: Optional[int] = Query(default=20, description="page size"),
page_size: Optional[int] = Query(default=10, description="page size"),
service: Service = Depends(get_service),
) -> Result[PaginationResult[ServerResponse]]:
"""Query Conversation entities
@@ -171,6 +177,37 @@ async def query_page(
return Result.succ(service.get_list_by_page(request, page, page_size))
@router.get(
"/list",
response_model=Result[List[ServerResponse]],
dependencies=[Depends(check_api_key)],
)
async def list_latest_conv(
user_name: str = None,
user_id: str = None,
sys_code: str = None,
page: Optional[int] = Query(default=1, description="current page"),
page_size: Optional[int] = Query(default=10, description="page size"),
service: Service = Depends(get_service),
) -> Result[List[ServerResponse]]:
"""Return latest conversations"""
request = ServeRequest(
user_name=user_name or user_id,
sys_code=sys_code,
)
return Result.succ(service.get_list_by_page(request, page, page_size).items)
@router.get(
"/messages/history",
response_model=Result[List[MessageVo]],
dependencies=[Depends(check_api_key)],
)
async def get_history_messages(con_uid: str, service: Service = Depends(get_service)):
"""Get the history messages of a conversation"""
return Result.succ(service.get_history_messages(ServeRequest(conv_uid=con_uid)))
def init_endpoints(system_app: SystemApp) -> None:
"""Initialize the endpoints"""
global global_system_app

View File

@@ -1,4 +1,6 @@
# Define your Pydantic schemas here
from typing import Any, Optional
from dbgpt._private.pydantic import BaseModel, Field
from ..config import SERVE_APP_NAME_HUMP
@@ -7,15 +9,133 @@ from ..config import SERVE_APP_NAME_HUMP
class ServeRequest(BaseModel):
"""Conversation request model"""
# TODO define your own fields here
class Config:
title = f"ServeRequest for {SERVE_APP_NAME_HUMP}"
# Just for query
chat_mode: str = Field(
default=None,
description="The chat mode.",
examples=[
"chat_normal",
],
)
conv_uid: Optional[str] = Field(
default=None,
description="The conversation uid.",
examples=[
"5e7100bc-9017-11ee-9876-8fe019728d79",
],
)
user_name: Optional[str] = Field(
default=None,
description="The user name.",
examples=[
"zhangsan",
],
)
sys_code: Optional[str] = Field(
default=None,
description="The system code.",
examples=[
"dbgpt",
],
)
class ServerResponse(BaseModel):
"""Conversation response model"""
# TODO define your own fields here
class Config:
title = f"ServerResponse for {SERVE_APP_NAME_HUMP}"
conv_uid: str = Field(
...,
description="The conversation uid.",
examples=[
"5e7100bc-9017-11ee-9876-8fe019728d79",
],
)
user_input: str = Field(
...,
description="The user input, we return it as the summary the conversation.",
examples=[
"Hello world",
],
)
chat_mode: str = Field(
...,
description="The chat mode.",
examples=[
"chat_normal",
],
)
select_param: Optional[str] = Field(
default=None,
description="The select param.",
examples=[
"my_knowledge_space_name",
],
)
model_name: Optional[str] = Field(
default=None,
description="The model name.",
examples=[
"vicuna-13b-v1.5",
],
)
user_name: Optional[str] = Field(
default=None,
description="The user name.",
examples=[
"zhangsan",
],
)
sys_code: Optional[str] = Field(
default=None,
description="The system code.",
examples=[
"dbgpt",
],
)
class MessageVo(BaseModel):
role: str = Field(
...,
description="The role that sends out the current message.",
examples=["human", "ai", "view"],
)
context: str = Field(
...,
description="The current message content.",
examples=[
"Hello",
"Hi, how are you?",
],
)
order: int = Field(
...,
description="The current message order.",
examples=[
1,
2,
],
)
time_stamp: Optional[Any] = Field(
default=None,
description="The current message time stamp.",
examples=[
"2023-01-07 09:00:00",
],
)
model_name: Optional[str] = Field(
default=None,
description="The model name.",
examples=[
"vicuna-13b-v1.5",
],
)

View File

@@ -20,3 +20,8 @@ class ServeConfig(BaseServeConfig):
api_keys: Optional[str] = field(
default=None, metadata={"help": "API keys for the endpoint, if None, allow all"}
)
default_model: Optional[str] = field(
default=None,
metadata={"help": "Default model name"},
)

View File

@@ -1,30 +1,20 @@
"""This is an auto-generated model file
You can define your own models and DAOs here
"""
import json
from datetime import datetime
from typing import Any, Dict, Union
from sqlalchemy import Column, DateTime, Index, Integer, String, Text
from typing import Any, Dict, List, Optional, Union
from dbgpt.core import MessageStorageItem
from dbgpt.storage.chat_history.chat_history_db import ChatHistoryEntity as ServeEntity
from dbgpt.storage.chat_history.chat_history_db import ChatHistoryMessageEntity
from dbgpt.storage.metadata import BaseDao, Model, db
from dbgpt.util import PaginationResult
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
class ServeEntity(Model):
__tablename__ = SERVER_APP_TABLE_NAME
id = Column(Integer, primary_key=True, comment="Auto increment id")
# TODO: define your own fields here
gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time")
gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time")
def __repr__(self):
return f"ServeEntity(id={self.id}, gmt_created='{self.gmt_created}', gmt_modified='{self.gmt_modified}')"
class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
"""The DAO class for Conversation"""
@@ -68,4 +58,95 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
RES: The response
"""
# TODO implement your own logic here, transfer the entity to a response
return ServerResponse()
return ServerResponse(
conv_uid=entity.conv_uid,
user_input=entity.summary,
chat_mode=entity.chat_mode,
user_name=entity.user_name,
sys_code=entity.sys_code,
)
def get_latest_message(self, conv_uid: str) -> Optional[MessageStorageItem]:
"""Get the latest message of a conversation
Args:
conv_uid (str): The conversation UID
Returns:
ChatHistoryMessageEntity: The latest message
"""
with self.session() as session:
entity: ChatHistoryMessageEntity = (
session.query(ChatHistoryMessageEntity)
.filter(ChatHistoryMessageEntity.conv_uid == conv_uid)
.order_by(ChatHistoryMessageEntity.gmt_created.desc())
.first()
)
if not entity:
return None
message_detail = (
json.loads(entity.message_detail) if entity.message_detail else {}
)
return MessageStorageItem(entity.conv_uid, entity.index, message_detail)
def _parse_old_messages(self, entity: ServeEntity) -> List[Dict[str, Any]]:
"""Parse the old messages
Args:
entity (ServeEntity): The entity
Returns:
str: The old messages
"""
messages = json.loads(entity.messages)
return messages
def get_conv_by_page(
self, req: ServeRequest, page: int, page_size: int
) -> PaginationResult[ServerResponse]:
"""Get conversation by page
Args:
req (ServeRequest): The request
page (int): The page number
page_size (int): The page size
Returns:
List[ChatHistoryEntity]: The conversation list
"""
with self.session(commit=False) as session:
query = self._create_query_object(session, req)
query = query.order_by(ServeEntity.gmt_created.desc())
total_count = query.count()
items = query.offset((page - 1) * page_size).limit(page_size)
total_pages = (total_count + page_size - 1) // page_size
result_items = []
for item in items:
select_param, model_name = "", None
if item.messages:
messages = self._parse_old_messages(item)
last_round = max(messages, key=lambda x: x["chat_order"])
if "param_value" in last_round:
select_param = last_round["param_value"]
else:
select_param = ""
else:
latest_message = self.get_latest_message(item.conv_uid)
if latest_message:
message = latest_message.to_message()
select_param = message.additional_kwargs.get("param_value")
model_name = message.additional_kwargs.get("model_name")
res_item = self.to_response(item)
res_item.select_param = select_param
res_item.model_name = model_name
result_items.append(res_item)
result = PaginationResult(
items=result_items,
total_count=total_count,
total_pages=total_pages,
page=page,
page_size=page_size,
)
return result

View File

@@ -8,6 +8,7 @@ from dbgpt.core import StorageInterface
from dbgpt.serve.core import BaseServe
from dbgpt.storage.metadata import DatabaseManager
from .api.endpoints import init_endpoints, router
from .config import (
APP_NAME,
SERVE_APP_NAME,
@@ -15,6 +16,7 @@ from .config import (
SERVE_CONFIG_KEY_PREFIX,
ServeConfig,
)
from .service.service import Service
logger = logging.getLogger(__name__)
@@ -58,6 +60,10 @@ class Serve(BaseServe):
if self._app_has_initiated:
return
self._system_app = system_app
self._system_app.app.include_router(
router, prefix=self._api_prefix, tags=self._api_tags
)
init_endpoints(self._system_app)
self._app_has_initiated = True
def on_init(self):

View File

@@ -1,11 +1,20 @@
from typing import List, Optional
from typing import Any, Dict, List, Optional, Union
from dbgpt.component import BaseComponent, SystemApp
from dbgpt.core import (
InMemoryStorage,
MessageStorageItem,
QuerySpec,
StorageConversation,
StorageInterface,
)
from dbgpt.core.interface.message import _append_view_messages
from dbgpt.serve.core import BaseService
from dbgpt.storage.metadata import BaseDao
from dbgpt.storage.metadata._base_dao import REQ, RES
from dbgpt.util.pagination_utils import PaginationResult
from ..api.schemas import ServeRequest, ServerResponse
from ..api.schemas import MessageVo, ServeRequest, ServerResponse
from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
from ..models.models import ServeDao, ServeEntity
@@ -15,10 +24,18 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
name = SERVE_SERVICE_COMPONENT_NAME
def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None):
def __init__(
self,
system_app: SystemApp,
dao: Optional[ServeDao] = None,
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
):
self._system_app = None
self._serve_config: ServeConfig = None
self._dao: ServeDao = dao
self._storage = storage
self._message_storage = message_storage
super().__init__(system_app)
def init_app(self, system_app: SystemApp) -> None:
@@ -34,7 +51,7 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
self._system_app = system_app
@property
def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]:
def dao(self) -> ServeDao:
"""Returns the internal DAO."""
return self._dao
@@ -43,6 +60,54 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
"""Returns the internal ServeConfig."""
return self._serve_config
def create(self, request: REQ) -> RES:
raise NotImplementedError()
@property
def conv_storage(self) -> StorageInterface:
"""The conversation storage, store the conversation items."""
if self._storage:
return self._storage
from ..serve import Serve
return Serve.call_on_current_serve(
self._system_app, lambda serve: serve.conv_storage
)
@property
def message_storage(self) -> StorageInterface:
"""The message storage, store the messages of one conversation."""
if self._message_storage:
return self._message_storage
from ..serve import Serve
return Serve.call_on_current_serve(
self._system_app,
lambda serve: serve.message_storage,
)
def create_storage_conv(
self, request: Union[ServeRequest, Dict[str, Any]], load_message: bool = True
) -> StorageConversation:
conv_storage = self.conv_storage
message_storage = self.message_storage
if not conv_storage or not message_storage:
raise RuntimeError(
"Can't get the conversation storage or message storage from current serve component."
)
if isinstance(request, dict):
request = ServeRequest(**request)
storage_conv: StorageConversation = StorageConversation(
conv_uid=request.conv_uid,
chat_mode=request.chat_mode,
user_name=request.user_name,
sys_code=request.sys_code,
conv_storage=conv_storage,
message_storage=message_storage,
load_message=load_message,
)
return storage_conv
def update(self, request: ServeRequest) -> ServerResponse:
"""Update a Conversation entity
@@ -74,18 +139,13 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
return self.dao.get_one(query_request)
def delete(self, request: ServeRequest) -> None:
"""Delete a Conversation entity
"""Delete current conversation and its messages
Args:
request (ServeRequest): The request
"""
# TODO: implement your own logic here
# Build the query request from the request
query_request = {
# "id": request.id
}
self.dao.delete(query_request)
conv: StorageConversation = self.create_storage_conv(request)
conv.delete()
def get_list(self, request: ServeRequest) -> List[ServerResponse]:
"""Get a list of Conversation entities
@@ -114,5 +174,29 @@ class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
Returns:
List[ServerResponse]: The response
"""
query_request = request
return self.dao.get_list_page(query_request, page, page_size)
return self.dao.get_conv_by_page(request, page, page_size)
def get_history_messages(
self, request: Union[ServeRequest, Dict[str, Any]]
) -> List[MessageVo]:
"""Get a list of Conversation entities
Args:
request (ServeRequest): The request
Returns:
List[ServerResponse]: The response
"""
conv: StorageConversation = self.create_storage_conv(request)
result = []
messages = _append_view_messages(conv.messages)
for msg in messages:
result.append(
MessageVo(
role=msg.type,
context=msg.content,
order=msg.round_index,
model_name=self.config.default_model,
)
)
return result

View File

@@ -29,7 +29,7 @@ def dao(server_config):
@pytest.fixture
def default_entity_dict():
# TODO: build your default entity dict
return {}
return {"conv_uid": "test_conv_uid", "summary": "hello", "chat_mode": "chat_normal"}
def test_table_exist():
@@ -67,19 +67,6 @@ def test_entity_all():
pass
def test_dao_create(dao, default_entity_dict):
# TODO: implement your test case
req = ServeRequest(**default_entity_dict)
res: ServerResponse = dao.create(req)
assert res is not None
def test_dao_get_one(dao, default_entity_dict):
# TODO: implement your test case
req = ServeRequest(**default_entity_dict)
res: ServerResponse = dao.create(req)
def test_get_dao_get_list(dao):
# TODO: implement your test case
pass

View File

@@ -73,7 +73,7 @@ class BaseServe(BaseComponent, ABC):
Returns:
Optional[BaseServe]: The current serve component.
"""
return system_app.get_component(cls.name, cls, default_component=None)
return cls.get_instance(system_app, default_component=None)
@classmethod
def call_on_current_serve(