From 0b79c23f68a9ae29f29298be9dcc649a94e662fd Mon Sep 17 00:00:00 2001 From: Saurab-Shrestha Date: Sun, 21 Jan 2024 12:02:36 +0545 Subject: [PATCH] Added routes for creating users through admin --- .env | 8 +- ...reated_a_relationships_based_on_company.py | 38 ++++ ..._added_relation_company_with_user_roles.py | 32 ++++ ...ebc068de_created_company_name_as_unique.py | 46 +++++ bash.exe.stackdump | 19 ++ private_gpt/launcher.py | 12 +- private_gpt/users/api/deps.py | 31 ++- private_gpt/users/api/v1/routers/auth.py | 181 +++++++++++++----- .../users/api/v1/routers/user_roles.py | 8 +- private_gpt/users/api/v1/routers/users.py | 26 ++- private_gpt/users/crud/company_crud.py | 5 +- private_gpt/users/crud/role_crud.py | 2 +- private_gpt/users/crud/user_crud.py | 10 + private_gpt/users/models/__init__.py | 5 +- private_gpt/users/models/company.py | 9 +- private_gpt/users/models/user.py | 5 +- private_gpt/users/models/user_role.py | 16 +- private_gpt/users/schemas/company.py | 2 +- private_gpt/users/schemas/token.py | 1 + private_gpt/users/schemas/user.py | 3 +- private_gpt/users/schemas/user_role.py | 4 +- 21 files changed, 377 insertions(+), 86 deletions(-) create mode 100644 alembic/versions/6f93f0d1defb_created_a_relationships_based_on_company.py create mode 100644 alembic/versions/cccea6c7d70d_added_relation_company_with_user_roles.py create mode 100644 alembic/versions/f93bebc068de_created_company_name_as_unique.py create mode 100644 bash.exe.stackdump diff --git a/.env b/.env index 6b411c22..75bcc113 100644 --- a/.env +++ b/.env @@ -15,8 +15,8 @@ ACCESS_TOKEN_EXPIRE_MINUTES=60 REFRESH_TOKEN_EXPIRE_MINUTES = 120 # 7 days -SMTP_SERVER=smtp-mail.outlook.com +SMTP_SERVER=smtp.gmail.com SMTP_PORT=587 -SMTP_SENDER_EMAIL=saurabstha7@outlook.com -SMTP_USERNAME=saurabstha7@outlook.com -SMTP_PASSWORD=avantador123 \ No newline at end of file +SMTP_SENDER_EMAIL=shresthasaurab030@outlook.com +SMTP_USERNAME=shresthasaurab030 +SMTP_PASSWORD=huurxwxeorxjorzw \ No newline at end of file diff --git a/alembic/versions/6f93f0d1defb_created_a_relationships_based_on_company.py b/alembic/versions/6f93f0d1defb_created_a_relationships_based_on_company.py new file mode 100644 index 00000000..a607ec0f --- /dev/null +++ b/alembic/versions/6f93f0d1defb_created_a_relationships_based_on_company.py @@ -0,0 +1,38 @@ +"""Created a relationships based on company + +Revision ID: 6f93f0d1defb +Revises: 6f3cc13e1339 +Create Date: 2024-01-20 15:46:13.093101 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '6f93f0d1defb' +down_revision: Union[str, None] = '6f3cc13e1339' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('user_roles', sa.Column('company_id', sa.Integer(), nullable=True)) + # op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id', 'company_id']) + op.create_foreign_key(None, 'user_roles', 'companies', ['company_id'], ['id']) + op.add_column('users', sa.Column('company_id', sa.Integer(), nullable=True)) + op.create_foreign_key(None, 'users', 'companies', ['company_id'], ['id']) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, 'users', type_='foreignkey') + op.drop_column('users', 'company_id') + op.drop_constraint(None, 'user_roles', type_='foreignkey') + # op.drop_constraint('unique_user_role', 'user_roles', type_='unique') + op.drop_column('user_roles', 'company_id') + # ### end Alembic commands ### diff --git a/alembic/versions/cccea6c7d70d_added_relation_company_with_user_roles.py b/alembic/versions/cccea6c7d70d_added_relation_company_with_user_roles.py new file mode 100644 index 00000000..afbe9d0d --- /dev/null +++ b/alembic/versions/cccea6c7d70d_added_relation_company_with_user_roles.py @@ -0,0 +1,32 @@ +"""Added relation company with user roles + +Revision ID: cccea6c7d70d +Revises: f93bebc068de +Create Date: 2024-01-20 17:17:36.178991 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'cccea6c7d70d' +down_revision: Union[str, None] = 'f93bebc068de' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + # op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id', 'company_id']) + pass + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + # op.drop_constraint('unique_user_role', 'user_roles', type_='unique') + pass + # ### end Alembic commands ### diff --git a/alembic/versions/f93bebc068de_created_company_name_as_unique.py b/alembic/versions/f93bebc068de_created_company_name_as_unique.py new file mode 100644 index 00000000..986153b0 --- /dev/null +++ b/alembic/versions/f93bebc068de_created_company_name_as_unique.py @@ -0,0 +1,46 @@ +"""Created company name as unique + +Revision ID: f93bebc068de +Revises: 6f93f0d1defb +Create Date: 2024-01-20 17:05:36.133343 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'f93bebc068de' +down_revision: Union[str, None] = '6f93f0d1defb' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index('ix_companies_name', table_name='companies') + op.create_index(op.f('ix_companies_name'), 'companies', ['name'], unique=True) + op.alter_column('user_roles', 'company_id', + existing_type=sa.INTEGER(), + nullable=True) + # op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id', 'company_id']) + op.alter_column('users', 'company_id', + existing_type=sa.INTEGER(), + nullable=True) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('users', 'company_id', + existing_type=sa.INTEGER(), + nullable=True) + # op.drop_constraint('unique_user_role', 'user_roles', type_='unique') + op.alter_column('user_roles', 'company_id', + existing_type=sa.INTEGER(), + nullable=True) + op.drop_index(op.f('ix_companies_name'), table_name='companies') + op.create_index('ix_companies_name', 'companies', ['name'], unique=False) + # ### end Alembic commands ### diff --git a/bash.exe.stackdump b/bash.exe.stackdump new file mode 100644 index 00000000..a511e211 --- /dev/null +++ b/bash.exe.stackdump @@ -0,0 +1,19 @@ +Stack trace: +Frame Function Args +000005FF350 00210062B0E (00210298702, 00210275E3E, 0000000005E, 000005FAEB0) +000005FF350 0021004846A (00000000000, 00000000000, 7FF800000000, 00000001000) +000005FF350 002100484A2 (00000000000, 0000000005A, 0000000005E, 00000000000) +000005FF350 002100DA798 (00000000000, 00200000000, 002102759F2, 000005FBFFC) +000005FF350 00210133477 (00000000000, 0021022B110, 0021022B100, 000005FDDE0) +000005FF350 002100488B4 (00210317960, 000005FDDE0, 00000000000, 00000000000) +000005FF350 0021004A01F (0007FFE0384, 00000000000, 00000000000, 00000000000) +000005FF350 002100DB7D8 (00000000000, 00000000000, 00000000000, 00000000000) +000005FF5F0 7FF8C2889A1D (00210040000, 00000000001, 00000000000, 000005FF538) +000005FF5F0 7FF8C28DC2C7 (7FF8C28ADB00, 000000B3F01, 7FF800000001, 00000000001) +000005FF5F0 7FF8C28DC05A (000000B3F70, 000005FF5F0, 000000B48C0, 00000070000) +000005FF5F0 7FF8C28DC0E0 (7FF8C2995A10, 00000000000, 00000000010, 000005FF690) +00000000000 7FF8C2943C42 (00000000000, 00000000000, 00000000001, 00000000000) +00000000000 7FF8C28E4DBB (7FF8C2870000, 00000000000, 000003C5000, 00000000000) +00000000000 7FF8C28E4C43 (00000000000, 00000000000, 00000000000, 00000000000) +00000000000 7FF8C28E4BEE (00000000000, 00000000000, 00000000000, 00000000000) +End of stack trace diff --git a/private_gpt/launcher.py b/private_gpt/launcher.py index 5ae70e2c..e84216d0 100644 --- a/private_gpt/launcher.py +++ b/private_gpt/launcher.py @@ -26,12 +26,12 @@ def create_app(root_injector: Injector) -> FastAPI: app = FastAPI(dependencies=[Depends(bind_injector_to_request)]) - # app.include_router(completions_router) - # app.include_router(chat_router) - # app.include_router(chunks_router) - # app.include_router(ingest_router) - # app.include_router(embeddings_router) - # app.include_router(health_router) + app.include_router(completions_router) + app.include_router(chat_router) + app.include_router(chunks_router) + app.include_router(ingest_router) + app.include_router(embeddings_router) + app.include_router(health_router) app.include_router(api_router) diff --git a/private_gpt/users/api/deps.py b/private_gpt/users/api/deps.py index 3614f9e3..0e255164 100644 --- a/private_gpt/users/api/deps.py +++ b/private_gpt/users/api/deps.py @@ -14,6 +14,7 @@ from private_gpt.users.core.security import ( from fastapi import Depends, HTTPException, Security, status from jose import jwt from pydantic import ValidationError +from private_gpt.users.constants.role import Role from private_gpt.users.schemas.token import TokenPayload from sqlalchemy.orm import Session @@ -94,4 +95,32 @@ async def get_current_active_user( ) -> models.User: if not crud.user.is_active(current_user): raise HTTPException(status_code=400, detail="Inactive user") - return current_user \ No newline at end of file + return current_user + + + +async def is_company_admin(current_user: models.User = Depends(get_current_user)): + if current_user.role == Role.ADMIN["name"]: + return current_user + else: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You don't have the necessary permissions to perform this action", + ) + + +async def is_super_admin(current_user: models.User = Depends(get_current_user)): + if current_user.role == Role.SUPER_ADMIN["name"]: + return current_user + else: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You don't have the necessary permissions to perform this action", + ) + + +async def get_company_name(company_id: int, db: Session = Depends(get_db)) -> str: + company = crud.company.get(db=db, id=company_id) + if not company: + raise HTTPException(status_code=404, detail="Company not found") + return company.name diff --git a/private_gpt/users/api/v1/routers/auth.py b/private_gpt/users/api/v1/routers/auth.py index cb76732b..ac68de07 100644 --- a/private_gpt/users/api/v1/routers/auth.py +++ b/private_gpt/users/api/v1/routers/auth.py @@ -1,10 +1,12 @@ -from typing import Any +from typing import Any, Optional from datetime import timedelta, datetime -from pydantic.networks import EmailStr from sqlalchemy.orm import Session -from fastapi import APIRouter, Body, Depends, HTTPException, Security +from pydantic.networks import EmailStr +from fastapi.responses import JSONResponse +from fastapi.encoders import jsonable_encoder from fastapi.security import OAuth2PasswordRequestForm +from fastapi import APIRouter, Body, Depends, HTTPException, Security, Path, status from private_gpt.users.api import deps from private_gpt.users.core import security @@ -16,6 +18,45 @@ from private_gpt.users.utils import send_registration_email router = APIRouter(prefix="/auth", tags=["auth"]) +def register_user( + db: Session, + email: str, + fullname: str, + password: str, + company: Optional[models.Company] = None, +) -> models.User: + """ + Register a new user in the database. + """ + print(f"{email} {fullname} {password} {company.id}") + user_in = schemas.UserCreate(email=email, password=password, fullname=fullname, company_id=company.id) + send_registration_email(fullname, email, password) + return crud.user.create(db, obj_in=user_in) + + +def create_user_role( + db: Session, + user: models.User, + role_name: str, + company: Optional[models.Company] = None, +) -> models.UserRole: + """ + Create a user role in the database. + """ + role = crud.role.get_by_name(db, name=role_name) + user_role_in = schemas.UserRoleCreate(user_id=user.id, role_id=role.id, company_id=company.id if company else None) + return crud.user_role.create(db, obj_in=user_role_in) + + +def create_token_payload(user: models.User, user_role: models.UserRole) -> dict: + """ + Create a token payload for authentication. + """ + return { + "id": str(user.id), + "role": user_role.role.name, + "company_id": user_role.company.id if user_role.company else None, + } @router.post("/login", response_model=schemas.TokenSchema) def login_access_token( @@ -32,9 +73,7 @@ def login_access_token( raise HTTPException( status_code=400, detail="Incorrect email or password" ) - # elif not crud.user.is_active(user): - # raise HTTPException(status_code=400, detail="Inactive user") - + access_token_expires = timedelta( minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES ) @@ -42,21 +81,26 @@ def login_access_token( minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES ) user_in = schemas.UserUpdate( - email = user.email, - fullname = user.fullname, + email=user.email, + fullname=user.fullname, + company_id=user.company_id, last_login=datetime.now() ) + user = crud.user.update(db, db_obj=user, obj_in=user_in) - if not user.user_role: - role = "GUEST" - else: + + if user.user_role: role = user.user_role.role.name - + if user.user_role.company_id: + company_id = user.user_role.company_id + else: company_id = None + token_payload = { "id": str(user.id), "role": role, + "company_id": company_id, } - + return { "access_token": security.create_access_token( token_payload, expires_delta=access_token_expires @@ -68,20 +112,21 @@ def login_access_token( } -@router.post("/register", response_model=schemas.TokenSchema) -def register( + +@router.post("/{company_name}/register", response_model=schemas.User) +def register_for_company( *, db: Session = Depends(deps.get_db), - email: EmailStr = Body(...), + email: str = Body(...), fullname: str = Body(...), - role: str = Body(Default="GUEST"), + company_name: Optional[str] = Path(..., title="Company Name", description="Only for company admin"), current_user: models.User = Security( deps.get_current_user, - scopes=[Role.ADMIN["name"], Role.SUPER_ADMIN["name"]], + scopes=[Role.SUPER_ADMIN["name"], Role.ADMIN['name']], ), ) -> Any: """ - Register new user. + Register new user for a specific company. """ user = crud.user.get_by_email(db, email=email) if user: @@ -89,42 +134,74 @@ def register( status_code=409, detail="The user with this username already exists in the system", ) - random_password = security.generate_random_password() - user_in = schemas.UserCreate( - email=email, - password=random_password, - fullname=fullname, - ) - user = crud.user.create(db, obj_in=user_in) - send_registration_email(fullname, email, random_password) - role_db = crud.role.get_by_name(db, name=role) - if not role_db: + if current_user.user_role.role.name not in {Role.ADMIN["name"], Role.SUPER_ADMIN["name"]}: raise HTTPException( - status_code=404, - detail=f"Role '{role}' not found", + status_code=403, + detail="You do not have permission to register users for a company.", ) - user_role_in = schemas.UserRoleCreate( - user_id=user.id, - role_id=role_db.id + + company = crud.company.get_by_company_name(db, company_name=company_name) + print(f"Company is : {company.id}") + if not (current_user.user_role.role.name == Role.ADMIN["name"] and current_user.user_role.company_id == company.id): + raise HTTPException( + status_code=403, + detail="You are not the admin of the specified company.", + ) + + random_password = security.generate_random_password() + user = register_user(db, email, fullname, random_password, company) + user_role = create_user_role(db, user, Role.GUEST["name"], company) + + token_payload = create_token_payload(user, user_role) + return JSONResponse( + status_code=status.HTTP_201_CREATED, + content={"message": "User registered successfully.\n\n Check respective user email for login credentials", "user": jsonable_encoder(user)}, ) - user_role = crud.user_role.create(db, obj_in=user_role_in) - access_token_expires = timedelta( - minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES - ) - refresh_token_expires = timedelta( - minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES - ) - token_payload = { - "id": str(user.id), - "role": user_role.role.name, - } + + +@router.post("/register", response_model=schemas.TokenSchema) +def register_without_company_assignment( + *, + db: Session = Depends(deps.get_db), + email: str = Body(...), + fullname: str = Body(...), + company_id: int = Body(None, title="Company ID", description="Company ID for the user (if applicable)"), + current_user: models.User = Security( + deps.get_current_user, + scopes=[Role.SUPER_ADMIN["name"], Role.ADMIN['name']], + ), +) -> Any: + """ + Register new user with company assignment. + """ + user = crud.user.get_by_email(db, email=email) + if user: + raise HTTPException( + status_code=409, + detail="The user with this username already exists in the system", + ) + + if current_user.user_role.role.name != Role.SUPER_ADMIN["name"]: + raise HTTPException( + status_code=403, + detail="You do not have permission to register users without a company.", + ) + + if company_id is None: + raise HTTPException( + status_code=400, + detail="Company ID is required for registering a user without a specific company.", + ) + + random_password = security.generate_random_password() + company = crud.company.get(db, company_id) + user = register_user(db, email, fullname, random_password, company) + user_role = create_user_role(db, user, Role.ADMIN["name"], company) + + token_payload = create_token_payload(user, user_role) return { - "access_token": security.create_access_token( - token_payload, expires_delta=access_token_expires - ), - "refresh_token": security.create_refresh_token( - token_payload, expires_delta=refresh_token_expires - ), + "access_token": security.create_access_token(token_payload, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)), + "refresh_token": security.create_refresh_token(token_payload, expires_delta=timedelta(minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES)), "token_type": "bearer", - } + } \ No newline at end of file diff --git a/private_gpt/users/api/v1/routers/user_roles.py b/private_gpt/users/api/v1/routers/user_roles.py index e0f21d36..2474a7d9 100644 --- a/private_gpt/users/api/v1/routers/user_roles.py +++ b/private_gpt/users/api/v1/routers/user_roles.py @@ -11,7 +11,6 @@ from private_gpt.users import crud, models, schemas router = APIRouter(prefix="/user-roles", tags=["user-roles"]) - @router.post("", response_model=schemas.UserRole) def assign_user_role( *, @@ -69,3 +68,10 @@ def update_user_role( status_code=status.HTTP_200_OK, content={"message": "User role updated successfully", "user_role": jsonable_encoder(user_role)}, ) + + + + +company_router = APIRouter(prefix="/user-roles", tags=["user-roles"]) + + diff --git a/private_gpt/users/api/v1/routers/users.py b/private_gpt/users/api/v1/routers/users.py index 5b8ef150..0cb99631 100644 --- a/private_gpt/users/api/v1/routers/users.py +++ b/private_gpt/users/api/v1/routers/users.py @@ -1,10 +1,10 @@ -from typing import Any, List +from typing import Any, List, Optional from sqlalchemy.orm import Session from pydantic.networks import EmailStr from fastapi.responses import JSONResponse from fastapi.encoders import jsonable_encoder -from fastapi import APIRouter, Body, Depends, HTTPException, Security, status +from fastapi import APIRouter, Body, Depends, HTTPException, Security, status, Path from private_gpt.users.api import deps from private_gpt.users.constants.role import Role @@ -24,13 +24,30 @@ def read_users( scopes=[Role.ADMIN["name"], Role.SUPER_ADMIN["name"]], ), ) -> Any: - """ + """ Retrieve all users. """ users = crud.user.get_multi(db, skip=skip, limit=limit) return users +@router.get("/{company_name}") +def read_users_by_company( + company_name: Optional[str] = Path(..., title="Company Name", description="Only for company admin"), + db: Session = Depends(deps.get_db), + current_user: models.User = Security( + deps.get_current_user, + scopes=[Role.ADMIN["name"], Role.SUPER_ADMIN["name"]], + ), +): + """ + Retrieve all users of that company only + """ + company = crud.company.get_by_company_name(db, company_name=company_name) + users = crud.user.get_multi_by_company_id(db, company_id=company.id) + return users + + @router.post("", response_model=schemas.User) def create_user( *, @@ -185,4 +202,5 @@ def update_user( return JSONResponse( status_code=status.HTTP_200_OK, content={"message": "User updated successfully", "user": jsonable_encoder(user_data)}, - ) \ No newline at end of file + ) + diff --git a/private_gpt/users/crud/company_crud.py b/private_gpt/users/crud/company_crud.py index 53f145ce..160c9cea 100644 --- a/private_gpt/users/crud/company_crud.py +++ b/private_gpt/users/crud/company_crud.py @@ -9,5 +9,8 @@ from sqlalchemy.orm import Session class CRUDCompany(CRUDBase[Company, CompanyCreate, CompanyUpdate]): def get_by_id(self, db: Session, *, id: str) -> Optional[Company]: return db.query(self.model).filter(Company.id == id).first() - + + def get_by_company_name(self, db: Session, *, company_name: str) -> Optional[Company]: + return db.query(self.model).filter(Company.name == company_name).first() + company = CRUDCompany(Company) \ No newline at end of file diff --git a/private_gpt/users/crud/role_crud.py b/private_gpt/users/crud/role_crud.py index ddd94bdc..3485c933 100644 --- a/private_gpt/users/crud/role_crud.py +++ b/private_gpt/users/crud/role_crud.py @@ -9,5 +9,5 @@ from sqlalchemy.orm import Session class CRUDRole(CRUDBase[Role, RoleCreate, RoleUpdate]): def get_by_name(self, db: Session, *, name: str) -> Optional[Role]: return db.query(self.model).filter(Role.name == name).first() - + role = CRUDRole(Role) \ No newline at end of file diff --git a/private_gpt/users/crud/user_crud.py b/private_gpt/users/crud/user_crud.py index ecd81c4c..3d333f40 100644 --- a/private_gpt/users/crud/user_crud.py +++ b/private_gpt/users/crud/user_crud.py @@ -73,5 +73,15 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]): .all() ) + def get_multi_by_company_id( + self, db: Session, *, company_id: str, skip: int = 0, limit: int = 100 + ) -> List[User]: + return ( + db.query(self.model) + .filter(User.company_id == company_id) + .offset(skip) + .limit(limit) + .all() + ) user = CRUDUser(User) \ No newline at end of file diff --git a/private_gpt/users/models/__init__.py b/private_gpt/users/models/__init__.py index b7bb9be8..a5e9e7c8 100644 --- a/private_gpt/users/models/__init__.py +++ b/private_gpt/users/models/__init__.py @@ -1 +1,4 @@ -from .user import User \ No newline at end of file +from .user import User +from .company import Company +from .user_role import UserRole +from .role import Role \ No newline at end of file diff --git a/private_gpt/users/models/company.py b/private_gpt/users/models/company.py index e8749474..3c9adf41 100644 --- a/private_gpt/users/models/company.py +++ b/private_gpt/users/models/company.py @@ -1,6 +1,8 @@ +from typing import List from sqlalchemy import Column, Integer, String from sqlalchemy.orm import relationship from private_gpt.users.db.base_class import Base +from private_gpt.users.schemas.user import User class Company(Base): """Models a Company table.""" @@ -8,7 +10,8 @@ class Company(Base): __tablename__ = "companies" id = Column(Integer, primary_key=True, index=True) - name = Column(String, index=True) - subscriptions = relationship("Subscription", back_populates="company") + name = Column(String, index=True, unique=True) - \ No newline at end of file + subscriptions = relationship("Subscription", back_populates="company") + users = relationship("User", back_populates="company") + user_roles = relationship("UserRole", back_populates="company") \ No newline at end of file diff --git a/private_gpt/users/models/user.py b/private_gpt/users/models/user.py index fd2c2f4c..655823b2 100644 --- a/private_gpt/users/models/user.py +++ b/private_gpt/users/models/user.py @@ -1,6 +1,5 @@ import datetime from sqlalchemy import ( - LargeBinary, Column, String, Integer, @@ -35,10 +34,10 @@ class User(Base): onupdate=datetime.datetime.utcnow, ) - # account_id = Column(Integer, ForeignKey("accounts.id"), nullable=True) + company_id = Column(Integer, ForeignKey("companies.id"), nullable=True) + company = relationship("Company", back_populates="users") user_role = relationship("UserRole", back_populates="user", uselist=False) - # account = relationship("Account", back_populates="users") def __repr__(self): """Returns string representation of model instance""" diff --git a/private_gpt/users/models/user_role.py b/private_gpt/users/models/user_role.py index 9aef0424..763d9972 100644 --- a/private_gpt/users/models/user_role.py +++ b/private_gpt/users/models/user_role.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import relationship class UserRole(Base): __tablename__ = "user_roles" user_id = Column( - Integer, + Integer, ForeignKey("users.id"), primary_key=True, nullable=False, @@ -16,10 +16,16 @@ class UserRole(Base): primary_key=True, nullable=False, ) - + company_id = Column( + Integer, + ForeignKey("companies.id"), + primary_key=True, + nullable=True, + ) role = relationship("Role") user = relationship("User", back_populates="user_role", uselist=False) - + company = relationship("Company", back_populates="user_roles") + __table_args__ = ( - UniqueConstraint("user_id", "role_id", name="unique_user_role"), - ) \ No newline at end of file + UniqueConstraint("user_id", "role_id", "company_id", name="unique_user_role"), + ) diff --git a/private_gpt/users/schemas/company.py b/private_gpt/users/schemas/company.py index 69c21063..6c1d7b3b 100644 --- a/private_gpt/users/schemas/company.py +++ b/private_gpt/users/schemas/company.py @@ -14,6 +14,6 @@ class CompanyUpdate(CompanyBase): class Company(CompanyBase): id: int - + class Config: orm_mode = True \ No newline at end of file diff --git a/private_gpt/users/schemas/token.py b/private_gpt/users/schemas/token.py index 7da7d2db..b8e44d4a 100644 --- a/private_gpt/users/schemas/token.py +++ b/private_gpt/users/schemas/token.py @@ -11,6 +11,7 @@ class TokenSchema(BaseModel): class TokenPayload(BaseModel): id: int role: str = None + company: str = None class Config: arbitrary_types_allowed = True \ No newline at end of file diff --git a/private_gpt/users/schemas/user.py b/private_gpt/users/schemas/user.py index e775b89b..b335bbf5 100644 --- a/private_gpt/users/schemas/user.py +++ b/private_gpt/users/schemas/user.py @@ -4,11 +4,12 @@ from typing import Optional from pydantic import BaseModel, Field, EmailStr from private_gpt.users.schemas.user_role import UserRole - +from private_gpt.users.schemas.company import Company class UserBaseSchema(BaseModel): email: EmailStr fullname: str + company_id: Optional[int] class Config: arbitrary_types_allowed = True diff --git a/private_gpt/users/schemas/user_role.py b/private_gpt/users/schemas/user_role.py index 091a2747..65894f50 100644 --- a/private_gpt/users/schemas/user_role.py +++ b/private_gpt/users/schemas/user_role.py @@ -7,6 +7,7 @@ from pydantic import BaseModel class UserRoleBase(BaseModel): user_id: Optional[int] role_id: Optional[int] + company_id: Optional[int] class Config: arbitrary_types_allowed = True @@ -31,11 +32,10 @@ class UserRoleInDBBase(UserRoleBase): arbitrary_types_allowed = True - # Additional properties to return via API class UserRole(UserRoleInDBBase): pass class UserRoleInDB(UserRoleInDBBase): - pass \ No newline at end of file + pass