mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-20 09:14:44 +00:00
feat(core): Support export messages
This commit is contained in:
parent
2ed145c3aa
commit
000cf5a97b
@ -1,9 +1,12 @@
|
|||||||
|
import io
|
||||||
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from functools import cache
|
from functools import cache
|
||||||
from typing import List, Optional
|
from typing import List, Literal, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||||
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
|
from starlette.responses import JSONResponse, StreamingResponse
|
||||||
|
|
||||||
from dbgpt.component import SystemApp
|
from dbgpt.component import SystemApp
|
||||||
from dbgpt.util import PaginationResult
|
from dbgpt.util import PaginationResult
|
||||||
@ -231,6 +234,89 @@ async def get_history_messages(con_uid: str, service: Service = Depends(get_serv
|
|||||||
return Result.succ(service.get_history_messages(ServeRequest(conv_uid=con_uid)))
|
return Result.succ(service.get_history_messages(ServeRequest(conv_uid=con_uid)))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/export_messages",
|
||||||
|
dependencies=[Depends(check_api_key)],
|
||||||
|
)
|
||||||
|
async def export_all_messages(
|
||||||
|
user_name: Optional[str] = None,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
sys_code: Optional[str] = None,
|
||||||
|
format: Literal["file", "json"] = Query(
|
||||||
|
"file", description="response format(file or json)"
|
||||||
|
),
|
||||||
|
service: Service = Depends(get_service),
|
||||||
|
):
|
||||||
|
"""Export all conversations and messages for a user
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_name (str): The user name
|
||||||
|
user_id (str): The user id (alternative to user_name)
|
||||||
|
sys_code (str): The system code
|
||||||
|
format (str): The format of the response, either 'file' or 'json', defaults to
|
||||||
|
'file'
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary containing all conversations and their messages
|
||||||
|
"""
|
||||||
|
# 1. Get all conversations for the user
|
||||||
|
request = ServeRequest(
|
||||||
|
user_name=user_name or user_id,
|
||||||
|
sys_code=sys_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize pagination variables
|
||||||
|
page = 1
|
||||||
|
page_size = 100 # Adjust based on your needs
|
||||||
|
all_conversations = []
|
||||||
|
|
||||||
|
# Paginate through all conversations
|
||||||
|
while True:
|
||||||
|
pagination_result = service.get_list_by_page(request, page, page_size)
|
||||||
|
all_conversations.extend(pagination_result.items)
|
||||||
|
|
||||||
|
if page >= pagination_result.total_pages:
|
||||||
|
break
|
||||||
|
page += 1
|
||||||
|
|
||||||
|
# 2. For each conversation, get all messages
|
||||||
|
result = {
|
||||||
|
"user_name": user_name or user_id,
|
||||||
|
"sys_code": sys_code,
|
||||||
|
"total_conversations": len(all_conversations),
|
||||||
|
"conversations": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
for conv in all_conversations:
|
||||||
|
messages = service.get_history_messages(ServeRequest(conv_uid=conv.conv_uid))
|
||||||
|
conversation_data = {
|
||||||
|
"conv_uid": conv.conv_uid,
|
||||||
|
"chat_mode": conv.chat_mode,
|
||||||
|
"app_code": conv.app_code,
|
||||||
|
"create_time": conv.gmt_created,
|
||||||
|
"update_time": conv.gmt_modified,
|
||||||
|
"total_messages": len(messages),
|
||||||
|
"messages": [msg.dict() for msg in messages],
|
||||||
|
}
|
||||||
|
result["conversations"].append(conversation_data)
|
||||||
|
|
||||||
|
if format == "json":
|
||||||
|
return JSONResponse(content=result)
|
||||||
|
else:
|
||||||
|
file_name = (
|
||||||
|
f"conversation_export_{user_name or user_id or 'dbgpt'}_"
|
||||||
|
f"{sys_code or 'dbgpt'}"
|
||||||
|
)
|
||||||
|
# Return the json file
|
||||||
|
return StreamingResponse(
|
||||||
|
io.BytesIO(
|
||||||
|
json.dumps(result, ensure_ascii=False, indent=4).encode("utf-8")
|
||||||
|
),
|
||||||
|
media_type="application/file",
|
||||||
|
headers={"Content-Disposition": f"attachment;filename={file_name}.json"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def init_endpoints(system_app: SystemApp, config: ServeConfig) -> None:
|
def init_endpoints(system_app: SystemApp, config: ServeConfig) -> None:
|
||||||
"""Initialize the endpoints"""
|
"""Initialize the endpoints"""
|
||||||
global global_system_app
|
global global_system_app
|
||||||
|
@ -109,6 +109,20 @@ class ServerResponse(BaseModel):
|
|||||||
"dbgpt",
|
"dbgpt",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
gmt_created: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The record creation time.",
|
||||||
|
examples=[
|
||||||
|
"2023-01-07 09:00:00",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
gmt_modified: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The record update time.",
|
||||||
|
examples=[
|
||||||
|
"2023-01-07 09:00:00",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
def to_dict(self, **kwargs) -> Dict[str, Any]:
|
def to_dict(self, **kwargs) -> Dict[str, Any]:
|
||||||
"""Convert the model to a dictionary"""
|
"""Convert the model to a dictionary"""
|
||||||
|
@ -60,6 +60,17 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
|||||||
RES: The response
|
RES: The response
|
||||||
"""
|
"""
|
||||||
# TODO implement your own logic here, transfer the entity to a response
|
# TODO implement your own logic here, transfer the entity to a response
|
||||||
|
gmt_created = (
|
||||||
|
entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
if entity.gmt_created
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
gmt_modified = (
|
||||||
|
entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
if entity.gmt_modified
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
return ServerResponse(
|
return ServerResponse(
|
||||||
app_code=entity.app_code,
|
app_code=entity.app_code,
|
||||||
conv_uid=entity.conv_uid,
|
conv_uid=entity.conv_uid,
|
||||||
@ -67,6 +78,8 @@ class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]):
|
|||||||
chat_mode=entity.chat_mode,
|
chat_mode=entity.chat_mode,
|
||||||
user_name="",
|
user_name="",
|
||||||
sys_code=entity.sys_code,
|
sys_code=entity.sys_code,
|
||||||
|
gmt_created=gmt_created,
|
||||||
|
gmt_modified=gmt_modified,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_latest_message(self, conv_uid: str) -> Optional[MessageStorageItem]:
|
def get_latest_message(self, conv_uid: str) -> Optional[MessageStorageItem]:
|
||||||
|
Loading…
Reference in New Issue
Block a user