mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-07-16 16:32:20 +00:00
Added subscriptions module and registering users with email
This commit is contained in:
parent
4c5fed7ea4
commit
8b396ac6fb
7
.env
7
.env
@ -13,3 +13,10 @@ SUPER_ADMIN_ACCOUNT_NAME=superaccount
|
||||
SECRET_KEY=ba9dc3f976cf8fb40519dcd152a8d7d21c0b7861d841711cdb2602be8e85fd7c
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=60
|
||||
REFRESH_TOKEN_EXPIRE_MINUTES = 120 # 7 days
|
||||
|
||||
|
||||
SMTP_SERVER=smtp-mail.outlook.com
|
||||
SMTP_PORT=587
|
||||
SMTP_SENDER_EMAIL=saurabstha7@outlook.com
|
||||
SMTP_USERNAME=saurabstha7@outlook.com
|
||||
SMTP_PASSWORD=avantador123
|
@ -1,52 +0,0 @@
|
||||
"""Add Subscription and Company model
|
||||
|
||||
Revision ID: 0e0eb0a1a514
|
||||
Revises: 65688535c5a5
|
||||
Create Date: 2024-01-17 15:53:13.091801
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '0e0eb0a1a514'
|
||||
down_revision: Union[str, None] = '65688535c5a5'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('companies',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_companies_id'), 'companies', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_companies_name'), 'companies', ['name'], unique=False)
|
||||
op.create_table('subscriptions',
|
||||
sa.Column('sub_id', sa.Integer(), nullable=False),
|
||||
sa.Column('company_id', sa.Integer(), nullable=True),
|
||||
sa.Column('start_date', sa.DateTime(), nullable=True),
|
||||
sa.Column('end_date', sa.DateTime(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['company_id'], ['companies.id'], ),
|
||||
sa.PrimaryKeyConstraint('sub_id')
|
||||
)
|
||||
op.create_index(op.f('ix_subscriptions_sub_id'), 'subscriptions', ['sub_id'], unique=False)
|
||||
# op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id'])
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint('unique_user_role', 'user_roles', type_='unique')
|
||||
op.drop_index(op.f('ix_subscriptions_sub_id'), table_name='subscriptions')
|
||||
op.drop_table('subscriptions')
|
||||
op.drop_index(op.f('ix_companies_name'), table_name='companies')
|
||||
op.drop_index(op.f('ix_companies_id'), table_name='companies')
|
||||
op.drop_table('companies')
|
||||
# ### end Alembic commands ###
|
@ -1,30 +0,0 @@
|
||||
"""Create user model
|
||||
|
||||
Revision ID: 3cd055fe81a3
|
||||
Revises:
|
||||
Create Date: 2024-01-14 10:44:37.040428
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '3cd055fe81a3'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
pass
|
||||
# ### end Alembic commands ###
|
@ -1,31 +0,0 @@
|
||||
"""Add Subscription and Company model:
|
||||
|
||||
|
||||
Revision ID: 65688535c5a5
|
||||
Revises: cba9e6e394ca
|
||||
Create Date: 2024-01-17 15:45:28.636265
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '65688535c5a5'
|
||||
down_revision: Union[str, None] = 'cba9e6e394ca'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
pass
|
||||
# ### end Alembic commands ###
|
@ -1,8 +1,8 @@
|
||||
"""Create user model and role
|
||||
"""Create user, roles, user roles, subscription and company model
|
||||
|
||||
Revision ID: cba9e6e394ca
|
||||
Revises: 3cd055fe81a3
|
||||
Create Date: 2024-01-14 10:46:33.847333
|
||||
Revision ID: 6f3cc13e1339
|
||||
Revises:
|
||||
Create Date: 2024-01-18 12:33:39.002575
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
@ -12,14 +12,21 @@ import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'cba9e6e394ca'
|
||||
down_revision: Union[str, None] = '3cd055fe81a3'
|
||||
revision: str = '6f3cc13e1339'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('companies',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_companies_id'), 'companies', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_companies_name'), 'companies', ['name'], unique=False)
|
||||
op.create_table('roles',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('name', sa.String(length=100), nullable=True),
|
||||
@ -40,6 +47,15 @@ def upgrade() -> None:
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('email')
|
||||
)
|
||||
op.create_table('subscriptions',
|
||||
sa.Column('sub_id', sa.Integer(), nullable=False),
|
||||
sa.Column('company_id', sa.Integer(), nullable=True),
|
||||
sa.Column('start_date', sa.DateTime(), nullable=True),
|
||||
sa.Column('end_date', sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['company_id'], ['companies.id'], ),
|
||||
sa.PrimaryKeyConstraint('sub_id')
|
||||
)
|
||||
op.create_index(op.f('ix_subscriptions_sub_id'), 'subscriptions', ['sub_id'], unique=False)
|
||||
op.create_table('user_roles',
|
||||
sa.Column('user_id', sa.Integer(), nullable=False),
|
||||
sa.Column('role_id', sa.Integer(), nullable=False),
|
||||
@ -54,8 +70,13 @@ def upgrade() -> None:
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('user_roles')
|
||||
op.drop_index(op.f('ix_subscriptions_sub_id'), table_name='subscriptions')
|
||||
op.drop_table('subscriptions')
|
||||
op.drop_table('users')
|
||||
op.drop_index(op.f('ix_roles_name'), table_name='roles')
|
||||
op.drop_index(op.f('ix_roles_id'), table_name='roles')
|
||||
op.drop_table('roles')
|
||||
op.drop_index(op.f('ix_companies_name'), table_name='companies')
|
||||
op.drop_index(op.f('ix_companies_id'), table_name='companies')
|
||||
op.drop_table('companies')
|
||||
# ### end Alembic commands ###
|
@ -1,15 +1,18 @@
|
||||
from datetime import timedelta, datetime
|
||||
from typing import Any
|
||||
from datetime import timedelta, datetime
|
||||
|
||||
from pydantic.networks import EmailStr
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Security
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from private_gpt.users import crud, models, schemas
|
||||
from private_gpt.users.api import deps
|
||||
from private_gpt.users.core import security
|
||||
from private_gpt.users.constants.role import Role
|
||||
from private_gpt.users.core.config import settings
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from pydantic.networks import EmailStr
|
||||
from sqlalchemy.orm import Session
|
||||
from private_gpt.users import crud, models, schemas
|
||||
from private_gpt.users.utils import send_registration_email
|
||||
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
@ -25,7 +28,6 @@ def login_access_token(
|
||||
user = crud.user.authenticate(
|
||||
db, email=form_data.username, password=form_data.password
|
||||
)
|
||||
print("USER object", user)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Incorrect email or password"
|
||||
@ -39,15 +41,12 @@ def login_access_token(
|
||||
refresh_token_expires = timedelta(
|
||||
minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
print(f"Access Token expires: {access_token_expires}\n Refresh token expires: {refresh_token_expires}")
|
||||
user_in = schemas.UserUpdate(
|
||||
email = user.email,
|
||||
fullname = user.fullname,
|
||||
last_login=datetime.now()
|
||||
)
|
||||
print("Update last login schema: ", user_in)
|
||||
user = crud.user.update(db, db_obj=user, obj_in=user_in)
|
||||
print("update in database:", user)
|
||||
if not user.user_role:
|
||||
role = "GUEST"
|
||||
else:
|
||||
@ -74,8 +73,12 @@ def register(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
email: EmailStr = Body(...),
|
||||
password: str = Body(...),
|
||||
fullname: str = Body(...),
|
||||
role: str = Body(Default="GUEST"),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.ADMIN["name"], Role.SUPER_ADMIN["name"]],
|
||||
),
|
||||
) -> Any:
|
||||
"""
|
||||
Register new user.
|
||||
@ -86,41 +89,35 @@ def register(
|
||||
status_code=409,
|
||||
detail="The user with this username already exists in the system",
|
||||
)
|
||||
random_password = security.generate_random_password()
|
||||
user_in = schemas.UserCreate(
|
||||
email=email,
|
||||
password=password,
|
||||
password=random_password,
|
||||
fullname=fullname,
|
||||
)
|
||||
|
||||
# create user
|
||||
user = crud.user.create(db, obj_in=user_in)
|
||||
send_registration_email(fullname, email, random_password)
|
||||
|
||||
# get role
|
||||
role = crud.role.get_by_name(db, name=Role.SUPER_ADMIN["name"])
|
||||
print("ROLE:", role)
|
||||
# assign user_role
|
||||
role_db = crud.role.get_by_name(db, name=role)
|
||||
if not role_db:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Role '{role}' not found",
|
||||
)
|
||||
user_role_in = schemas.UserRoleCreate(
|
||||
user_id=user.id,
|
||||
role_id=role.id
|
||||
role_id=role_db.id
|
||||
)
|
||||
user_role = crud.user_role.create(db, obj_in=user_role_in)
|
||||
|
||||
print(user)
|
||||
|
||||
access_token_expires = timedelta(
|
||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
refresh_token_expires = timedelta(
|
||||
minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
if not user.user_role:
|
||||
role = "GUEST"
|
||||
else:
|
||||
role = user.user_role.role.name
|
||||
|
||||
token_payload = {
|
||||
"id": str(user.id),
|
||||
"role": role,
|
||||
"role": user_role.role.name,
|
||||
}
|
||||
return {
|
||||
"access_token": security.create_access_token(
|
||||
|
@ -1,20 +1,27 @@
|
||||
from typing import Any, List
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from private_gpt.users import crud, models, schemas
|
||||
from private_gpt.users.constants.role import Role
|
||||
from private_gpt.users.api import deps
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Security
|
||||
|
||||
from private_gpt.users.api import deps
|
||||
from private_gpt.users.constants.role import Role
|
||||
from private_gpt.users import crud, models, schemas
|
||||
|
||||
|
||||
router = APIRouter(prefix="/companies", tags=["Companies"])
|
||||
|
||||
|
||||
@router.get("", response_model=List[schemas.Company])
|
||||
def list_companies(
|
||||
db: Session = Depends(deps.get_db),
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.SUPER_ADMIN["name"]],
|
||||
),
|
||||
) -> List[schemas.Company]:
|
||||
"""
|
||||
List companies
|
||||
@ -27,18 +34,34 @@ def list_companies(
|
||||
def create_company(
|
||||
company_in: schemas.CompanyCreate,
|
||||
db: Session = Depends(deps.get_db),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.SUPER_ADMIN["name"]],
|
||||
),
|
||||
) -> schemas.Company:
|
||||
"""
|
||||
Create a new company
|
||||
"""
|
||||
company = crud.company.create(db=db, obj_in=company_in)
|
||||
return company
|
||||
company = jsonable_encoder(company)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
content={
|
||||
"message": "Company created successfully",
|
||||
"subscription": company
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{company_id}", response_model=schemas.Company)
|
||||
def read_company(
|
||||
company_id: int,
|
||||
db: Session = Depends(deps.get_db),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.SUPER_ADMIN["name"]],
|
||||
),
|
||||
) -> schemas.Company:
|
||||
"""
|
||||
Read a company by ID
|
||||
@ -54,6 +77,10 @@ def update_company(
|
||||
company_id: int,
|
||||
company_in: schemas.CompanyUpdate,
|
||||
db: Session = Depends(deps.get_db),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.SUPER_ADMIN["name"]],
|
||||
),
|
||||
) -> schemas.Company:
|
||||
"""
|
||||
Update a company by ID
|
||||
@ -72,10 +99,15 @@ def update_company(
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{company_id}", response_model=schemas.Company)
|
||||
def delete_company(
|
||||
company_id: int,
|
||||
db: Session = Depends(deps.get_db),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.SUPER_ADMIN["name"]],
|
||||
),
|
||||
) -> schemas.Company:
|
||||
"""
|
||||
Delete a company by ID
|
||||
|
@ -1,9 +1,12 @@
|
||||
from typing import Any, List
|
||||
|
||||
from private_gpt.users import crud, schemas
|
||||
from private_gpt.users.api import deps
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import APIRouter, Depends, status, Security
|
||||
|
||||
from private_gpt.users.api import deps
|
||||
from private_gpt.users.constants.role import Role
|
||||
from private_gpt.users import crud, schemas, models
|
||||
|
||||
|
||||
router = APIRouter(prefix='/roles', tags=['roles'])
|
||||
@ -12,9 +15,19 @@ router = APIRouter(prefix='/roles', tags=['roles'])
|
||||
@router.get("/", response_model=List[schemas.Role])
|
||||
def get_roles(
|
||||
db: Session = Depends(deps.get_db), skip: int = 0, limit: int = 100,
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.SUPER_ADMIN["name"]],
|
||||
),
|
||||
) -> Any:
|
||||
"""
|
||||
Retrieve all available user roles.
|
||||
"""
|
||||
roles = crud.role.get_multi(db, skip=skip, limit=limit)
|
||||
return roles
|
||||
|
||||
roles_data = [{"id": role.id, "name": role.name} for role in roles]
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={"message": "Roles retrieved successfully", "roles": roles_data},
|
||||
)
|
@ -1,47 +1,63 @@
|
||||
from typing import Any, List
|
||||
from private_gpt.users import crud, models, schemas
|
||||
from private_gpt.users.api import deps
|
||||
from private_gpt.users.constants.role import Role
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Security,status
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic.networks import EmailStr
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic.networks import EmailStr
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Security, status
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
router = APIRouter(prefix="/subscriptions", tags=["Subscriptions"])
|
||||
from private_gpt.users.api import deps
|
||||
from private_gpt.users.constants.role import Role
|
||||
from private_gpt.users import crud, models, schemas
|
||||
|
||||
|
||||
router = APIRouter(prefix="/subscriptions", tags=["Subscriptions"])
|
||||
|
||||
@router.post("/create", response_model=schemas.Subscription)
|
||||
def create_subscription(
|
||||
subscription_in: schemas.SubscriptionCreate,
|
||||
db: Session = Depends(deps.get_db),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.SUPER_ADMIN["name"]],
|
||||
),
|
||||
) -> Any:
|
||||
"""
|
||||
Create a new subscription
|
||||
"""
|
||||
existing_subscription = crud.subscription.get_by_company_id(db, company_id=subscription_in.company_id)
|
||||
if existing_subscription:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Company is already subscribed to a plan.",
|
||||
)
|
||||
|
||||
subscription = crud.subscription.create(db=db, obj_in=subscription_in)
|
||||
subscription_dict = jsonable_encoder(subscription)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
content={
|
||||
"message": "Subscription created successfully",
|
||||
"subscription": subscription_dict
|
||||
},
|
||||
active_subscription = crud.subscription.get_active_subscription_by_company(
|
||||
db=db, company_id=subscription_in.company_id
|
||||
)
|
||||
|
||||
if active_subscription:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={
|
||||
"message": "Active subscription found",
|
||||
"subscription": jsonable_encoder(active_subscription),
|
||||
},
|
||||
)
|
||||
else:
|
||||
subscription = crud.subscription.create(db=db, obj_in=subscription_in)
|
||||
subscription_dict = jsonable_encoder(subscription)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
content={
|
||||
"message": "Subscription created successfully",
|
||||
"subscription": subscription_dict,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{subscription_id}", response_model=schemas.Subscription)
|
||||
def read_subscription(
|
||||
subscription_id: int,
|
||||
db: Session = Depends(deps.get_db),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.SUPER_ADMIN["name"]],
|
||||
),
|
||||
):
|
||||
subscription = crud.subscription.get_by_id(db, subscription_id=subscription_id)
|
||||
if subscription is None:
|
||||
@ -61,6 +77,10 @@ def read_subscription(
|
||||
def read_subscriptions_by_company(
|
||||
company_id: int,
|
||||
db: Session = Depends(deps.get_db),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.SUPER_ADMIN["name"]],
|
||||
),
|
||||
):
|
||||
subscriptions = crud.subscription.get_by_company_id(db, company_id=company_id)
|
||||
subscriptions_list = [jsonable_encoder(subscription) for subscription in subscriptions]
|
||||
@ -79,6 +99,10 @@ def update_subscription(
|
||||
subscription_id: int,
|
||||
subscription_in: schemas.SubscriptionUpdate,
|
||||
db: Session = Depends(deps.get_db),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.SUPER_ADMIN["name"]],
|
||||
),
|
||||
):
|
||||
subscription = crud.subscription.get_by_id(db, subscription_id=subscription_id)
|
||||
if subscription is None:
|
||||
@ -105,11 +129,17 @@ def update_subscription(
|
||||
def delete_subscription(
|
||||
subscription_id: int,
|
||||
db: Session = Depends(deps.get_db),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.SUPER_ADMIN["name"]],
|
||||
),
|
||||
):
|
||||
subscription = crud.subscription.remove(db=db, id=subscription_id)
|
||||
if subscription is None:
|
||||
raise HTTPException(status_code=404, detail="Subscription not found")
|
||||
return JSONResponse(
|
||||
content={"message": "Subscription deleted successfully"},
|
||||
status_code=status.HTTP_200_OK
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={
|
||||
"message": "Subscription deleted successfully"
|
||||
}
|
||||
)
|
@ -1,10 +1,13 @@
|
||||
from typing import Any
|
||||
|
||||
from private_gpt.users import crud, models, schemas
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi import APIRouter, Depends, HTTPException, Security, status
|
||||
|
||||
from private_gpt.users.api import deps
|
||||
from private_gpt.users.constants.role import Role
|
||||
from fastapi import APIRouter, Depends, HTTPException, Security
|
||||
from sqlalchemy.orm import Session
|
||||
from private_gpt.users import crud, models, schemas
|
||||
|
||||
router = APIRouter(prefix="/user-roles", tags=["user-roles"])
|
||||
|
||||
@ -14,7 +17,12 @@ def assign_user_role(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
user_role_in: schemas.UserRoleCreate,
|
||||
current_user: models.User = Depends(deps.get_current_user),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[
|
||||
Role.SUPER_ADMIN["name"],
|
||||
],
|
||||
),
|
||||
) -> Any:
|
||||
"""
|
||||
Assign a role to a user after creation of a user
|
||||
@ -26,7 +34,10 @@ def assign_user_role(
|
||||
detail="This user has already been assigned a role.",
|
||||
)
|
||||
user_role = crud.user_role.create(db, obj_in=user_role_in)
|
||||
return user_role
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
content={"message": "User role assigned successfully", "user_role": jsonable_encoder(user_role)},
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{user_id}", response_model=schemas.UserRole)
|
||||
@ -38,20 +49,23 @@ def update_user_role(
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[
|
||||
Role.ADMIN["name"],
|
||||
Role.SUPER_ADMIN["name"],
|
||||
],
|
||||
),
|
||||
) -> Any:
|
||||
"""
|
||||
Update a users role.
|
||||
Update a user's role.
|
||||
"""
|
||||
user_role = crud.user_role.get_by_user_id(db, user_id=user_id)
|
||||
if not user_role:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="There is no role assigned to this user",
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="There is no role assigned to this user",
|
||||
)
|
||||
user_role = crud.user_role.update(
|
||||
db, db_obj=user_role, obj_in=user_role_in
|
||||
)
|
||||
return user_role
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={"message": "User role updated successfully", "user_role": jsonable_encoder(user_role)},
|
||||
)
|
||||
|
@ -1,17 +1,19 @@
|
||||
from typing import Any, List
|
||||
from private_gpt.users import crud, models, schemas
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic.networks import EmailStr
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Security, status
|
||||
|
||||
from private_gpt.users.api import deps
|
||||
from private_gpt.users.constants.role import Role
|
||||
from private_gpt.users.core.config import settings
|
||||
from private_gpt.users import crud, models, schemas
|
||||
from private_gpt.users.core.security import verify_password, get_password_hash
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Security
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic.networks import EmailStr
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
|
||||
|
||||
@router.get("", response_model=List[schemas.User])
|
||||
def read_users(
|
||||
db: Session = Depends(deps.get_db),
|
||||
@ -25,7 +27,7 @@ def read_users(
|
||||
"""
|
||||
Retrieve all users.
|
||||
"""
|
||||
users = crud.user.get_multi(db, skip=skip, limit=limit,)
|
||||
users = crud.user.get_multi(db, skip=skip, limit=limit)
|
||||
return users
|
||||
|
||||
|
||||
@ -45,11 +47,14 @@ def create_user(
|
||||
user = crud.user.get_by_email(db, email=user_in.email)
|
||||
if user:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="The user with this username already exists in the system.",
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="The user with this email already exists in the system.",
|
||||
)
|
||||
user = crud.user.create(db, obj_in=user_in)
|
||||
return user
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
content={"message": "User created successfully", "user": jsonable_encoder(user)},
|
||||
)
|
||||
|
||||
|
||||
@router.put("/me", response_model=schemas.User)
|
||||
@ -64,15 +69,20 @@ def update_user_me(
|
||||
Update own user.
|
||||
"""
|
||||
current_user_data = jsonable_encoder(current_user)
|
||||
print("Current user data: ", current_user_data)
|
||||
user_in = schemas.UserUpdate(**current_user_data)
|
||||
if fullname is not None:
|
||||
user_in.fullname = fullname
|
||||
if email is not None:
|
||||
user_in.email = email
|
||||
print(f"DB obj: {current_user}\n obj IN : {user_in}")
|
||||
user = crud.user.update(db, db_obj=current_user, obj_in=user_in)
|
||||
return user
|
||||
user_data = schemas.UserBaseSchema(
|
||||
email=user.email,
|
||||
fullname=user.fullname,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={"message": "User updated successfully", "user": jsonable_encoder(user_data)},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=schemas.User)
|
||||
@ -83,21 +93,15 @@ def read_user_me(
|
||||
"""
|
||||
Get current user.
|
||||
"""
|
||||
if not current_user.user_role:
|
||||
role = None
|
||||
else:
|
||||
role = current_user.user_role.role.name
|
||||
user_data = schemas.User(
|
||||
id=current_user.id,
|
||||
role = current_user.user_role.role.name if current_user.user_role else None
|
||||
user_data = schemas.UserBaseSchema(
|
||||
email=current_user.email,
|
||||
is_active=current_user.is_active,
|
||||
fullname=current_user.fullname,
|
||||
created_at=current_user.created_at,
|
||||
updated_at=current_user.updated_at,
|
||||
last_login = current_user.last_login,
|
||||
role=role
|
||||
)
|
||||
return user_data
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={"message": "Current user retrieved successfully", "user": jsonable_encoder(user_data)},
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/me/change-password", response_model=schemas.User)
|
||||
@ -111,27 +115,24 @@ def change_password(
|
||||
"""
|
||||
Change current user's password.
|
||||
"""
|
||||
# Verify the old password
|
||||
if not verify_password(old_password, current_user.hashed_password):
|
||||
raise HTTPException(status_code=400, detail="Old password is incorrect")
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Old password is incorrect")
|
||||
|
||||
# Change the password
|
||||
new_password_hashed = get_password_hash(new_password)
|
||||
current_user.hashed_password = new_password_hashed
|
||||
db.commit()
|
||||
|
||||
role = current_user.user_role.role.name if current_user.user_role else None
|
||||
user_data = schemas.User(
|
||||
user_data = schemas.UserBaseSchema(
|
||||
id=current_user.id,
|
||||
email=current_user.email,
|
||||
is_active=current_user.is_active,
|
||||
fullname=current_user.fullname,
|
||||
created_at=current_user.created_at,
|
||||
updated_at=current_user.updated_at,
|
||||
last_login=current_user.last_login,
|
||||
role=role,
|
||||
)
|
||||
return user_data
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={"message": "Password changed successfully", "user": jsonable_encoder(user_data)},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=schemas.User)
|
||||
@ -147,9 +148,12 @@ def read_user_by_id(
|
||||
Get a specific user by id.
|
||||
"""
|
||||
if user_id is None:
|
||||
return "User id is not given."
|
||||
return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content={"message": "User id is not given."})
|
||||
user = crud.user.get(db, id=user_id)
|
||||
return user
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={"message": "User retrieved successfully", "user": jsonable_encoder(user)},
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{user_id}", response_model=schemas.User)
|
||||
@ -169,8 +173,16 @@ def update_user(
|
||||
user = crud.user.get(db, id=user_id)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="The user with this username does not exist in the system",
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The user with this id does not exist in the system",
|
||||
)
|
||||
user = crud.user.update(db, db_obj=user, obj_in=user_in)
|
||||
return user
|
||||
user_data = schemas.UserBaseSchema(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
fullname=user.fullname,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={"message": "User updated successfully", "user": jsonable_encoder(user_data)},
|
||||
)
|
@ -32,6 +32,12 @@ class Settings(BaseSettings):
|
||||
DB_NAME: str
|
||||
PORT: str
|
||||
|
||||
SMTP_SERVER: str
|
||||
SMTP_PORT: str
|
||||
SMTP_SENDER_EMAIL: str
|
||||
SMTP_USERNAME: str
|
||||
SMTP_PASSWORD: str
|
||||
|
||||
|
||||
# SQLALCHEMY_DATABASE_URI: Optional[PostgresDsn] = None
|
||||
|
||||
|
@ -1,8 +1,10 @@
|
||||
from passlib.context import CryptContext
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Union, Any
|
||||
from jose import jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30 # 30 minutes
|
||||
REFRESH_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7 days
|
||||
@ -41,4 +43,11 @@ def create_refresh_token(subject: Union[str, Any], expires_delta: int = None) ->
|
||||
|
||||
to_encode = {"exp": expires_delta, **subject}
|
||||
encoded_jwt = jwt.encode(to_encode, JWT_REFRESH_SECRET_KEY, ALGORITHM)
|
||||
return encoded_jwt
|
||||
return encoded_jwt
|
||||
|
||||
def generate_random_password(length: int = 12) -> str:
|
||||
"""
|
||||
Generate a random password.
|
||||
"""
|
||||
characters = string.ascii_letters + string.digits + string.punctuation
|
||||
return ''.join(random.choice(characters) for i in range(length))
|
@ -1,9 +1,12 @@
|
||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
|
||||
|
||||
from private_gpt.users.db.base import Base
|
||||
from fastapi import HTTPException, status
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from pydantic.error_wrappers import ValidationError
|
||||
|
||||
# Define custom types for SQLAlchemy model, and Pydantic schemas
|
||||
ModelType = TypeVar("ModelType", bound=Base)
|
||||
@ -30,12 +33,19 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
return db.query(self.model).filter(self.model.id == id).first()
|
||||
|
||||
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
|
||||
obj_in_data = jsonable_encoder(obj_in)
|
||||
db_obj = self.model(**obj_in_data) # type: ignore
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
try:
|
||||
obj_in_data = jsonable_encoder(obj_in)
|
||||
db_obj = self.model(**obj_in_data)
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Integrity Error: {str(e)}",
|
||||
)
|
||||
|
||||
def update(
|
||||
self,
|
||||
@ -44,21 +54,34 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
||||
db_obj: ModelType,
|
||||
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
|
||||
) -> ModelType:
|
||||
obj_data = jsonable_encoder(db_obj)
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.dict(exclude_unset=True)
|
||||
for field in obj_data:
|
||||
if field in update_data:
|
||||
setattr(db_obj, field, update_data[field])
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
try:
|
||||
obj_data = jsonable_encoder(db_obj)
|
||||
if isinstance(obj_in, dict):
|
||||
update_data = obj_in
|
||||
else:
|
||||
update_data = obj_in.dict(exclude_unset=True)
|
||||
for field in obj_data:
|
||||
if field in update_data:
|
||||
setattr(db_obj, field, update_data[field])
|
||||
db.add(db_obj)
|
||||
db.commit()
|
||||
db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Integrity Error: {str(e)}",
|
||||
)
|
||||
|
||||
def remove(self, db: Session, *, id: int) -> ModelType:
|
||||
obj = db.query(self.model).get(id)
|
||||
db.delete(obj)
|
||||
db.commit()
|
||||
return obj
|
||||
if obj:
|
||||
db.delete(obj)
|
||||
db.commit()
|
||||
return obj
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"{self.model.__name__} not found with id: {id}",
|
||||
)
|
@ -4,13 +4,28 @@ from private_gpt.users.crud.base import CRUDBase
|
||||
from private_gpt.users.models.subscription import Subscription
|
||||
from private_gpt.users.schemas.subscription import SubscriptionCreate, SubscriptionUpdate
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
class CRUDSubscription(CRUDBase[Subscription, SubscriptionCreate, SubscriptionUpdate]):
|
||||
|
||||
def get_by_id(self, db: Session, *, subscription_id: int) -> Optional[Subscription]:
|
||||
return db.query(self.model).filter(Subscription.sub_id == subscription_id).first()
|
||||
|
||||
def get_by_company_id(self, db: Session, *, company_id: int) -> List[Subscription]:
|
||||
return db.query(self.model).filter(Subscription.company_id == company_id).all()
|
||||
|
||||
def get_active_subscription_by_company(self, db: Session, *, company_id: int) -> List[Subscription]:
|
||||
current_datetime = datetime.utcnow()
|
||||
return (
|
||||
db.query(self.model)
|
||||
.filter(
|
||||
Subscription.company_id == company_id,
|
||||
Subscription.is_active == True, # Active subscriptions
|
||||
Subscription.end_date >= current_datetime, # End date is not passed
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
subscription = CRUDSubscription(Subscription)
|
||||
|
||||
|
||||
subscription = CRUDSubscription(Subscription)
|
@ -11,9 +11,13 @@ class Subscription(Base):
|
||||
|
||||
sub_id = Column(Integer, primary_key=True, index=True)
|
||||
company_id = Column(Integer, ForeignKey("companies.id"))
|
||||
start_date = Column(DateTime, default=datetime.utcnow)
|
||||
start_date = Column(DateTime, default=datetime.utcnow())
|
||||
end_date = Column(DateTime, default=datetime.utcnow() + timedelta(days=30)) # Example: 30 days subscription period
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_active = Column(Boolean, default=False)
|
||||
|
||||
company = relationship("Company", back_populates="subscriptions")
|
||||
|
||||
@property
|
||||
def is_active(self) -> bool:
|
||||
"""Check if the subscription is active based on the end_date."""
|
||||
return self.end_date >= datetime.utcnow()
|
||||
|
@ -1,6 +1,6 @@
|
||||
from .role import Role, RoleCreate, RoleInDB, RoleUpdate
|
||||
from .token import TokenSchema, TokenPayload
|
||||
from .user import User, UserCreate, UserInDB, UserUpdate
|
||||
from .user import User, UserCreate, UserInDB, UserUpdate, UserBaseSchema
|
||||
from .user_role import UserRole, UserRoleCreate, UserRoleInDB, UserRoleUpdate
|
||||
from .subscription import Subscription, SubscriptionBase, SubscriptionCreate, SubscriptionUpdate
|
||||
from .company import Company, CompanyBase, CompanyCreate, CompanyUpdate
|
1
private_gpt/users/utils/__init__.py
Normal file
1
private_gpt/users/utils/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .utils import send_registration_email
|
28
private_gpt/users/utils/utils.py
Normal file
28
private_gpt/users/utils/utils.py
Normal file
@ -0,0 +1,28 @@
|
||||
import smtplib
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from private_gpt.users.core.config import settings
|
||||
|
||||
|
||||
def send_registration_email(fullname: str, email: str, random_password: str) -> None:
|
||||
"""
|
||||
Send a registration email with a random password.
|
||||
"""
|
||||
subject = "Welcome to QuickGPT - Registration Successful"
|
||||
body = f"Hello {fullname},\n\nThank you for registering with Your App!\n\n"\
|
||||
f"Your temporary password is: {random_password}\n\n"\
|
||||
f"Please use this password to log in and consider changing it"\
|
||||
" to a more secure one after logging in.\n\n"\
|
||||
"Best regards,\nQuickGPT Team"
|
||||
|
||||
msg = MIMEMultipart()
|
||||
msg.attach(MIMEText(body, "plain"))
|
||||
msg["Subject"] = subject
|
||||
msg["From"] = settings.SMTP_SENDER_EMAIL
|
||||
msg["To"] = email
|
||||
|
||||
with smtplib.SMTP(settings.SMTP_SERVER, settings.SMTP_PORT) as server:
|
||||
server.starttls()
|
||||
server.login(settings.SMTP_USERNAME, settings.SMTP_PASSWORD)
|
||||
server.sendmail(settings.SMTP_SENDER_EMAIL, email, msg.as_string())
|
||||
|
Loading…
Reference in New Issue
Block a user