DB-GPT/dbgpt/serve/conversation/models/models.py
明天 b124ecc10b
feat: (0.6)New UI (#1855)
Co-authored-by: 夏姜 <wenfengjiang.jwf@digital-engine.com>
Co-authored-by: aries_ckt <916701291@qq.com>
Co-authored-by: wb-lh513319 <wb-lh513319@alibaba-inc.com>
Co-authored-by: csunny <cfqsunny@163.com>
2024-08-21 17:37:45 +08:00

155 lines
5.3 KiB
Python

"""This is an auto-generated model file
You can define your own models and DAOs here
"""
import json
from typing import Any, Dict, List, Optional, Union
from dbgpt.core import MessageStorageItem
from dbgpt.storage.chat_history.chat_history_db import ChatHistoryEntity as ServeEntity
from dbgpt.storage.chat_history.chat_history_db import ChatHistoryMessageEntity
from dbgpt.storage.metadata import BaseDao, Model, db
from dbgpt.util import PaginationResult
from ..api.schemas import ServeRequest, ServerResponse
from ..config import SERVER_APP_TABLE_NAME, ServeConfig
class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
    """The DAO class for Conversation.

    Converts between the request/response schemas and the chat history
    entity, looks up the latest message of a conversation, and lists
    conversations page by page.
    """

    def __init__(self, serve_config: ServeConfig):
        """Create the DAO.

        Args:
            serve_config (ServeConfig): The serve configuration, kept on the
                instance as ``_serve_config``.
        """
        super().__init__()
        self._serve_config = serve_config

    def from_request(self, request: Union[ServeRequest, Dict[str, Any]]) -> ServeEntity:
        """Convert the request to an entity.

        Args:
            request (Union[ServeRequest, Dict[str, Any]]): The request, either
                as a schema object (converted with ``to_dict()``) or as a
                plain dict that is used as-is.

        Returns:
            ServeEntity: The entity built from the request fields.
        """
        request_dict = (
            request.to_dict() if isinstance(request, ServeRequest) else request
        )
        # TODO implement your own logic here, transfer the request_dict to an entity
        return ServeEntity(**request_dict)

    def to_request(self, entity: ServeEntity) -> ServeRequest:
        """Convert the entity to a request.

        NOTE: this is still the generated stub; it ignores ``entity`` and
        returns an empty request.

        Args:
            entity (ServeEntity): The entity.

        Returns:
            ServeRequest: An empty request.
        """
        # TODO implement your own logic here, transfer the entity to a request
        return ServeRequest()

    def to_response(self, entity: ServeEntity) -> ServerResponse:
        """Convert the entity to a response.

        Args:
            entity (ServeEntity): The entity.

        Returns:
            ServerResponse: The response. ``user_input`` is filled from the
                entity's ``summary`` and ``user_name`` is left empty.
        """
        # TODO implement your own logic here, transfer the entity to a response
        return ServerResponse(
            app_code=entity.app_code,
            conv_uid=entity.conv_uid,
            user_input=entity.summary,
            chat_mode=entity.chat_mode,
            user_name="",
            sys_code=entity.sys_code,
        )

    def get_latest_message(self, conv_uid: str) -> Optional[MessageStorageItem]:
        """Get the latest message of a conversation.

        Args:
            conv_uid (str): The conversation UID.

        Returns:
            Optional[MessageStorageItem]: The most recently created message
                (by ``gmt_created``), or None when the conversation has no
                stored message.
        """
        # Read-only lookup: no commit needed, same as get_conv_by_page.
        with self.session(commit=False) as session:
            entity: ChatHistoryMessageEntity = (
                session.query(ChatHistoryMessageEntity)
                .filter(ChatHistoryMessageEntity.conv_uid == conv_uid)
                .order_by(ChatHistoryMessageEntity.gmt_created.desc())
                .first()
            )
            if not entity:
                return None
            # An empty/missing detail column is treated as an empty payload.
            message_detail = (
                json.loads(entity.message_detail) if entity.message_detail else {}
            )
            return MessageStorageItem(entity.conv_uid, entity.index, message_detail)

    def _parse_old_messages(self, entity: ServeEntity) -> List[Dict[str, Any]]:
        """Parse the old messages.

        Args:
            entity (ServeEntity): The entity whose ``messages`` column holds
                a JSON document.

        Returns:
            List[Dict[str, Any]]: The decoded messages.
        """
        return json.loads(entity.messages)

    def _resolve_selection(self, entity: ServeEntity) -> tuple:
        """Work out the ``(select_param, model_name)`` pair for a conversation.

        Args:
            entity (ServeEntity): The conversation entity.

        Returns:
            tuple: ``(select_param, model_name)``. Defaults to ``("", None)``
                when nothing can be resolved.
        """
        if entity.messages:
            rounds = self._parse_old_messages(entity)
            # ``max`` raises ValueError on an empty sequence, so a messages
            # column that decodes to an empty list (e.g. "[]") must not reach
            # it; such a conversation falls through to the stored-message
            # lookup below.
            if rounds:
                last_round = max(rounds, key=lambda x: x["chat_order"])
                return last_round.get("param_value", ""), None

        latest_message = self.get_latest_message(entity.conv_uid)
        if not latest_message:
            return "", None
        message = latest_message.to_message()
        return (
            message.additional_kwargs.get("param_value"),
            message.additional_kwargs.get("model_name"),
        )

    def get_conv_by_page(
        self, req: ServeRequest, page: int, page_size: int
    ) -> PaginationResult[ServerResponse]:
        """Get conversation by page.

        Conversations are ordered by ``gmt_created`` descending.

        Args:
            req (ServeRequest): The request used to build the query filter.
            page (int): The page number (1-based: page 1 has offset 0).
            page_size (int): The page size.

        Returns:
            PaginationResult[ServerResponse]: The requested page of
                conversations, with ``select_param`` and ``model_name`` set on
                each item, plus the total count and total number of pages.
        """
        with self.session(commit=False) as session:
            query = self._create_query_object(session, req)
            query = query.order_by(ServeEntity.gmt_created.desc())
            total_count = query.count()
            items = query.offset((page - 1) * page_size).limit(page_size)
            # Ceiling division without floats.
            total_pages = (total_count + page_size - 1) // page_size

            result_items = []
            for item in items:
                select_param, model_name = self._resolve_selection(item)
                res_item = self.to_response(item)
                res_item.select_param = select_param
                res_item.model_name = model_name
                result_items.append(res_item)

            result = PaginationResult(
                items=result_items,
                total_count=total_count,
                total_pages=total_pages,
                page=page,
                page_size=page_size,
            )
        return result