mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-08-13 21:25:56 +00:00
Added table for chat history
This commit is contained in:
parent
542ed0ef4e
commit
355271be93
@ -129,9 +129,7 @@ class ChatService:
|
||||
else None
|
||||
)
|
||||
system_prompt = (
|
||||
chat_engine_input.system_message.content
|
||||
if chat_engine_input.system_message
|
||||
else None
|
||||
"You can only answer questions about the provided context. If you know the answer but it is not based in the provided context, don't provide the answer, just state the answer is not in the context provided."
|
||||
)
|
||||
chat_history = (
|
||||
chat_engine_input.chat_history if chat_engine_input.chat_history else None
|
||||
@ -165,14 +163,11 @@ class ChatService:
|
||||
else None
|
||||
)
|
||||
system_prompt = (
|
||||
chat_engine_input.system_message.content
|
||||
if chat_engine_input.system_message
|
||||
else None
|
||||
"You can only answer questions about the provided context. If you know the answer but it is not based in the provided context, don't provide the answer, just state the answer is not in the context provided."
|
||||
)
|
||||
chat_history = (
|
||||
chat_engine_input.chat_history if chat_engine_input.chat_history else None
|
||||
)
|
||||
|
||||
chat_engine = self._chat_engine(
|
||||
system_prompt=system_prompt,
|
||||
use_context=use_context,
|
||||
|
@ -1,3 +1,6 @@
|
||||
from private_gpt.users import crud, models, schemas
|
||||
import itertools
|
||||
from llama_index.llms import ChatMessage, ChatResponse, MessageRole
|
||||
from fastapi import APIRouter, Depends, Request, Security, HTTPException, status
|
||||
from private_gpt.server.ingest.ingest_service import IngestService
|
||||
from pydantic import BaseModel
|
||||
@ -17,11 +20,13 @@ from private_gpt.open_ai.openai_models import (
|
||||
from private_gpt.server.chat.chat_router import ChatBody, chat_completion
|
||||
from private_gpt.server.utils.auth import authenticated
|
||||
from private_gpt.users.api import deps
|
||||
from private_gpt.users import crud, models, schemas
|
||||
from pydantic import Optional
|
||||
|
||||
completions_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
|
||||
|
||||
|
||||
class CompletionsBody(BaseModel):
|
||||
conversation_id: Optional[int]
|
||||
prompt: str
|
||||
system_prompt: str | None = None
|
||||
use_context: bool = False
|
||||
@ -145,25 +150,82 @@ async def prompt_completion(
|
||||
detail="Internal Server Error",
|
||||
)
|
||||
|
||||
messages = [OpenAIMessage(content=body.prompt, role="user")]
|
||||
if body.conversation_id:
|
||||
chat_history = crud.chat.get_by_id(db, id=body.conversation_id)
|
||||
if chat_history is None or chat_history.user_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Chat history not found")
|
||||
else:
|
||||
chat_create_in = schemas.ChatCreate(user_id=current_user.id)
|
||||
chat_history = crud.chat.create(db=db, obj_in=chat_create_in)
|
||||
|
||||
_history = chat_history.messages or []
|
||||
|
||||
def build_history() -> list[ChatMessage]:
|
||||
history_messages: list[ChatMessage] = []
|
||||
for interaction in _history:
|
||||
user_message = interaction.get("user", "")
|
||||
ai_message = interaction.get("ai", "")
|
||||
if user_message:
|
||||
history_messages.append(
|
||||
ChatMessage(
|
||||
content=user_message,
|
||||
role=MessageRole.USER
|
||||
)
|
||||
)
|
||||
if ai_message:
|
||||
history_messages.append(
|
||||
ChatMessage(
|
||||
content=ai_message,
|
||||
role=MessageRole.ASSISTANT
|
||||
)
|
||||
)
|
||||
|
||||
# max 20 messages to try to avoid context overflow
|
||||
return history_messages[:20]
|
||||
|
||||
# Prepare new messages
|
||||
new_messages = []
|
||||
|
||||
if body.prompt:
|
||||
new_messages.append(OpenAIMessage(content=body.prompt, role="user"))
|
||||
if body.system_prompt:
|
||||
messages.insert(0, OpenAIMessage(
|
||||
new_messages.insert(0, OpenAIMessage(
|
||||
content=body.system_prompt, role="system"))
|
||||
|
||||
# Update chat history with new user messages
|
||||
if new_messages:
|
||||
new_message = ChatMessage(content=new_messages, role=MessageRole.USER)
|
||||
_history.append(new_message.dict())
|
||||
|
||||
# Process chat completion
|
||||
chat_body = ChatBody(
|
||||
messages=messages,
|
||||
messages=build_history(),
|
||||
use_context=body.use_context,
|
||||
stream=body.stream,
|
||||
include_sources=body.include_sources,
|
||||
context_filter=body.context_filter,
|
||||
)
|
||||
log_audit(
|
||||
model='Chat',
|
||||
action='Chat',
|
||||
details={
|
||||
"query": body.prompt,
|
||||
'user': current_user.username,
|
||||
},
|
||||
user_id=current_user.id
|
||||
)
|
||||
return await chat_completion(request, chat_body)
|
||||
|
||||
ai_response = await chat_completion(request, chat_body)
|
||||
|
||||
# Update chat history with AI response
|
||||
if ai_response.messages:
|
||||
ai_message = OpenAIMessage(
|
||||
content=ai_response.messages, role="assistant")
|
||||
_history.append(ai_message.dict())
|
||||
|
||||
# Update chat history in the database
|
||||
chat_obj_in = schemas.ChatUpdate(messages=build_history())
|
||||
crud.chat.update_messages(db, db_obj=chat_history, obj_in=chat_obj_in)
|
||||
|
||||
return ai_response
|
||||
# log_audit(
|
||||
# model='Chat',
|
||||
# action='Chat',
|
||||
# details={
|
||||
# "query": body.prompt,
|
||||
# 'user': current_user.username,
|
||||
# },
|
||||
# user_id=current_user.id
|
||||
# )
|
||||
|
154
private_gpt/users/api/v1/routers/chat_histories.py
Normal file
154
private_gpt/users/api/v1/routers/chat_histories.py
Normal file
@ -0,0 +1,154 @@
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Security
|
||||
|
||||
from private_gpt.users.api import deps
|
||||
from private_gpt.users.constants.role import Role
|
||||
from private_gpt.users import crud, models, schemas
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/c", tags=["Chat Histories"])
|
||||
|
||||
|
||||
@router.get("", response_model=list[schemas.ChatBase])
|
||||
def list_chat_histories(
|
||||
db: Session = Depends(deps.get_db),
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
),
|
||||
) -> list[schemas.ChatBase]:
|
||||
"""
|
||||
Retrieve a list of chat histories with pagination support.
|
||||
"""
|
||||
try:
|
||||
chat_histories = crud.chat.get_multi(
|
||||
db, skip=skip, limit=limit, user_id=current_user.id)
|
||||
return chat_histories
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
logger.error(f"Error listing chat histories: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal Server Error",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/create", response_model=schemas.ChatBase)
|
||||
def create_chat_history(
|
||||
chat_history_in: schemas.ChatCreate,
|
||||
db: Session = Depends(deps.get_db),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
),
|
||||
) -> schemas.ChatBase:
|
||||
"""
|
||||
Create a new chat history
|
||||
"""
|
||||
try:
|
||||
chat_history = crud.chat.create(
|
||||
db=db, obj_in=chat_history_in, user_id=current_user.id)
|
||||
return chat_history
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
logger.error(f"Error creating chat history: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal Server Error",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{chat_history_id}", response_model=schemas.ChatMessages)
|
||||
def read_chat_history(
|
||||
chat_history_id: int,
|
||||
db: Session = Depends(deps.get_db),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
),
|
||||
) -> schemas.ChatMessages:
|
||||
"""
|
||||
Read a chat history by ID
|
||||
"""
|
||||
try:
|
||||
chat_history = crud.chat.get_by_id(db, id=chat_history_id)
|
||||
if chat_history is None or chat_history.user_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Chat history not found")
|
||||
return chat_history
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
logger.error(f"Error reading chat history: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal Server Error",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/conversation", response_model=schemas.ChatHistory)
|
||||
def conversation(
|
||||
chat_history_in: schemas.ChatUpdate,
|
||||
db: Session = Depends(deps.get_db),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
),
|
||||
) -> schemas.ChatHistory:
|
||||
"""
|
||||
Update a chat history by ID
|
||||
"""
|
||||
try:
|
||||
chat_history = crud.chat.get_by_id(
|
||||
db, id=chat_history_in.conversation_id)
|
||||
if chat_history is None or chat_history.user_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Chat history not found")
|
||||
|
||||
updated_chat_history = crud.chat.update_messages(
|
||||
db=db, db_obj=chat_history, obj_in=chat_history_in)
|
||||
|
||||
return updated_chat_history
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
logger.error(f"Error updating chat history: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal Server Error",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/delete")
|
||||
def delete_chat_history(
|
||||
chat_history_in: schemas.ChatDelete,
|
||||
db: Session = Depends(deps.get_db),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
),
|
||||
):
|
||||
"""
|
||||
Delete a chat history by ID
|
||||
"""
|
||||
try:
|
||||
chat_history_id = chat_history_in.id
|
||||
chat_history = crud.chat.get(db, id=chat_history_id)
|
||||
if chat_history is None or chat_history.user_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Chat history not found")
|
||||
|
||||
crud.chat.remove(db=db, id=chat_history_id)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={
|
||||
"message": "Chat history deleted successfully",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
logger.error(f"Error deleting chat history: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal Server Error",
|
||||
)
|
@ -5,4 +5,5 @@ from .company_crud import company
|
||||
from .subscription_crud import subscription
|
||||
from .document_crud import documents
|
||||
from .department_crud import department
|
||||
from .audit_crud import audit
|
||||
from .audit_crud import audit
|
||||
from .chathistory_crud import chat
|
50
private_gpt/users/crud/chathistory_crud.py
Normal file
50
private_gpt/users/crud/chathistory_crud.py
Normal file
@ -0,0 +1,50 @@
|
||||
from typing import Optional, List, Union, Dict, Any
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm.util import object_mapper
|
||||
|
||||
from private_gpt.users.crud.base import CRUDBase
|
||||
from private_gpt.users.models.chat_history import ChatHistory
|
||||
from private_gpt.users.schemas.chat_history import ChatCreate, ChatUpdate
|
||||
|
||||
|
||||
class CRUDChat(CRUDBase[ChatHistory, ChatCreate, ChatUpdate]):
|
||||
def get_by_id(self, db: Session, *, id: int) -> Optional[ChatHistory]:
|
||||
return db.query(self.model).filter(ChatHistory.conversation_id == id).first()
|
||||
|
||||
def update_messages(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
db_obj: ChatHistory,
|
||||
obj_in: Union[ChatUpdate, Dict[str, Any]]
|
||||
) -> ChatHistory:
|
||||
try:
|
||||
obj_data = object_mapper(db_obj).data
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.dict(exclude_unset=True)
|
||||
|
||||
# Update the `messages` field by appending new messages
|
||||
existing_messages = obj_data.get("messages", [])
|
||||
new_messages = update_data.get("messages", [])
|
||||
obj_data["messages"] = existing_messages + new_messages
|
||||
|
||||
for field, value in obj_data.items():
|
||||
setattr(db_obj, field, value)
|
||||
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Integrity Error: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
chat = CRUDChat(ChatHistory)
|
@ -6,4 +6,5 @@ from .document import Document
|
||||
from .subscription import Subscription
|
||||
from .department import Department
|
||||
from .audit import Audit
|
||||
from .document_department import document_department_association
|
||||
from .document_department import document_department_association
|
||||
from .chat_history import ChatHistory
|
39
private_gpt/users/models/chat_history.py
Normal file
39
private_gpt/users/models/chat_history.py
Normal file
@ -0,0 +1,39 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, JSON
|
||||
|
||||
from private_gpt.users.db.base_class import Base
|
||||
|
||||
|
||||
class ChatHistory(Base):
|
||||
"""Models a chat history table"""
|
||||
|
||||
__tablename__ = "chat_history"
|
||||
|
||||
id = Column(Integer, nullable=False, primary_key=True)
|
||||
title = Column(String(255), nullable=False)
|
||||
messages = Column(JSON, nullable=False)
|
||||
created_at = Column(DateTime, default=datetime.now)
|
||||
user_id = Column(Integer, ForeignKey("users.id"))
|
||||
user = relationship("User", back_populates="chat_histories")
|
||||
|
||||
def __init__(self, messages, user_id, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.messages = messages
|
||||
self.user_id = user_id
|
||||
self.title = self.generate_title()
|
||||
|
||||
def generate_title(self):
|
||||
if self.messages:
|
||||
first_user_message = next(+
|
||||
(msg["message"]
|
||||
for msg in self.messages if msg["sender"] == "user"), None
|
||||
)
|
||||
if first_user_message:
|
||||
return first_user_message[:30]
|
||||
return "Untitled Chat"
|
||||
|
||||
def __repr__(self):
|
||||
"""Returns string representation of model instance"""
|
||||
return f"<ChatHistory {self.id!r}>"
|
@ -51,6 +51,7 @@ class User(Base):
|
||||
Integer, ForeignKey("departments.id"), nullable=False)
|
||||
|
||||
department = relationship("Department", back_populates="users")
|
||||
chat_histories = relationship("ChatHistory", back_populates="user")
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint('username', name='unique_username_no_spacing'),
|
||||
|
@ -7,3 +7,4 @@ from .company import Company, CompanyBase, CompanyCreate, CompanyUpdate
|
||||
from .documents import Document, DocumentCreate, DocumentsBase, DocumentUpdate, DocumentList, DepartmentList, DocumentEnable, DocumentDepartmentUpdate, DocumentCheckerUpdate, DocumentMakerCreate, DocumentDepartmentList, DocumentView, DocumentVerify, DocumentFilter
|
||||
from .department import Department, DepartmentCreate, DepartmentUpdate, DepartmentAdminCreate, DepartmentDelete
|
||||
from .audit import AuditBase, AuditCreate, AuditUpdate, Audit, GetAudit
|
||||
from .chat_history import Chat, ChatBase, ChatCreate, ChatDelete, ChatUpdate, ChatMessages
|
||||
|
30
private_gpt/users/schemas/chat_history.py
Normal file
30
private_gpt/users/schemas/chat_history.py
Normal file
@ -0,0 +1,30 @@
|
||||
from typing import Optional, Dict, Any
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class ChatBase(BaseModel):
|
||||
title: Optional[str]
|
||||
|
||||
class ChatCreate(ChatBase):
|
||||
user_id: int
|
||||
messages: Optional[Dict[str, Any]]
|
||||
|
||||
class ChatUpdate(ChatBase):
|
||||
conversation_id: int
|
||||
messages: Dict
|
||||
|
||||
|
||||
class ChatDelete(BaseModel):
|
||||
conversation_id: int
|
||||
|
||||
class ChatMessages(BaseModel):
|
||||
messages: Dict[str, Any]
|
||||
|
||||
class Chat(ChatBase):
|
||||
conversation_id: int
|
||||
created_at: datetime
|
||||
user_id: int
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
@ -30,7 +30,6 @@ class UserUpdate(BaseModel):
|
||||
last_login: Optional[datetime] = None
|
||||
|
||||
|
||||
|
||||
class UserLoginSchema(BaseModel):
|
||||
email: EmailStr = Field(alias="email")
|
||||
password: str
|
||||
|
Loading…
Reference in New Issue
Block a user