From 355271be9363eb27c47f039da2054518ed4115c9 Mon Sep 17 00:00:00 2001 From: Saurab-Shrestha Date: Wed, 3 Apr 2024 14:11:08 +0545 Subject: [PATCH] Added table for chat history --- private_gpt/server/chat/chat_service.py | 9 +- .../server/completions/completions_router.py | 90 ++++++++-- .../users/api/v1/routers/chat_histories.py | 154 ++++++++++++++++++ private_gpt/users/crud/__init__.py | 3 +- private_gpt/users/crud/chathistory_crud.py | 50 ++++++ private_gpt/users/models/__init__.py | 3 +- private_gpt/users/models/chat_history.py | 39 +++++ private_gpt/users/models/user.py | 1 + private_gpt/users/schemas/__init__.py | 1 + private_gpt/users/schemas/chat_history.py | 30 ++++ private_gpt/users/schemas/user.py | 1 - 11 files changed, 357 insertions(+), 24 deletions(-) create mode 100644 private_gpt/users/api/v1/routers/chat_histories.py create mode 100644 private_gpt/users/crud/chathistory_crud.py create mode 100644 private_gpt/users/models/chat_history.py create mode 100644 private_gpt/users/schemas/chat_history.py diff --git a/private_gpt/server/chat/chat_service.py b/private_gpt/server/chat/chat_service.py index 5369200b..2e15f712 100644 --- a/private_gpt/server/chat/chat_service.py +++ b/private_gpt/server/chat/chat_service.py @@ -129,9 +129,7 @@ class ChatService: else None ) system_prompt = ( - chat_engine_input.system_message.content - if chat_engine_input.system_message - else None + "You can only answer questions about the provided context. If you know the answer but it is not based in the provided context, don't provide the answer, just state the answer is not in the context provided." ) chat_history = ( chat_engine_input.chat_history if chat_engine_input.chat_history else None @@ -165,14 +163,11 @@ class ChatService: else None ) system_prompt = ( - chat_engine_input.system_message.content - if chat_engine_input.system_message - else None + "You can only answer questions about the provided context. If you know the answer but it is not based in the provided context, don't provide the answer, just state the answer is not in the context provided." ) chat_history = ( chat_engine_input.chat_history if chat_engine_input.chat_history else None ) - chat_engine = self._chat_engine( system_prompt=system_prompt, use_context=use_context, diff --git a/private_gpt/server/completions/completions_router.py b/private_gpt/server/completions/completions_router.py index 45e3e302..f560b953 100644 --- a/private_gpt/server/completions/completions_router.py +++ b/private_gpt/server/completions/completions_router.py @@ -1,3 +1,6 @@ +from private_gpt.users import crud, models, schemas +import itertools +from llama_index.llms import ChatMessage, ChatResponse, MessageRole from fastapi import APIRouter, Depends, Request, Security, HTTPException, status from private_gpt.server.ingest.ingest_service import IngestService from pydantic import BaseModel @@ -17,11 +20,13 @@ from private_gpt.open_ai.openai_models import ( from private_gpt.server.chat.chat_router import ChatBody, chat_completion from private_gpt.server.utils.auth import authenticated from private_gpt.users.api import deps -from private_gpt.users import crud, models, schemas +from pydantic import Optional + completions_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)]) class CompletionsBody(BaseModel): + conversation_id: Optional[int] prompt: str system_prompt: str | None = None use_context: bool = False @@ -145,25 +150,82 @@ async def prompt_completion( detail="Internal Server Error", ) - messages = [OpenAIMessage(content=body.prompt, role="user")] + if body.conversation_id: + chat_history = crud.chat.get_by_id(db, id=body.conversation_id) + if chat_history is None or chat_history.user_id != current_user.id: + raise HTTPException( + status_code=404, detail="Chat history not found") + else: + chat_create_in = schemas.ChatCreate(user_id=current_user.id) + chat_history = crud.chat.create(db=db, obj_in=chat_create_in) + + _history = chat_history.messages or [] + + def build_history() -> list[ChatMessage]: + history_messages: list[ChatMessage] = [] + for interaction in _history: + user_message = interaction.get("user", "") + ai_message = interaction.get("ai", "") + if user_message: + history_messages.append( + ChatMessage( + content=user_message, + role=MessageRole.USER + ) + ) + if ai_message: + history_messages.append( + ChatMessage( + content=ai_message, + role=MessageRole.ASSISTANT + ) + ) + + # max 20 messages to try to avoid context overflow + return history_messages[:20] + + # Prepare new messages + new_messages = [] + + if body.prompt: + new_messages.append(OpenAIMessage(content=body.prompt, role="user")) if body.system_prompt: - messages.insert(0, OpenAIMessage( + new_messages.insert(0, OpenAIMessage( content=body.system_prompt, role="system")) + # Update chat history with new user messages + if new_messages: + new_message = ChatMessage(content=new_messages, role=MessageRole.USER) + _history.append(new_message.dict()) + + # Process chat completion chat_body = ChatBody( - messages=messages, + messages=build_history(), use_context=body.use_context, stream=body.stream, include_sources=body.include_sources, context_filter=body.context_filter, ) - log_audit( - model='Chat', - action='Chat', - details={ - "query": body.prompt, - 'user': current_user.username, - }, - user_id=current_user.id - ) - return await chat_completion(request, chat_body) + + ai_response = await chat_completion(request, chat_body) + + # Update chat history with AI response + if ai_response.messages: + ai_message = OpenAIMessage( + content=ai_response.messages, role="assistant") + _history.append(ai_message.dict()) + + # Update chat history in the database + chat_obj_in = schemas.ChatUpdate(messages=build_history()) + crud.chat.update_messages(db, db_obj=chat_history, obj_in=chat_obj_in) + + return ai_response + # log_audit( + # model='Chat', + # action='Chat', + # details={ + # "query": body.prompt, + # 'user': current_user.username, + # }, + # user_id=current_user.id + # ) diff --git a/private_gpt/users/api/v1/routers/chat_histories.py b/private_gpt/users/api/v1/routers/chat_histories.py new file mode 100644 index 00000000..b0e13a2a --- /dev/null +++ b/private_gpt/users/api/v1/routers/chat_histories.py @@ -0,0 +1,154 @@ +import logging +import traceback + +from sqlalchemy.orm import Session +from fastapi.responses import JSONResponse +from fastapi.encoders import jsonable_encoder +from fastapi import APIRouter, Depends, HTTPException, status, Security + +from private_gpt.users.api import deps +from private_gpt.users.constants.role import Role +from private_gpt.users import crud, models, schemas + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/c", tags=["Chat Histories"]) + + +@router.get("", response_model=list[schemas.ChatBase]) +def list_chat_histories( + db: Session = Depends(deps.get_db), + skip: int = 0, + limit: int = 100, + current_user: models.User = Security( + deps.get_current_user, + ), +) -> list[schemas.ChatBase]: + """ + Retrieve a list of chat histories with pagination support. + """ + try: + chat_histories = crud.chat.get_multi( + db, skip=skip, limit=limit, user_id=current_user.id) + return chat_histories + except Exception as e: + print(traceback.format_exc()) + logger.error(f"Error listing chat histories: {str(e)}") + raise HTTPException( + status_code=500, + detail="Internal Server Error", + ) + + +@router.post("/create", response_model=schemas.ChatBase) +def create_chat_history( + chat_history_in: schemas.ChatCreate, + db: Session = Depends(deps.get_db), + current_user: models.User = Security( + deps.get_current_user, + ), +) -> schemas.ChatBase: + """ + Create a new chat history + """ + try: + chat_history = crud.chat.create( + db=db, obj_in=chat_history_in, user_id=current_user.id) + return chat_history + except Exception as e: + print(traceback.format_exc()) + logger.error(f"Error creating chat history: {str(e)}") + raise HTTPException( + status_code=500, + detail="Internal Server Error", + ) + + +@router.get("/{chat_history_id}", response_model=schemas.ChatMessages) +def read_chat_history( + chat_history_id: int, + db: Session = Depends(deps.get_db), + current_user: models.User = Security( + deps.get_current_user, + ), +) -> schemas.ChatMessages: + """ + Read a chat history by ID + """ + try: + chat_history = crud.chat.get_by_id(db, id=chat_history_id) + if chat_history is None or chat_history.user_id != current_user.id: + raise HTTPException( + status_code=404, detail="Chat history not found") + return chat_history + except Exception as e: + print(traceback.format_exc()) + logger.error(f"Error reading chat history: {str(e)}") + raise HTTPException( + status_code=500, + detail="Internal Server Error", + ) + + +@router.post("/conversation", response_model=schemas.ChatHistory) +def conversation( + chat_history_in: schemas.ChatUpdate, + db: Session = Depends(deps.get_db), + current_user: models.User = Security( + deps.get_current_user, + ), +) -> schemas.ChatHistory: + """ + Update a chat history by ID + """ + try: + chat_history = crud.chat.get_by_id( + db, id=chat_history_in.conversation_id) + if chat_history is None or chat_history.user_id != current_user.id: + raise HTTPException( + status_code=404, detail="Chat history not found") + + updated_chat_history = crud.chat.update_messages( + db=db, db_obj=chat_history, obj_in=chat_history_in) + + return updated_chat_history + except Exception as e: + print(traceback.format_exc()) + logger.error(f"Error updating chat history: {str(e)}") + raise HTTPException( + status_code=500, + detail="Internal Server Error", + ) + + +@router.post("/delete") +def delete_chat_history( + chat_history_in: schemas.ChatDelete, + db: Session = Depends(deps.get_db), + current_user: models.User = Security( + deps.get_current_user, + ), +): + """ + Delete a chat history by ID + """ + try: + chat_history_id = chat_history_in.id + chat_history = crud.chat.get(db, id=chat_history_id) + if chat_history is None or chat_history.user_id != current_user.id: + raise HTTPException( + status_code=404, detail="Chat history not found") + + crud.chat.remove(db=db, id=chat_history_id) + return JSONResponse( + status_code=status.HTTP_200_OK, + content={ + "message": "Chat history deleted successfully", + }, + ) + except Exception as e: + print(traceback.format_exc()) + logger.error(f"Error deleting chat history: {str(e)}") + raise HTTPException( + status_code=500, + detail="Internal Server Error", + ) diff --git a/private_gpt/users/crud/__init__.py b/private_gpt/users/crud/__init__.py index fd7663a9..b87b967a 100644 --- a/private_gpt/users/crud/__init__.py +++ b/private_gpt/users/crud/__init__.py @@ -5,4 +5,5 @@ from .company_crud import company from .subscription_crud import subscription from .document_crud import documents from .department_crud import department -from .audit_crud import audit \ No newline at end of file +from .audit_crud import audit +from .chathistory_crud import chat \ No newline at end of file diff --git a/private_gpt/users/crud/chathistory_crud.py b/private_gpt/users/crud/chathistory_crud.py new file mode 100644 index 00000000..7db220db --- /dev/null +++ b/private_gpt/users/crud/chathistory_crud.py @@ -0,0 +1,50 @@ +from typing import Optional, List, Union, Dict, Any +from fastapi import HTTPException, status +from sqlalchemy.orm import Session +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm.util import object_mapper + +from private_gpt.users.crud.base import CRUDBase +from private_gpt.users.models.chat_history import ChatHistory +from private_gpt.users.schemas.chat_history import ChatCreate, ChatUpdate + + +class CRUDChat(CRUDBase[ChatHistory, ChatCreate, ChatUpdate]): + def get_by_id(self, db: Session, *, id: int) -> Optional[ChatHistory]: + return db.query(self.model).filter(ChatHistory.conversation_id == id).first() + + def update_messages( + self, + db: Session, + *, + db_obj: ChatHistory, + obj_in: Union[ChatUpdate, Dict[str, Any]] + ) -> ChatHistory: + try: + obj_data = object_mapper(db_obj).data + if isinstance(obj_in, dict): + update_data = obj_in + else: + update_data = obj_in.dict(exclude_unset=True) + + # Update the `messages` field by appending new messages + existing_messages = obj_data.get("messages", []) + new_messages = update_data.get("messages", []) + obj_data["messages"] = existing_messages + new_messages + + for field, value in obj_data.items(): + setattr(db_obj, field, value) + + db.add(db_obj) + db.commit() + db.refresh(db_obj) + return db_obj + except IntegrityError as e: + db.rollback() + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Integrity Error: {str(e)}", + ) + + +chat = CRUDChat(ChatHistory) diff --git a/private_gpt/users/models/__init__.py b/private_gpt/users/models/__init__.py index 7dfd1e98..d3d3ab14 100644 --- a/private_gpt/users/models/__init__.py +++ b/private_gpt/users/models/__init__.py @@ -6,4 +6,5 @@ from .document import Document from .subscription import Subscription from .department import Department from .audit import Audit -from .document_department import document_department_association \ No newline at end of file +from .document_department import document_department_association +from .chat_history import ChatHistory \ No newline at end of file diff --git a/private_gpt/users/models/chat_history.py b/private_gpt/users/models/chat_history.py new file mode 100644 index 00000000..1b468f79 --- /dev/null +++ b/private_gpt/users/models/chat_history.py @@ -0,0 +1,39 @@ +from datetime import datetime + +from sqlalchemy.orm import relationship +from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, JSON + +from private_gpt.users.db.base_class import Base + + +class ChatHistory(Base): + """Models a chat history table""" + + __tablename__ = "chat_history" + + id = Column(Integer, nullable=False, primary_key=True) + title = Column(String(255), nullable=False) + messages = Column(JSON, nullable=False) + created_at = Column(DateTime, default=datetime.now) + user_id = Column(Integer, ForeignKey("users.id")) + user = relationship("User", back_populates="chat_histories") + + def __init__(self, messages, user_id, *args, **kwargs): + super().__init__(*args, **kwargs) + self.messages = messages + self.user_id = user_id + self.title = self.generate_title() + + def generate_title(self): + if self.messages: + first_user_message = next(+ + (msg["message"] + for msg in self.messages if msg["sender"] == "user"), None + ) + if first_user_message: + return first_user_message[:30] + return "Untitled Chat" + + def __repr__(self): + """Returns string representation of model instance""" + return f"" diff --git a/private_gpt/users/models/user.py b/private_gpt/users/models/user.py index 5b527208..dd4be3db 100644 --- a/private_gpt/users/models/user.py +++ b/private_gpt/users/models/user.py @@ -51,6 +51,7 @@ class User(Base): Integer, ForeignKey("departments.id"), nullable=False) department = relationship("Department", back_populates="users") + chat_histories = relationship("ChatHistory", back_populates="user") __table_args__ = ( UniqueConstraint('username', name='unique_username_no_spacing'), diff --git a/private_gpt/users/schemas/__init__.py b/private_gpt/users/schemas/__init__.py index 487e92c5..d59f0a15 100644 --- a/private_gpt/users/schemas/__init__.py +++ b/private_gpt/users/schemas/__init__.py @@ -7,3 +7,4 @@ from .company import Company, CompanyBase, CompanyCreate, CompanyUpdate from .documents import Document, DocumentCreate, DocumentsBase, DocumentUpdate, DocumentList, DepartmentList, DocumentEnable, DocumentDepartmentUpdate, DocumentCheckerUpdate, DocumentMakerCreate, DocumentDepartmentList, DocumentView, DocumentVerify, DocumentFilter from .department import Department, DepartmentCreate, DepartmentUpdate, DepartmentAdminCreate, DepartmentDelete from .audit import AuditBase, AuditCreate, AuditUpdate, Audit, GetAudit +from .chat_history import Chat, ChatBase, ChatCreate, ChatDelete, ChatUpdate, ChatMessages diff --git a/private_gpt/users/schemas/chat_history.py b/private_gpt/users/schemas/chat_history.py new file mode 100644 index 00000000..236518a1 --- /dev/null +++ b/private_gpt/users/schemas/chat_history.py @@ -0,0 +1,30 @@ +from typing import Optional, Dict, Any +from pydantic import BaseModel +from datetime import datetime + + +class ChatBase(BaseModel): + title: Optional[str] + +class ChatCreate(ChatBase): + user_id: int + messages: Optional[Dict[str, Any]] + +class ChatUpdate(ChatBase): + conversation_id: int + messages: Dict + + +class ChatDelete(BaseModel): + conversation_id: int + +class ChatMessages(BaseModel): + messages: Dict[str, Any] + +class Chat(ChatBase): + conversation_id: int + created_at: datetime + user_id: int + + class Config: + orm_mode = True diff --git a/private_gpt/users/schemas/user.py b/private_gpt/users/schemas/user.py index 4d583dbe..52d0f3e5 100644 --- a/private_gpt/users/schemas/user.py +++ b/private_gpt/users/schemas/user.py @@ -30,7 +30,6 @@ class UserUpdate(BaseModel): last_login: Optional[datetime] = None - class UserLoginSchema(BaseModel): email: EmailStr = Field(alias="email") password: str