Added table for chat history

This commit is contained in:
Saurab-Shrestha 2024-04-03 14:11:08 +05:45
parent 542ed0ef4e
commit 355271be93
11 changed files with 357 additions and 24 deletions

View File

@ -129,9 +129,7 @@ class ChatService:
else None
)
system_prompt = (
chat_engine_input.system_message.content
if chat_engine_input.system_message
else None
"You can only answer questions about the provided context. If you know the answer but it is not based in the provided context, don't provide the answer, just state the answer is not in the context provided."
)
chat_history = (
chat_engine_input.chat_history if chat_engine_input.chat_history else None
@ -165,14 +163,11 @@ class ChatService:
else None
)
system_prompt = (
chat_engine_input.system_message.content
if chat_engine_input.system_message
else None
"You can only answer questions about the provided context. If you know the answer but it is not based in the provided context, don't provide the answer, just state the answer is not in the context provided."
)
chat_history = (
chat_engine_input.chat_history if chat_engine_input.chat_history else None
)
chat_engine = self._chat_engine(
system_prompt=system_prompt,
use_context=use_context,

View File

@ -1,3 +1,6 @@
from private_gpt.users import crud, models, schemas
import itertools
from llama_index.llms import ChatMessage, ChatResponse, MessageRole
from fastapi import APIRouter, Depends, Request, Security, HTTPException, status
from private_gpt.server.ingest.ingest_service import IngestService
from pydantic import BaseModel
@ -17,11 +20,13 @@ from private_gpt.open_ai.openai_models import (
from private_gpt.server.chat.chat_router import ChatBody, chat_completion
from private_gpt.server.utils.auth import authenticated
from private_gpt.users.api import deps
from private_gpt.users import crud, models, schemas
from pydantic import Optional
completions_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
class CompletionsBody(BaseModel):
conversation_id: Optional[int]
prompt: str
system_prompt: str | None = None
use_context: bool = False
@ -145,25 +150,82 @@ async def prompt_completion(
detail="Internal Server Error",
)
messages = [OpenAIMessage(content=body.prompt, role="user")]
if body.conversation_id:
chat_history = crud.chat.get_by_id(db, id=body.conversation_id)
if chat_history is None or chat_history.user_id != current_user.id:
raise HTTPException(
status_code=404, detail="Chat history not found")
else:
chat_create_in = schemas.ChatCreate(user_id=current_user.id)
chat_history = crud.chat.create(db=db, obj_in=chat_create_in)
_history = chat_history.messages or []
def build_history() -> list[ChatMessage]:
history_messages: list[ChatMessage] = []
for interaction in _history:
user_message = interaction.get("user", "")
ai_message = interaction.get("ai", "")
if user_message:
history_messages.append(
ChatMessage(
content=user_message,
role=MessageRole.USER
)
)
if ai_message:
history_messages.append(
ChatMessage(
content=ai_message,
role=MessageRole.ASSISTANT
)
)
# max 20 messages to try to avoid context overflow
return history_messages[:20]
# Prepare new messages
new_messages = []
if body.prompt:
new_messages.append(OpenAIMessage(content=body.prompt, role="user"))
if body.system_prompt:
messages.insert(0, OpenAIMessage(
new_messages.insert(0, OpenAIMessage(
content=body.system_prompt, role="system"))
# Update chat history with new user messages
if new_messages:
new_message = ChatMessage(content=new_messages, role=MessageRole.USER)
_history.append(new_message.dict())
# Process chat completion
chat_body = ChatBody(
messages=messages,
messages=build_history(),
use_context=body.use_context,
stream=body.stream,
include_sources=body.include_sources,
context_filter=body.context_filter,
)
log_audit(
model='Chat',
action='Chat',
details={
"query": body.prompt,
'user': current_user.username,
},
user_id=current_user.id
)
return await chat_completion(request, chat_body)
ai_response = await chat_completion(request, chat_body)
# Update chat history with AI response
if ai_response.messages:
ai_message = OpenAIMessage(
content=ai_response.messages, role="assistant")
_history.append(ai_message.dict())
# Update chat history in the database
chat_obj_in = schemas.ChatUpdate(messages=build_history())
crud.chat.update_messages(db, db_obj=chat_history, obj_in=chat_obj_in)
return ai_response
# log_audit(
# model='Chat',
# action='Chat',
# details={
# "query": body.prompt,
# 'user': current_user.username,
# },
# user_id=current_user.id
# )

View File

@ -0,0 +1,154 @@
import logging
import traceback
from sqlalchemy.orm import Session
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
from fastapi import APIRouter, Depends, HTTPException, status, Security
from private_gpt.users.api import deps
from private_gpt.users.constants.role import Role
from private_gpt.users import crud, models, schemas
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/c", tags=["Chat Histories"])
@router.get("", response_model=list[schemas.ChatBase])
def list_chat_histories(
db: Session = Depends(deps.get_db),
skip: int = 0,
limit: int = 100,
current_user: models.User = Security(
deps.get_current_user,
),
) -> list[schemas.ChatBase]:
"""
Retrieve a list of chat histories with pagination support.
"""
try:
chat_histories = crud.chat.get_multi(
db, skip=skip, limit=limit, user_id=current_user.id)
return chat_histories
except Exception as e:
print(traceback.format_exc())
logger.error(f"Error listing chat histories: {str(e)}")
raise HTTPException(
status_code=500,
detail="Internal Server Error",
)
@router.post("/create", response_model=schemas.ChatBase)
def create_chat_history(
chat_history_in: schemas.ChatCreate,
db: Session = Depends(deps.get_db),
current_user: models.User = Security(
deps.get_current_user,
),
) -> schemas.ChatBase:
"""
Create a new chat history
"""
try:
chat_history = crud.chat.create(
db=db, obj_in=chat_history_in, user_id=current_user.id)
return chat_history
except Exception as e:
print(traceback.format_exc())
logger.error(f"Error creating chat history: {str(e)}")
raise HTTPException(
status_code=500,
detail="Internal Server Error",
)
@router.get("/{chat_history_id}", response_model=schemas.ChatMessages)
def read_chat_history(
chat_history_id: int,
db: Session = Depends(deps.get_db),
current_user: models.User = Security(
deps.get_current_user,
),
) -> schemas.ChatMessages:
"""
Read a chat history by ID
"""
try:
chat_history = crud.chat.get_by_id(db, id=chat_history_id)
if chat_history is None or chat_history.user_id != current_user.id:
raise HTTPException(
status_code=404, detail="Chat history not found")
return chat_history
except Exception as e:
print(traceback.format_exc())
logger.error(f"Error reading chat history: {str(e)}")
raise HTTPException(
status_code=500,
detail="Internal Server Error",
)
@router.post("/conversation", response_model=schemas.ChatHistory)
def conversation(
chat_history_in: schemas.ChatUpdate,
db: Session = Depends(deps.get_db),
current_user: models.User = Security(
deps.get_current_user,
),
) -> schemas.ChatHistory:
"""
Update a chat history by ID
"""
try:
chat_history = crud.chat.get_by_id(
db, id=chat_history_in.conversation_id)
if chat_history is None or chat_history.user_id != current_user.id:
raise HTTPException(
status_code=404, detail="Chat history not found")
updated_chat_history = crud.chat.update_messages(
db=db, db_obj=chat_history, obj_in=chat_history_in)
return updated_chat_history
except Exception as e:
print(traceback.format_exc())
logger.error(f"Error updating chat history: {str(e)}")
raise HTTPException(
status_code=500,
detail="Internal Server Error",
)
@router.post("/delete")
def delete_chat_history(
chat_history_in: schemas.ChatDelete,
db: Session = Depends(deps.get_db),
current_user: models.User = Security(
deps.get_current_user,
),
):
"""
Delete a chat history by ID
"""
try:
chat_history_id = chat_history_in.id
chat_history = crud.chat.get(db, id=chat_history_id)
if chat_history is None or chat_history.user_id != current_user.id:
raise HTTPException(
status_code=404, detail="Chat history not found")
crud.chat.remove(db=db, id=chat_history_id)
return JSONResponse(
status_code=status.HTTP_200_OK,
content={
"message": "Chat history deleted successfully",
},
)
except Exception as e:
print(traceback.format_exc())
logger.error(f"Error deleting chat history: {str(e)}")
raise HTTPException(
status_code=500,
detail="Internal Server Error",
)

View File

@ -5,4 +5,5 @@ from .company_crud import company
from .subscription_crud import subscription
from .document_crud import documents
from .department_crud import department
from .audit_crud import audit
from .audit_crud import audit
from .chathistory_crud import chat

View File

@ -0,0 +1,50 @@
from typing import Optional, List, Union, Dict, Any
from fastapi import HTTPException, status
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.util import object_mapper
from private_gpt.users.crud.base import CRUDBase
from private_gpt.users.models.chat_history import ChatHistory
from private_gpt.users.schemas.chat_history import ChatCreate, ChatUpdate
class CRUDChat(CRUDBase[ChatHistory, ChatCreate, ChatUpdate]):
def get_by_id(self, db: Session, *, id: int) -> Optional[ChatHistory]:
return db.query(self.model).filter(ChatHistory.conversation_id == id).first()
def update_messages(
self,
db: Session,
*,
db_obj: ChatHistory,
obj_in: Union[ChatUpdate, Dict[str, Any]]
) -> ChatHistory:
try:
obj_data = object_mapper(db_obj).data
if isinstance(obj_in, dict):
update_data = obj_in
else:
update_data = obj_in.dict(exclude_unset=True)
# Update the `messages` field by appending new messages
existing_messages = obj_data.get("messages", [])
new_messages = update_data.get("messages", [])
obj_data["messages"] = existing_messages + new_messages
for field, value in obj_data.items():
setattr(db_obj, field, value)
db.add(db_obj)
db.commit()
db.refresh(db_obj)
return db_obj
except IntegrityError as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Integrity Error: {str(e)}",
)
chat = CRUDChat(ChatHistory)

View File

@ -6,4 +6,5 @@ from .document import Document
from .subscription import Subscription
from .department import Department
from .audit import Audit
from .document_department import document_department_association
from .document_department import document_department_association
from .chat_history import ChatHistory

View File

@ -0,0 +1,39 @@
from datetime import datetime
from sqlalchemy.orm import relationship
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, JSON
from private_gpt.users.db.base_class import Base
class ChatHistory(Base):
"""Models a chat history table"""
__tablename__ = "chat_history"
id = Column(Integer, nullable=False, primary_key=True)
title = Column(String(255), nullable=False)
messages = Column(JSON, nullable=False)
created_at = Column(DateTime, default=datetime.now)
user_id = Column(Integer, ForeignKey("users.id"))
user = relationship("User", back_populates="chat_histories")
def __init__(self, messages, user_id, *args, **kwargs):
super().__init__(*args, **kwargs)
self.messages = messages
self.user_id = user_id
self.title = self.generate_title()
def generate_title(self):
if self.messages:
first_user_message = next(+
(msg["message"]
for msg in self.messages if msg["sender"] == "user"), None
)
if first_user_message:
return first_user_message[:30]
return "Untitled Chat"
def __repr__(self):
"""Returns string representation of model instance"""
return f"<ChatHistory {self.id!r}>"

View File

@ -51,6 +51,7 @@ class User(Base):
Integer, ForeignKey("departments.id"), nullable=False)
department = relationship("Department", back_populates="users")
chat_histories = relationship("ChatHistory", back_populates="user")
__table_args__ = (
UniqueConstraint('username', name='unique_username_no_spacing'),

View File

@ -7,3 +7,4 @@ from .company import Company, CompanyBase, CompanyCreate, CompanyUpdate
from .documents import Document, DocumentCreate, DocumentsBase, DocumentUpdate, DocumentList, DepartmentList, DocumentEnable, DocumentDepartmentUpdate, DocumentCheckerUpdate, DocumentMakerCreate, DocumentDepartmentList, DocumentView, DocumentVerify, DocumentFilter
from .department import Department, DepartmentCreate, DepartmentUpdate, DepartmentAdminCreate, DepartmentDelete
from .audit import AuditBase, AuditCreate, AuditUpdate, Audit, GetAudit
from .chat_history import Chat, ChatBase, ChatCreate, ChatDelete, ChatUpdate, ChatMessages

View File

@ -0,0 +1,30 @@
from typing import Optional, Dict, Any
from pydantic import BaseModel
from datetime import datetime
class ChatBase(BaseModel):
title: Optional[str]
class ChatCreate(ChatBase):
user_id: int
messages: Optional[Dict[str, Any]]
class ChatUpdate(ChatBase):
conversation_id: int
messages: Dict
class ChatDelete(BaseModel):
conversation_id: int
class ChatMessages(BaseModel):
messages: Dict[str, Any]
class Chat(ChatBase):
conversation_id: int
created_at: datetime
user_id: int
class Config:
orm_mode = True

View File

@ -30,7 +30,6 @@ class UserUpdate(BaseModel):
last_login: Optional[datetime] = None
class UserLoginSchema(BaseModel):
email: EmailStr = Field(alias="email")
password: str