mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-07-12 14:48:00 +00:00
Updated with new api for chat with context filtering for based on files wrt to departments
This commit is contained in:
parent
1f5c0d5d7b
commit
500d4a1494
@ -1,5 +1,12 @@
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi import APIRouter, Depends, Request, Security, HTTPException, status
|
||||
from private_gpt.server.ingest.ingest_service import IngestService
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
import traceback
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from private_gpt.open_ai.extensions.context_filter import ContextFilter
|
||||
@ -9,7 +16,8 @@ from private_gpt.open_ai.openai_models import (
|
||||
)
|
||||
from private_gpt.server.chat.chat_router import ChatBody, chat_completion
|
||||
from private_gpt.server.utils.auth import authenticated
|
||||
|
||||
from private_gpt.users import crud, models, schemas
|
||||
from private_gpt.users.api import deps
|
||||
completions_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
|
||||
|
||||
|
||||
@ -83,3 +91,59 @@ def prompt_completion(
|
||||
context_filter=body.context_filter,
|
||||
)
|
||||
return chat_completion(request, chat_body)
|
||||
|
||||
|
||||
@completions_router.post(
|
||||
"/chat",
|
||||
response_model=None,
|
||||
summary="Completion",
|
||||
responses={200: {"model": OpenAICompletion}},
|
||||
tags=["Contextual Completions"],
|
||||
)
|
||||
def prompt_completion(
|
||||
request: Request,
|
||||
body: CompletionsBody,
|
||||
db: Session = Depends(deps.get_db),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
),
|
||||
) -> OpenAICompletion | StreamingResponse:
|
||||
try:
|
||||
service = request.state.injector.get(IngestService)
|
||||
|
||||
department = crud.department.get_by_id(db, id=current_user.department_id)
|
||||
if not department:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No department assigned to you")
|
||||
documents = crud.documents.get_multi_documents(db, department_id=department.id)
|
||||
if not documents:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No documents uploaded for your department.")
|
||||
docs_list = [document.filename for document in documents]
|
||||
docs_ids = []
|
||||
for filename in docs_list:
|
||||
doc_id = service.get_doc_ids_by_filename(filename)
|
||||
docs_ids.extend(doc_id)
|
||||
body.context_filter = {"docs_ids": docs_ids}
|
||||
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
logger.error(f"There was an error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal Server Error",
|
||||
)
|
||||
|
||||
messages = [OpenAIMessage(content=body.prompt, role="user")]
|
||||
if body.system_prompt:
|
||||
messages.insert(0, OpenAIMessage(
|
||||
content=body.system_prompt, role="system"))
|
||||
|
||||
chat_body = ChatBody(
|
||||
messages=messages,
|
||||
use_context=body.use_context,
|
||||
stream=body.stream,
|
||||
include_sources=body.include_sources,
|
||||
context_filter=body.context_filter,
|
||||
)
|
||||
return chat_completion(request, chat_body)
|
||||
|
@ -1,5 +1,7 @@
|
||||
import os
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, List
|
||||
|
||||
@ -8,7 +10,6 @@ from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, File
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from private_gpt.home import Home
|
||||
from private_gpt.users import crud, models, schemas
|
||||
from private_gpt.users.api import deps
|
||||
from private_gpt.users.constants.role import Role
|
||||
@ -194,7 +195,7 @@ def ingest_file(
|
||||
)
|
||||
|
||||
try:
|
||||
docs_in = schemas.DocumentCreate(filename=file.filename, uploaded_by=current_user.id)
|
||||
docs_in = schemas.DocumentCreate(filename=file.filename, uploaded_by=current_user.id, department_id=current_user.department_id)
|
||||
crud.documents.create(db=db, obj_in=docs_in)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
@ -216,7 +217,6 @@ def ingest_file(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"There was an error uploading the file(s): {str(e)}")
|
||||
print("ERROR: ", e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal Server Error: Unable to ingest file.",
|
||||
@ -250,13 +250,11 @@ async def common_ingest_logic(
|
||||
)
|
||||
|
||||
docs_in = schemas.DocumentCreate(
|
||||
filename=file_name, uploaded_by=current_user.id)
|
||||
filename=file_name, uploaded_by=current_user.id, department_id=current_user.department_id)
|
||||
crud.documents.create(db=db, obj_in=docs_in)
|
||||
|
||||
with open(upload_path, "wb") as f:
|
||||
f.write(file.read())
|
||||
|
||||
# Ingest binary data
|
||||
file.seek(0) # Move the file pointer back to the beginning
|
||||
ingested_documents = service.ingest_bin_data(file_name, file)
|
||||
|
||||
@ -270,35 +268,7 @@ async def common_ingest_logic(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"There was an error uploading the file(s): {str(e)}")
|
||||
print("ERROR: ", e)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal Server Error: Unable to ingest file.",
|
||||
)
|
||||
|
||||
from private_gpt.users.schemas import Document
|
||||
|
||||
@ingest_router.get("/ingest/list_files", response_model=List[schemas.Document], tags=["Ingestion"])
|
||||
def list_files(
|
||||
request: Request,
|
||||
db: Session = Depends(deps.get_db),
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.ADMIN["name"], Role.SUPER_ADMIN["name"]],
|
||||
|
||||
)
|
||||
):
|
||||
try:
|
||||
docs = crud.documents.get_multi(db, skip=skip, limit=limit)
|
||||
return docs
|
||||
except Exception as e:
|
||||
logger.error(f"There was an error uploading the file(s): {str(e)}")
|
||||
print("ERROR: ", e)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal Server Error: Unable to ingest file.",
|
||||
)
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from private_gpt.users.api.v1.routers import auth, roles, user_roles, users, subscriptions, companies, departments
|
||||
from private_gpt.users.api.v1.routers import auth, roles, user_roles, users, subscriptions, companies, departments, documents
|
||||
from fastapi import APIRouter
|
||||
|
||||
api_router = APIRouter(prefix="/v1")
|
||||
@ -10,4 +10,5 @@ api_router.include_router(user_roles.router)
|
||||
api_router.include_router(companies.router)
|
||||
api_router.include_router(subscriptions.router)
|
||||
api_router.include_router(departments.router)
|
||||
api_router.include_router(documents.router)
|
||||
|
||||
|
@ -10,7 +10,7 @@ from private_gpt.users.constants.role import Role
|
||||
from private_gpt.users import crud, models, schemas
|
||||
|
||||
|
||||
router = APIRouter(prefix="/departments", tags=["Deparments"])
|
||||
router = APIRouter(prefix="/departments", tags=["Departments"])
|
||||
|
||||
|
||||
@router.get("", response_model=List[schemas.Department])
|
||||
@ -42,7 +42,9 @@ def create_deparment(
|
||||
"""
|
||||
Create a new department
|
||||
"""
|
||||
deparment = crud.department.create(db=db, obj_in=department_in)
|
||||
company_id = current_user.company_id
|
||||
department_create_in = schemas.DepartmentAdminCreate(name=department_in.name, company_id=company_id)
|
||||
department = crud.department.create(db=db, obj_in=department_create_in)
|
||||
department = jsonable_encoder(department)
|
||||
|
||||
return JSONResponse(
|
||||
|
@ -0,0 +1,94 @@
|
||||
import traceback
|
||||
import logging
|
||||
from typing import Any, List
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Security, Request
|
||||
|
||||
from private_gpt.users.api import deps
|
||||
from private_gpt.users.constants.role import Role
|
||||
from private_gpt.users import crud, models, schemas
|
||||
from private_gpt.users.schemas import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix='/documents', tags=['Documents'])
|
||||
|
||||
@router.get("", response_model=List[schemas.Document])
|
||||
def list_files(
|
||||
request: Request,
|
||||
db: Session = Depends(deps.get_db),
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.SUPER_ADMIN["name"]],
|
||||
)
|
||||
):
|
||||
try:
|
||||
docs = crud.documents.get_multi(db, skip=skip, limit=limit)
|
||||
return docs
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
logger.error(f"There was an error listing the file(s).")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal Server Error: Unable to ingest file.",
|
||||
)
|
||||
|
||||
|
||||
@router.get('{department_id}', response_model=List[schemas.Document])
|
||||
def list_files_by_department(
|
||||
request: Request,
|
||||
department_id: int,
|
||||
db: Session = Depends(deps.get_db),
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.SUPER_ADMIN["name"]],
|
||||
)
|
||||
):
|
||||
'''
|
||||
Listing the documents by the department id
|
||||
'''
|
||||
try:
|
||||
docs = crud.documents.get_multi_documents(
|
||||
db, department_id=department_id, skip=skip, limit=limit)
|
||||
return docs
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
logger.error(f"There was an error listing the file(s).")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal Server Error: Unable to ingest file.",
|
||||
)
|
||||
|
||||
|
||||
@router.get('/files', response_model=List[schemas.DocumentList])
|
||||
def list_files_by_department(
|
||||
request: Request,
|
||||
db: Session = Depends(deps.get_db),
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.ADMIN["name"]],
|
||||
)
|
||||
):
|
||||
'''
|
||||
Listing the documents by the ADMIN of the Department
|
||||
'''
|
||||
try:
|
||||
department_id = current_user.department_id
|
||||
docs = crud.documents.get_multi_documents(
|
||||
db, department_id=department_id, skip=skip, limit=limit)
|
||||
return docs
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
logger.error(f"There was an error listing the file(s).")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal Server Error: Unable to ingest file.",
|
||||
)
|
@ -95,7 +95,8 @@ def update_username(
|
||||
user_data = schemas.UserBaseSchema(
|
||||
email=user.email,
|
||||
fullname=user.fullname,
|
||||
company_id=user.company_id
|
||||
company_id=user.company_id,
|
||||
department_id=user.department_id,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
@ -114,11 +115,11 @@ def read_user_me(
|
||||
Get current user.
|
||||
"""
|
||||
role = current_user.user_role.role.name if current_user.user_role else None
|
||||
print("THe role is: ", role)
|
||||
user_data = schemas.Profile(
|
||||
email=current_user.email,
|
||||
fullname=current_user.fullname,
|
||||
company_id = current_user.company_id,
|
||||
department_id=current_user.department_id,
|
||||
role =role
|
||||
)
|
||||
return JSONResponse(
|
||||
@ -151,6 +152,7 @@ def change_password(
|
||||
email=current_user.email,
|
||||
fullname=current_user.fullname,
|
||||
company_id= current_user.company_id,
|
||||
department_id=current_user.department_id,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
@ -205,6 +207,8 @@ def update_user(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
fullname=user.fullname,
|
||||
company_id=user.company_id,
|
||||
department_id=user.department_id,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
@ -254,6 +258,7 @@ def admin_change_password(
|
||||
email=user.email,
|
||||
fullname=user.fullname,
|
||||
company_id=user.company_id,
|
||||
department_id=user.department_id,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
@ -331,8 +336,7 @@ def admin_update_user(
|
||||
role = crud.user_role.update(db, db_obj=user_role, obj_in=role_in)
|
||||
|
||||
user_in = schemas.UserUpdate(fullname=user_update.fullname,
|
||||
email=existing_user.email, company_id=existing_user.user_role.company_id)
|
||||
print("User in: ", user_in)
|
||||
email=existing_user.email, company_id=existing_user.user_role.company_id, department_id=user_update.department_id)
|
||||
user = crud.user.update(db, db_obj=existing_user, obj_in=user_in)
|
||||
|
||||
return JSONResponse(
|
||||
|
@ -1,5 +1,5 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from private_gpt.users.schemas.deparment import DepartmentCreate, DepartmentUpdate
|
||||
from private_gpt.users.schemas.department import DepartmentCreate, DepartmentUpdate
|
||||
from private_gpt.users.models.department import Department
|
||||
from private_gpt.users.crud.base import CRUDBase
|
||||
from typing import Optional
|
||||
|
@ -2,7 +2,7 @@ from sqlalchemy.orm import Session
|
||||
from private_gpt.users.schemas.documents import DocumentCreate, DocumentUpdate
|
||||
from private_gpt.users.models.document import Document
|
||||
from private_gpt.users.crud.base import CRUDBase
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
|
||||
|
||||
class CRUDDocuments(CRUDBase[Document, DocumentCreate, DocumentUpdate]):
|
||||
@ -11,6 +11,16 @@ class CRUDDocuments(CRUDBase[Document, DocumentCreate, DocumentUpdate]):
|
||||
|
||||
def get_by_filename(self, db: Session, *, file_name: str) -> Optional[Document]:
|
||||
return db.query(self.model).filter(Document.filename == file_name).first()
|
||||
|
||||
|
||||
def get_multi_documents(
|
||||
self, db: Session, *,department_id: int, skip: int = 0, limit: int = 100
|
||||
) -> List[Document]:
|
||||
return (
|
||||
db.query(self.model)
|
||||
.filter(Document.department_id == department_id)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
documents = CRUDDocuments(Document)
|
||||
|
@ -4,5 +4,5 @@ from .user import User, UserCreate, UserInDB, UserUpdate, UserBaseSchema, Profil
|
||||
from .user_role import UserRole, UserRoleCreate, UserRoleInDB, UserRoleUpdate
|
||||
from .subscription import Subscription, SubscriptionBase, SubscriptionCreate, SubscriptionUpdate
|
||||
from .company import Company, CompanyBase, CompanyCreate, CompanyUpdate
|
||||
from .documents import Document, DocumentCreate, DocumentsBase, DocumentUpdate
|
||||
from .deparment import Department, DepartmentCreate, DepartmentUpdate
|
||||
from .documents import Document, DocumentCreate, DocumentsBase, DocumentUpdate, DocumentList
|
||||
from .department import Department, DepartmentCreate, DepartmentUpdate, DepartmentAdminCreate
|
@ -21,6 +21,11 @@ class DepartmentInDB(DepartmentBase):
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
class DepartmentAdminCreate(DepartmentBase):
|
||||
company_id: int
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
class Department(DepartmentInDB):
|
||||
pass
|
@ -9,16 +9,23 @@ class DocumentsBase(BaseModel):
|
||||
|
||||
class DocumentCreate(DocumentsBase):
|
||||
uploaded_by: int
|
||||
department_id: int
|
||||
|
||||
|
||||
class DocumentUpdate(DocumentsBase):
|
||||
pass
|
||||
|
||||
class DocumentList(DocumentsBase):
|
||||
id: int
|
||||
uploaded_by: int
|
||||
uploaded_at: datetime
|
||||
|
||||
|
||||
class Document(DocumentsBase):
|
||||
id: int
|
||||
uploaded_by: int
|
||||
uploaded_at: datetime
|
||||
department_id: int
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
Loading…
Reference in New Issue
Block a user