mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-06-29 08:47:19 +00:00
Added api for deleting user, fixed bug in register user and refresh token
This commit is contained in:
parent
71999eb150
commit
17a7ded46b
@ -40,7 +40,7 @@ def get_db() -> Generator:
|
|||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
async def get_current_user(
|
def get_current_user(
|
||||||
security_scopes: SecurityScopes,
|
security_scopes: SecurityScopes,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
token: str = Depends(reusable_oauth2)
|
token: str = Depends(reusable_oauth2)
|
||||||
@ -62,7 +62,6 @@ async def get_current_user(
|
|||||||
)
|
)
|
||||||
if payload.get("id") is None:
|
if payload.get("id") is None:
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
print(payload)
|
|
||||||
token_data = schemas.TokenPayload(**payload)
|
token_data = schemas.TokenPayload(**payload)
|
||||||
except (jwt.JWTError, ValidationError):
|
except (jwt.JWTError, ValidationError):
|
||||||
logger.error("Error Decoding Token", exc_info=True)
|
logger.error("Error Decoding Token", exc_info=True)
|
||||||
|
@ -114,8 +114,10 @@ def login_access_token(
|
|||||||
|
|
||||||
|
|
||||||
@router.post("/login/refresh-token", response_model=schemas.TokenSchema)
|
@router.post("/login/refresh-token", response_model=schemas.TokenSchema)
|
||||||
def refresh_access_token(db: Session = Depends(deps.get_db), form_data: OAuth2PasswordRequestForm = Depends()) -> Any:
|
def refresh_access_token(
|
||||||
refresh_token = form_data.refresh_token
|
db: Session = Depends(deps.get_db),
|
||||||
|
refresh_token: str = Body(..., embed=True),
|
||||||
|
) -> Any:
|
||||||
token_payload = security.verify_refresh_token(refresh_token)
|
token_payload = security.verify_refresh_token(refresh_token)
|
||||||
|
|
||||||
if not token_payload:
|
if not token_payload:
|
||||||
@ -131,12 +133,11 @@ def refresh_access_token(db: Session = Depends(deps.get_db), form_data: OAuth2Pa
|
|||||||
"refresh_token": security.create_refresh_token(token_payload, expires_delta=refresh_token_expires),
|
"refresh_token": security.create_refresh_token(token_payload, expires_delta=refresh_token_expires),
|
||||||
"token_type": "bearer",
|
"token_type": "bearer",
|
||||||
}
|
}
|
||||||
|
|
||||||
return JSONResponse(content=response_dict)
|
return JSONResponse(content=response_dict)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/register", response_model=schemas.TokenSchema)
|
@router.post("/register", response_model=schemas.TokenSchema)
|
||||||
def register_user(
|
def register(
|
||||||
*,
|
*,
|
||||||
db: Session = Depends(deps.get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
email: str = Body(...),
|
email: str = Body(...),
|
||||||
@ -153,14 +154,15 @@ def register_user(
|
|||||||
"""
|
"""
|
||||||
Register new user with optional company assignment and role selection.
|
Register new user with optional company assignment and role selection.
|
||||||
"""
|
"""
|
||||||
user = crud.user.get_by_email(db, email=email)
|
existing_user = crud.user.get_by_email(db, email=email)
|
||||||
if user:
|
if existing_user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=409,
|
status_code=409,
|
||||||
detail="The user with this email already exists in the system",
|
detail="The user with this email already exists in the system",
|
||||||
)
|
)
|
||||||
|
random_password = security.generate_random_password()
|
||||||
|
|
||||||
if company_id is not None:
|
if company_id:
|
||||||
# Registering user with a specific company
|
# Registering user with a specific company
|
||||||
company = crud.company.get(db, company_id)
|
company = crud.company.get(db, company_id)
|
||||||
if not company:
|
if not company:
|
||||||
@ -168,41 +170,28 @@ def register_user(
|
|||||||
status_code=404,
|
status_code=404,
|
||||||
detail="Company not found.",
|
detail="Company not found.",
|
||||||
)
|
)
|
||||||
|
|
||||||
if current_user.user_role.role.name not in {Role.SUPER_ADMIN["name"], Role.ADMIN["name"]}:
|
if current_user.user_role.role.name not in {Role.SUPER_ADMIN["name"], Role.ADMIN["name"]}:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=403,
|
status_code=403,
|
||||||
detail="You do not have permission to register users for a specific company.",
|
detail="You do not have permission to register users for a specific company.",
|
||||||
)
|
)
|
||||||
|
user = register_user(db, email, fullname, random_password, company)
|
||||||
user_role_name = role_name or Role.GUEST["name"]
|
user_role_name = role_name or Role.GUEST["name"]
|
||||||
if user_role_name == Role.SUPER_USER["name"]:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403,
|
|
||||||
detail="Cannot create a user with SUPER_USER role.",
|
|
||||||
)
|
|
||||||
|
|
||||||
user_role = create_user_role(db, user, user_role_name, company)
|
user_role = create_user_role(db, user, user_role_name, company)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Registering user without a specific company
|
|
||||||
if current_user.user_role.role.name != Role.SUPER_ADMIN["name"]:
|
if current_user.user_role.role.name != Role.SUPER_ADMIN["name"]:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=403,
|
status_code=403,
|
||||||
detail="You do not have permission to register users without a company.",
|
detail="You do not have permission to register users without a company.",
|
||||||
)
|
)
|
||||||
|
user = register_user(db, email, fullname, random_password, None)
|
||||||
user_role_name = role_name or Role.ADMIN["name"]
|
user_role_name = role_name or Role.ADMIN["name"]
|
||||||
if user_role_name == Role.SUPER_USER["name"]:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403,
|
|
||||||
detail="Cannot create a user with SUPER_USER role.",
|
|
||||||
)
|
|
||||||
|
|
||||||
user_role = create_user_role(db, user, user_role_name, None)
|
user_role = create_user_role(db, user, user_role_name, None)
|
||||||
|
|
||||||
random_password = security.generate_random_password()
|
print("USER REGISTERED: ", user.email, user.fullname, user.company_id)
|
||||||
user = register_user(db, email, fullname, random_password, company)
|
print("USER ROLE REGISTERED: ", user_role.user.email,
|
||||||
|
user_role.role.name, user_role.company_id)
|
||||||
|
|
||||||
token_payload = create_token_payload(user, user_role)
|
token_payload = create_token_payload(user, user_role)
|
||||||
response_dict = {
|
response_dict = {
|
||||||
|
@ -215,6 +215,7 @@ def home_page(
|
|||||||
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": "Welcome to QuickGPT"})
|
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": "Welcome to QuickGPT"})
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.patch("/{user_id}/change-password", response_model=schemas.User)
|
@router.patch("/{user_id}/change-password", response_model=schemas.User)
|
||||||
def admin_change_password(
|
def admin_change_password(
|
||||||
*,
|
*,
|
||||||
@ -251,3 +252,26 @@ def admin_change_password(
|
|||||||
content={"message": "User password changed successfully",
|
content={"message": "User password changed successfully",
|
||||||
"user": jsonable_encoder(user_data)},
|
"user": jsonable_encoder(user_data)},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{user_id}")
|
||||||
|
def delete_user(
|
||||||
|
*,
|
||||||
|
db: Session = Depends(deps.get_db),
|
||||||
|
user_id: int,
|
||||||
|
current_user: models.User = Security(
|
||||||
|
deps.get_current_user,
|
||||||
|
scopes=[Role.ADMIN["name"], Role.SUPER_ADMIN["name"]],
|
||||||
|
),
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Delete a user by ID.
|
||||||
|
"""
|
||||||
|
user = crud.user.get(db, id=user_id)
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(status_code=404, detail="User not found")
|
||||||
|
crud.user.remove(db, id=user_id)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_200_OK,
|
||||||
|
content={"message": "User deleted successfully"},
|
||||||
|
)
|
||||||
|
@ -50,7 +50,7 @@ def generate_random_password(length: int = 12) -> str:
|
|||||||
"""
|
"""
|
||||||
Generate a random password.
|
Generate a random password.
|
||||||
"""
|
"""
|
||||||
characters = string.ascii_letters + string.digits + string.punctuation
|
characters = string.ascii_letters + string.digits
|
||||||
return ''.join(random.choice(characters) for i in range(length))
|
return ''.join(random.choice(characters) for i in range(length))
|
||||||
|
|
||||||
|
|
||||||
|
@ -12,5 +12,9 @@ class CRUDUserRole(CRUDBase[UserRole, UserRoleCreate, UserRoleUpdate]):
|
|||||||
) -> Optional[UserRole]:
|
) -> Optional[UserRole]:
|
||||||
return db.query(UserRole).filter(UserRole.user_id == user_id).first()
|
return db.query(UserRole).filter(UserRole.user_id == user_id).first()
|
||||||
|
|
||||||
|
def remove_user(
|
||||||
|
self, db: Session, *, user_id: int
|
||||||
|
)-> Optional[UserRole]:
|
||||||
|
return db.query(UserRole).filter(UserRole.user_id == user_id).delete()
|
||||||
|
|
||||||
user_role = CRUDUserRole(UserRole)
|
user_role = CRUDUserRole(UserRole)
|
@ -37,7 +37,8 @@ class User(Base):
|
|||||||
company_id = Column(Integer, ForeignKey("companies.id"), nullable=True)
|
company_id = Column(Integer, ForeignKey("companies.id"), nullable=True)
|
||||||
company = relationship("Company", back_populates="users")
|
company = relationship("Company", back_populates="users")
|
||||||
|
|
||||||
user_role = relationship("UserRole", back_populates="user", uselist=False)
|
user_role = relationship(
|
||||||
|
"UserRole", back_populates="user", uselist=False, cascade="all, delete-orphan")
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
"""Returns string representation of model instance"""
|
"""Returns string representation of model instance"""
|
||||||
|
Loading…
Reference in New Issue
Block a user