mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-06-21 21:19:42 +00:00
Added routes for creating users through admin
This commit is contained in:
parent
8b396ac6fb
commit
0b79c23f68
8
.env
8
.env
@ -15,8 +15,8 @@ ACCESS_TOKEN_EXPIRE_MINUTES=60
|
||||
REFRESH_TOKEN_EXPIRE_MINUTES = 120 # 7 days
|
||||
|
||||
|
||||
SMTP_SERVER=smtp-mail.outlook.com
|
||||
SMTP_SERVER=smtp.gmail.com
|
||||
SMTP_PORT=587
|
||||
SMTP_SENDER_EMAIL=saurabstha7@outlook.com
|
||||
SMTP_USERNAME=saurabstha7@outlook.com
|
||||
SMTP_PASSWORD=avantador123
|
||||
SMTP_SENDER_EMAIL=shresthasaurab030@outlook.com
|
||||
SMTP_USERNAME=shresthasaurab030
|
||||
SMTP_PASSWORD=huurxwxeorxjorzw
|
@ -0,0 +1,38 @@
|
||||
"""Created a relationships based on company
|
||||
|
||||
Revision ID: 6f93f0d1defb
|
||||
Revises: 6f3cc13e1339
|
||||
Create Date: 2024-01-20 15:46:13.093101
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '6f93f0d1defb'
|
||||
down_revision: Union[str, None] = '6f3cc13e1339'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('user_roles', sa.Column('company_id', sa.Integer(), nullable=True))
|
||||
# op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id', 'company_id'])
|
||||
op.create_foreign_key(None, 'user_roles', 'companies', ['company_id'], ['id'])
|
||||
op.add_column('users', sa.Column('company_id', sa.Integer(), nullable=True))
|
||||
op.create_foreign_key(None, 'users', 'companies', ['company_id'], ['id'])
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint(None, 'users', type_='foreignkey')
|
||||
op.drop_column('users', 'company_id')
|
||||
op.drop_constraint(None, 'user_roles', type_='foreignkey')
|
||||
# op.drop_constraint('unique_user_role', 'user_roles', type_='unique')
|
||||
op.drop_column('user_roles', 'company_id')
|
||||
# ### end Alembic commands ###
|
@ -0,0 +1,32 @@
|
||||
"""Added relation company with user roles
|
||||
|
||||
Revision ID: cccea6c7d70d
|
||||
Revises: f93bebc068de
|
||||
Create Date: 2024-01-20 17:17:36.178991
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'cccea6c7d70d'
|
||||
down_revision: Union[str, None] = 'f93bebc068de'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
# op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id', 'company_id'])
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
# op.drop_constraint('unique_user_role', 'user_roles', type_='unique')
|
||||
pass
|
||||
# ### end Alembic commands ###
|
@ -0,0 +1,46 @@
|
||||
"""Created company name as unique
|
||||
|
||||
Revision ID: f93bebc068de
|
||||
Revises: 6f93f0d1defb
|
||||
Create Date: 2024-01-20 17:05:36.133343
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'f93bebc068de'
|
||||
down_revision: Union[str, None] = '6f93f0d1defb'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index('ix_companies_name', table_name='companies')
|
||||
op.create_index(op.f('ix_companies_name'), 'companies', ['name'], unique=True)
|
||||
op.alter_column('user_roles', 'company_id',
|
||||
existing_type=sa.INTEGER(),
|
||||
nullable=True)
|
||||
# op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id', 'company_id'])
|
||||
op.alter_column('users', 'company_id',
|
||||
existing_type=sa.INTEGER(),
|
||||
nullable=True)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column('users', 'company_id',
|
||||
existing_type=sa.INTEGER(),
|
||||
nullable=True)
|
||||
# op.drop_constraint('unique_user_role', 'user_roles', type_='unique')
|
||||
op.alter_column('user_roles', 'company_id',
|
||||
existing_type=sa.INTEGER(),
|
||||
nullable=True)
|
||||
op.drop_index(op.f('ix_companies_name'), table_name='companies')
|
||||
op.create_index('ix_companies_name', 'companies', ['name'], unique=False)
|
||||
# ### end Alembic commands ###
|
19
bash.exe.stackdump
Normal file
19
bash.exe.stackdump
Normal file
@ -0,0 +1,19 @@
|
||||
Stack trace:
|
||||
Frame Function Args
|
||||
000005FF350 00210062B0E (00210298702, 00210275E3E, 0000000005E, 000005FAEB0)
|
||||
000005FF350 0021004846A (00000000000, 00000000000, 7FF800000000, 00000001000)
|
||||
000005FF350 002100484A2 (00000000000, 0000000005A, 0000000005E, 00000000000)
|
||||
000005FF350 002100DA798 (00000000000, 00200000000, 002102759F2, 000005FBFFC)
|
||||
000005FF350 00210133477 (00000000000, 0021022B110, 0021022B100, 000005FDDE0)
|
||||
000005FF350 002100488B4 (00210317960, 000005FDDE0, 00000000000, 00000000000)
|
||||
000005FF350 0021004A01F (0007FFE0384, 00000000000, 00000000000, 00000000000)
|
||||
000005FF350 002100DB7D8 (00000000000, 00000000000, 00000000000, 00000000000)
|
||||
000005FF5F0 7FF8C2889A1D (00210040000, 00000000001, 00000000000, 000005FF538)
|
||||
000005FF5F0 7FF8C28DC2C7 (7FF8C28ADB00, 000000B3F01, 7FF800000001, 00000000001)
|
||||
000005FF5F0 7FF8C28DC05A (000000B3F70, 000005FF5F0, 000000B48C0, 00000070000)
|
||||
000005FF5F0 7FF8C28DC0E0 (7FF8C2995A10, 00000000000, 00000000010, 000005FF690)
|
||||
00000000000 7FF8C2943C42 (00000000000, 00000000000, 00000000001, 00000000000)
|
||||
00000000000 7FF8C28E4DBB (7FF8C2870000, 00000000000, 000003C5000, 00000000000)
|
||||
00000000000 7FF8C28E4C43 (00000000000, 00000000000, 00000000000, 00000000000)
|
||||
00000000000 7FF8C28E4BEE (00000000000, 00000000000, 00000000000, 00000000000)
|
||||
End of stack trace
|
@ -26,12 +26,12 @@ def create_app(root_injector: Injector) -> FastAPI:
|
||||
|
||||
app = FastAPI(dependencies=[Depends(bind_injector_to_request)])
|
||||
|
||||
# app.include_router(completions_router)
|
||||
# app.include_router(chat_router)
|
||||
# app.include_router(chunks_router)
|
||||
# app.include_router(ingest_router)
|
||||
# app.include_router(embeddings_router)
|
||||
# app.include_router(health_router)
|
||||
app.include_router(completions_router)
|
||||
app.include_router(chat_router)
|
||||
app.include_router(chunks_router)
|
||||
app.include_router(ingest_router)
|
||||
app.include_router(embeddings_router)
|
||||
app.include_router(health_router)
|
||||
|
||||
app.include_router(api_router)
|
||||
|
||||
|
@ -14,6 +14,7 @@ from private_gpt.users.core.security import (
|
||||
from fastapi import Depends, HTTPException, Security, status
|
||||
from jose import jwt
|
||||
from pydantic import ValidationError
|
||||
from private_gpt.users.constants.role import Role
|
||||
from private_gpt.users.schemas.token import TokenPayload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -94,4 +95,32 @@ async def get_current_active_user(
|
||||
) -> models.User:
|
||||
if not crud.user.is_active(current_user):
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return current_user
|
||||
return current_user
|
||||
|
||||
|
||||
|
||||
async def is_company_admin(current_user: models.User = Depends(get_current_user)):
|
||||
if current_user.role == Role.ADMIN["name"]:
|
||||
return current_user
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have the necessary permissions to perform this action",
|
||||
)
|
||||
|
||||
|
||||
async def is_super_admin(current_user: models.User = Depends(get_current_user)):
|
||||
if current_user.role == Role.SUPER_ADMIN["name"]:
|
||||
return current_user
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have the necessary permissions to perform this action",
|
||||
)
|
||||
|
||||
|
||||
async def get_company_name(company_id: int, db: Session = Depends(get_db)) -> str:
|
||||
company = crud.company.get(db=db, id=company_id)
|
||||
if not company:
|
||||
raise HTTPException(status_code=404, detail="Company not found")
|
||||
return company.name
|
||||
|
@ -1,10 +1,12 @@
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
from datetime import timedelta, datetime
|
||||
|
||||
from pydantic.networks import EmailStr
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Security
|
||||
from pydantic.networks import EmailStr
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Security, Path, status
|
||||
|
||||
from private_gpt.users.api import deps
|
||||
from private_gpt.users.core import security
|
||||
@ -16,6 +18,45 @@ from private_gpt.users.utils import send_registration_email
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
def register_user(
|
||||
db: Session,
|
||||
email: str,
|
||||
fullname: str,
|
||||
password: str,
|
||||
company: Optional[models.Company] = None,
|
||||
) -> models.User:
|
||||
"""
|
||||
Register a new user in the database.
|
||||
"""
|
||||
print(f"{email} {fullname} {password} {company.id}")
|
||||
user_in = schemas.UserCreate(email=email, password=password, fullname=fullname, company_id=company.id)
|
||||
send_registration_email(fullname, email, password)
|
||||
return crud.user.create(db, obj_in=user_in)
|
||||
|
||||
|
||||
def create_user_role(
|
||||
db: Session,
|
||||
user: models.User,
|
||||
role_name: str,
|
||||
company: Optional[models.Company] = None,
|
||||
) -> models.UserRole:
|
||||
"""
|
||||
Create a user role in the database.
|
||||
"""
|
||||
role = crud.role.get_by_name(db, name=role_name)
|
||||
user_role_in = schemas.UserRoleCreate(user_id=user.id, role_id=role.id, company_id=company.id if company else None)
|
||||
return crud.user_role.create(db, obj_in=user_role_in)
|
||||
|
||||
|
||||
def create_token_payload(user: models.User, user_role: models.UserRole) -> dict:
|
||||
"""
|
||||
Create a token payload for authentication.
|
||||
"""
|
||||
return {
|
||||
"id": str(user.id),
|
||||
"role": user_role.role.name,
|
||||
"company_id": user_role.company.id if user_role.company else None,
|
||||
}
|
||||
|
||||
@router.post("/login", response_model=schemas.TokenSchema)
|
||||
def login_access_token(
|
||||
@ -32,9 +73,7 @@ def login_access_token(
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Incorrect email or password"
|
||||
)
|
||||
# elif not crud.user.is_active(user):
|
||||
# raise HTTPException(status_code=400, detail="Inactive user")
|
||||
|
||||
|
||||
access_token_expires = timedelta(
|
||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
@ -42,21 +81,26 @@ def login_access_token(
|
||||
minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
user_in = schemas.UserUpdate(
|
||||
email = user.email,
|
||||
fullname = user.fullname,
|
||||
email=user.email,
|
||||
fullname=user.fullname,
|
||||
company_id=user.company_id,
|
||||
last_login=datetime.now()
|
||||
)
|
||||
|
||||
user = crud.user.update(db, db_obj=user, obj_in=user_in)
|
||||
if not user.user_role:
|
||||
role = "GUEST"
|
||||
else:
|
||||
|
||||
if user.user_role:
|
||||
role = user.user_role.role.name
|
||||
|
||||
if user.user_role.company_id:
|
||||
company_id = user.user_role.company_id
|
||||
else: company_id = None
|
||||
|
||||
token_payload = {
|
||||
"id": str(user.id),
|
||||
"role": role,
|
||||
"company_id": company_id,
|
||||
}
|
||||
|
||||
|
||||
return {
|
||||
"access_token": security.create_access_token(
|
||||
token_payload, expires_delta=access_token_expires
|
||||
@ -68,20 +112,21 @@ def login_access_token(
|
||||
}
|
||||
|
||||
|
||||
@router.post("/register", response_model=schemas.TokenSchema)
|
||||
def register(
|
||||
|
||||
@router.post("/{company_name}/register", response_model=schemas.User)
|
||||
def register_for_company(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
email: EmailStr = Body(...),
|
||||
email: str = Body(...),
|
||||
fullname: str = Body(...),
|
||||
role: str = Body(Default="GUEST"),
|
||||
company_name: Optional[str] = Path(..., title="Company Name", description="Only for company admin"),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.ADMIN["name"], Role.SUPER_ADMIN["name"]],
|
||||
scopes=[Role.SUPER_ADMIN["name"], Role.ADMIN['name']],
|
||||
),
|
||||
) -> Any:
|
||||
"""
|
||||
Register new user.
|
||||
Register new user for a specific company.
|
||||
"""
|
||||
user = crud.user.get_by_email(db, email=email)
|
||||
if user:
|
||||
@ -89,42 +134,74 @@ def register(
|
||||
status_code=409,
|
||||
detail="The user with this username already exists in the system",
|
||||
)
|
||||
random_password = security.generate_random_password()
|
||||
user_in = schemas.UserCreate(
|
||||
email=email,
|
||||
password=random_password,
|
||||
fullname=fullname,
|
||||
)
|
||||
user = crud.user.create(db, obj_in=user_in)
|
||||
send_registration_email(fullname, email, random_password)
|
||||
|
||||
role_db = crud.role.get_by_name(db, name=role)
|
||||
if not role_db:
|
||||
if current_user.user_role.role.name not in {Role.ADMIN["name"], Role.SUPER_ADMIN["name"]}:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Role '{role}' not found",
|
||||
status_code=403,
|
||||
detail="You do not have permission to register users for a company.",
|
||||
)
|
||||
user_role_in = schemas.UserRoleCreate(
|
||||
user_id=user.id,
|
||||
role_id=role_db.id
|
||||
|
||||
company = crud.company.get_by_company_name(db, company_name=company_name)
|
||||
print(f"Company is : {company.id}")
|
||||
if not (current_user.user_role.role.name == Role.ADMIN["name"] and current_user.user_role.company_id == company.id):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="You are not the admin of the specified company.",
|
||||
)
|
||||
|
||||
random_password = security.generate_random_password()
|
||||
user = register_user(db, email, fullname, random_password, company)
|
||||
user_role = create_user_role(db, user, Role.GUEST["name"], company)
|
||||
|
||||
token_payload = create_token_payload(user, user_role)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
content={"message": "User registered successfully.\n\n Check respective user email for login credentials", "user": jsonable_encoder(user)},
|
||||
)
|
||||
user_role = crud.user_role.create(db, obj_in=user_role_in)
|
||||
access_token_expires = timedelta(
|
||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
refresh_token_expires = timedelta(
|
||||
minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
token_payload = {
|
||||
"id": str(user.id),
|
||||
"role": user_role.role.name,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/register", response_model=schemas.TokenSchema)
|
||||
def register_without_company_assignment(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
email: str = Body(...),
|
||||
fullname: str = Body(...),
|
||||
company_id: int = Body(None, title="Company ID", description="Company ID for the user (if applicable)"),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.SUPER_ADMIN["name"], Role.ADMIN['name']],
|
||||
),
|
||||
) -> Any:
|
||||
"""
|
||||
Register new user with company assignment.
|
||||
"""
|
||||
user = crud.user.get_by_email(db, email=email)
|
||||
if user:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="The user with this username already exists in the system",
|
||||
)
|
||||
|
||||
if current_user.user_role.role.name != Role.SUPER_ADMIN["name"]:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="You do not have permission to register users without a company.",
|
||||
)
|
||||
|
||||
if company_id is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Company ID is required for registering a user without a specific company.",
|
||||
)
|
||||
|
||||
random_password = security.generate_random_password()
|
||||
company = crud.company.get(db, company_id)
|
||||
user = register_user(db, email, fullname, random_password, company)
|
||||
user_role = create_user_role(db, user, Role.ADMIN["name"], company)
|
||||
|
||||
token_payload = create_token_payload(user, user_role)
|
||||
return {
|
||||
"access_token": security.create_access_token(
|
||||
token_payload, expires_delta=access_token_expires
|
||||
),
|
||||
"refresh_token": security.create_refresh_token(
|
||||
token_payload, expires_delta=refresh_token_expires
|
||||
),
|
||||
"access_token": security.create_access_token(token_payload, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)),
|
||||
"refresh_token": security.create_refresh_token(token_payload, expires_delta=timedelta(minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES)),
|
||||
"token_type": "bearer",
|
||||
}
|
||||
}
|
@ -11,7 +11,6 @@ from private_gpt.users import crud, models, schemas
|
||||
|
||||
router = APIRouter(prefix="/user-roles", tags=["user-roles"])
|
||||
|
||||
|
||||
@router.post("", response_model=schemas.UserRole)
|
||||
def assign_user_role(
|
||||
*,
|
||||
@ -69,3 +68,10 @@ def update_user_role(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={"message": "User role updated successfully", "user_role": jsonable_encoder(user_role)},
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
company_router = APIRouter(prefix="/user-roles", tags=["user-roles"])
|
||||
|
||||
|
||||
|
@ -1,10 +1,10 @@
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic.networks import EmailStr
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Security, status
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Security, status, Path
|
||||
|
||||
from private_gpt.users.api import deps
|
||||
from private_gpt.users.constants.role import Role
|
||||
@ -24,13 +24,30 @@ def read_users(
|
||||
scopes=[Role.ADMIN["name"], Role.SUPER_ADMIN["name"]],
|
||||
),
|
||||
) -> Any:
|
||||
"""
|
||||
"""
|
||||
Retrieve all users.
|
||||
"""
|
||||
users = crud.user.get_multi(db, skip=skip, limit=limit)
|
||||
return users
|
||||
|
||||
|
||||
@router.get("/{company_name}")
|
||||
def read_users_by_company(
|
||||
company_name: Optional[str] = Path(..., title="Company Name", description="Only for company admin"),
|
||||
db: Session = Depends(deps.get_db),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.ADMIN["name"], Role.SUPER_ADMIN["name"]],
|
||||
),
|
||||
):
|
||||
"""
|
||||
Retrieve all users of that company only
|
||||
"""
|
||||
company = crud.company.get_by_company_name(db, company_name=company_name)
|
||||
users = crud.user.get_multi_by_company_id(db, company_id=company.id)
|
||||
return users
|
||||
|
||||
|
||||
@router.post("", response_model=schemas.User)
|
||||
def create_user(
|
||||
*,
|
||||
@ -185,4 +202,5 @@ def update_user(
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={"message": "User updated successfully", "user": jsonable_encoder(user_data)},
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -9,5 +9,8 @@ from sqlalchemy.orm import Session
|
||||
class CRUDCompany(CRUDBase[Company, CompanyCreate, CompanyUpdate]):
|
||||
def get_by_id(self, db: Session, *, id: str) -> Optional[Company]:
|
||||
return db.query(self.model).filter(Company.id == id).first()
|
||||
|
||||
|
||||
def get_by_company_name(self, db: Session, *, company_name: str) -> Optional[Company]:
|
||||
return db.query(self.model).filter(Company.name == company_name).first()
|
||||
|
||||
company = CRUDCompany(Company)
|
@ -9,5 +9,5 @@ from sqlalchemy.orm import Session
|
||||
class CRUDRole(CRUDBase[Role, RoleCreate, RoleUpdate]):
|
||||
def get_by_name(self, db: Session, *, name: str) -> Optional[Role]:
|
||||
return db.query(self.model).filter(Role.name == name).first()
|
||||
|
||||
|
||||
role = CRUDRole(Role)
|
@ -73,5 +73,15 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_multi_by_company_id(
|
||||
self, db: Session, *, company_id: str, skip: int = 0, limit: int = 100
|
||||
) -> List[User]:
|
||||
return (
|
||||
db.query(self.model)
|
||||
.filter(User.company_id == company_id)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
user = CRUDUser(User)
|
@ -1 +1,4 @@
|
||||
from .user import User
|
||||
from .user import User
|
||||
from .company import Company
|
||||
from .user_role import UserRole
|
||||
from .role import Role
|
@ -1,6 +1,8 @@
|
||||
from typing import List
|
||||
from sqlalchemy import Column, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from private_gpt.users.db.base_class import Base
|
||||
from private_gpt.users.schemas.user import User
|
||||
|
||||
class Company(Base):
|
||||
"""Models a Company table."""
|
||||
@ -8,7 +10,8 @@ class Company(Base):
|
||||
__tablename__ = "companies"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, index=True)
|
||||
subscriptions = relationship("Subscription", back_populates="company")
|
||||
name = Column(String, index=True, unique=True)
|
||||
|
||||
|
||||
subscriptions = relationship("Subscription", back_populates="company")
|
||||
users = relationship("User", back_populates="company")
|
||||
user_roles = relationship("UserRole", back_populates="company")
|
@ -1,6 +1,5 @@
|
||||
import datetime
|
||||
from sqlalchemy import (
|
||||
LargeBinary,
|
||||
Column,
|
||||
String,
|
||||
Integer,
|
||||
@ -35,10 +34,10 @@ class User(Base):
|
||||
onupdate=datetime.datetime.utcnow,
|
||||
)
|
||||
|
||||
# account_id = Column(Integer, ForeignKey("accounts.id"), nullable=True)
|
||||
company_id = Column(Integer, ForeignKey("companies.id"), nullable=True)
|
||||
company = relationship("Company", back_populates="users")
|
||||
|
||||
user_role = relationship("UserRole", back_populates="user", uselist=False)
|
||||
# account = relationship("Account", back_populates="users")
|
||||
|
||||
def __repr__(self):
|
||||
"""Returns string representation of model instance"""
|
||||
|
@ -5,7 +5,7 @@ from sqlalchemy.orm import relationship
|
||||
class UserRole(Base):
|
||||
__tablename__ = "user_roles"
|
||||
user_id = Column(
|
||||
Integer,
|
||||
Integer,
|
||||
ForeignKey("users.id"),
|
||||
primary_key=True,
|
||||
nullable=False,
|
||||
@ -16,10 +16,16 @@ class UserRole(Base):
|
||||
primary_key=True,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
company_id = Column(
|
||||
Integer,
|
||||
ForeignKey("companies.id"),
|
||||
primary_key=True,
|
||||
nullable=True,
|
||||
)
|
||||
role = relationship("Role")
|
||||
user = relationship("User", back_populates="user_role", uselist=False)
|
||||
|
||||
company = relationship("Company", back_populates="user_roles")
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("user_id", "role_id", name="unique_user_role"),
|
||||
)
|
||||
UniqueConstraint("user_id", "role_id", "company_id", name="unique_user_role"),
|
||||
)
|
||||
|
@ -14,6 +14,6 @@ class CompanyUpdate(CompanyBase):
|
||||
|
||||
class Company(CompanyBase):
|
||||
id: int
|
||||
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
@ -11,6 +11,7 @@ class TokenSchema(BaseModel):
|
||||
class TokenPayload(BaseModel):
|
||||
id: int
|
||||
role: str = None
|
||||
company: str = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
@ -4,11 +4,12 @@ from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, EmailStr
|
||||
from private_gpt.users.schemas.user_role import UserRole
|
||||
|
||||
from private_gpt.users.schemas.company import Company
|
||||
|
||||
class UserBaseSchema(BaseModel):
|
||||
email: EmailStr
|
||||
fullname: str
|
||||
company_id: Optional[int]
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
@ -7,6 +7,7 @@ from pydantic import BaseModel
|
||||
class UserRoleBase(BaseModel):
|
||||
user_id: Optional[int]
|
||||
role_id: Optional[int]
|
||||
company_id: Optional[int]
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
@ -31,11 +32,10 @@ class UserRoleInDBBase(UserRoleBase):
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
|
||||
# Additional properties to return via API
|
||||
class UserRole(UserRoleInDBBase):
|
||||
pass
|
||||
|
||||
|
||||
class UserRoleInDB(UserRoleInDBBase):
|
||||
pass
|
||||
pass
|
||||
|
Loading…
Reference in New Issue
Block a user