Updated with department model

This commit is contained in:
Saurab-Shrestha 2024-02-22 07:11:50 +05:45
parent 500d4a1494
commit 062a0ae7da
12 changed files with 171 additions and 79 deletions

View File

@ -0,0 +1,34 @@
"""update department model
Revision ID: 36beb9b73c64
Revises: 0aeaf9df35a6
Create Date: 2024-02-21 15:12:07.840057
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '36beb9b73c64'
down_revision: Union[str, None] = '0aeaf9df35a6'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('departments', sa.Column('total_users', sa.Integer(), nullable=True))
op.add_column('departments', sa.Column('total_documents', sa.Integer(), nullable=True))
# op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id', 'company_id'])
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# op.drop_constraint('unique_user_role', 'user_roles', type_='unique')
op.drop_column('departments', 'total_documents')
op.drop_column('departments', 'total_users')
# ### end Alembic commands ###

View File

@ -16,8 +16,8 @@ from private_gpt.open_ai.openai_models import (
) )
from private_gpt.server.chat.chat_router import ChatBody, chat_completion from private_gpt.server.chat.chat_router import ChatBody, chat_completion
from private_gpt.server.utils.auth import authenticated from private_gpt.server.utils.auth import authenticated
from private_gpt.users import crud, models, schemas
from private_gpt.users.api import deps from private_gpt.users.api import deps
from private_gpt.users import crud, models, schemas
completions_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)]) completions_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
@ -44,53 +44,53 @@ class CompletionsBody(BaseModel):
} }
@completions_router.post( # @completions_router.post(
"/completions", # "/completions",
response_model=None, # response_model=None,
summary="Completion", # summary="Completion",
responses={200: {"model": OpenAICompletion}}, # responses={200: {"model": OpenAICompletion}},
tags=["Contextual Completions"], # tags=["Contextual Completions"],
) # )
def prompt_completion( # def prompt_completion(
request: Request, body: CompletionsBody # request: Request, body: CompletionsBody
) -> OpenAICompletion | StreamingResponse: # ) -> OpenAICompletion | StreamingResponse:
"""We recommend most users use our Chat completions API. # """We recommend most users use our Chat completions API.
Given a prompt, the model will return one predicted completion. # Given a prompt, the model will return one predicted completion.
Optionally include a `system_prompt` to influence the way the LLM answers. # Optionally include a `system_prompt` to influence the way the LLM answers.
If `use_context` # If `use_context`
is set to `true`, the model will use context coming from the ingested documents # is set to `true`, the model will use context coming from the ingested documents
to create the response. The documents being used can be filtered using the # to create the response. The documents being used can be filtered using the
`context_filter` and passing the document IDs to be used. Ingested documents IDs # `context_filter` and passing the document IDs to be used. Ingested documents IDs
can be found using `/ingest/list` endpoint. If you want all ingested documents to # can be found using `/ingest/list` endpoint. If you want all ingested documents to
be used, remove `context_filter` altogether. # be used, remove `context_filter` altogether.
When using `'include_sources': true`, the API will return the source Chunks used # When using `'include_sources': true`, the API will return the source Chunks used
to create the response, which come from the context provided. # to create the response, which come from the context provided.
When using `'stream': true`, the API will return data chunks following [OpenAI's # When using `'stream': true`, the API will return data chunks following [OpenAI's
streaming model](https://platform.openai.com/docs/api-reference/chat/streaming): # streaming model](https://platform.openai.com/docs/api-reference/chat/streaming):
``` # ```
{"id":"12345","object":"completion.chunk","created":1694268190, # {"id":"12345","object":"completion.chunk","created":1694268190,
"model":"private-gpt","choices":[{"index":0,"delta":{"content":"Hello"}, # "model":"private-gpt","choices":[{"index":0,"delta":{"content":"Hello"},
"finish_reason":null}]} # "finish_reason":null}]}
``` # ```
""" # """
messages = [OpenAIMessage(content=body.prompt, role="user")] # messages = [OpenAIMessage(content=body.prompt, role="user")]
# If system prompt is passed, create a fake message with the system prompt. # # If system prompt is passed, create a fake message with the system prompt.
if body.system_prompt: # if body.system_prompt:
messages.insert(0, OpenAIMessage(content=body.system_prompt, role="system")) # messages.insert(0, OpenAIMessage(content=body.system_prompt, role="system"))
chat_body = ChatBody( # chat_body = ChatBody(
messages=messages, # messages=messages,
use_context=body.use_context, # use_context=body.use_context,
stream=body.stream, # stream=body.stream,
include_sources=body.include_sources, # include_sources=body.include_sources,
context_filter=body.context_filter, # context_filter=body.context_filter,
) # )
return chat_completion(request, chat_body) # return chat_completion(request, chat_body)
@completions_router.post( @completions_router.post(
@ -101,7 +101,7 @@ def prompt_completion(
tags=["Contextual Completions"], tags=["Contextual Completions"],
) )
def prompt_completion( def prompt_completion(
request: Request, request: Request,
body: CompletionsBody, body: CompletionsBody,
db: Session = Depends(deps.get_db), db: Session = Depends(deps.get_db),
current_user: models.User = Security( current_user: models.User = Security(
@ -111,15 +111,17 @@ def prompt_completion(
try: try:
service = request.state.injector.get(IngestService) service = request.state.injector.get(IngestService)
department = crud.department.get_by_id(db, id=current_user.department_id) department = crud.department.get_by_id(
db, id=current_user.department_id)
if not department: if not department:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
detail=f"No department assigned to you") detail=f"No department assigned to you")
documents = crud.documents.get_multi_documents(db, department_id=department.id) documents = crud.documents.get_multi_documents(
db, department_id=department.id)
if not documents: if not documents:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
detail=f"No documents uploaded for your department.") detail=f"No documents uploaded for your department.")
docs_list = [document.filename for document in documents] docs_list = [document.filename for document in documents]
docs_ids = [] docs_ids = []
for filename in docs_list: for filename in docs_list:
doc_id = service.get_doc_ids_by_filename(filename) doc_id = service.get_doc_ids_by_filename(filename)

View File

@ -34,13 +34,13 @@ def register_user(
""" """
Register a new user in the database. Register a new user in the database.
""" """
logging.info(f"User : {email} Password: {password} company_id: {company.id} deparment_id: {department.id}") logging.info(f"User : {email} Password: {password} company_id: {company.id} deparment_id: {department}")
user_in = schemas.UserCreate( user_in = schemas.UserCreate(
email=email, email=email,
password=password, password=password,
fullname=fullname, fullname=fullname,
company_id=company.id, company_id=company.id,
department_id=department.id, department_id=department,
) )
try: try:
send_registration_email(fullname, email, password) send_registration_email(fullname, email, password)
@ -204,8 +204,8 @@ def register(
# password: str = Body(...), # password: str = Body(...),
company_id: int = Body(None, title="Company ID", company_id: int = Body(None, title="Company ID",
description="Company ID for the user (if applicable)"), description="Company ID for the user (if applicable)"),
department_name: str = Body(None, title="Department Name", department_id: int = Body(None, title="Department Id",
description="Department name for the user (if applicable)"), description="Department Id for the user (if applicable)"),
role_name: str = Body(None, title="Role Name", role_name: str = Body(None, title="Role Name",
description="User role name (if applicable)"), description="User role name (if applicable)"),
current_user: models.User = Security( current_user: models.User = Security(
@ -232,17 +232,15 @@ def register(
status_code=404, status_code=404,
detail="Company not found.", detail="Company not found.",
) )
if department_name:
department = crud.department.get_by_department_name( if not department_id:
db=db, name=department_name) raise HTTPException(
if not department: status_code=404,
raise HTTPException( detail="Department not found.",
status_code=404, )
detail="Department not found.", logging.info(f"Department is {department_id}")
)
logging.info(f"Department is {department}")
user = register_user( user = register_user(
db, email, fullname, random_password, company, department db, email, fullname, random_password, company, department_id
) )
user_role_name = role_name or Role.GUEST["name"] user_role_name = role_name or Role.GUEST["name"]
user_role = create_user_role(db, user, user_role_name, company) user_role = create_user_role(db, user, user_role_name, company)

View File

@ -56,7 +56,7 @@ def create_deparment(
) )
@router.get("/{department_id}", response_model=schemas.Department) @router.post("/read", response_model=schemas.Department)
def read_department( def read_department(
department_id: int, department_id: int,
db: Session = Depends(deps.get_db), db: Session = Depends(deps.get_db),
@ -74,9 +74,8 @@ def read_department(
return department return department
@router.put("/{department_id}", response_model=schemas.Department) @router.post("/update", response_model=schemas.Department)
def update_department( def update_department(
department_id: int,
department_in: schemas.DepartmentUpdate, department_in: schemas.DepartmentUpdate,
db: Session = Depends(deps.get_db), db: Session = Depends(deps.get_db),
current_user: models.User = Security( current_user: models.User = Security(
@ -87,7 +86,7 @@ def update_department(
""" """
Update a Department by ID Update a Department by ID
""" """
department = crud.department.get_by_id(db, id=department_id) department = crud.department.get_by_id(db, id=department_in.id)
if department is None: if department is None:
raise HTTPException(status_code=404, detail="department not found") raise HTTPException(status_code=404, detail="department not found")
@ -103,9 +102,9 @@ def update_department(
) )
@router.delete("/{department_id}", response_model=schemas.Department) @router.post("/delete", response_model=schemas.Department)
def delete_department( def delete_department(
department_id: int, department_in: schemas.DepartmentDelete,
db: Session = Depends(deps.get_db), db: Session = Depends(deps.get_db),
current_user: models.User = Security( current_user: models.User = Security(
deps.get_current_user, deps.get_current_user,
@ -115,6 +114,10 @@ def delete_department(
""" """
Delete a Department by ID Delete a Department by ID
""" """
department_id = department_in.id
department = crud.department.get(db, id=department_id)
if department is None:
raise HTTPException(status_code=404, detail="User not found")
department = crud.department.remove(db=db, id=department_id) department = crud.department.remove(db=db, id=department_id)
if department is None: if department is None:

View File

@ -15,6 +15,7 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix='/documents', tags=['Documents']) router = APIRouter(prefix='/documents', tags=['Documents'])
@router.get("", response_model=List[schemas.Document]) @router.get("", response_model=List[schemas.Document])
def list_files( def list_files(
request: Request, request: Request,
@ -26,18 +27,36 @@ def list_files(
scopes=[Role.SUPER_ADMIN["name"]], scopes=[Role.SUPER_ADMIN["name"]],
) )
): ):
def get_department_name(db, id):
dep = crud.department.get_by_id(db=db, id=id)
return dep.name
def get_username(db, id):
user = crud.user.get_by_id(db=db, id=id)
return user.fullname
try: try:
docs = crud.documents.get_multi(db, skip=skip, limit=limit) docs = crud.documents.get_multi(db, skip=skip, limit=limit)
docs = [
schemas.Document(
id=doc.id,
filename=doc.filename,
uploaded_at=doc.uploaded_at,
uploaded_by=get_username(db, doc.uploaded_by),
department=get_department_name(db, doc.department_id)
)
for doc in docs
]
return docs return docs
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())
logger.error(f"There was an error listing the file(s).") logger.error(f"There was an error listing the file(s).")
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail="Internal Server Error: Unable to ingest file.", detail="Internal Server Error",
) )
@router.get('{department_id}', response_model=List[schemas.Document]) @router.get('{department_id}', response_model=List[schemas.Document])
def list_files_by_department( def list_files_by_department(
request: Request, request: Request,
@ -62,7 +81,7 @@ def list_files_by_department(
logger.error(f"There was an error listing the file(s).") logger.error(f"There was an error listing the file(s).")
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail="Internal Server Error: Unable to ingest file.", detail="Internal Server Error.",
) )
@ -90,5 +109,5 @@ def list_files_by_department(
logger.error(f"There was an error listing the file(s).") logger.error(f"There was an error listing the file(s).")
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail="Internal Server Error: Unable to ingest file.", detail="Internal Server Error.",
) )

View File

@ -6,7 +6,7 @@ from typing import Optional
class CRUDDepartments(CRUDBase[Department, DepartmentCreate, DepartmentUpdate]): class CRUDDepartments(CRUDBase[Department, DepartmentCreate, DepartmentUpdate]):
def get_by_id(self, db: Session, *, id: str) -> Optional[Department]: def get_by_id(self, db: Session, *, id: int) -> Optional[Department]:
return db.query(self.model).filter(Department.id == id).first() return db.query(self.model).filter(Department.id == id).first()
def get_by_department_name(self, db: Session, *, name: str) -> Optional[Department]: def get_by_department_name(self, db: Session, *, name: str) -> Optional[Department]:

View File

@ -6,7 +6,7 @@ from typing import Optional, List
class CRUDDocuments(CRUDBase[Document, DocumentCreate, DocumentUpdate]): class CRUDDocuments(CRUDBase[Document, DocumentCreate, DocumentUpdate]):
def get_by_id(self, db: Session, *, id: str) -> Optional[Document]: def get_by_id(self, db: Session, *, id: int) -> Optional[Document]:
return db.query(self.model).filter(Document.id == id).first() return db.query(self.model).filter(Document.id == id).first()
def get_by_filename(self, db: Session, *, file_name: str) -> Optional[Document]: def get_by_filename(self, db: Session, *, file_name: str) -> Optional[Document]:

View File

@ -104,4 +104,8 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
.limit(limit) .limit(limit)
.all() .all()
) )
def get_by_id(self, db: Session, *, id: int) -> Optional[User]:
return db.query(self.model).filter(User.id == id).first()
user = CRUDUser(User) user = CRUDUser(User)

View File

@ -1,8 +1,10 @@
from sqlalchemy import ForeignKey from sqlalchemy import ForeignKey, event
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship, Session
from sqlalchemy import Column, Integer, String from sqlalchemy import Column, Integer, String
from private_gpt.users.db.base_class import Base from private_gpt.users.db.base_class import Base
from private_gpt.users.models.document import Document
from private_gpt.users.models.user import User
class Department(Base): class Department(Base):
@ -17,4 +19,27 @@ class Department(Base):
company = relationship("Company", back_populates="departments") company = relationship("Company", back_populates="departments")
users = relationship("User", back_populates="department") users = relationship("User", back_populates="department")
documents = relationship("Document", back_populates="department") documents = relationship("Document", back_populates="department")
total_users = Column(Integer, default=0)
total_documents = Column(Integer, default=0)
def update_total_users(mapper, connection, target):
session = Session(bind=connection)
target.total_users = session.query(User).filter_by(
department_id=target.id).count()
def update_total_documents(mapper, connection, target):
session = Session(bind=connection)
target.total_documents = session.query(
Document).filter_by(department_id=target.id).count()
# Attach event listeners to Department model
event.listen(Department, 'after_insert', update_total_users)
event.listen(Department, 'after_update', update_total_users)
event.listen(Department, 'after_insert', update_total_documents)
event.listen(Department, 'after_update', update_total_documents)

View File

@ -5,4 +5,4 @@ from .user_role import UserRole, UserRoleCreate, UserRoleInDB, UserRoleUpdate
from .subscription import Subscription, SubscriptionBase, SubscriptionCreate, SubscriptionUpdate from .subscription import Subscription, SubscriptionBase, SubscriptionCreate, SubscriptionUpdate
from .company import Company, CompanyBase, CompanyCreate, CompanyUpdate from .company import Company, CompanyBase, CompanyCreate, CompanyUpdate
from .documents import Document, DocumentCreate, DocumentsBase, DocumentUpdate, DocumentList from .documents import Document, DocumentCreate, DocumentsBase, DocumentUpdate, DocumentList
from .department import Department, DepartmentCreate, DepartmentUpdate, DepartmentAdminCreate from .department import Department, DepartmentCreate, DepartmentUpdate, DepartmentAdminCreate, DepartmentDelete

View File

@ -1,4 +1,4 @@
from typing import List from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel
@ -11,12 +11,18 @@ class DepartmentCreate(DepartmentBase):
class DepartmentUpdate(DepartmentBase): class DepartmentUpdate(DepartmentBase):
pass id: int
class DepartmentDelete(BaseModel):
id: int
class DepartmentInDB(DepartmentBase): class DepartmentInDB(DepartmentBase):
id: int id: int
company_id: int company_id: int
total_users: Optional[int]
total_documents: Optional[int]
class Config: class Config:
orm_mode = True orm_mode = True
@ -29,3 +35,4 @@ class DepartmentAdminCreate(DepartmentBase):
class Department(DepartmentInDB): class Department(DepartmentInDB):
pass pass

View File

@ -23,9 +23,9 @@ class DocumentList(DocumentsBase):
class Document(DocumentsBase): class Document(DocumentsBase):
id: int id: int
uploaded_by: int uploaded_by: str
uploaded_at: datetime uploaded_at: datetime
department_id: int department: str
class Config: class Config:
orm_mode = True orm_mode = True