Added pagination in chat history

This commit is contained in:
Saurab-Shrestha9639*969**9858//852
2024-04-30 09:52:27 +05:45
parent f9a454861d
commit 461d5afac0
5 changed files with 9 additions and 60 deletions

View File

@@ -4,8 +4,7 @@ services:
private-gpt:
build:
dockerfile: Dockerfile.external
entrypoint: ./docker-entrypoint.sh
# entrypoint: ./docker-entrypoint.sh
env_file:
- .env
volumes:

View File

@@ -1,4 +1,4 @@
#!/bin/bash
#!/bin/sh
# Initialize alembic ini

View File

@@ -4,6 +4,7 @@ import uuid
from sqlalchemy.orm import Session
from fastapi.responses import JSONResponse
from fastapi import APIRouter, Depends, HTTPException, status, Security
from fastapi_pagination import Page, paginate
from private_gpt.users.api import deps
from private_gpt.users import crud, models, schemas
@@ -12,7 +13,7 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/c", tags=["Chat Histories"])
@router.get("", response_model=list[schemas.ChatHistory])
@router.get("", response_model=Page[schemas.ChatHistory])
def list_chat_histories(
db: Session = Depends(deps.get_db),
skip: int = 0,
@@ -20,14 +21,14 @@ def list_chat_histories(
current_user: models.User = Security(
deps.get_current_user,
),
) -> list[schemas.ChatHistory]:
) -> Page[schemas.ChatHistory]:
"""
Retrieve a list of chat histories with pagination support.
"""
try:
chat_histories = crud.chat.get_chat_history(
db, user_id=current_user.id, skip=skip, limit=limit)
return chat_histories
return paginate(chat_histories)
except Exception as e:
print(traceback.format_exc())
logger.error(f"Error listing chat histories: {str(e)}")
@@ -63,14 +64,14 @@ def create_chat_history(
)
@router.get("/{conversation_id}", response_model=schemas.ChatHistory)
@router.get("/{conversation_id}", response_model=Page[schemas.ChatHistory])
def read_chat_history(
conversation_id: uuid.UUID,
db: Session = Depends(deps.get_db),
current_user: models.User = Security(
deps.get_current_user,
),
) -> schemas.ChatHistory:
) -> Page[schemas.ChatHistory]:
"""
Read a chat history by ID
"""
@@ -79,7 +80,7 @@ def read_chat_history(
if chat_history is None or chat_history.user_id != current_user.id:
raise HTTPException(
status_code=404, detail="Chat history not found")
return chat_history
return paginate(chat_history)
except Exception as e:
print(traceback.format_exc())
logger.error(f"Error reading chat history: {str(e)}")

View File

@@ -1,50 +0,0 @@
from typing import Optional, List, Union, Dict, Any
from fastapi import HTTPException, status
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.util import object_mapper
from private_gpt.users.crud.base import CRUDBase
from private_gpt.users.models.chat import ChatHistory
from private_gpt.users.schemas.chat import ChatHistoryCreate, ChatHistoryUpdate
class CRUDChat(CRUDBase[ChatHistory, ChatHistoryCreate, ChatHistoryUpdate]):
def get_by_id(self, db: Session, *, id: int) -> Optional[ChatHistory]:
return db.query(self.model).filter(ChatHistory.conversation_id == id).first()
def update_messages(
self,
db: Session,
*,
db_obj: ChatHistory,
obj_in: Union[ChatHistoryUpdate, Dict[str, Any]]
) -> ChatHistory:
try:
obj_data = object_mapper(db_obj).data
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.dict(exclude_unset=True)
# Update the `messages` field by appending new messages
existing_messages = obj_data.get("messages", [])
new_messages = update_data.get("messages", [])
obj_data["messages"] = existing_messages + new_messages
for field, value in obj_data.items():
setattr(db_obj, field, value)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
except IntegrityError as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Integrity Error: {str(e)}",
)
chat = CRUDChat(ChatHistory)

View File

@@ -7,7 +7,6 @@ from private_gpt.users.models.department import Department
from private_gpt.users.models.document_department import document_department_association
from private_gpt.users.crud.base import CRUDBase
from typing import Optional, List
from fastapi_pagination import Page, paginate
class CRUDDocuments(CRUDBase[Document, DocumentCreate, DocumentUpdate]):