mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-06-22 05:30:34 +00:00
Added routes for creating users through admin
This commit is contained in:
parent
8b396ac6fb
commit
0b79c23f68
8
.env
8
.env
@ -15,8 +15,8 @@ ACCESS_TOKEN_EXPIRE_MINUTES=60
|
|||||||
REFRESH_TOKEN_EXPIRE_MINUTES = 120 # 7 days
|
REFRESH_TOKEN_EXPIRE_MINUTES = 120 # 7 days
|
||||||
|
|
||||||
|
|
||||||
SMTP_SERVER=smtp-mail.outlook.com
|
SMTP_SERVER=smtp.gmail.com
|
||||||
SMTP_PORT=587
|
SMTP_PORT=587
|
||||||
SMTP_SENDER_EMAIL=saurabstha7@outlook.com
|
SMTP_SENDER_EMAIL=shresthasaurab030@outlook.com
|
||||||
SMTP_USERNAME=saurabstha7@outlook.com
|
SMTP_USERNAME=shresthasaurab030
|
||||||
SMTP_PASSWORD=avantador123
|
SMTP_PASSWORD=huurxwxeorxjorzw
|
@ -0,0 +1,38 @@
|
|||||||
|
"""Created a relationships based on company
|
||||||
|
|
||||||
|
Revision ID: 6f93f0d1defb
|
||||||
|
Revises: 6f3cc13e1339
|
||||||
|
Create Date: 2024-01-20 15:46:13.093101
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '6f93f0d1defb'
|
||||||
|
down_revision: Union[str, None] = '6f3cc13e1339'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.add_column('user_roles', sa.Column('company_id', sa.Integer(), nullable=True))
|
||||||
|
# op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id', 'company_id'])
|
||||||
|
op.create_foreign_key(None, 'user_roles', 'companies', ['company_id'], ['id'])
|
||||||
|
op.add_column('users', sa.Column('company_id', sa.Integer(), nullable=True))
|
||||||
|
op.create_foreign_key(None, 'users', 'companies', ['company_id'], ['id'])
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_constraint(None, 'users', type_='foreignkey')
|
||||||
|
op.drop_column('users', 'company_id')
|
||||||
|
op.drop_constraint(None, 'user_roles', type_='foreignkey')
|
||||||
|
# op.drop_constraint('unique_user_role', 'user_roles', type_='unique')
|
||||||
|
op.drop_column('user_roles', 'company_id')
|
||||||
|
# ### end Alembic commands ###
|
@ -0,0 +1,32 @@
|
|||||||
|
"""Added relation company with user roles
|
||||||
|
|
||||||
|
Revision ID: cccea6c7d70d
|
||||||
|
Revises: f93bebc068de
|
||||||
|
Create Date: 2024-01-20 17:17:36.178991
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = 'cccea6c7d70d'
|
||||||
|
down_revision: Union[str, None] = 'f93bebc068de'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
# op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id', 'company_id'])
|
||||||
|
pass
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
# op.drop_constraint('unique_user_role', 'user_roles', type_='unique')
|
||||||
|
pass
|
||||||
|
# ### end Alembic commands ###
|
@ -0,0 +1,46 @@
|
|||||||
|
"""Created company name as unique
|
||||||
|
|
||||||
|
Revision ID: f93bebc068de
|
||||||
|
Revises: 6f93f0d1defb
|
||||||
|
Create Date: 2024-01-20 17:05:36.133343
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = 'f93bebc068de'
|
||||||
|
down_revision: Union[str, None] = '6f93f0d1defb'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_index('ix_companies_name', table_name='companies')
|
||||||
|
op.create_index(op.f('ix_companies_name'), 'companies', ['name'], unique=True)
|
||||||
|
op.alter_column('user_roles', 'company_id',
|
||||||
|
existing_type=sa.INTEGER(),
|
||||||
|
nullable=True)
|
||||||
|
# op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id', 'company_id'])
|
||||||
|
op.alter_column('users', 'company_id',
|
||||||
|
existing_type=sa.INTEGER(),
|
||||||
|
nullable=True)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.alter_column('users', 'company_id',
|
||||||
|
existing_type=sa.INTEGER(),
|
||||||
|
nullable=True)
|
||||||
|
# op.drop_constraint('unique_user_role', 'user_roles', type_='unique')
|
||||||
|
op.alter_column('user_roles', 'company_id',
|
||||||
|
existing_type=sa.INTEGER(),
|
||||||
|
nullable=True)
|
||||||
|
op.drop_index(op.f('ix_companies_name'), table_name='companies')
|
||||||
|
op.create_index('ix_companies_name', 'companies', ['name'], unique=False)
|
||||||
|
# ### end Alembic commands ###
|
19
bash.exe.stackdump
Normal file
19
bash.exe.stackdump
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
Stack trace:
|
||||||
|
Frame Function Args
|
||||||
|
000005FF350 00210062B0E (00210298702, 00210275E3E, 0000000005E, 000005FAEB0)
|
||||||
|
000005FF350 0021004846A (00000000000, 00000000000, 7FF800000000, 00000001000)
|
||||||
|
000005FF350 002100484A2 (00000000000, 0000000005A, 0000000005E, 00000000000)
|
||||||
|
000005FF350 002100DA798 (00000000000, 00200000000, 002102759F2, 000005FBFFC)
|
||||||
|
000005FF350 00210133477 (00000000000, 0021022B110, 0021022B100, 000005FDDE0)
|
||||||
|
000005FF350 002100488B4 (00210317960, 000005FDDE0, 00000000000, 00000000000)
|
||||||
|
000005FF350 0021004A01F (0007FFE0384, 00000000000, 00000000000, 00000000000)
|
||||||
|
000005FF350 002100DB7D8 (00000000000, 00000000000, 00000000000, 00000000000)
|
||||||
|
000005FF5F0 7FF8C2889A1D (00210040000, 00000000001, 00000000000, 000005FF538)
|
||||||
|
000005FF5F0 7FF8C28DC2C7 (7FF8C28ADB00, 000000B3F01, 7FF800000001, 00000000001)
|
||||||
|
000005FF5F0 7FF8C28DC05A (000000B3F70, 000005FF5F0, 000000B48C0, 00000070000)
|
||||||
|
000005FF5F0 7FF8C28DC0E0 (7FF8C2995A10, 00000000000, 00000000010, 000005FF690)
|
||||||
|
00000000000 7FF8C2943C42 (00000000000, 00000000000, 00000000001, 00000000000)
|
||||||
|
00000000000 7FF8C28E4DBB (7FF8C2870000, 00000000000, 000003C5000, 00000000000)
|
||||||
|
00000000000 7FF8C28E4C43 (00000000000, 00000000000, 00000000000, 00000000000)
|
||||||
|
00000000000 7FF8C28E4BEE (00000000000, 00000000000, 00000000000, 00000000000)
|
||||||
|
End of stack trace
|
@ -26,12 +26,12 @@ def create_app(root_injector: Injector) -> FastAPI:
|
|||||||
|
|
||||||
app = FastAPI(dependencies=[Depends(bind_injector_to_request)])
|
app = FastAPI(dependencies=[Depends(bind_injector_to_request)])
|
||||||
|
|
||||||
# app.include_router(completions_router)
|
app.include_router(completions_router)
|
||||||
# app.include_router(chat_router)
|
app.include_router(chat_router)
|
||||||
# app.include_router(chunks_router)
|
app.include_router(chunks_router)
|
||||||
# app.include_router(ingest_router)
|
app.include_router(ingest_router)
|
||||||
# app.include_router(embeddings_router)
|
app.include_router(embeddings_router)
|
||||||
# app.include_router(health_router)
|
app.include_router(health_router)
|
||||||
|
|
||||||
app.include_router(api_router)
|
app.include_router(api_router)
|
||||||
|
|
||||||
|
@ -14,6 +14,7 @@ from private_gpt.users.core.security import (
|
|||||||
from fastapi import Depends, HTTPException, Security, status
|
from fastapi import Depends, HTTPException, Security, status
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
from private_gpt.users.constants.role import Role
|
||||||
from private_gpt.users.schemas.token import TokenPayload
|
from private_gpt.users.schemas.token import TokenPayload
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@ -95,3 +96,31 @@ async def get_current_active_user(
|
|||||||
if not crud.user.is_active(current_user):
|
if not crud.user.is_active(current_user):
|
||||||
raise HTTPException(status_code=400, detail="Inactive user")
|
raise HTTPException(status_code=400, detail="Inactive user")
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
async def is_company_admin(current_user: models.User = Depends(get_current_user)):
|
||||||
|
if current_user.role == Role.ADMIN["name"]:
|
||||||
|
return current_user
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="You don't have the necessary permissions to perform this action",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def is_super_admin(current_user: models.User = Depends(get_current_user)):
|
||||||
|
if current_user.role == Role.SUPER_ADMIN["name"]:
|
||||||
|
return current_user
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="You don't have the necessary permissions to perform this action",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_company_name(company_id: int, db: Session = Depends(get_db)) -> str:
|
||||||
|
company = crud.company.get(db=db, id=company_id)
|
||||||
|
if not company:
|
||||||
|
raise HTTPException(status_code=404, detail="Company not found")
|
||||||
|
return company.name
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
from datetime import timedelta, datetime
|
from datetime import timedelta, datetime
|
||||||
|
|
||||||
from pydantic.networks import EmailStr
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from fastapi import APIRouter, Body, Depends, HTTPException, Security
|
from pydantic.networks import EmailStr
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
from fastapi import APIRouter, Body, Depends, HTTPException, Security, Path, status
|
||||||
|
|
||||||
from private_gpt.users.api import deps
|
from private_gpt.users.api import deps
|
||||||
from private_gpt.users.core import security
|
from private_gpt.users.core import security
|
||||||
@ -16,6 +18,45 @@ from private_gpt.users.utils import send_registration_email
|
|||||||
|
|
||||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
def register_user(
|
||||||
|
db: Session,
|
||||||
|
email: str,
|
||||||
|
fullname: str,
|
||||||
|
password: str,
|
||||||
|
company: Optional[models.Company] = None,
|
||||||
|
) -> models.User:
|
||||||
|
"""
|
||||||
|
Register a new user in the database.
|
||||||
|
"""
|
||||||
|
print(f"{email} {fullname} {password} {company.id}")
|
||||||
|
user_in = schemas.UserCreate(email=email, password=password, fullname=fullname, company_id=company.id)
|
||||||
|
send_registration_email(fullname, email, password)
|
||||||
|
return crud.user.create(db, obj_in=user_in)
|
||||||
|
|
||||||
|
|
||||||
|
def create_user_role(
|
||||||
|
db: Session,
|
||||||
|
user: models.User,
|
||||||
|
role_name: str,
|
||||||
|
company: Optional[models.Company] = None,
|
||||||
|
) -> models.UserRole:
|
||||||
|
"""
|
||||||
|
Create a user role in the database.
|
||||||
|
"""
|
||||||
|
role = crud.role.get_by_name(db, name=role_name)
|
||||||
|
user_role_in = schemas.UserRoleCreate(user_id=user.id, role_id=role.id, company_id=company.id if company else None)
|
||||||
|
return crud.user_role.create(db, obj_in=user_role_in)
|
||||||
|
|
||||||
|
|
||||||
|
def create_token_payload(user: models.User, user_role: models.UserRole) -> dict:
|
||||||
|
"""
|
||||||
|
Create a token payload for authentication.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"id": str(user.id),
|
||||||
|
"role": user_role.role.name,
|
||||||
|
"company_id": user_role.company.id if user_role.company else None,
|
||||||
|
}
|
||||||
|
|
||||||
@router.post("/login", response_model=schemas.TokenSchema)
|
@router.post("/login", response_model=schemas.TokenSchema)
|
||||||
def login_access_token(
|
def login_access_token(
|
||||||
@ -32,8 +73,6 @@ def login_access_token(
|
|||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="Incorrect email or password"
|
status_code=400, detail="Incorrect email or password"
|
||||||
)
|
)
|
||||||
# elif not crud.user.is_active(user):
|
|
||||||
# raise HTTPException(status_code=400, detail="Inactive user")
|
|
||||||
|
|
||||||
access_token_expires = timedelta(
|
access_token_expires = timedelta(
|
||||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||||
@ -44,17 +83,22 @@ def login_access_token(
|
|||||||
user_in = schemas.UserUpdate(
|
user_in = schemas.UserUpdate(
|
||||||
email=user.email,
|
email=user.email,
|
||||||
fullname=user.fullname,
|
fullname=user.fullname,
|
||||||
|
company_id=user.company_id,
|
||||||
last_login=datetime.now()
|
last_login=datetime.now()
|
||||||
)
|
)
|
||||||
|
|
||||||
user = crud.user.update(db, db_obj=user, obj_in=user_in)
|
user = crud.user.update(db, db_obj=user, obj_in=user_in)
|
||||||
if not user.user_role:
|
|
||||||
role = "GUEST"
|
if user.user_role:
|
||||||
else:
|
|
||||||
role = user.user_role.role.name
|
role = user.user_role.role.name
|
||||||
|
if user.user_role.company_id:
|
||||||
|
company_id = user.user_role.company_id
|
||||||
|
else: company_id = None
|
||||||
|
|
||||||
token_payload = {
|
token_payload = {
|
||||||
"id": str(user.id),
|
"id": str(user.id),
|
||||||
"role": role,
|
"role": role,
|
||||||
|
"company_id": company_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -68,20 +112,21 @@ def login_access_token(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/register", response_model=schemas.TokenSchema)
|
|
||||||
def register(
|
@router.post("/{company_name}/register", response_model=schemas.User)
|
||||||
|
def register_for_company(
|
||||||
*,
|
*,
|
||||||
db: Session = Depends(deps.get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
email: EmailStr = Body(...),
|
email: str = Body(...),
|
||||||
fullname: str = Body(...),
|
fullname: str = Body(...),
|
||||||
role: str = Body(Default="GUEST"),
|
company_name: Optional[str] = Path(..., title="Company Name", description="Only for company admin"),
|
||||||
current_user: models.User = Security(
|
current_user: models.User = Security(
|
||||||
deps.get_current_user,
|
deps.get_current_user,
|
||||||
scopes=[Role.ADMIN["name"], Role.SUPER_ADMIN["name"]],
|
scopes=[Role.SUPER_ADMIN["name"], Role.ADMIN['name']],
|
||||||
),
|
),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Register new user.
|
Register new user for a specific company.
|
||||||
"""
|
"""
|
||||||
user = crud.user.get_by_email(db, email=email)
|
user = crud.user.get_by_email(db, email=email)
|
||||||
if user:
|
if user:
|
||||||
@ -89,42 +134,74 @@ def register(
|
|||||||
status_code=409,
|
status_code=409,
|
||||||
detail="The user with this username already exists in the system",
|
detail="The user with this username already exists in the system",
|
||||||
)
|
)
|
||||||
random_password = security.generate_random_password()
|
|
||||||
user_in = schemas.UserCreate(
|
|
||||||
email=email,
|
|
||||||
password=random_password,
|
|
||||||
fullname=fullname,
|
|
||||||
)
|
|
||||||
user = crud.user.create(db, obj_in=user_in)
|
|
||||||
send_registration_email(fullname, email, random_password)
|
|
||||||
|
|
||||||
role_db = crud.role.get_by_name(db, name=role)
|
if current_user.user_role.role.name not in {Role.ADMIN["name"], Role.SUPER_ADMIN["name"]}:
|
||||||
if not role_db:
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=403,
|
||||||
detail=f"Role '{role}' not found",
|
detail="You do not have permission to register users for a company.",
|
||||||
)
|
)
|
||||||
user_role_in = schemas.UserRoleCreate(
|
|
||||||
user_id=user.id,
|
company = crud.company.get_by_company_name(db, company_name=company_name)
|
||||||
role_id=role_db.id
|
print(f"Company is : {company.id}")
|
||||||
|
if not (current_user.user_role.role.name == Role.ADMIN["name"] and current_user.user_role.company_id == company.id):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail="You are not the admin of the specified company.",
|
||||||
)
|
)
|
||||||
user_role = crud.user_role.create(db, obj_in=user_role_in)
|
|
||||||
access_token_expires = timedelta(
|
random_password = security.generate_random_password()
|
||||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
user = register_user(db, email, fullname, random_password, company)
|
||||||
|
user_role = create_user_role(db, user, Role.GUEST["name"], company)
|
||||||
|
|
||||||
|
token_payload = create_token_payload(user, user_role)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_201_CREATED,
|
||||||
|
content={"message": "User registered successfully.\n\n Check respective user email for login credentials", "user": jsonable_encoder(user)},
|
||||||
)
|
)
|
||||||
refresh_token_expires = timedelta(
|
|
||||||
minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES
|
|
||||||
|
@router.post("/register", response_model=schemas.TokenSchema)
|
||||||
|
def register_without_company_assignment(
|
||||||
|
*,
|
||||||
|
db: Session = Depends(deps.get_db),
|
||||||
|
email: str = Body(...),
|
||||||
|
fullname: str = Body(...),
|
||||||
|
company_id: int = Body(None, title="Company ID", description="Company ID for the user (if applicable)"),
|
||||||
|
current_user: models.User = Security(
|
||||||
|
deps.get_current_user,
|
||||||
|
scopes=[Role.SUPER_ADMIN["name"], Role.ADMIN['name']],
|
||||||
|
),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Register new user with company assignment.
|
||||||
|
"""
|
||||||
|
user = crud.user.get_by_email(db, email=email)
|
||||||
|
if user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=409,
|
||||||
|
detail="The user with this username already exists in the system",
|
||||||
)
|
)
|
||||||
token_payload = {
|
|
||||||
"id": str(user.id),
|
if current_user.user_role.role.name != Role.SUPER_ADMIN["name"]:
|
||||||
"role": user_role.role.name,
|
raise HTTPException(
|
||||||
}
|
status_code=403,
|
||||||
|
detail="You do not have permission to register users without a company.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if company_id is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Company ID is required for registering a user without a specific company.",
|
||||||
|
)
|
||||||
|
|
||||||
|
random_password = security.generate_random_password()
|
||||||
|
company = crud.company.get(db, company_id)
|
||||||
|
user = register_user(db, email, fullname, random_password, company)
|
||||||
|
user_role = create_user_role(db, user, Role.ADMIN["name"], company)
|
||||||
|
|
||||||
|
token_payload = create_token_payload(user, user_role)
|
||||||
return {
|
return {
|
||||||
"access_token": security.create_access_token(
|
"access_token": security.create_access_token(token_payload, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)),
|
||||||
token_payload, expires_delta=access_token_expires
|
"refresh_token": security.create_refresh_token(token_payload, expires_delta=timedelta(minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES)),
|
||||||
),
|
|
||||||
"refresh_token": security.create_refresh_token(
|
|
||||||
token_payload, expires_delta=refresh_token_expires
|
|
||||||
),
|
|
||||||
"token_type": "bearer",
|
"token_type": "bearer",
|
||||||
}
|
}
|
@ -11,7 +11,6 @@ from private_gpt.users import crud, models, schemas
|
|||||||
|
|
||||||
router = APIRouter(prefix="/user-roles", tags=["user-roles"])
|
router = APIRouter(prefix="/user-roles", tags=["user-roles"])
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=schemas.UserRole)
|
@router.post("", response_model=schemas.UserRole)
|
||||||
def assign_user_role(
|
def assign_user_role(
|
||||||
*,
|
*,
|
||||||
@ -69,3 +68,10 @@ def update_user_role(
|
|||||||
status_code=status.HTTP_200_OK,
|
status_code=status.HTTP_200_OK,
|
||||||
content={"message": "User role updated successfully", "user_role": jsonable_encoder(user_role)},
|
content={"message": "User role updated successfully", "user_role": jsonable_encoder(user_role)},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
company_router = APIRouter(prefix="/user-roles", tags=["user-roles"])
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
from typing import Any, List
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from pydantic.networks import EmailStr
|
from pydantic.networks import EmailStr
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from fastapi import APIRouter, Body, Depends, HTTPException, Security, status
|
from fastapi import APIRouter, Body, Depends, HTTPException, Security, status, Path
|
||||||
|
|
||||||
from private_gpt.users.api import deps
|
from private_gpt.users.api import deps
|
||||||
from private_gpt.users.constants.role import Role
|
from private_gpt.users.constants.role import Role
|
||||||
@ -31,6 +31,23 @@ def read_users(
|
|||||||
return users
|
return users
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{company_name}")
|
||||||
|
def read_users_by_company(
|
||||||
|
company_name: Optional[str] = Path(..., title="Company Name", description="Only for company admin"),
|
||||||
|
db: Session = Depends(deps.get_db),
|
||||||
|
current_user: models.User = Security(
|
||||||
|
deps.get_current_user,
|
||||||
|
scopes=[Role.ADMIN["name"], Role.SUPER_ADMIN["name"]],
|
||||||
|
),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Retrieve all users of that company only
|
||||||
|
"""
|
||||||
|
company = crud.company.get_by_company_name(db, company_name=company_name)
|
||||||
|
users = crud.user.get_multi_by_company_id(db, company_id=company.id)
|
||||||
|
return users
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=schemas.User)
|
@router.post("", response_model=schemas.User)
|
||||||
def create_user(
|
def create_user(
|
||||||
*,
|
*,
|
||||||
@ -186,3 +203,4 @@ def update_user(
|
|||||||
status_code=status.HTTP_200_OK,
|
status_code=status.HTTP_200_OK,
|
||||||
content={"message": "User updated successfully", "user": jsonable_encoder(user_data)},
|
content={"message": "User updated successfully", "user": jsonable_encoder(user_data)},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -10,4 +10,7 @@ class CRUDCompany(CRUDBase[Company, CompanyCreate, CompanyUpdate]):
|
|||||||
def get_by_id(self, db: Session, *, id: str) -> Optional[Company]:
|
def get_by_id(self, db: Session, *, id: str) -> Optional[Company]:
|
||||||
return db.query(self.model).filter(Company.id == id).first()
|
return db.query(self.model).filter(Company.id == id).first()
|
||||||
|
|
||||||
|
def get_by_company_name(self, db: Session, *, company_name: str) -> Optional[Company]:
|
||||||
|
return db.query(self.model).filter(Company.name == company_name).first()
|
||||||
|
|
||||||
company = CRUDCompany(Company)
|
company = CRUDCompany(Company)
|
@ -73,5 +73,15 @@ class CRUDUser(CRUDBase[User, UserCreate, UserUpdate]):
|
|||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_multi_by_company_id(
|
||||||
|
self, db: Session, *, company_id: str, skip: int = 0, limit: int = 100
|
||||||
|
) -> List[User]:
|
||||||
|
return (
|
||||||
|
db.query(self.model)
|
||||||
|
.filter(User.company_id == company_id)
|
||||||
|
.offset(skip)
|
||||||
|
.limit(limit)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
user = CRUDUser(User)
|
user = CRUDUser(User)
|
@ -1 +1,4 @@
|
|||||||
from .user import User
|
from .user import User
|
||||||
|
from .company import Company
|
||||||
|
from .user_role import UserRole
|
||||||
|
from .role import Role
|
@ -1,6 +1,8 @@
|
|||||||
|
from typing import List
|
||||||
from sqlalchemy import Column, Integer, String
|
from sqlalchemy import Column, Integer, String
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
from private_gpt.users.db.base_class import Base
|
from private_gpt.users.db.base_class import Base
|
||||||
|
from private_gpt.users.schemas.user import User
|
||||||
|
|
||||||
class Company(Base):
|
class Company(Base):
|
||||||
"""Models a Company table."""
|
"""Models a Company table."""
|
||||||
@ -8,7 +10,8 @@ class Company(Base):
|
|||||||
__tablename__ = "companies"
|
__tablename__ = "companies"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, index=True)
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
name = Column(String, index=True)
|
name = Column(String, index=True, unique=True)
|
||||||
|
|
||||||
subscriptions = relationship("Subscription", back_populates="company")
|
subscriptions = relationship("Subscription", back_populates="company")
|
||||||
|
users = relationship("User", back_populates="company")
|
||||||
|
user_roles = relationship("UserRole", back_populates="company")
|
@ -1,6 +1,5 @@
|
|||||||
import datetime
|
import datetime
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
LargeBinary,
|
|
||||||
Column,
|
Column,
|
||||||
String,
|
String,
|
||||||
Integer,
|
Integer,
|
||||||
@ -35,10 +34,10 @@ class User(Base):
|
|||||||
onupdate=datetime.datetime.utcnow,
|
onupdate=datetime.datetime.utcnow,
|
||||||
)
|
)
|
||||||
|
|
||||||
# account_id = Column(Integer, ForeignKey("accounts.id"), nullable=True)
|
company_id = Column(Integer, ForeignKey("companies.id"), nullable=True)
|
||||||
|
company = relationship("Company", back_populates="users")
|
||||||
|
|
||||||
user_role = relationship("UserRole", back_populates="user", uselist=False)
|
user_role = relationship("UserRole", back_populates="user", uselist=False)
|
||||||
# account = relationship("Account", back_populates="users")
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
"""Returns string representation of model instance"""
|
"""Returns string representation of model instance"""
|
||||||
|
@ -16,10 +16,16 @@ class UserRole(Base):
|
|||||||
primary_key=True,
|
primary_key=True,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
)
|
)
|
||||||
|
company_id = Column(
|
||||||
|
Integer,
|
||||||
|
ForeignKey("companies.id"),
|
||||||
|
primary_key=True,
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
role = relationship("Role")
|
role = relationship("Role")
|
||||||
user = relationship("User", back_populates="user_role", uselist=False)
|
user = relationship("User", back_populates="user_role", uselist=False)
|
||||||
|
company = relationship("Company", back_populates="user_roles")
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
UniqueConstraint("user_id", "role_id", name="unique_user_role"),
|
UniqueConstraint("user_id", "role_id", "company_id", name="unique_user_role"),
|
||||||
)
|
)
|
@ -11,6 +11,7 @@ class TokenSchema(BaseModel):
|
|||||||
class TokenPayload(BaseModel):
|
class TokenPayload(BaseModel):
|
||||||
id: int
|
id: int
|
||||||
role: str = None
|
role: str = None
|
||||||
|
company: str = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
@ -4,11 +4,12 @@ from typing import Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field, EmailStr
|
from pydantic import BaseModel, Field, EmailStr
|
||||||
from private_gpt.users.schemas.user_role import UserRole
|
from private_gpt.users.schemas.user_role import UserRole
|
||||||
|
from private_gpt.users.schemas.company import Company
|
||||||
|
|
||||||
class UserBaseSchema(BaseModel):
|
class UserBaseSchema(BaseModel):
|
||||||
email: EmailStr
|
email: EmailStr
|
||||||
fullname: str
|
fullname: str
|
||||||
|
company_id: Optional[int]
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
@ -7,6 +7,7 @@ from pydantic import BaseModel
|
|||||||
class UserRoleBase(BaseModel):
|
class UserRoleBase(BaseModel):
|
||||||
user_id: Optional[int]
|
user_id: Optional[int]
|
||||||
role_id: Optional[int]
|
role_id: Optional[int]
|
||||||
|
company_id: Optional[int]
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
@ -31,7 +32,6 @@ class UserRoleInDBBase(UserRoleBase):
|
|||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Additional properties to return via API
|
# Additional properties to return via API
|
||||||
class UserRole(UserRoleInDBBase):
|
class UserRole(UserRoleInDBBase):
|
||||||
pass
|
pass
|
||||||
|
Loading…
Reference in New Issue
Block a user