mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-07-17 08:53:39 +00:00
Added subscriptions module and registering users with email
This commit is contained in:
parent
4c5fed7ea4
commit
8b396ac6fb
7
.env
7
.env
@ -13,3 +13,10 @@ SUPER_ADMIN_ACCOUNT_NAME=superaccount
|
|||||||
SECRET_KEY=ba9dc3f976cf8fb40519dcd152a8d7d21c0b7861d841711cdb2602be8e85fd7c
|
SECRET_KEY=ba9dc3f976cf8fb40519dcd152a8d7d21c0b7861d841711cdb2602be8e85fd7c
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES=60
|
ACCESS_TOKEN_EXPIRE_MINUTES=60
|
||||||
REFRESH_TOKEN_EXPIRE_MINUTES = 120 # 7 days
|
REFRESH_TOKEN_EXPIRE_MINUTES = 120 # 7 days
|
||||||
|
|
||||||
|
|
||||||
|
SMTP_SERVER=smtp-mail.outlook.com
|
||||||
|
SMTP_PORT=587
|
||||||
|
SMTP_SENDER_EMAIL=saurabstha7@outlook.com
|
||||||
|
SMTP_USERNAME=saurabstha7@outlook.com
|
||||||
|
SMTP_PASSWORD=avantador123
|
@ -1,52 +0,0 @@
|
|||||||
"""Add Subscription and Company model
|
|
||||||
|
|
||||||
Revision ID: 0e0eb0a1a514
|
|
||||||
Revises: 65688535c5a5
|
|
||||||
Create Date: 2024-01-17 15:53:13.091801
|
|
||||||
|
|
||||||
"""
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = '0e0eb0a1a514'
|
|
||||||
down_revision: Union[str, None] = '65688535c5a5'
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.create_table('companies',
|
|
||||||
sa.Column('id', sa.Integer(), nullable=False),
|
|
||||||
sa.Column('name', sa.String(), nullable=True),
|
|
||||||
sa.PrimaryKeyConstraint('id')
|
|
||||||
)
|
|
||||||
op.create_index(op.f('ix_companies_id'), 'companies', ['id'], unique=False)
|
|
||||||
op.create_index(op.f('ix_companies_name'), 'companies', ['name'], unique=False)
|
|
||||||
op.create_table('subscriptions',
|
|
||||||
sa.Column('sub_id', sa.Integer(), nullable=False),
|
|
||||||
sa.Column('company_id', sa.Integer(), nullable=True),
|
|
||||||
sa.Column('start_date', sa.DateTime(), nullable=True),
|
|
||||||
sa.Column('end_date', sa.DateTime(), nullable=True),
|
|
||||||
sa.Column('is_active', sa.Boolean(), nullable=True),
|
|
||||||
sa.ForeignKeyConstraint(['company_id'], ['companies.id'], ),
|
|
||||||
sa.PrimaryKeyConstraint('sub_id')
|
|
||||||
)
|
|
||||||
op.create_index(op.f('ix_subscriptions_sub_id'), 'subscriptions', ['sub_id'], unique=False)
|
|
||||||
# op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id'])
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.drop_constraint('unique_user_role', 'user_roles', type_='unique')
|
|
||||||
op.drop_index(op.f('ix_subscriptions_sub_id'), table_name='subscriptions')
|
|
||||||
op.drop_table('subscriptions')
|
|
||||||
op.drop_index(op.f('ix_companies_name'), table_name='companies')
|
|
||||||
op.drop_index(op.f('ix_companies_id'), table_name='companies')
|
|
||||||
op.drop_table('companies')
|
|
||||||
# ### end Alembic commands ###
|
|
@ -1,30 +0,0 @@
|
|||||||
"""Create user model
|
|
||||||
|
|
||||||
Revision ID: 3cd055fe81a3
|
|
||||||
Revises:
|
|
||||||
Create Date: 2024-01-14 10:44:37.040428
|
|
||||||
|
|
||||||
"""
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = '3cd055fe81a3'
|
|
||||||
down_revision: Union[str, None] = None
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
pass
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
pass
|
|
||||||
# ### end Alembic commands ###
|
|
@ -1,31 +0,0 @@
|
|||||||
"""Add Subscription and Company model:
|
|
||||||
|
|
||||||
|
|
||||||
Revision ID: 65688535c5a5
|
|
||||||
Revises: cba9e6e394ca
|
|
||||||
Create Date: 2024-01-17 15:45:28.636265
|
|
||||||
|
|
||||||
"""
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = '65688535c5a5'
|
|
||||||
down_revision: Union[str, None] = 'cba9e6e394ca'
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
pass
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
pass
|
|
||||||
# ### end Alembic commands ###
|
|
@ -1,8 +1,8 @@
|
|||||||
"""Create user model and role
|
"""Create user, roles, user roles, subscription and company model
|
||||||
|
|
||||||
Revision ID: cba9e6e394ca
|
Revision ID: 6f3cc13e1339
|
||||||
Revises: 3cd055fe81a3
|
Revises:
|
||||||
Create Date: 2024-01-14 10:46:33.847333
|
Create Date: 2024-01-18 12:33:39.002575
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
@ -12,14 +12,21 @@ import sqlalchemy as sa
|
|||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = 'cba9e6e394ca'
|
revision: str = '6f3cc13e1339'
|
||||||
down_revision: Union[str, None] = '3cd055fe81a3'
|
down_revision: Union[str, None] = None
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('companies',
|
||||||
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('name', sa.String(), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_companies_id'), 'companies', ['id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_companies_name'), 'companies', ['name'], unique=False)
|
||||||
op.create_table('roles',
|
op.create_table('roles',
|
||||||
sa.Column('id', sa.Integer(), nullable=False),
|
sa.Column('id', sa.Integer(), nullable=False),
|
||||||
sa.Column('name', sa.String(length=100), nullable=True),
|
sa.Column('name', sa.String(length=100), nullable=True),
|
||||||
@ -40,6 +47,15 @@ def upgrade() -> None:
|
|||||||
sa.PrimaryKeyConstraint('id'),
|
sa.PrimaryKeyConstraint('id'),
|
||||||
sa.UniqueConstraint('email')
|
sa.UniqueConstraint('email')
|
||||||
)
|
)
|
||||||
|
op.create_table('subscriptions',
|
||||||
|
sa.Column('sub_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('company_id', sa.Integer(), nullable=True),
|
||||||
|
sa.Column('start_date', sa.DateTime(), nullable=True),
|
||||||
|
sa.Column('end_date', sa.DateTime(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(['company_id'], ['companies.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('sub_id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_subscriptions_sub_id'), 'subscriptions', ['sub_id'], unique=False)
|
||||||
op.create_table('user_roles',
|
op.create_table('user_roles',
|
||||||
sa.Column('user_id', sa.Integer(), nullable=False),
|
sa.Column('user_id', sa.Integer(), nullable=False),
|
||||||
sa.Column('role_id', sa.Integer(), nullable=False),
|
sa.Column('role_id', sa.Integer(), nullable=False),
|
||||||
@ -54,8 +70,13 @@ def upgrade() -> None:
|
|||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.drop_table('user_roles')
|
op.drop_table('user_roles')
|
||||||
|
op.drop_index(op.f('ix_subscriptions_sub_id'), table_name='subscriptions')
|
||||||
|
op.drop_table('subscriptions')
|
||||||
op.drop_table('users')
|
op.drop_table('users')
|
||||||
op.drop_index(op.f('ix_roles_name'), table_name='roles')
|
op.drop_index(op.f('ix_roles_name'), table_name='roles')
|
||||||
op.drop_index(op.f('ix_roles_id'), table_name='roles')
|
op.drop_index(op.f('ix_roles_id'), table_name='roles')
|
||||||
op.drop_table('roles')
|
op.drop_table('roles')
|
||||||
|
op.drop_index(op.f('ix_companies_name'), table_name='companies')
|
||||||
|
op.drop_index(op.f('ix_companies_id'), table_name='companies')
|
||||||
|
op.drop_table('companies')
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
@ -1,15 +1,18 @@
|
|||||||
from datetime import timedelta, datetime
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from datetime import timedelta, datetime
|
||||||
|
|
||||||
|
from pydantic.networks import EmailStr
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from fastapi import APIRouter, Body, Depends, HTTPException, Security
|
||||||
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
|
||||||
from private_gpt.users import crud, models, schemas
|
|
||||||
from private_gpt.users.api import deps
|
from private_gpt.users.api import deps
|
||||||
from private_gpt.users.core import security
|
from private_gpt.users.core import security
|
||||||
from private_gpt.users.constants.role import Role
|
from private_gpt.users.constants.role import Role
|
||||||
from private_gpt.users.core.config import settings
|
from private_gpt.users.core.config import settings
|
||||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
from private_gpt.users import crud, models, schemas
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from private_gpt.users.utils import send_registration_email
|
||||||
from pydantic.networks import EmailStr
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
@ -25,7 +28,6 @@ def login_access_token(
|
|||||||
user = crud.user.authenticate(
|
user = crud.user.authenticate(
|
||||||
db, email=form_data.username, password=form_data.password
|
db, email=form_data.username, password=form_data.password
|
||||||
)
|
)
|
||||||
print("USER object", user)
|
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail="Incorrect email or password"
|
status_code=400, detail="Incorrect email or password"
|
||||||
@ -39,15 +41,12 @@ def login_access_token(
|
|||||||
refresh_token_expires = timedelta(
|
refresh_token_expires = timedelta(
|
||||||
minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES
|
minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES
|
||||||
)
|
)
|
||||||
print(f"Access Token expires: {access_token_expires}\n Refresh token expires: {refresh_token_expires}")
|
|
||||||
user_in = schemas.UserUpdate(
|
user_in = schemas.UserUpdate(
|
||||||
email = user.email,
|
email = user.email,
|
||||||
fullname = user.fullname,
|
fullname = user.fullname,
|
||||||
last_login=datetime.now()
|
last_login=datetime.now()
|
||||||
)
|
)
|
||||||
print("Update last login schema: ", user_in)
|
|
||||||
user = crud.user.update(db, db_obj=user, obj_in=user_in)
|
user = crud.user.update(db, db_obj=user, obj_in=user_in)
|
||||||
print("update in database:", user)
|
|
||||||
if not user.user_role:
|
if not user.user_role:
|
||||||
role = "GUEST"
|
role = "GUEST"
|
||||||
else:
|
else:
|
||||||
@ -74,8 +73,12 @@ def register(
|
|||||||
*,
|
*,
|
||||||
db: Session = Depends(deps.get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
email: EmailStr = Body(...),
|
email: EmailStr = Body(...),
|
||||||
password: str = Body(...),
|
|
||||||
fullname: str = Body(...),
|
fullname: str = Body(...),
|
||||||
|
role: str = Body(Default="GUEST"),
|
||||||
|
current_user: models.User = Security(
|
||||||
|
deps.get_current_user,
|
||||||
|
scopes=[Role.ADMIN["name"], Role.SUPER_ADMIN["name"]],
|
||||||
|
),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Register new user.
|
Register new user.
|
||||||
@ -86,41 +89,35 @@ def register(
|
|||||||
status_code=409,
|
status_code=409,
|
||||||
detail="The user with this username already exists in the system",
|
detail="The user with this username already exists in the system",
|
||||||
)
|
)
|
||||||
|
random_password = security.generate_random_password()
|
||||||
user_in = schemas.UserCreate(
|
user_in = schemas.UserCreate(
|
||||||
email=email,
|
email=email,
|
||||||
password=password,
|
password=random_password,
|
||||||
fullname=fullname,
|
fullname=fullname,
|
||||||
)
|
)
|
||||||
|
|
||||||
# create user
|
|
||||||
user = crud.user.create(db, obj_in=user_in)
|
user = crud.user.create(db, obj_in=user_in)
|
||||||
|
send_registration_email(fullname, email, random_password)
|
||||||
|
|
||||||
# get role
|
role_db = crud.role.get_by_name(db, name=role)
|
||||||
role = crud.role.get_by_name(db, name=Role.SUPER_ADMIN["name"])
|
if not role_db:
|
||||||
print("ROLE:", role)
|
raise HTTPException(
|
||||||
# assign user_role
|
status_code=404,
|
||||||
|
detail=f"Role '{role}' not found",
|
||||||
|
)
|
||||||
user_role_in = schemas.UserRoleCreate(
|
user_role_in = schemas.UserRoleCreate(
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
role_id=role.id
|
role_id=role_db.id
|
||||||
)
|
)
|
||||||
user_role = crud.user_role.create(db, obj_in=user_role_in)
|
user_role = crud.user_role.create(db, obj_in=user_role_in)
|
||||||
|
|
||||||
print(user)
|
|
||||||
|
|
||||||
access_token_expires = timedelta(
|
access_token_expires = timedelta(
|
||||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||||
)
|
)
|
||||||
refresh_token_expires = timedelta(
|
refresh_token_expires = timedelta(
|
||||||
minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES
|
minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES
|
||||||
)
|
)
|
||||||
if not user.user_role:
|
|
||||||
role = "GUEST"
|
|
||||||
else:
|
|
||||||
role = user.user_role.role.name
|
|
||||||
|
|
||||||
token_payload = {
|
token_payload = {
|
||||||
"id": str(user.id),
|
"id": str(user.id),
|
||||||
"role": role,
|
"role": user_role.role.name,
|
||||||
}
|
}
|
||||||
return {
|
return {
|
||||||
"access_token": security.create_access_token(
|
"access_token": security.create_access_token(
|
||||||
|
@ -1,20 +1,27 @@
|
|||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from private_gpt.users import crud, models, schemas
|
from fastapi.responses import JSONResponse
|
||||||
from private_gpt.users.constants.role import Role
|
|
||||||
from private_gpt.users.api import deps
|
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status, Security
|
||||||
|
|
||||||
|
from private_gpt.users.api import deps
|
||||||
|
from private_gpt.users.constants.role import Role
|
||||||
|
from private_gpt.users import crud, models, schemas
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/companies", tags=["Companies"])
|
router = APIRouter(prefix="/companies", tags=["Companies"])
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=List[schemas.Company])
|
@router.get("", response_model=List[schemas.Company])
|
||||||
def list_companies(
|
def list_companies(
|
||||||
db: Session = Depends(deps.get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
|
current_user: models.User = Security(
|
||||||
|
deps.get_current_user,
|
||||||
|
scopes=[Role.SUPER_ADMIN["name"]],
|
||||||
|
),
|
||||||
) -> List[schemas.Company]:
|
) -> List[schemas.Company]:
|
||||||
"""
|
"""
|
||||||
List companies
|
List companies
|
||||||
@ -27,18 +34,34 @@ def list_companies(
|
|||||||
def create_company(
|
def create_company(
|
||||||
company_in: schemas.CompanyCreate,
|
company_in: schemas.CompanyCreate,
|
||||||
db: Session = Depends(deps.get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
|
current_user: models.User = Security(
|
||||||
|
deps.get_current_user,
|
||||||
|
scopes=[Role.SUPER_ADMIN["name"]],
|
||||||
|
),
|
||||||
) -> schemas.Company:
|
) -> schemas.Company:
|
||||||
"""
|
"""
|
||||||
Create a new company
|
Create a new company
|
||||||
"""
|
"""
|
||||||
company = crud.company.create(db=db, obj_in=company_in)
|
company = crud.company.create(db=db, obj_in=company_in)
|
||||||
return company
|
company = jsonable_encoder(company)
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_201_CREATED,
|
||||||
|
content={
|
||||||
|
"message": "Company created successfully",
|
||||||
|
"subscription": company
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{company_id}", response_model=schemas.Company)
|
@router.get("/{company_id}", response_model=schemas.Company)
|
||||||
def read_company(
|
def read_company(
|
||||||
company_id: int,
|
company_id: int,
|
||||||
db: Session = Depends(deps.get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
|
current_user: models.User = Security(
|
||||||
|
deps.get_current_user,
|
||||||
|
scopes=[Role.SUPER_ADMIN["name"]],
|
||||||
|
),
|
||||||
) -> schemas.Company:
|
) -> schemas.Company:
|
||||||
"""
|
"""
|
||||||
Read a company by ID
|
Read a company by ID
|
||||||
@ -54,6 +77,10 @@ def update_company(
|
|||||||
company_id: int,
|
company_id: int,
|
||||||
company_in: schemas.CompanyUpdate,
|
company_in: schemas.CompanyUpdate,
|
||||||
db: Session = Depends(deps.get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
|
current_user: models.User = Security(
|
||||||
|
deps.get_current_user,
|
||||||
|
scopes=[Role.SUPER_ADMIN["name"]],
|
||||||
|
),
|
||||||
) -> schemas.Company:
|
) -> schemas.Company:
|
||||||
"""
|
"""
|
||||||
Update a company by ID
|
Update a company by ID
|
||||||
@ -72,10 +99,15 @@ def update_company(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{company_id}", response_model=schemas.Company)
|
@router.delete("/{company_id}", response_model=schemas.Company)
|
||||||
def delete_company(
|
def delete_company(
|
||||||
company_id: int,
|
company_id: int,
|
||||||
db: Session = Depends(deps.get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
|
current_user: models.User = Security(
|
||||||
|
deps.get_current_user,
|
||||||
|
scopes=[Role.SUPER_ADMIN["name"]],
|
||||||
|
),
|
||||||
) -> schemas.Company:
|
) -> schemas.Company:
|
||||||
"""
|
"""
|
||||||
Delete a company by ID
|
Delete a company by ID
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
|
|
||||||
from private_gpt.users import crud, schemas
|
|
||||||
from private_gpt.users.api import deps
|
|
||||||
from fastapi import APIRouter, Depends
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi import APIRouter, Depends, status, Security
|
||||||
|
|
||||||
|
from private_gpt.users.api import deps
|
||||||
|
from private_gpt.users.constants.role import Role
|
||||||
|
from private_gpt.users import crud, schemas, models
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix='/roles', tags=['roles'])
|
router = APIRouter(prefix='/roles', tags=['roles'])
|
||||||
@ -12,9 +15,19 @@ router = APIRouter(prefix='/roles', tags=['roles'])
|
|||||||
@router.get("/", response_model=List[schemas.Role])
|
@router.get("/", response_model=List[schemas.Role])
|
||||||
def get_roles(
|
def get_roles(
|
||||||
db: Session = Depends(deps.get_db), skip: int = 0, limit: int = 100,
|
db: Session = Depends(deps.get_db), skip: int = 0, limit: int = 100,
|
||||||
|
current_user: models.User = Security(
|
||||||
|
deps.get_current_user,
|
||||||
|
scopes=[Role.SUPER_ADMIN["name"]],
|
||||||
|
),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Retrieve all available user roles.
|
Retrieve all available user roles.
|
||||||
"""
|
"""
|
||||||
roles = crud.role.get_multi(db, skip=skip, limit=limit)
|
roles = crud.role.get_multi(db, skip=skip, limit=limit)
|
||||||
return roles
|
|
||||||
|
roles_data = [{"id": role.id, "name": role.name} for role in roles]
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_200_OK,
|
||||||
|
content={"message": "Roles retrieved successfully", "roles": roles_data},
|
||||||
|
)
|
@ -1,47 +1,63 @@
|
|||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
from private_gpt.users import crud, models, schemas
|
|
||||||
from private_gpt.users.api import deps
|
|
||||||
from private_gpt.users.constants.role import Role
|
|
||||||
from fastapi import APIRouter, Body, Depends, HTTPException, Security,status
|
|
||||||
from fastapi.encoders import jsonable_encoder
|
|
||||||
from pydantic.networks import EmailStr
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from pydantic.networks import EmailStr
|
||||||
|
from fastapi import APIRouter, Body, Depends, HTTPException, Security, status
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
router = APIRouter(prefix="/subscriptions", tags=["Subscriptions"])
|
from private_gpt.users.api import deps
|
||||||
|
from private_gpt.users.constants.role import Role
|
||||||
|
from private_gpt.users import crud, models, schemas
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/subscriptions", tags=["Subscriptions"])
|
||||||
|
|
||||||
@router.post("/create", response_model=schemas.Subscription)
|
@router.post("/create", response_model=schemas.Subscription)
|
||||||
def create_subscription(
|
def create_subscription(
|
||||||
subscription_in: schemas.SubscriptionCreate,
|
subscription_in: schemas.SubscriptionCreate,
|
||||||
db: Session = Depends(deps.get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
|
current_user: models.User = Security(
|
||||||
|
deps.get_current_user,
|
||||||
|
scopes=[Role.SUPER_ADMIN["name"]],
|
||||||
|
),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Create a new subscription
|
Create a new subscription
|
||||||
"""
|
"""
|
||||||
existing_subscription = crud.subscription.get_by_company_id(db, company_id=subscription_in.company_id)
|
active_subscription = crud.subscription.get_active_subscription_by_company(
|
||||||
if existing_subscription:
|
db=db, company_id=subscription_in.company_id
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail="Company is already subscribed to a plan.",
|
|
||||||
)
|
|
||||||
|
|
||||||
subscription = crud.subscription.create(db=db, obj_in=subscription_in)
|
|
||||||
subscription_dict = jsonable_encoder(subscription)
|
|
||||||
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=status.HTTP_201_CREATED,
|
|
||||||
content={
|
|
||||||
"message": "Subscription created successfully",
|
|
||||||
"subscription": subscription_dict
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if active_subscription:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_200_OK,
|
||||||
|
content={
|
||||||
|
"message": "Active subscription found",
|
||||||
|
"subscription": jsonable_encoder(active_subscription),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
subscription = crud.subscription.create(db=db, obj_in=subscription_in)
|
||||||
|
subscription_dict = jsonable_encoder(subscription)
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_201_CREATED,
|
||||||
|
content={
|
||||||
|
"message": "Subscription created successfully",
|
||||||
|
"subscription": subscription_dict,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{subscription_id}", response_model=schemas.Subscription)
|
@router.get("/{subscription_id}", response_model=schemas.Subscription)
|
||||||
def read_subscription(
|
def read_subscription(
|
||||||
subscription_id: int,
|
subscription_id: int,
|
||||||
db: Session = Depends(deps.get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
|
current_user: models.User = Security(
|
||||||
|
deps.get_current_user,
|
||||||
|
scopes=[Role.SUPER_ADMIN["name"]],
|
||||||
|
),
|
||||||
):
|
):
|
||||||
subscription = crud.subscription.get_by_id(db, subscription_id=subscription_id)
|
subscription = crud.subscription.get_by_id(db, subscription_id=subscription_id)
|
||||||
if subscription is None:
|
if subscription is None:
|
||||||
@ -61,6 +77,10 @@ def read_subscription(
|
|||||||
def read_subscriptions_by_company(
|
def read_subscriptions_by_company(
|
||||||
company_id: int,
|
company_id: int,
|
||||||
db: Session = Depends(deps.get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
|
current_user: models.User = Security(
|
||||||
|
deps.get_current_user,
|
||||||
|
scopes=[Role.SUPER_ADMIN["name"]],
|
||||||
|
),
|
||||||
):
|
):
|
||||||
subscriptions = crud.subscription.get_by_company_id(db, company_id=company_id)
|
subscriptions = crud.subscription.get_by_company_id(db, company_id=company_id)
|
||||||
subscriptions_list = [jsonable_encoder(subscription) for subscription in subscriptions]
|
subscriptions_list = [jsonable_encoder(subscription) for subscription in subscriptions]
|
||||||
@ -79,6 +99,10 @@ def update_subscription(
|
|||||||
subscription_id: int,
|
subscription_id: int,
|
||||||
subscription_in: schemas.SubscriptionUpdate,
|
subscription_in: schemas.SubscriptionUpdate,
|
||||||
db: Session = Depends(deps.get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
|
current_user: models.User = Security(
|
||||||
|
deps.get_current_user,
|
||||||
|
scopes=[Role.SUPER_ADMIN["name"]],
|
||||||
|
),
|
||||||
):
|
):
|
||||||
subscription = crud.subscription.get_by_id(db, subscription_id=subscription_id)
|
subscription = crud.subscription.get_by_id(db, subscription_id=subscription_id)
|
||||||
if subscription is None:
|
if subscription is None:
|
||||||
@ -105,11 +129,17 @@ def update_subscription(
|
|||||||
def delete_subscription(
|
def delete_subscription(
|
||||||
subscription_id: int,
|
subscription_id: int,
|
||||||
db: Session = Depends(deps.get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
|
current_user: models.User = Security(
|
||||||
|
deps.get_current_user,
|
||||||
|
scopes=[Role.SUPER_ADMIN["name"]],
|
||||||
|
),
|
||||||
):
|
):
|
||||||
subscription = crud.subscription.remove(db=db, id=subscription_id)
|
subscription = crud.subscription.remove(db=db, id=subscription_id)
|
||||||
if subscription is None:
|
if subscription is None:
|
||||||
raise HTTPException(status_code=404, detail="Subscription not found")
|
raise HTTPException(status_code=404, detail="Subscription not found")
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
content={"message": "Subscription deleted successfully"},
|
status_code=status.HTTP_200_OK,
|
||||||
status_code=status.HTTP_200_OK
|
content={
|
||||||
|
"message": "Subscription deleted successfully"
|
||||||
|
}
|
||||||
)
|
)
|
@ -1,10 +1,13 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from private_gpt.users import crud, models, schemas
|
from sqlalchemy.orm import Session
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Security, status
|
||||||
|
|
||||||
from private_gpt.users.api import deps
|
from private_gpt.users.api import deps
|
||||||
from private_gpt.users.constants.role import Role
|
from private_gpt.users.constants.role import Role
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Security
|
from private_gpt.users import crud, models, schemas
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/user-roles", tags=["user-roles"])
|
router = APIRouter(prefix="/user-roles", tags=["user-roles"])
|
||||||
|
|
||||||
@ -14,7 +17,12 @@ def assign_user_role(
|
|||||||
*,
|
*,
|
||||||
db: Session = Depends(deps.get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
user_role_in: schemas.UserRoleCreate,
|
user_role_in: schemas.UserRoleCreate,
|
||||||
current_user: models.User = Depends(deps.get_current_user),
|
current_user: models.User = Security(
|
||||||
|
deps.get_current_user,
|
||||||
|
scopes=[
|
||||||
|
Role.SUPER_ADMIN["name"],
|
||||||
|
],
|
||||||
|
),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Assign a role to a user after creation of a user
|
Assign a role to a user after creation of a user
|
||||||
@ -26,7 +34,10 @@ def assign_user_role(
|
|||||||
detail="This user has already been assigned a role.",
|
detail="This user has already been assigned a role.",
|
||||||
)
|
)
|
||||||
user_role = crud.user_role.create(db, obj_in=user_role_in)
|
user_role = crud.user_role.create(db, obj_in=user_role_in)
|
||||||
return user_role
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_201_CREATED,
|
||||||
|
content={"message": "User role assigned successfully", "user_role": jsonable_encoder(user_role)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{user_id}", response_model=schemas.UserRole)
|
@router.put("/{user_id}", response_model=schemas.UserRole)
|
||||||
@ -38,20 +49,23 @@ def update_user_role(
|
|||||||
current_user: models.User = Security(
|
current_user: models.User = Security(
|
||||||
deps.get_current_user,
|
deps.get_current_user,
|
||||||
scopes=[
|
scopes=[
|
||||||
Role.ADMIN["name"],
|
|
||||||
Role.SUPER_ADMIN["name"],
|
Role.SUPER_ADMIN["name"],
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Update a users role.
|
Update a user's role.
|
||||||
"""
|
"""
|
||||||
user_role = crud.user_role.get_by_user_id(db, user_id=user_id)
|
user_role = crud.user_role.get_by_user_id(db, user_id=user_id)
|
||||||
if not user_role:
|
if not user_role:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404, detail="There is no role assigned to this user",
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="There is no role assigned to this user",
|
||||||
)
|
)
|
||||||
user_role = crud.user_role.update(
|
user_role = crud.user_role.update(
|
||||||
db, db_obj=user_role, obj_in=user_role_in
|
db, db_obj=user_role, obj_in=user_role_in
|
||||||
)
|
)
|
||||||
return user_role
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_200_OK,
|
||||||
|
content={"message": "User role updated successfully", "user_role": jsonable_encoder(user_role)},
|
||||||
|
)
|
||||||
|
@ -1,17 +1,19 @@
|
|||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
from private_gpt.users import crud, models, schemas
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from pydantic.networks import EmailStr
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
from fastapi import APIRouter, Body, Depends, HTTPException, Security, status
|
||||||
|
|
||||||
from private_gpt.users.api import deps
|
from private_gpt.users.api import deps
|
||||||
from private_gpt.users.constants.role import Role
|
from private_gpt.users.constants.role import Role
|
||||||
from private_gpt.users.core.config import settings
|
from private_gpt.users.core.config import settings
|
||||||
|
from private_gpt.users import crud, models, schemas
|
||||||
from private_gpt.users.core.security import verify_password, get_password_hash
|
from private_gpt.users.core.security import verify_password, get_password_hash
|
||||||
from fastapi import APIRouter, Body, Depends, HTTPException, Security
|
|
||||||
from fastapi.encoders import jsonable_encoder
|
|
||||||
from pydantic.networks import EmailStr
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/users", tags=["users"])
|
router = APIRouter(prefix="/users", tags=["users"])
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=List[schemas.User])
|
@router.get("", response_model=List[schemas.User])
|
||||||
def read_users(
|
def read_users(
|
||||||
db: Session = Depends(deps.get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
@ -25,7 +27,7 @@ def read_users(
|
|||||||
"""
|
"""
|
||||||
Retrieve all users.
|
Retrieve all users.
|
||||||
"""
|
"""
|
||||||
users = crud.user.get_multi(db, skip=skip, limit=limit,)
|
users = crud.user.get_multi(db, skip=skip, limit=limit)
|
||||||
return users
|
return users
|
||||||
|
|
||||||
|
|
||||||
@ -45,11 +47,14 @@ def create_user(
|
|||||||
user = crud.user.get_by_email(db, email=user_in.email)
|
user = crud.user.get_by_email(db, email=user_in.email)
|
||||||
if user:
|
if user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=409,
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
detail="The user with this username already exists in the system.",
|
detail="The user with this email already exists in the system.",
|
||||||
)
|
)
|
||||||
user = crud.user.create(db, obj_in=user_in)
|
user = crud.user.create(db, obj_in=user_in)
|
||||||
return user
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_201_CREATED,
|
||||||
|
content={"message": "User created successfully", "user": jsonable_encoder(user)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.put("/me", response_model=schemas.User)
|
@router.put("/me", response_model=schemas.User)
|
||||||
@ -64,15 +69,20 @@ def update_user_me(
|
|||||||
Update own user.
|
Update own user.
|
||||||
"""
|
"""
|
||||||
current_user_data = jsonable_encoder(current_user)
|
current_user_data = jsonable_encoder(current_user)
|
||||||
print("Current user data: ", current_user_data)
|
|
||||||
user_in = schemas.UserUpdate(**current_user_data)
|
user_in = schemas.UserUpdate(**current_user_data)
|
||||||
if fullname is not None:
|
if fullname is not None:
|
||||||
user_in.fullname = fullname
|
user_in.fullname = fullname
|
||||||
if email is not None:
|
if email is not None:
|
||||||
user_in.email = email
|
user_in.email = email
|
||||||
print(f"DB obj: {current_user}\n obj IN : {user_in}")
|
|
||||||
user = crud.user.update(db, db_obj=current_user, obj_in=user_in)
|
user = crud.user.update(db, db_obj=current_user, obj_in=user_in)
|
||||||
return user
|
user_data = schemas.UserBaseSchema(
|
||||||
|
email=user.email,
|
||||||
|
fullname=user.fullname,
|
||||||
|
)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_200_OK,
|
||||||
|
content={"message": "User updated successfully", "user": jsonable_encoder(user_data)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me", response_model=schemas.User)
|
@router.get("/me", response_model=schemas.User)
|
||||||
@ -83,21 +93,15 @@ def read_user_me(
|
|||||||
"""
|
"""
|
||||||
Get current user.
|
Get current user.
|
||||||
"""
|
"""
|
||||||
if not current_user.user_role:
|
role = current_user.user_role.role.name if current_user.user_role else None
|
||||||
role = None
|
user_data = schemas.UserBaseSchema(
|
||||||
else:
|
|
||||||
role = current_user.user_role.role.name
|
|
||||||
user_data = schemas.User(
|
|
||||||
id=current_user.id,
|
|
||||||
email=current_user.email,
|
email=current_user.email,
|
||||||
is_active=current_user.is_active,
|
|
||||||
fullname=current_user.fullname,
|
fullname=current_user.fullname,
|
||||||
created_at=current_user.created_at,
|
|
||||||
updated_at=current_user.updated_at,
|
|
||||||
last_login = current_user.last_login,
|
|
||||||
role=role
|
|
||||||
)
|
)
|
||||||
return user_data
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_200_OK,
|
||||||
|
content={"message": "Current user retrieved successfully", "user": jsonable_encoder(user_data)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.patch("/me/change-password", response_model=schemas.User)
|
@router.patch("/me/change-password", response_model=schemas.User)
|
||||||
@ -111,27 +115,24 @@ def change_password(
|
|||||||
"""
|
"""
|
||||||
Change current user's password.
|
Change current user's password.
|
||||||
"""
|
"""
|
||||||
# Verify the old password
|
|
||||||
if not verify_password(old_password, current_user.hashed_password):
|
if not verify_password(old_password, current_user.hashed_password):
|
||||||
raise HTTPException(status_code=400, detail="Old password is incorrect")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Old password is incorrect")
|
||||||
|
|
||||||
# Change the password
|
|
||||||
new_password_hashed = get_password_hash(new_password)
|
new_password_hashed = get_password_hash(new_password)
|
||||||
current_user.hashed_password = new_password_hashed
|
current_user.hashed_password = new_password_hashed
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
role = current_user.user_role.role.name if current_user.user_role else None
|
role = current_user.user_role.role.name if current_user.user_role else None
|
||||||
user_data = schemas.User(
|
user_data = schemas.UserBaseSchema(
|
||||||
id=current_user.id,
|
id=current_user.id,
|
||||||
email=current_user.email,
|
email=current_user.email,
|
||||||
is_active=current_user.is_active,
|
|
||||||
fullname=current_user.fullname,
|
fullname=current_user.fullname,
|
||||||
created_at=current_user.created_at,
|
|
||||||
updated_at=current_user.updated_at,
|
|
||||||
last_login=current_user.last_login,
|
|
||||||
role=role,
|
|
||||||
)
|
)
|
||||||
return user_data
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_200_OK,
|
||||||
|
content={"message": "Password changed successfully", "user": jsonable_encoder(user_data)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{user_id}", response_model=schemas.User)
|
@router.get("/{user_id}", response_model=schemas.User)
|
||||||
@ -147,9 +148,12 @@ def read_user_by_id(
|
|||||||
Get a specific user by id.
|
Get a specific user by id.
|
||||||
"""
|
"""
|
||||||
if user_id is None:
|
if user_id is None:
|
||||||
return "User id is not given."
|
return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content={"message": "User id is not given."})
|
||||||
user = crud.user.get(db, id=user_id)
|
user = crud.user.get(db, id=user_id)
|
||||||
return user
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_200_OK,
|
||||||
|
content={"message": "User retrieved successfully", "user": jsonable_encoder(user)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{user_id}", response_model=schemas.User)
|
@router.put("/{user_id}", response_model=schemas.User)
|
||||||
@ -169,8 +173,16 @@ def update_user(
|
|||||||
user = crud.user.get(db, id=user_id)
|
user = crud.user.get(db, id=user_id)
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail="The user with this username does not exist in the system",
|
detail="The user with this id does not exist in the system",
|
||||||
)
|
)
|
||||||
user = crud.user.update(db, db_obj=user, obj_in=user_in)
|
user = crud.user.update(db, db_obj=user, obj_in=user_in)
|
||||||
return user
|
user_data = schemas.UserBaseSchema(
|
||||||
|
id=user.id,
|
||||||
|
email=user.email,
|
||||||
|
fullname=user.fullname,
|
||||||
|
)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_200_OK,
|
||||||
|
content={"message": "User updated successfully", "user": jsonable_encoder(user_data)},
|
||||||
|
)
|
@ -32,6 +32,12 @@ class Settings(BaseSettings):
|
|||||||
DB_NAME: str
|
DB_NAME: str
|
||||||
PORT: str
|
PORT: str
|
||||||
|
|
||||||
|
SMTP_SERVER: str
|
||||||
|
SMTP_PORT: str
|
||||||
|
SMTP_SENDER_EMAIL: str
|
||||||
|
SMTP_USERNAME: str
|
||||||
|
SMTP_PASSWORD: str
|
||||||
|
|
||||||
|
|
||||||
# SQLALCHEMY_DATABASE_URI: Optional[PostgresDsn] = None
|
# SQLALCHEMY_DATABASE_URI: Optional[PostgresDsn] = None
|
||||||
|
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
from passlib.context import CryptContext
|
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
|
import string
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Union, Any
|
from typing import Union, Any
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
|
from passlib.context import CryptContext
|
||||||
|
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30 # 30 minutes
|
ACCESS_TOKEN_EXPIRE_MINUTES = 30 # 30 minutes
|
||||||
REFRESH_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7 days
|
REFRESH_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7 days
|
||||||
@ -41,4 +43,11 @@ def create_refresh_token(subject: Union[str, Any], expires_delta: int = None) ->
|
|||||||
|
|
||||||
to_encode = {"exp": expires_delta, **subject}
|
to_encode = {"exp": expires_delta, **subject}
|
||||||
encoded_jwt = jwt.encode(to_encode, JWT_REFRESH_SECRET_KEY, ALGORITHM)
|
encoded_jwt = jwt.encode(to_encode, JWT_REFRESH_SECRET_KEY, ALGORITHM)
|
||||||
return encoded_jwt
|
return encoded_jwt
|
||||||
|
|
||||||
|
def generate_random_password(length: int = 12) -> str:
|
||||||
|
"""
|
||||||
|
Generate a random password.
|
||||||
|
"""
|
||||||
|
characters = string.ascii_letters + string.digits + string.punctuation
|
||||||
|
return ''.join(random.choice(characters) for i in range(length))
|
@ -1,9 +1,12 @@
|
|||||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
|
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
|
||||||
|
|
||||||
from private_gpt.users.db.base import Base
|
from private_gpt.users.db.base import Base
|
||||||
|
from fastapi import HTTPException, status
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from pydantic.error_wrappers import ValidationError
|
||||||
|
|
||||||
# Define custom types for SQLAlchemy model, and Pydantic schemas
|
# Define custom types for SQLAlchemy model, and Pydantic schemas
|
||||||
ModelType = TypeVar("ModelType", bound=Base)
|
ModelType = TypeVar("ModelType", bound=Base)
|
||||||
@ -30,12 +33,19 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
|||||||
return db.query(self.model).filter(self.model.id == id).first()
|
return db.query(self.model).filter(self.model.id == id).first()
|
||||||
|
|
||||||
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
|
def create(self, db: Session, *, obj_in: CreateSchemaType) -> ModelType:
|
||||||
obj_in_data = jsonable_encoder(obj_in)
|
try:
|
||||||
db_obj = self.model(**obj_in_data) # type: ignore
|
obj_in_data = jsonable_encoder(obj_in)
|
||||||
db.add(db_obj)
|
db_obj = self.model(**obj_in_data)
|
||||||
db.commit()
|
db.add(db_obj)
|
||||||
db.refresh(db_obj)
|
db.commit()
|
||||||
return db_obj
|
db.refresh(db_obj)
|
||||||
|
return db_obj
|
||||||
|
except IntegrityError as e:
|
||||||
|
db.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Integrity Error: {str(e)}",
|
||||||
|
)
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
@ -44,21 +54,34 @@ class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
|
|||||||
db_obj: ModelType,
|
db_obj: ModelType,
|
||||||
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
|
obj_in: Union[UpdateSchemaType, Dict[str, Any]]
|
||||||
) -> ModelType:
|
) -> ModelType:
|
||||||
obj_data = jsonable_encoder(db_obj)
|
try:
|
||||||
if isinstance(obj_in, dict):
|
obj_data = jsonable_encoder(db_obj)
|
||||||
update_data = obj_in
|
if isinstance(obj_in, dict):
|
||||||
else:
|
update_data = obj_in
|
||||||
update_data = obj_in.dict(exclude_unset=True)
|
else:
|
||||||
for field in obj_data:
|
update_data = obj_in.dict(exclude_unset=True)
|
||||||
if field in update_data:
|
for field in obj_data:
|
||||||
setattr(db_obj, field, update_data[field])
|
if field in update_data:
|
||||||
db.add(db_obj)
|
setattr(db_obj, field, update_data[field])
|
||||||
db.commit()
|
db.add(db_obj)
|
||||||
db.refresh(db_obj)
|
db.commit()
|
||||||
return db_obj
|
db.refresh(db_obj)
|
||||||
|
return db_obj
|
||||||
|
except IntegrityError as e:
|
||||||
|
db.rollback()
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Integrity Error: {str(e)}",
|
||||||
|
)
|
||||||
|
|
||||||
def remove(self, db: Session, *, id: int) -> ModelType:
|
def remove(self, db: Session, *, id: int) -> ModelType:
|
||||||
obj = db.query(self.model).get(id)
|
obj = db.query(self.model).get(id)
|
||||||
db.delete(obj)
|
if obj:
|
||||||
db.commit()
|
db.delete(obj)
|
||||||
return obj
|
db.commit()
|
||||||
|
return obj
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"{self.model.__name__} not found with id: {id}",
|
||||||
|
)
|
@ -4,13 +4,28 @@ from private_gpt.users.crud.base import CRUDBase
|
|||||||
from private_gpt.users.models.subscription import Subscription
|
from private_gpt.users.models.subscription import Subscription
|
||||||
from private_gpt.users.schemas.subscription import SubscriptionCreate, SubscriptionUpdate
|
from private_gpt.users.schemas.subscription import SubscriptionCreate, SubscriptionUpdate
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
class CRUDSubscription(CRUDBase[Subscription, SubscriptionCreate, SubscriptionUpdate]):
|
class CRUDSubscription(CRUDBase[Subscription, SubscriptionCreate, SubscriptionUpdate]):
|
||||||
|
|
||||||
def get_by_id(self, db: Session, *, subscription_id: int) -> Optional[Subscription]:
|
def get_by_id(self, db: Session, *, subscription_id: int) -> Optional[Subscription]:
|
||||||
return db.query(self.model).filter(Subscription.sub_id == subscription_id).first()
|
return db.query(self.model).filter(Subscription.sub_id == subscription_id).first()
|
||||||
|
|
||||||
def get_by_company_id(self, db: Session, *, company_id: int) -> List[Subscription]:
|
def get_by_company_id(self, db: Session, *, company_id: int) -> List[Subscription]:
|
||||||
return db.query(self.model).filter(Subscription.company_id == company_id).all()
|
return db.query(self.model).filter(Subscription.company_id == company_id).all()
|
||||||
|
|
||||||
|
def get_active_subscription_by_company(self, db: Session, *, company_id: int) -> List[Subscription]:
|
||||||
|
current_datetime = datetime.utcnow()
|
||||||
|
return (
|
||||||
|
db.query(self.model)
|
||||||
|
.filter(
|
||||||
|
Subscription.company_id == company_id,
|
||||||
|
Subscription.is_active == True, # Active subscriptions
|
||||||
|
Subscription.end_date >= current_datetime, # End date is not passed
|
||||||
|
)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
subscription = CRUDSubscription(Subscription)
|
||||||
|
|
||||||
|
|
||||||
subscription = CRUDSubscription(Subscription)
|
|
@ -11,9 +11,13 @@ class Subscription(Base):
|
|||||||
|
|
||||||
sub_id = Column(Integer, primary_key=True, index=True)
|
sub_id = Column(Integer, primary_key=True, index=True)
|
||||||
company_id = Column(Integer, ForeignKey("companies.id"))
|
company_id = Column(Integer, ForeignKey("companies.id"))
|
||||||
start_date = Column(DateTime, default=datetime.utcnow)
|
start_date = Column(DateTime, default=datetime.utcnow())
|
||||||
end_date = Column(DateTime, default=datetime.utcnow() + timedelta(days=30)) # Example: 30 days subscription period
|
end_date = Column(DateTime, default=datetime.utcnow() + timedelta(days=30)) # Example: 30 days subscription period
|
||||||
is_active = Column(Boolean, default=True)
|
is_active = Column(Boolean, default=False)
|
||||||
|
|
||||||
company = relationship("Company", back_populates="subscriptions")
|
company = relationship("Company", back_populates="subscriptions")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_active(self) -> bool:
|
||||||
|
"""Check if the subscription is active based on the end_date."""
|
||||||
|
return self.end_date >= datetime.utcnow()
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from .role import Role, RoleCreate, RoleInDB, RoleUpdate
|
from .role import Role, RoleCreate, RoleInDB, RoleUpdate
|
||||||
from .token import TokenSchema, TokenPayload
|
from .token import TokenSchema, TokenPayload
|
||||||
from .user import User, UserCreate, UserInDB, UserUpdate
|
from .user import User, UserCreate, UserInDB, UserUpdate, UserBaseSchema
|
||||||
from .user_role import UserRole, UserRoleCreate, UserRoleInDB, UserRoleUpdate
|
from .user_role import UserRole, UserRoleCreate, UserRoleInDB, UserRoleUpdate
|
||||||
from .subscription import Subscription, SubscriptionBase, SubscriptionCreate, SubscriptionUpdate
|
from .subscription import Subscription, SubscriptionBase, SubscriptionCreate, SubscriptionUpdate
|
||||||
from .company import Company, CompanyBase, CompanyCreate, CompanyUpdate
|
from .company import Company, CompanyBase, CompanyCreate, CompanyUpdate
|
1
private_gpt/users/utils/__init__.py
Normal file
1
private_gpt/users/utils/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from .utils import send_registration_email
|
28
private_gpt/users/utils/utils.py
Normal file
28
private_gpt/users/utils/utils.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
import smtplib
|
||||||
|
from email.mime.text import MIMEText
|
||||||
|
from email.mime.multipart import MIMEMultipart
|
||||||
|
from private_gpt.users.core.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
def send_registration_email(fullname: str, email: str, random_password: str) -> None:
|
||||||
|
"""
|
||||||
|
Send a registration email with a random password.
|
||||||
|
"""
|
||||||
|
subject = "Welcome to QuickGPT - Registration Successful"
|
||||||
|
body = f"Hello {fullname},\n\nThank you for registering with Your App!\n\n"\
|
||||||
|
f"Your temporary password is: {random_password}\n\n"\
|
||||||
|
f"Please use this password to log in and consider changing it"\
|
||||||
|
" to a more secure one after logging in.\n\n"\
|
||||||
|
"Best regards,\nQuickGPT Team"
|
||||||
|
|
||||||
|
msg = MIMEMultipart()
|
||||||
|
msg.attach(MIMEText(body, "plain"))
|
||||||
|
msg["Subject"] = subject
|
||||||
|
msg["From"] = settings.SMTP_SENDER_EMAIL
|
||||||
|
msg["To"] = email
|
||||||
|
|
||||||
|
with smtplib.SMTP(settings.SMTP_SERVER, settings.SMTP_PORT) as server:
|
||||||
|
server.starttls()
|
||||||
|
server.login(settings.SMTP_USERNAME, settings.SMTP_PASSWORD)
|
||||||
|
server.sendmail(settings.SMTP_SENDER_EMAIL, email, msg.as_string())
|
||||||
|
|
Loading…
Reference in New Issue
Block a user