mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-09-15 22:59:53 +00:00
Updated with new models and auth
This commit is contained in:
@@ -14,7 +14,8 @@ from private_gpt.users.models.role import Role
|
||||
from private_gpt.users.models.user_role import UserRole
|
||||
from private_gpt.users.models.subscription import Subscription
|
||||
from private_gpt.users.models.company import Company
|
||||
from private_gpt.users.models.documents import Documents
|
||||
from private_gpt.users.models.document import Document
|
||||
from private_gpt.users.models.department import Department
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
@@ -1,8 +1,8 @@
|
||||
"""Create models
|
||||
"""Create models
|
||||
|
||||
Revision ID: dcf96cb11a85
|
||||
Revision ID: 0aeaf9df35a6
|
||||
Revises:
|
||||
Create Date: 2024-02-14 16:30:51.094285
|
||||
Create Date: 2024-02-20 19:16:15.608391
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
@@ -12,7 +12,7 @@ import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'dcf96cb11a85'
|
||||
revision: str = '0aeaf9df35a6'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
@@ -35,6 +35,15 @@ def upgrade() -> None:
|
||||
)
|
||||
op.create_index(op.f('ix_roles_id'), 'roles', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_roles_name'), 'roles', ['name'], unique=False)
|
||||
op.create_table('departments',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=True),
|
||||
sa.Column('company_id', sa.Integer(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['company_id'], ['companies.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_departments_id'), 'departments', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_departments_name'), 'departments', ['name'], unique=True)
|
||||
op.create_table('subscriptions',
|
||||
sa.Column('sub_id', sa.Integer(), nullable=False),
|
||||
sa.Column('company_id', sa.Integer(), nullable=True),
|
||||
@@ -54,7 +63,9 @@ def upgrade() -> None:
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('company_id', sa.Integer(), nullable=True),
|
||||
sa.Column('department_id', sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['company_id'], ['companies.id'], ),
|
||||
sa.ForeignKeyConstraint(['department_id'], ['departments.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('email'),
|
||||
sa.UniqueConstraint('fullname'),
|
||||
@@ -65,6 +76,8 @@ def upgrade() -> None:
|
||||
sa.Column('filename', sa.String(length=225), nullable=False),
|
||||
sa.Column('uploaded_by', sa.Integer(), nullable=False),
|
||||
sa.Column('uploaded_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('department_id', sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['department_id'], ['departments.id'], ),
|
||||
sa.ForeignKeyConstraint(['uploaded_by'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('filename')
|
||||
@@ -91,6 +104,9 @@ def downgrade() -> None:
|
||||
op.drop_table('users')
|
||||
op.drop_index(op.f('ix_subscriptions_sub_id'), table_name='subscriptions')
|
||||
op.drop_table('subscriptions')
|
||||
op.drop_index(op.f('ix_departments_name'), table_name='departments')
|
||||
op.drop_index(op.f('ix_departments_id'), table_name='departments')
|
||||
op.drop_table('departments')
|
||||
op.drop_index(op.f('ix_roles_name'), table_name='roles')
|
||||
op.drop_index(op.f('ix_roles_id'), table_name='roles')
|
||||
op.drop_table('roles')
|
@@ -34,7 +34,7 @@ def create_app(root_injector: Injector) -> FastAPI:
|
||||
app.include_router(health_router)
|
||||
|
||||
app.include_router(api_router)
|
||||
app.include_router(home_router)
|
||||
# app.include_router(home_router)
|
||||
app.include_router(pdf_router)
|
||||
settings = root_injector.get(Settings)
|
||||
if settings.server.cors.enabled:
|
||||
|
@@ -1,4 +1,4 @@
|
||||
from private_gpt.users.api.v1.routers import auth, roles, user_roles, users, subscriptions, companies
|
||||
from private_gpt.users.api.v1.routers import auth, roles, user_roles, users, subscriptions, companies, departments
|
||||
from fastapi import APIRouter
|
||||
|
||||
api_router = APIRouter(prefix="/v1")
|
||||
@@ -9,4 +9,5 @@ api_router.include_router(roles.router)
|
||||
api_router.include_router(user_roles.router)
|
||||
api_router.include_router(companies.router)
|
||||
api_router.include_router(subscriptions.router)
|
||||
api_router.include_router(departments.router)
|
||||
|
||||
|
@@ -1,3 +1,4 @@
|
||||
import traceback
|
||||
from typing import Any, Optional
|
||||
from datetime import timedelta, datetime
|
||||
|
||||
@@ -12,9 +13,13 @@ from private_gpt.users.constants.role import Role
|
||||
from private_gpt.users.core.config import settings
|
||||
from private_gpt.users import crud, models, schemas
|
||||
from private_gpt.users.utils import send_registration_email, Ldap
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LDAP_SERVER = settings.LDAP_SERVER
|
||||
LDAP_ENABLE = settings.LDAP_ENABLE
|
||||
# LDAP_ENABLE = settings.LDAP_ENABLE
|
||||
LDAP_ENABLE = False
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
@@ -24,16 +29,23 @@ def register_user(
|
||||
fullname: str,
|
||||
password: str,
|
||||
company: Optional[models.Company] = None,
|
||||
department: Optional[models.Department] = None,
|
||||
) -> models.User:
|
||||
"""
|
||||
Register a new user in the database.
|
||||
"""
|
||||
print(f"{email} {fullname} {password} {company.id}")
|
||||
user_in = schemas.UserCreate(email=email, password=password, fullname=fullname, company_id=company.id)
|
||||
logging.info(f"User : {email} Password: {password} company_id: {company.id} deparment_id: {department.id}")
|
||||
user_in = schemas.UserCreate(
|
||||
email=email,
|
||||
password=password,
|
||||
fullname=fullname,
|
||||
company_id=company.id,
|
||||
department_id=department.id,
|
||||
)
|
||||
try:
|
||||
send_registration_email(fullname, email, password)
|
||||
except Exception as e:
|
||||
print(f"Failed to send registration email: {str(e)}")
|
||||
logging.info(f"Failed to send registration email: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to send registration email.")
|
||||
return crud.user.create(db, obj_in=user_in)
|
||||
@@ -72,6 +84,7 @@ def create_token_payload(user: models.User, user_role: models.UserRole) -> dict:
|
||||
"role": user_role.role.name,
|
||||
"username": str(user.fullname),
|
||||
"company_id": user_role.company.id if user_role.company else None,
|
||||
"department_id": user.department_id
|
||||
}
|
||||
|
||||
def ad_user_register(
|
||||
@@ -82,17 +95,15 @@ def ad_user_register(
|
||||
|
||||
) -> models.User:
|
||||
"""
|
||||
Register a new user in the database.
|
||||
Register a new user in the database. Company id is directly given here.
|
||||
"""
|
||||
user_in = schemas.UserCreate(email=email, password=password, fullname=fullname, company_id=1)
|
||||
print("user is: ", user_in)
|
||||
user = crud.user.create(db, obj_in=user_in)
|
||||
print("AD user created......................................................................")
|
||||
user_role_name = Role.GUEST["name"]
|
||||
company = crud.company.get(db, 1)
|
||||
|
||||
user_role = create_user_role(db, user, user_role_name, company)
|
||||
print("AD user role created----------------------------------------------------------------")
|
||||
return user
|
||||
|
||||
|
||||
@@ -104,17 +115,17 @@ def login_access_token(
|
||||
"""
|
||||
OAuth2 compatible token login, get an access token for future requests
|
||||
"""
|
||||
# if LDAP_ENABLE:
|
||||
# existing_user = crud.user.get_by_email(db, email=form_data.username)
|
||||
if LDAP_ENABLE:
|
||||
existing_user = crud.user.get_by_email(db, email=form_data.username)
|
||||
|
||||
# if existing_user:
|
||||
# if existing_user.user_role.role.name == "SUPER_ADMIN":
|
||||
# pass
|
||||
# else:
|
||||
# ldap = ldap_login(db=db, username=form_data.username, password=form_data.password)
|
||||
# else:
|
||||
# ldap = ldap_login(db=db, username=form_data.username, password=form_data.password)
|
||||
# ad_user_register(db=db, email=form_data.username,fullname=ldap, password=form_data.password)
|
||||
if existing_user:
|
||||
if existing_user.user_role.role.name == "SUPER_ADMIN":
|
||||
pass
|
||||
else:
|
||||
ldap = ldap_login(db=db, username=form_data.username, password=form_data.password)
|
||||
else:
|
||||
ldap = ldap_login(db=db, username=form_data.username, password=form_data.password)
|
||||
ad_user_register(db=db, email=form_data.username,fullname=ldap, password=form_data.password)
|
||||
|
||||
user = crud.user.authenticate(
|
||||
db, email=form_data.username, password=form_data.password
|
||||
@@ -130,9 +141,6 @@ def login_access_token(
|
||||
minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
user_in = schemas.UserUpdate(
|
||||
email=user.email,
|
||||
fullname=user.fullname,
|
||||
company_id=user.user_role.company_id,
|
||||
last_login=datetime.now()
|
||||
)
|
||||
user = crud.user.update(db, db_obj=user, obj_in=user_in)
|
||||
@@ -148,6 +156,7 @@ def login_access_token(
|
||||
"username": str(user.fullname),
|
||||
"role": role,
|
||||
"company_id": company_id,
|
||||
"department_id": str(user.department_id),
|
||||
}
|
||||
|
||||
response_dict = {
|
||||
@@ -192,8 +201,11 @@ def register(
|
||||
db: Session = Depends(deps.get_db),
|
||||
email: str = Body(...),
|
||||
fullname: str = Body(...),
|
||||
# password: str = Body(...),
|
||||
company_id: int = Body(None, title="Company ID",
|
||||
description="Company ID for the user (if applicable)"),
|
||||
department_name: str = Body(None, title="Department Name",
|
||||
description="Department name for the user (if applicable)"),
|
||||
role_name: str = Body(None, title="Role Name",
|
||||
description="User role name (if applicable)"),
|
||||
current_user: models.User = Security(
|
||||
@@ -212,29 +224,34 @@ def register(
|
||||
detail="The user with this email already exists!",
|
||||
)
|
||||
random_password = security.generate_random_password()
|
||||
# random_password = password
|
||||
try:
|
||||
if company_id:
|
||||
# Registering user with a specific company
|
||||
company = crud.company.get(db, company_id)
|
||||
if not company:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Company not found.",
|
||||
)
|
||||
user = register_user(db, email, fullname, random_password, company)
|
||||
if department_name:
|
||||
department = crud.department.get_by_department_name(
|
||||
db=db, name=department_name)
|
||||
if not department:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Department not found.",
|
||||
)
|
||||
logging.info(f"Department is {department}")
|
||||
user = register_user(
|
||||
db, email, fullname, random_password, company, department
|
||||
)
|
||||
user_role_name = role_name or Role.GUEST["name"]
|
||||
user_role = create_user_role(db, user, user_role_name, company)
|
||||
|
||||
else:
|
||||
user = register_user(db, email, fullname, random_password, None)
|
||||
user_role_name = role_name or Role.ADMIN["name"]
|
||||
user_role = create_user_role(db, user, user_role_name, None)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Unable to create account.",
|
||||
)
|
||||
status_code=500,
|
||||
detail="Unable to create account.",
|
||||
)
|
||||
token_payload = create_token_payload(user, user_role)
|
||||
response_dict = {
|
||||
"access_token": security.create_access_token(token_payload, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)),
|
||||
@@ -242,6 +259,4 @@ def register(
|
||||
"token_type": "bearer",
|
||||
"password": random_password,
|
||||
}
|
||||
print("RESPONSE DICT: ", response_dict)
|
||||
return JSONResponse(content=response_dict, status_code=status.HTTP_201_CREATED)
|
||||
|
||||
|
@@ -24,104 +24,104 @@ def list_deparments(
|
||||
),
|
||||
) -> List[schemas.Department]:
|
||||
"""
|
||||
Retrieve a list of companies with pagination support.
|
||||
Retrieve a list of department with pagination support.
|
||||
"""
|
||||
deparments = crud.deparment.get_multi(db, skip=skip, limit=limit)
|
||||
deparments = crud.department.get_multi(db, skip=skip, limit=limit)
|
||||
return deparments
|
||||
|
||||
|
||||
@router.post("/create", response_model=schemas.Department)
|
||||
def create_deparment(
|
||||
company_in: schemas.DepartmentCreate,
|
||||
department_in: schemas.DepartmentCreate,
|
||||
db: Session = Depends(deps.get_db),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.SUPER_ADMIN["name"]],
|
||||
),
|
||||
) -> schemas.Company:
|
||||
) -> schemas.Department:
|
||||
"""
|
||||
Create a new company
|
||||
Create a new department
|
||||
"""
|
||||
deparment = crud.deparment.create(db=db, obj_in=company_in)
|
||||
deparment = jsonable_encoder(deparment)
|
||||
deparment = crud.department.create(db=db, obj_in=department_in)
|
||||
department = jsonable_encoder(department)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
content={
|
||||
"message": "Department created successfully",
|
||||
"department": deparment
|
||||
"department": department
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{deparment_id}", response_model=schemas.Department)
|
||||
def read_company(
|
||||
deparment_id: int,
|
||||
@router.get("/{department_id}", response_model=schemas.Department)
|
||||
def read_department(
|
||||
department_id: int,
|
||||
db: Session = Depends(deps.get_db),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.SUPER_ADMIN["name"]],
|
||||
),
|
||||
) -> schemas.Company:
|
||||
) -> schemas.Department:
|
||||
"""
|
||||
Read a company by ID
|
||||
Read a Department by ID
|
||||
"""
|
||||
deparment = crud.deparment.get_by_id(db, id=deparment_id)
|
||||
if deparment is None:
|
||||
raise HTTPException(status_code=404, detail="Deparment not found")
|
||||
return deparment
|
||||
department = crud.department.get_by_id(db, id=department_id)
|
||||
if department is None:
|
||||
raise HTTPException(status_code=404, detail="department not found")
|
||||
return department
|
||||
|
||||
|
||||
@router.put("/{deparment_id}", response_model=schemas.Department)
|
||||
def update_company(
|
||||
deparment_id: int,
|
||||
deparment_in: schemas.DepartmentUpdate,
|
||||
@router.put("/{department_id}", response_model=schemas.Department)
|
||||
def update_department(
|
||||
department_id: int,
|
||||
department_in: schemas.DepartmentUpdate,
|
||||
db: Session = Depends(deps.get_db),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.SUPER_ADMIN["name"]],
|
||||
),
|
||||
) -> schemas.Company:
|
||||
) -> schemas.Department:
|
||||
"""
|
||||
Update a company by ID
|
||||
Update a Department by ID
|
||||
"""
|
||||
deparment = crud.deparment.get_by_id(db, id=deparment_id)
|
||||
if deparment is None:
|
||||
raise HTTPException(status_code=404, detail="Deparment not found")
|
||||
department = crud.department.get_by_id(db, id=department_id)
|
||||
if department is None:
|
||||
raise HTTPException(status_code=404, detail="department not found")
|
||||
|
||||
updated_deparment = crud.deparment.update(
|
||||
db=db, db_obj=deparment, obj_in=deparment_in)
|
||||
updated_deparment = jsonable_encoder(updated_deparment)
|
||||
updated_department = crud.department.update(
|
||||
db=db, db_obj=department, obj_in=department_in)
|
||||
updated_department = jsonable_encoder(updated_department)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={
|
||||
"message": f"{deparment_in} Deparment updated successfully",
|
||||
"deparment": updated_deparment
|
||||
"message": f"{department_in} department updated successfully",
|
||||
"department": updated_department
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{deparment_id}", response_model=schemas.Department)
|
||||
def delete_company(
|
||||
deparment_id: int,
|
||||
@router.delete("/{department_id}", response_model=schemas.Department)
|
||||
def delete_department(
|
||||
department_id: int,
|
||||
db: Session = Depends(deps.get_db),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.SUPER_ADMIN["name"]],
|
||||
),
|
||||
) -> schemas.Company:
|
||||
) -> schemas.Department:
|
||||
"""
|
||||
Delete a company by ID
|
||||
Delete a Department by ID
|
||||
"""
|
||||
|
||||
deparment = crud.deparment.remove(db=db, id=deparment_id)
|
||||
if deparment is None:
|
||||
raise HTTPException(status_code=404, detail="Deparment not found")
|
||||
deparment = jsonable_encoder(deparment)
|
||||
department = crud.department.remove(db=db, id=department_id)
|
||||
if department is None:
|
||||
raise HTTPException(status_code=404, detail="department not found")
|
||||
department = jsonable_encoder(department)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={
|
||||
"message": "Deparment deleted successfully",
|
||||
"deparment": deparment
|
||||
"message": "Department deleted successfully",
|
||||
"deparment": department,
|
||||
},
|
||||
)
|
||||
|
@@ -4,4 +4,4 @@ from .user_role_crud import user_role
|
||||
from .company_crud import company
|
||||
from .subscription_crud import subscription
|
||||
from .document_crud import documents
|
||||
from .department_crud import deparment
|
||||
from .department_crud import department
|
||||
|
@@ -9,8 +9,7 @@ class CRUDDepartments(CRUDBase[Department, DepartmentCreate, DepartmentUpdate]):
|
||||
def get_by_id(self, db: Session, *, id: str) -> Optional[Department]:
|
||||
return db.query(self.model).filter(Department.id == id).first()
|
||||
|
||||
def get_by_deparment_name(self, db: Session, *, name: str) -> Optional[Department]:
|
||||
def get_by_department_name(self, db: Session, *, name: str) -> Optional[Department]:
|
||||
return db.query(self.model).filter(Department.name == name).first()
|
||||
|
||||
|
||||
deparment = CRUDDepartments(Department)
|
||||
department = CRUDDepartments(Department)
|
||||
|
@@ -1,16 +1,16 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from private_gpt.users.schemas.documents import DocumentCreate, DocumentUpdate
|
||||
from private_gpt.users.models.documents import Documents
|
||||
from private_gpt.users.models.document import Document
|
||||
from private_gpt.users.crud.base import CRUDBase
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class CRUDDocuments(CRUDBase[Documents, DocumentCreate, DocumentUpdate]):
|
||||
def get_by_id(self, db: Session, *, id: str) -> Optional[Documents]:
|
||||
return db.query(self.model).filter(Documents.id == id).first()
|
||||
class CRUDDocuments(CRUDBase[Document, DocumentCreate, DocumentUpdate]):
|
||||
def get_by_id(self, db: Session, *, id: str) -> Optional[Document]:
|
||||
return db.query(self.model).filter(Document.id == id).first()
|
||||
|
||||
def get_by_filename(self, db: Session, *, file_name: str) -> Optional[Documents]:
|
||||
return db.query(self.model).filter(Documents.filename == file_name).first()
|
||||
def get_by_filename(self, db: Session, *, file_name: str) -> Optional[Document]:
|
||||
return db.query(self.model).filter(Document.filename == file_name).first()
|
||||
|
||||
|
||||
documents = CRUDDocuments(Documents)
|
||||
documents = CRUDDocuments(Document)
|
||||
|
@@ -19,6 +19,8 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
email=obj_in.email,
|
||||
hashed_password=get_password_hash(obj_in.password),
|
||||
fullname=obj_in.fullname,
|
||||
company_id=obj_in.company_id,
|
||||
department_id=obj_in.department_id,
|
||||
)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
@@ -92,4 +94,14 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
def get_by_name(self, db: Session, *, name: str) -> Optional[User]:
|
||||
return db.query(self.model).filter(User.fullname == name).first()
|
||||
|
||||
def get_by_department_id(
|
||||
self, db: Session, *, department_id: int, skip: int = 0, limit: int = 100
|
||||
) -> List[User]:
|
||||
return (
|
||||
db.query(self.model)
|
||||
.filter(User.department_id == department_id)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
user = CRUDUser(User)
|
||||
|
@@ -2,6 +2,6 @@ from .user import User
|
||||
from .company import Company
|
||||
from .user_role import UserRole
|
||||
from .role import Role
|
||||
from .documents import Documents
|
||||
from .document import Document
|
||||
from .subscription import Subscription
|
||||
from .department import Department
|
@@ -15,3 +15,6 @@ class Department(Base):
|
||||
|
||||
company_id = Column(Integer, ForeignKey('companies.id'))
|
||||
company = relationship("Company", back_populates="departments")
|
||||
|
||||
users = relationship("User", back_populates="department")
|
||||
documents = relationship("Document", back_populates="department")
|
@@ -4,9 +4,10 @@ from sqlalchemy.orm import relationship
|
||||
from sqlalchemy import Column, Integer, String, ForeignKey, DateTime
|
||||
|
||||
|
||||
class Documents(Base):
|
||||
class Document(Base):
|
||||
"""Models a document table"""
|
||||
__tablename__ = "document"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
filename = Column(String(225), nullable=False, unique=True)
|
||||
uploaded_by = Column(
|
||||
@@ -21,8 +22,9 @@ class Documents(Base):
|
||||
)
|
||||
uploaded_by_user = relationship(
|
||||
"User", back_populates="uploaded_documents")
|
||||
department_id = Column(Integer, ForeignKey("departments.id"), nullable=True)
|
||||
|
||||
department_id = Column(Integer, ForeignKey(
|
||||
"departments.id"), nullable=False)
|
||||
uploaded_by_user = relationship(
|
||||
"User", back_populates="uploaded_documents")
|
||||
department = relationship("Department", back_populates="documents")
|
||||
department = relationship("Department", back_populates="documents")
|
@@ -36,12 +36,14 @@ class User(Base):
|
||||
|
||||
company_id = Column(Integer, ForeignKey("companies.id"), nullable=True)
|
||||
company = relationship("Company", back_populates="users")
|
||||
uploaded_documents = relationship("Documents", back_populates="uploaded_by_user")
|
||||
|
||||
uploaded_documents = relationship("Document", back_populates="uploaded_by_user")
|
||||
|
||||
user_role = relationship(
|
||||
"UserRole", back_populates="user", uselist=False, cascade="all, delete-orphan")
|
||||
|
||||
department_id = Column(Integer, ForeignKey(
|
||||
"departments.id"), nullable=True)
|
||||
"departments.id"), nullable=False)
|
||||
department = relationship("Department", back_populates="users")
|
||||
|
||||
def __repr__(self):
|
||||
|
@@ -1,9 +1,6 @@
|
||||
from typing import List
|
||||
from pydantic import BaseModel
|
||||
|
||||
from private_gpt.users.schemas import Department, User
|
||||
|
||||
|
||||
class CompanyBase(BaseModel):
|
||||
name: str
|
||||
|
||||
@@ -24,7 +21,4 @@ class CompanyInDB(CompanyBase):
|
||||
|
||||
|
||||
class Company(CompanyInDB):
|
||||
subscriptions: List[str] = []
|
||||
users: List[User] = []
|
||||
user_roles: List[str] = []
|
||||
departments: List[Department] = []
|
||||
pass
|
@@ -10,7 +10,8 @@ from private_gpt.users.schemas.company import Company
|
||||
class UserBaseSchema(BaseModel):
|
||||
email: EmailStr
|
||||
fullname: str
|
||||
company_id: Optional[int]
|
||||
company_id: int
|
||||
department_id: int
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
@@ -24,7 +25,7 @@ class UsernameUpdate(BaseModel):
|
||||
fullname: str
|
||||
|
||||
|
||||
class UserUpdate(UserBaseSchema):
|
||||
class UserUpdate(BaseModel):
|
||||
last_login: Optional[datetime] = None
|
||||
|
||||
|
||||
@@ -69,4 +70,4 @@ class UserAdminUpdate(BaseModel):
|
||||
id: int
|
||||
fullname: str
|
||||
role: str
|
||||
department_id: Optional[int] = None
|
||||
department_id: int
|
||||
|
Reference in New Issue
Block a user