mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-08-31 14:52:19 +00:00
Added pagination in chat history
This commit is contained in:
@@ -4,8 +4,7 @@ services:
|
||||
private-gpt:
|
||||
build:
|
||||
dockerfile: Dockerfile.external
|
||||
entrypoint: ./docker-entrypoint.sh
|
||||
|
||||
# entrypoint: ./docker-entrypoint.sh
|
||||
env_file:
|
||||
- .env
|
||||
volumes:
|
||||
|
@@ -1,4 +1,4 @@
|
||||
#!/bin/bash
|
||||
#!/bin/sh
|
||||
|
||||
|
||||
# Initialize alembic ini
|
||||
|
@@ -4,6 +4,7 @@ import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Security
|
||||
from fastapi_pagination import Page, paginate
|
||||
|
||||
from private_gpt.users.api import deps
|
||||
from private_gpt.users import crud, models, schemas
|
||||
@@ -12,7 +13,7 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/c", tags=["Chat Histories"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[schemas.ChatHistory])
|
||||
@router.get("", response_model=Page[schemas.ChatHistory])
|
||||
def list_chat_histories(
|
||||
db: Session = Depends(deps.get_db),
|
||||
skip: int = 0,
|
||||
@@ -20,14 +21,14 @@ def list_chat_histories(
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
),
|
||||
) -> list[schemas.ChatHistory]:
|
||||
) -> Page[schemas.ChatHistory]:
|
||||
"""
|
||||
Retrieve a list of chat histories with pagination support.
|
||||
"""
|
||||
try:
|
||||
chat_histories = crud.chat.get_chat_history(
|
||||
db, user_id=current_user.id, skip=skip, limit=limit)
|
||||
return chat_histories
|
||||
return paginate(chat_histories)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
logger.error(f"Error listing chat histories: {str(e)}")
|
||||
@@ -63,14 +64,14 @@ def create_chat_history(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{conversation_id}", response_model=schemas.ChatHistory)
|
||||
@router.get("/{conversation_id}", response_model=Page[schemas.ChatHistory])
|
||||
def read_chat_history(
|
||||
conversation_id: uuid.UUID,
|
||||
db: Session = Depends(deps.get_db),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
),
|
||||
) -> schemas.ChatHistory:
|
||||
) -> Page[schemas.ChatHistory]:
|
||||
"""
|
||||
Read a chat history by ID
|
||||
"""
|
||||
@@ -79,7 +80,7 @@ def read_chat_history(
|
||||
if chat_history is None or chat_history.user_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Chat history not found")
|
||||
return chat_history
|
||||
return paginate(chat_history)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
logger.error(f"Error reading chat history: {str(e)}")
|
||||
|
@@ -1,50 +0,0 @@
|
||||
from typing import Optional, List, Union, Dict, Any
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm.util import object_mapper
|
||||
|
||||
from private_gpt.users.crud.base import CRUDBase
|
||||
from private_gpt.users.models.chat import ChatHistory
|
||||
from private_gpt.users.schemas.chat import ChatHistoryCreate, ChatHistoryUpdate
|
||||
|
||||
|
||||
class CRUDChat(CRUDBase[ChatHistory, ChatHistoryCreate, ChatHistoryUpdate]):
|
||||
def get_by_id(self, db: Session, *, id: int) -> Optional[ChatHistory]:
|
||||
return db.query(self.model).filter(ChatHistory.conversation_id == id).first()
|
||||
|
||||
def update_messages(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
db_obj: ChatHistory,
|
||||
obj_in: Union[ChatHistoryUpdate, Dict[str, Any]]
|
||||
) -> ChatHistory:
|
||||
try:
|
||||
obj_data = object_mapper(db_obj).data
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.dict(exclude_unset=True)
|
||||
|
||||
# Update the `messages` field by appending new messages
|
||||
existing_messages = obj_data.get("messages", [])
|
||||
new_messages = update_data.get("messages", [])
|
||||
obj_data["messages"] = existing_messages + new_messages
|
||||
|
||||
for field, value in obj_data.items():
|
||||
setattr(db_obj, field, value)
|
||||
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Integrity Error: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
chat = CRUDChat(ChatHistory)
|
@@ -7,7 +7,6 @@ from private_gpt.users.models.department import Department
|
||||
from private_gpt.users.models.document_department import document_department_association
|
||||
from private_gpt.users.crud.base import CRUDBase
|
||||
from typing import Optional, List
|
||||
from fastapi_pagination import Page, paginate
|
||||
|
||||
|
||||
class CRUDDocuments(CRUDBase[Document, DocumentCreate, DocumentUpdate]):
|
||||
|
Reference in New Issue
Block a user