feat(core): Support export messages

This commit is contained in:
Fangyin Cheng 2025-06-25 11:43:08 +08:00
parent 2ed145c3aa
commit 000cf5a97b
3 changed files with 114 additions and 1 deletions

View File

@ -1,9 +1,12 @@
import io
import json
import uuid
from functools import cache
from typing import List, Optional
from typing import List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
from starlette.responses import JSONResponse, StreamingResponse
from dbgpt.component import SystemApp
from dbgpt.util import PaginationResult
@ -231,6 +234,89 @@ async def get_history_messages(con_uid: str, service: Service = Depends(get_serv
return Result.succ(service.get_history_messages(ServeRequest(conv_uid=con_uid)))
@router.get(
"/export_messages",
dependencies=[Depends(check_api_key)],
)
async def export_all_messages(
user_name: Optional[str] = None,
user_id: Optional[str] = None,
sys_code: Optional[str] = None,
format: Literal["file", "json"] = Query(
"file", description="response format(file or json)"
),
service: Service = Depends(get_service),
):
"""Export all conversations and messages for a user
Args:
user_name (str): The user name
user_id (str): The user id (alternative to user_name)
sys_code (str): The system code
format (str): The format of the response, either 'file' or 'json', defaults to
'file'
Returns:
A dictionary containing all conversations and their messages
"""
# 1. Get all conversations for the user
request = ServeRequest(
user_name=user_name or user_id,
sys_code=sys_code,
)
# Initialize pagination variables
page = 1
page_size = 100 # Adjust based on your needs
all_conversations = []
# Paginate through all conversations
while True:
pagination_result = service.get_list_by_page(request, page, page_size)
all_conversations.extend(pagination_result.items)
if page >= pagination_result.total_pages:
break
page += 1
# 2. For each conversation, get all messages
result = {
"user_name": user_name or user_id,
"sys_code": sys_code,
"total_conversations": len(all_conversations),
"conversations": [],
}
for conv in all_conversations:
messages = service.get_history_messages(ServeRequest(conv_uid=conv.conv_uid))
conversation_data = {
"conv_uid": conv.conv_uid,
"chat_mode": conv.chat_mode,
"app_code": conv.app_code,
"create_time": conv.gmt_created,
"update_time": conv.gmt_modified,
"total_messages": len(messages),
"messages": [msg.dict() for msg in messages],
}
result["conversations"].append(conversation_data)
if format == "json":
return JSONResponse(content=result)
else:
file_name = (
f"conversation_export_{user_name or user_id or 'dbgpt'}_"
f"{sys_code or 'dbgpt'}"
)
# Return the json file
return StreamingResponse(
io.BytesIO(
json.dumps(result, ensure_ascii=False, indent=4).encode("utf-8")
),
media_type="application/file",
headers={"Content-Disposition": f"attachment;filename={file_name}.json"},
)
def init_endpoints(system_app: SystemApp, config: ServeConfig) -> None:
"""Initialize the endpoints"""
global global_system_app

View File

@ -109,6 +109,20 @@ class ServerResponse(BaseModel):
"dbgpt",
],
)
gmt_created: Optional[str] = Field(
default=None,
description="The record creation time.",
examples=[
"2023-01-07 09:00:00",
],
)
gmt_modified: Optional[str] = Field(
default=None,
description="The record update time.",
examples=[
"2023-01-07 09:00:00",
],
)
def to_dict(self, **kwargs) -> Dict[str, Any]:
"""Convert the model to a dictionary"""

View File

@ -60,6 +60,17 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
RES: The response
"""
# TODO implement your own logic here, transfer the entity to a response
gmt_created = (
entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S")
if entity.gmt_created
else None
)
gmt_modified = (
entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S")
if entity.gmt_modified
else None
)
return ServerResponse(
app_code=entity.app_code,
conv_uid=entity.conv_uid,
@ -67,6 +78,8 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
chat_mode=entity.chat_mode,
user_name="",
sys_code=entity.sys_code,
gmt_created=gmt_created,
gmt_modified=gmt_modified,
)
def get_latest_message(self, conv_uid: str) -> Optional[MessageStorageItem]: