Added pagination in chat history

This commit is contained in:
Saurab-Shrestha9639*969**9858//852
2024-04-30 09:52:27 +05:45
parent f9a454861d
commit 461d5afac0
5 changed files with 9 additions and 60 deletions

View File

@@ -4,8 +4,7 @@ services:
private-gpt: private-gpt:
build: build:
dockerfile: Dockerfile.external dockerfile: Dockerfile.external
entrypoint: ./docker-entrypoint.sh # entrypoint: ./docker-entrypoint.sh
env_file: env_file:
- .env - .env
volumes: volumes:

View File

@@ -1,4 +1,4 @@
#!/bin/bash #!/bin/sh
# Initialize alembic ini # Initialize alembic ini

View File

@@ -4,6 +4,7 @@ import uuid
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from fastapi import APIRouter, Depends, HTTPException, status, Security from fastapi import APIRouter, Depends, HTTPException, status, Security
from fastapi_pagination import Page, paginate
from private_gpt.users.api import deps from private_gpt.users.api import deps
from private_gpt.users import crud, models, schemas from private_gpt.users import crud, models, schemas
@@ -12,7 +13,7 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/c", tags=["Chat Histories"]) router = APIRouter(prefix="/c", tags=["Chat Histories"])
@router.get("", response_model=list[schemas.ChatHistory]) @router.get("", response_model=Page[schemas.ChatHistory])
def list_chat_histories( def list_chat_histories(
db: Session = Depends(deps.get_db), db: Session = Depends(deps.get_db),
skip: int = 0, skip: int = 0,
@@ -20,14 +21,14 @@ def list_chat_histories(
current_user: models.User = Security( current_user: models.User = Security(
deps.get_current_user, deps.get_current_user,
), ),
) -> list[schemas.ChatHistory]: ) -> Page[schemas.ChatHistory]:
""" """
Retrieve a list of chat histories with pagination support. Retrieve a list of chat histories with pagination support.
""" """
try: try:
chat_histories = crud.chat.get_chat_history( chat_histories = crud.chat.get_chat_history(
db, user_id=current_user.id, skip=skip, limit=limit) db, user_id=current_user.id, skip=skip, limit=limit)
return chat_histories return paginate(chat_histories)
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())
logger.error(f"Error listing chat histories: {str(e)}") logger.error(f"Error listing chat histories: {str(e)}")
@@ -63,14 +64,14 @@ def create_chat_history(
) )
@router.get("/{conversation_id}", response_model=schemas.ChatHistory) @router.get("/{conversation_id}", response_model=Page[schemas.ChatHistory])
def read_chat_history( def read_chat_history(
conversation_id: uuid.UUID, conversation_id: uuid.UUID,
db: Session = Depends(deps.get_db), db: Session = Depends(deps.get_db),
current_user: models.User = Security( current_user: models.User = Security(
deps.get_current_user, deps.get_current_user,
), ),
) -> schemas.ChatHistory: ) -> Page[schemas.ChatHistory]:
""" """
Read a chat history by ID Read a chat history by ID
""" """
@@ -79,7 +80,7 @@ def read_chat_history(
if chat_history is None or chat_history.user_id != current_user.id: if chat_history is None or chat_history.user_id != current_user.id:
raise HTTPException( raise HTTPException(
status_code=404, detail="Chat history not found") status_code=404, detail="Chat history not found")
return chat_history return paginate(chat_history)
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())
logger.error(f"Error reading chat history: {str(e)}") logger.error(f"Error reading chat history: {str(e)}")

View File

@@ -1,50 +0,0 @@
from typing import Optional, List, Union, Dict, Any
from fastapi import HTTPException, status
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.util import object_mapper
from private_gpt.users.crud.base import CRUDBase
from private_gpt.users.models.chat import ChatHistory
from private_gpt.users.schemas.chat import ChatHistoryCreate, ChatHistoryUpdate
class CRUDChat(CRUDBase[ChatHistory, ChatHistoryCreate, ChatHistoryUpdate]):
def get_by_id(self, db: Session, *, id: int) -> Optional[ChatHistory]:
return db.query(self.model).filter(ChatHistory.conversation_id == id).first()
def update_messages(
self,
db: Session,
*,
db_obj: ChatHistory,
obj_in: Union[ChatHistoryUpdate, Dict[str, Any]]
) -> ChatHistory:
try:
obj_data = object_mapper(db_obj).data
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.dict(exclude_unset=True)
# Update the `messages` field by appending new messages
existing_messages = obj_data.get("messages", [])
new_messages = update_data.get("messages", [])
obj_data["messages"] = existing_messages + new_messages
for field, value in obj_data.items():
setattr(db_obj, field, value)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
except IntegrityError as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Integrity Error: {str(e)}",
)
chat = CRUDChat(ChatHistory)

View File

@@ -7,7 +7,6 @@ from private_gpt.users.models.department import Department
from private_gpt.users.models.document_department import document_department_association from private_gpt.users.models.document_department import document_department_association
from private_gpt.users.crud.base import CRUDBase from private_gpt.users.crud.base import CRUDBase
from typing import Optional, List from typing import Optional, List
from fastapi_pagination import Page, paginate
class CRUDDocuments(CRUDBase[Document, DocumentCreate, DocumentUpdate]): class CRUDDocuments(CRUDBase[Document, DocumentCreate, DocumentUpdate]):