mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-08-01 23:47:54 +00:00
Added apis for documents update and user checker mode
This commit is contained in:
parent
f011bb6a7a
commit
e2bad96854
32
alembic/versions/59b6ae907209_added_checker.py
Normal file
32
alembic/versions/59b6ae907209_added_checker.py
Normal file
@ -0,0 +1,32 @@
|
||||
"""Added checker
|
||||
|
||||
Revision ID: 59b6ae907209
|
||||
Revises: ee8ae7222697
|
||||
Create Date: 2024-03-07 16:48:21.115238
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '59b6ae907209'
|
||||
down_revision: Union[str, None] = 'ee8ae7222697'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
# op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id', 'company_id'])
|
||||
op.add_column('users', sa.Column('checker', sa.Boolean(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('users', 'checker')
|
||||
# op.drop_constraint('unique_user_role', 'user_roles', type_='unique')
|
||||
# ### end Alembic commands ###
|
@ -1,8 +1,8 @@
|
||||
"""Documents association
|
||||
"""added is_enabled in documents
|
||||
|
||||
Revision ID: 2f490371bf6c
|
||||
Revision ID: ee8ae7222697
|
||||
Revises: f2978211af18
|
||||
Create Date: 2024-03-06 17:17:54.701414
|
||||
Create Date: 2024-03-07 15:34:22.365353
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
@ -12,7 +12,7 @@ import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '2f490371bf6c'
|
||||
revision: str = 'ee8ae7222697'
|
||||
down_revision: Union[str, None] = 'f2978211af18'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
@ -20,6 +20,7 @@ depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('document', sa.Column('is_enabled', sa.Boolean(), nullable=True))
|
||||
# op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id', 'company_id'])
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@ -27,4 +28,5 @@ def upgrade() -> None:
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
# op.drop_constraint('unique_user_role', 'user_roles', type_='unique')
|
||||
op.drop_column('document', 'is_enabled')
|
||||
# ### end Alembic commands ###
|
@ -125,7 +125,7 @@ async def process_both(
|
||||
@pdf_router.post("/pdf_ocr")
|
||||
async def get_pdf_ocr_wrapper(
|
||||
request: Request,
|
||||
departments: schemas.DepartmentList = Depends(),
|
||||
departments: schemas.DocumentDepartmentList = Depends(),
|
||||
db: Session = Depends(deps.get_db),
|
||||
log_audit: models.Audit = Depends(deps.get_audit_logger),
|
||||
file: UploadFile = File(...),
|
||||
@ -140,7 +140,7 @@ async def get_pdf_ocr_wrapper(
|
||||
@pdf_router.post("/both")
|
||||
async def get_both_wrapper(
|
||||
request: Request,
|
||||
departments: schemas.DepartmentList = Depends(),
|
||||
departments: schemas.DocumentDepartmentList = Depends(),
|
||||
db: Session = Depends(deps.get_db),
|
||||
log_audit: models.Audit = Depends(deps.get_audit_logger),
|
||||
file: UploadFile = File(...),
|
||||
|
@ -110,33 +110,33 @@ async def prompt_completion(
|
||||
),
|
||||
) -> OpenAICompletion | StreamingResponse:
|
||||
|
||||
# service = request.state.injector.get(IngestService)
|
||||
# try:
|
||||
service = request.state.injector.get(IngestService)
|
||||
try:
|
||||
department = crud.department.get_by_id(
|
||||
db, id=current_user.department_id)
|
||||
if not department:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No department assigned to you")
|
||||
documents = crud.documents.get_enabled_documents_by_departments(
|
||||
db, department_id=department.id)
|
||||
if not documents:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No documents uploaded for your department.")
|
||||
docs_list = [document.filename for document in documents]
|
||||
print("DOCUMENTS ASSIGNED TO THIS DEPARTMENTS: ", docs_list)
|
||||
docs_ids = []
|
||||
for filename in docs_list:
|
||||
doc_id = service.get_doc_ids_by_filename(filename)
|
||||
docs_ids.extend(doc_id)
|
||||
body.context_filter = {"docs_ids": docs_ids}
|
||||
|
||||
# department = crud.department.get_by_id(
|
||||
# db, id=current_user.department_id)
|
||||
# if not department:
|
||||
# raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
|
||||
# detail=f"No department assigned to you")
|
||||
# documents = crud.documents.get_multi_documents(
|
||||
# db, department_id=department.id)
|
||||
# if not documents:
|
||||
# raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
|
||||
# detail=f"No documents uploaded for your department.")
|
||||
# docs_list = [document.filename for document in documents]
|
||||
# docs_ids = []
|
||||
# for filename in docs_list:
|
||||
# doc_id = service.get_doc_ids_by_filename(filename)
|
||||
# docs_ids.extend(doc_id)
|
||||
# body.context_filter = {"docs_ids": docs_ids}
|
||||
|
||||
# except Exception as e:
|
||||
# print(traceback.format_exc())
|
||||
# logger.error(f"There was an error: {str(e)}")
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
# detail="Internal Server Error",
|
||||
# )
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
logger.error(f"There was an error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal Server Error",
|
||||
)
|
||||
|
||||
messages = [OpenAIMessage(content=body.prompt, role="user")]
|
||||
if body.system_prompt:
|
||||
|
@ -184,7 +184,7 @@ async def create_documents(
|
||||
db: Session,
|
||||
file_name: str = None,
|
||||
current_user: models.User = None,
|
||||
departments: schemas.DepartmentList = Depends(),
|
||||
departments: schemas.DocumentDepartmentList = Depends(),
|
||||
log_audit: models.Audit = None,
|
||||
):
|
||||
"""
|
||||
@ -226,7 +226,7 @@ async def common_ingest_logic(
|
||||
ocr_file,
|
||||
original_file: str = None,
|
||||
current_user: models.User = None,
|
||||
departments: schemas.DepartmentList = Depends(),
|
||||
departments: schemas.DocumentDepartmentList = Depends(),
|
||||
log_audit: models.Audit = None,
|
||||
):
|
||||
service = request.state.injector.get(IngestService)
|
||||
@ -245,8 +245,7 @@ async def common_ingest_logic(
|
||||
with open(upload_path, "wb") as f:
|
||||
f.write(file.read())
|
||||
file.seek(0)
|
||||
ingested_documents = service.ingest_bin_data(file_name, file)
|
||||
|
||||
ingested_documents = service.ingest_bin_data(file_name, file)
|
||||
# Handling Original File
|
||||
if original_file:
|
||||
try:
|
||||
@ -297,7 +296,7 @@ async def common_ingest_logic(
|
||||
@ingest_router.post("/ingest/file", response_model=IngestResponse, tags=["Ingestion"])
|
||||
async def ingest_file(
|
||||
request: Request,
|
||||
departments: schemas.DepartmentList = Depends(),
|
||||
departments: schemas.DocumentDepartmentList = Depends(),
|
||||
file: UploadFile = File(...),
|
||||
log_audit: models.Audit = Depends(deps.get_audit_logger),
|
||||
db: Session = Depends(deps.get_db),
|
||||
|
@ -27,10 +27,6 @@ def list_files(
|
||||
scopes=[Role.ADMIN["name"], Role.SUPER_ADMIN["name"], Role.OPERATOR["name"]],
|
||||
)
|
||||
):
|
||||
def get_department_name(db, id):
|
||||
dep = crud.department.get_by_id(db=db, id=id)
|
||||
return dep.name
|
||||
|
||||
def get_username(db, id):
|
||||
user = crud.user.get_by_id(db=db, id=id)
|
||||
return user.fullname
|
||||
@ -43,18 +39,21 @@ def list_files(
|
||||
docs = crud.documents.get_multi_documents(
|
||||
db, department_id=current_user.department_id, skip=skip, limit=limit)
|
||||
|
||||
docs = [
|
||||
documents = [
|
||||
schemas.Document(
|
||||
id=doc.id,
|
||||
filename=doc.filename,
|
||||
uploaded_at=doc.uploaded_at,
|
||||
uploaded_by=get_username(db, doc.uploaded_by),
|
||||
# department=get_department_name(db, doc.department_id)
|
||||
department="deparments"
|
||||
uploaded_at=doc.uploaded_at,
|
||||
is_enabled=doc.is_enabled,
|
||||
departments=[
|
||||
schemas.DepartmentList(id=dep.id, name=dep.name)
|
||||
for dep in doc.departments
|
||||
]
|
||||
)
|
||||
for doc in docs
|
||||
]
|
||||
return docs
|
||||
return documents
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
logger.error(f"There was an error listing the file(s).")
|
||||
@ -64,8 +63,7 @@ def list_files(
|
||||
)
|
||||
|
||||
|
||||
|
||||
@router.get('{department_id}', response_model=List[schemas.Document])
|
||||
@router.get('{department_id}', response_model=List[schemas.DocumentList])
|
||||
def list_files_by_department(
|
||||
request: Request,
|
||||
department_id: int,
|
||||
@ -81,7 +79,7 @@ def list_files_by_department(
|
||||
Listing the documents by the department id
|
||||
'''
|
||||
try:
|
||||
docs = crud.documents.get_multi_documents(
|
||||
docs = crud.documents.get_documents_by_departments(
|
||||
db, department_id=department_id, skip=skip, limit=limit)
|
||||
return docs
|
||||
except Exception as e:
|
||||
@ -101,7 +99,7 @@ def list_files_by_department(
|
||||
limit: int = 100,
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.ADMIN["name"]],
|
||||
scopes=[Role.ADMIN["name"], Role.SUPER_ADMIN["name"], Role.OPERATOR["name"]],
|
||||
)
|
||||
):
|
||||
'''
|
||||
@ -109,7 +107,7 @@ def list_files_by_department(
|
||||
'''
|
||||
try:
|
||||
department_id = current_user.department_id
|
||||
docs = crud.documents.get_multi_documents(
|
||||
docs = crud.documents.get_documents_by_departments(
|
||||
db, department_id=department_id, skip=skip, limit=limit)
|
||||
return docs
|
||||
except Exception as e:
|
||||
@ -119,3 +117,88 @@ def list_files_by_department(
|
||||
status_code=500,
|
||||
detail="Internal Server Error.",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@router.post('/update', response_model=schemas.DocumentEnable)
|
||||
def update_document(
|
||||
request: Request,
|
||||
document_in: schemas.DocumentEnable ,
|
||||
db: Session = Depends(deps.get_db),
|
||||
log_audit: models.Audit = Depends(deps.get_audit_logger),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.SUPER_ADMIN["name"], Role.OPERATOR["name"]],
|
||||
)
|
||||
):
|
||||
try:
|
||||
document = crud.documents.get_by_filename(
|
||||
db, file_name=document_in.filename)
|
||||
if not document:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Document with this filename doesn't exist!",
|
||||
)
|
||||
docs = crud.documents.update(db=db, db_obj=document, obj_in=document_in)
|
||||
log_audit(
|
||||
model='Document',
|
||||
action='update',
|
||||
details={
|
||||
'detail': f'{document_in.filename} status changed to {document_in.is_enabled} from {document.is_enabled}'
|
||||
},
|
||||
user_id=current_user.id
|
||||
)
|
||||
return docs
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
logger.error(f"There was an error listing the file(s).")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal Server Error.",
|
||||
)
|
||||
|
||||
|
||||
@router.post('/department_update', response_model=schemas.DocumentList)
|
||||
def update_department(
|
||||
request: Request,
|
||||
document_in: schemas.DocumentDepartmentUpdate ,
|
||||
db: Session = Depends(deps.get_db),
|
||||
log_audit: models.Audit = Depends(deps.get_audit_logger),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.SUPER_ADMIN["name"], Role.OPERATOR["name"]],
|
||||
)
|
||||
):
|
||||
"""
|
||||
Update the department list for the documents
|
||||
"""
|
||||
try:
|
||||
document = crud.documents.get_by_filename(
|
||||
db, file_name=document_in.filename)
|
||||
old_departments = document.departments
|
||||
if not document:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Document with this filename doesn't exist!",
|
||||
)
|
||||
department_ids = [int(number) for number in document_in.departments]
|
||||
print("Department update: ", document_in, department_ids)
|
||||
for department_id in department_ids:
|
||||
db.execute(models.document_department_association.insert().values(document_id=document.id, department_id=department_id))
|
||||
log_audit(
|
||||
model='Document',
|
||||
action='update',
|
||||
details={
|
||||
'detail': f'{document_in.filename} assigned to {department_ids} from {old_departments}'
|
||||
},
|
||||
user_id=current_user.id
|
||||
)
|
||||
return document
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
logger.error(f"There was an error listing the file(s).")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Internal Server Error.",
|
||||
)
|
||||
|
||||
|
@ -1,6 +1,9 @@
|
||||
from sqlalchemy import or_, and_
|
||||
from sqlalchemy.orm import Session
|
||||
from private_gpt.users.schemas.documents import DocumentCreate, DocumentUpdate
|
||||
from private_gpt.users.models.document import Document
|
||||
from private_gpt.users.models.department import Department
|
||||
from private_gpt.users.models.document_department import document_department_association
|
||||
from private_gpt.users.crud.base import CRUDBase
|
||||
from typing import Optional, List
|
||||
|
||||
@ -23,4 +26,44 @@ class CRUDDocuments(CRUDBase[Document, DocumentCreate, DocumentUpdate]):
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_documents_by_departments(
|
||||
self, db: Session, *, department_id: int, skip: int = 0, limit: int = 100
|
||||
) -> List[Document]:
|
||||
return (
|
||||
db.query(self.model)
|
||||
.join(document_department_association)
|
||||
.join(Department)
|
||||
.filter(document_department_association.c.department_id == department_id)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_enabled_documents_by_departments(
|
||||
self, db: Session, *, department_id: int, skip: int = 0, limit: int = 100
|
||||
) -> List[Document]:
|
||||
all_department_id = 4 # department ID for "ALL" is 4
|
||||
|
||||
return (
|
||||
db.query(self.model)
|
||||
.join(document_department_association)
|
||||
.join(Department)
|
||||
.filter(
|
||||
or_(
|
||||
and_(
|
||||
document_department_association.c.department_id == department_id,
|
||||
Document.is_enabled == True,
|
||||
),
|
||||
and_(
|
||||
document_department_association.c.department_id == all_department_id,
|
||||
Document.is_enabled == True,
|
||||
),
|
||||
)
|
||||
)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
documents = CRUDDocuments(Document)
|
||||
|
@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from sqlalchemy import event, select, func, update
|
||||
from sqlalchemy import Boolean, event, select, func, update
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import Column, Integer, String, ForeignKey, DateTime, Table
|
||||
from private_gpt.users.models.department import Department
|
||||
@ -26,7 +26,7 @@ class Document(Base):
|
||||
)
|
||||
uploaded_by_user = relationship(
|
||||
"User", back_populates="uploaded_documents")
|
||||
|
||||
is_enabled = Column(Boolean, default=True)
|
||||
# Use document_department_association as the secondary for the relationship
|
||||
departments = relationship(
|
||||
"Department",
|
||||
|
@ -37,6 +37,7 @@ class User(Base):
|
||||
)
|
||||
|
||||
password_created = Column(DateTime, nullable=True)
|
||||
checker = Column(Boolean, default=False)
|
||||
|
||||
company_id = Column(Integer, ForeignKey("companies.id"), nullable=True)
|
||||
company = relationship("Company", back_populates="users")
|
||||
|
@ -1,9 +1,9 @@
|
||||
from .role import Role, RoleCreate, RoleInDB, RoleUpdate
|
||||
from .token import TokenSchema, TokenPayload
|
||||
from .user import User, UserCreate, UserInDB, UserUpdate, UserBaseSchema, Profile, UsernameUpdate, DeleteUser, UserAdminUpdate, UserAdmin, PasswordUpdate
|
||||
from .user import User, UserCreate, UserInDB, UserUpdate, UserBaseSchema, Profile, UsernameUpdate, DeleteUser, UserAdminUpdate, UserAdmin, PasswordUpdate, SuperMakerUpdate
|
||||
from .user_role import UserRole, UserRoleCreate, UserRoleInDB, UserRoleUpdate
|
||||
from .subscription import Subscription, SubscriptionBase, SubscriptionCreate, SubscriptionUpdate
|
||||
from .company import Company, CompanyBase, CompanyCreate, CompanyUpdate
|
||||
from .documents import Document, DocumentCreate, DocumentsBase, DocumentUpdate, DocumentList
|
||||
from .department import Department, DepartmentCreate, DepartmentUpdate, DepartmentAdminCreate, DepartmentDelete, DepartmentList
|
||||
from .documents import Document, DocumentCreate, DocumentsBase, DocumentUpdate, DocumentList, DepartmentList, DocumentEnable, DocumentDepartmentUpdate
|
||||
from .department import Department, DepartmentCreate, DepartmentUpdate, DepartmentAdminCreate, DepartmentDelete, DocumentDepartmentList
|
||||
from .audit import AuditBase, AuditCreate, AuditUpdate, Audit, GetAudit
|
@ -51,5 +51,5 @@ class Department(DepartmentBase):
|
||||
orm_mode = True
|
||||
|
||||
|
||||
class DepartmentList(BaseModel):
|
||||
class DocumentDepartmentList(BaseModel):
|
||||
departments_ids: str = Form(...)
|
@ -15,17 +15,24 @@ class DocumentCreate(DocumentsBase):
|
||||
class DocumentUpdate(DocumentsBase):
|
||||
pass
|
||||
|
||||
class DocumentEnable(DocumentsBase):
|
||||
is_enabled: bool
|
||||
|
||||
class DocumentDepartmentUpdate(DocumentsBase):
|
||||
departments: List[int] = []
|
||||
|
||||
class DocumentList(DocumentsBase):
|
||||
id: int
|
||||
is_enabled: bool
|
||||
uploaded_by: int
|
||||
uploaded_at: datetime
|
||||
departments: List[DepartmentList] = []
|
||||
|
||||
class Document(DocumentsBase):
|
||||
class Document(BaseModel):
|
||||
id: int
|
||||
is_enabled: bool
|
||||
filename: str
|
||||
uploaded_by: int
|
||||
uploaded_by: str
|
||||
uploaded_at: datetime
|
||||
departments: List[DepartmentList] = []
|
||||
|
||||
|
@ -80,3 +80,6 @@ class UserAdmin(BaseModel):
|
||||
|
||||
class PasswordUpdate(BaseModel):
|
||||
password_created: Optional[datetime] = None
|
||||
|
||||
class SuperMakerUpdate(BaseModel):
|
||||
checker: bool
|
Loading…
Reference in New Issue
Block a user