mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-19 08:47:32 +00:00
feat(core): Support export messages
This commit is contained in:
parent
2ed145c3aa
commit
000cf5a97b
@ -1,9 +1,12 @@
|
||||
import io
|
||||
import json
|
||||
import uuid
|
||||
from functools import cache
|
||||
from typing import List, Optional
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from starlette.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from dbgpt.component import SystemApp
|
||||
from dbgpt.util import PaginationResult
|
||||
@ -231,6 +234,89 @@ async def get_history_messages(con_uid: str, service: Service = Depends(get_serv
|
||||
return Result.succ(service.get_history_messages(ServeRequest(conv_uid=con_uid)))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/export_messages",
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def export_all_messages(
|
||||
user_name: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
format: Literal["file", "json"] = Query(
|
||||
"file", description="response format(file or json)"
|
||||
),
|
||||
service: Service = Depends(get_service),
|
||||
):
|
||||
"""Export all conversations and messages for a user
|
||||
|
||||
Args:
|
||||
user_name (str): The user name
|
||||
user_id (str): The user id (alternative to user_name)
|
||||
sys_code (str): The system code
|
||||
format (str): The format of the response, either 'file' or 'json', defaults to
|
||||
'file'
|
||||
|
||||
Returns:
|
||||
A dictionary containing all conversations and their messages
|
||||
"""
|
||||
# 1. Get all conversations for the user
|
||||
request = ServeRequest(
|
||||
user_name=user_name or user_id,
|
||||
sys_code=sys_code,
|
||||
)
|
||||
|
||||
# Initialize pagination variables
|
||||
page = 1
|
||||
page_size = 100 # Adjust based on your needs
|
||||
all_conversations = []
|
||||
|
||||
# Paginate through all conversations
|
||||
while True:
|
||||
pagination_result = service.get_list_by_page(request, page, page_size)
|
||||
all_conversations.extend(pagination_result.items)
|
||||
|
||||
if page >= pagination_result.total_pages:
|
||||
break
|
||||
page += 1
|
||||
|
||||
# 2. For each conversation, get all messages
|
||||
result = {
|
||||
"user_name": user_name or user_id,
|
||||
"sys_code": sys_code,
|
||||
"total_conversations": len(all_conversations),
|
||||
"conversations": [],
|
||||
}
|
||||
|
||||
for conv in all_conversations:
|
||||
messages = service.get_history_messages(ServeRequest(conv_uid=conv.conv_uid))
|
||||
conversation_data = {
|
||||
"conv_uid": conv.conv_uid,
|
||||
"chat_mode": conv.chat_mode,
|
||||
"app_code": conv.app_code,
|
||||
"create_time": conv.gmt_created,
|
||||
"update_time": conv.gmt_modified,
|
||||
"total_messages": len(messages),
|
||||
"messages": [msg.dict() for msg in messages],
|
||||
}
|
||||
result["conversations"].append(conversation_data)
|
||||
|
||||
if format == "json":
|
||||
return JSONResponse(content=result)
|
||||
else:
|
||||
file_name = (
|
||||
f"conversation_export_{user_name or user_id or 'dbgpt'}_"
|
||||
f"{sys_code or 'dbgpt'}"
|
||||
)
|
||||
# Return the json file
|
||||
return StreamingResponse(
|
||||
io.BytesIO(
|
||||
json.dumps(result, ensure_ascii=False, indent=4).encode("utf-8")
|
||||
),
|
||||
media_type="application/file",
|
||||
headers={"Content-Disposition": f"attachment;filename={file_name}.json"},
|
||||
)
|
||||
|
||||
|
||||
def init_endpoints(system_app: SystemApp, config: ServeConfig) -> None:
|
||||
"""Initialize the endpoints"""
|
||||
global global_system_app
|
||||
|
@ -109,6 +109,20 @@ class ServerResponse(BaseModel):
|
||||
"dbgpt",
|
||||
],
|
||||
)
|
||||
gmt_created: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The record creation time.",
|
||||
examples=[
|
||||
"2023-01-07 09:00:00",
|
||||
],
|
||||
)
|
||||
gmt_modified: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The record update time.",
|
||||
examples=[
|
||||
"2023-01-07 09:00:00",
|
||||
],
|
||||
)
|
||||
|
||||
def to_dict(self, **kwargs) -> Dict[str, Any]:
|
||||
"""Convert the model to a dictionary"""
|
||||
|
@ -60,6 +60,17 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
RES: The response
|
||||
"""
|
||||
# TODO implement your own logic here, transfer the entity to a response
|
||||
gmt_created = (
|
||||
entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S")
|
||||
if entity.gmt_created
|
||||
else None
|
||||
)
|
||||
gmt_modified = (
|
||||
entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S")
|
||||
if entity.gmt_modified
|
||||
else None
|
||||
)
|
||||
|
||||
return ServerResponse(
|
||||
app_code=entity.app_code,
|
||||
conv_uid=entity.conv_uid,
|
||||
@ -67,6 +78,8 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
||||
chat_mode=entity.chat_mode,
|
||||
user_name="",
|
||||
sys_code=entity.sys_code,
|
||||
gmt_created=gmt_created,
|
||||
gmt_modified=gmt_modified,
|
||||
)
|
||||
|
||||
def get_latest_message(self, conv_uid: str) -> Optional[MessageStorageItem]:
|
||||
|
Loading…
Reference in New Issue
Block a user