Added subscription module

This commit is contained in:
Saurab-Shrestha 2024-01-17 17:56:24 +05:45
parent c66feb4187
commit 4c5fed7ea4
17 changed files with 430 additions and 20 deletions

View File

@ -12,6 +12,8 @@ from private_gpt.users.core.config import SQLALCHEMY_DATABASE_URI
from private_gpt.users.models.user import User
from private_gpt.users.models.role import Role
from private_gpt.users.models.user_role import UserRole
from private_gpt.users.models.subscription import Subscription
from private_gpt.users.models.company import Company
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

View File

@ -0,0 +1,52 @@
"""Add Subscription and Company model
Revision ID: 0e0eb0a1a514
Revises: 65688535c5a5
Create Date: 2024-01-17 15:53:13.091801
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '0e0eb0a1a514'
down_revision: Union[str, None] = '65688535c5a5'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('companies',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_companies_id'), 'companies', ['id'], unique=False)
op.create_index(op.f('ix_companies_name'), 'companies', ['name'], unique=False)
op.create_table('subscriptions',
sa.Column('sub_id', sa.Integer(), nullable=False),
sa.Column('company_id', sa.Integer(), nullable=True),
sa.Column('start_date', sa.DateTime(), nullable=True),
sa.Column('end_date', sa.DateTime(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True),
sa.ForeignKeyConstraint(['company_id'], ['companies.id'], ),
sa.PrimaryKeyConstraint('sub_id')
)
op.create_index(op.f('ix_subscriptions_sub_id'), 'subscriptions', ['sub_id'], unique=False)
# op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id'])
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint('unique_user_role', 'user_roles', type_='unique')
op.drop_index(op.f('ix_subscriptions_sub_id'), table_name='subscriptions')
op.drop_table('subscriptions')
op.drop_index(op.f('ix_companies_name'), table_name='companies')
op.drop_index(op.f('ix_companies_id'), table_name='companies')
op.drop_table('companies')
# ### end Alembic commands ###

View File

@ -0,0 +1,31 @@
"""Add Subscription and Company model:
Revision ID: 65688535c5a5
Revises: cba9e6e394ca
Create Date: 2024-01-17 15:45:28.636265
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '65688535c5a5'
down_revision: Union[str, None] = 'cba9e6e394ca'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###

View File

@ -26,12 +26,12 @@ def create_app(root_injector: Injector) -> FastAPI:
app = FastAPI(dependencies=[Depends(bind_injector_to_request)])
app.include_router(completions_router)
app.include_router(chat_router)
app.include_router(chunks_router)
app.include_router(ingest_router)
app.include_router(embeddings_router)
app.include_router(health_router)
# app.include_router(completions_router)
# app.include_router(chat_router)
# app.include_router(chunks_router)
# app.include_router(ingest_router)
# app.include_router(embeddings_router)
# app.include_router(health_router)
app.include_router(api_router)
@ -48,14 +48,14 @@ def create_app(root_injector: Injector) -> FastAPI:
allow_headers=settings.server.cors.allow_headers,
)
if settings.ui.enabled:
logger.debug("Importing the UI module")
from private_gpt.ui.admin_ui import PrivateAdminGptUi
admin_ui = root_injector.get(PrivateAdminGptUi)
admin_ui.mount_in_admin_app(app, '/admin')
# if settings.ui.enabled:
# logger.debug("Importing the UI module")
# from private_gpt.ui.admin_ui import PrivateAdminGptUi
# admin_ui = root_injector.get(PrivateAdminGptUi)
# admin_ui.mount_in_admin_app(app, '/admin')
from private_gpt.ui.ui import PrivateGptUi
ui = root_injector.get(PrivateGptUi)
ui.mount_in_app(app, settings.ui.path)
# from private_gpt.ui.ui import PrivateGptUi
# ui = root_injector.get(PrivateGptUi)
# ui.mount_in_app(app, settings.ui.path)
return app

View File

@ -250,7 +250,6 @@ class PrivateAdminGptUi:
self._set_system_prompt,
inputs=system_prompt_input,
)
with gr.Column(scale=7, elem_id="col"):
_ = gr.ChatInterface(
self._chat,

View File

@ -1,4 +1,4 @@
from private_gpt.users.api.v1.routers import auth, roles, user_roles, users
from private_gpt.users.api.v1.routers import auth, roles, user_roles, users, subscriptions, companies
from fastapi import APIRouter
api_router = APIRouter(prefix="/v1")
@ -6,4 +6,7 @@ api_router = APIRouter(prefix="/v1")
api_router.include_router(auth.router)
api_router.include_router(users.router)
api_router.include_router(roles.router)
api_router.include_router(user_roles.router)
api_router.include_router(user_roles.router)
api_router.include_router(companies.router)
api_router.include_router(subscriptions.router)

View File

@ -0,0 +1,94 @@
from typing import Any, List
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from private_gpt.users import crud, models, schemas
from private_gpt.users.constants.role import Role
from private_gpt.users.api import deps
from fastapi.encoders import jsonable_encoder
router = APIRouter(prefix="/companies", tags=["Companies"])
@router.get("", response_model=List[schemas.Company])
def list_companies(
db: Session = Depends(deps.get_db),
skip: int = 0,
limit: int = 100,
) -> List[schemas.Company]:
"""
List companies
"""
companies = crud.company.get_multi(db, skip=skip, limit=limit)
return companies
@router.post("/create", response_model=schemas.Company)
def create_company(
company_in: schemas.CompanyCreate,
db: Session = Depends(deps.get_db),
) -> schemas.Company:
"""
Create a new company
"""
company = crud.company.create(db=db, obj_in=company_in)
return company
@router.get("/{company_id}", response_model=schemas.Company)
def read_company(
company_id: int,
db: Session = Depends(deps.get_db),
) -> schemas.Company:
"""
Read a company by ID
"""
company = crud.company.get_by_id(db, id=company_id)
if company is None:
raise HTTPException(status_code=404, detail="Company not found")
return company
@router.put("/{company_id}", response_model=schemas.Company)
def update_company(
company_id: int,
company_in: schemas.CompanyUpdate,
db: Session = Depends(deps.get_db),
) -> schemas.Company:
"""
Update a company by ID
"""
company = crud.company.get_by_id(db, id=company_id)
if company is None:
raise HTTPException(status_code=404, detail="Company not found")
updated_company = crud.company.update(db=db, db_obj=company, obj_in=company_in)
updated_company = jsonable_encoder(updated_company)
return JSONResponse(
status_code=status.HTTP_200_OK,
content={
"message": f"{company_id} Company updated successfully",
"company": updated_company
},
)
@router.delete("/{company_id}", response_model=schemas.Company)
def delete_company(
company_id: int,
db: Session = Depends(deps.get_db),
) -> schemas.Company:
"""
Delete a company by ID
"""
company = crud.company.remove(db=db, id=company_id)
if company is None:
raise HTTPException(status_code=404, detail="Company not found")
company = jsonable_encoder(company)
return JSONResponse(
status_code=status.HTTP_200_OK,
content={
"message": "Company deleted successfully",
"company": company
},
)

View File

@ -0,0 +1,115 @@
from typing import Any, List
from private_gpt.users import crud, models, schemas
from private_gpt.users.api import deps
from private_gpt.users.constants.role import Role
from fastapi import APIRouter, Body, Depends, HTTPException, Security,status
from fastapi.encoders import jsonable_encoder
from pydantic.networks import EmailStr
from sqlalchemy.orm import Session
from fastapi.responses import JSONResponse
router = APIRouter(prefix="/subscriptions", tags=["Subscriptions"])
@router.post("/create", response_model=schemas.Subscription)
def create_subscription(
subscription_in: schemas.SubscriptionCreate,
db: Session = Depends(deps.get_db),
) -> Any:
"""
Create a new subscription
"""
existing_subscription = crud.subscription.get_by_company_id(db, company_id=subscription_in.company_id)
if existing_subscription:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Company is already subscribed to a plan.",
)
subscription = crud.subscription.create(db=db, obj_in=subscription_in)
subscription_dict = jsonable_encoder(subscription)
return JSONResponse(
status_code=status.HTTP_201_CREATED,
content={
"message": "Subscription created successfully",
"subscription": subscription_dict
},
)
@router.get("/{subscription_id}", response_model=schemas.Subscription)
def read_subscription(
subscription_id: int,
db: Session = Depends(deps.get_db),
):
subscription = crud.subscription.get_by_id(db, subscription_id=subscription_id)
if subscription is None:
raise HTTPException(status_code=404, detail="Subscription not found")
subscription_dict = jsonable_encoder(subscription)
return JSONResponse(
status_code=status.HTTP_200_OK,
content={
"message": "Subscription retrieved successfully",
"subscription": subscription_dict
},
)
@router.get("/company/{company_id}", response_model=List[schemas.Subscription])
def read_subscriptions_by_company(
company_id: int,
db: Session = Depends(deps.get_db),
):
subscriptions = crud.subscription.get_by_company_id(db, company_id=company_id)
subscriptions_list = [jsonable_encoder(subscription) for subscription in subscriptions]
return JSONResponse(
status_code=status.HTTP_200_OK,
content={
"message": "Subscriptions retrieved successfully",
"subscriptions": subscriptions_list
},
)
@router.put("/{subscription_id}", response_model=schemas.Subscription)
def update_subscription(
subscription_id: int,
subscription_in: schemas.SubscriptionUpdate,
db: Session = Depends(deps.get_db),
):
subscription = crud.subscription.get_by_id(db, subscription_id=subscription_id)
if subscription is None:
raise HTTPException(status_code=404, detail="Subscription not found")
if not subscription.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Company subscription is not active. Please contact support for assistance.",
)
updated_subscription = crud.subscription.update(db=db, db_obj=subscription, obj_in=subscription_in)
return JSONResponse(
status_code=status.HTTP_200_OK,
content={
"message": "Subscription updated successfully",
"subscription": jsonable_encoder(updated_subscription)
},
)
@router.delete("/{subscription_id}")
def delete_subscription(
subscription_id: int,
db: Session = Depends(deps.get_db),
):
subscription = crud.subscription.remove(db=db, id=subscription_id)
if subscription is None:
raise HTTPException(status_code=404, detail="Subscription not found")
return JSONResponse(
content={"message": "Subscription deleted successfully"},
status_code=status.HTTP_200_OK
)

View File

@ -1,3 +1,5 @@
from .role_crud import role
from .user_crud import user
from .user_role_crud import user_role
from .user_role_crud import user_role
from .company_crud import company
from .subscription_crud import subscription

View File

@ -0,0 +1,13 @@
from typing import Optional
from private_gpt.users.crud.base import CRUDBase
from private_gpt.users.models.company import Company
from private_gpt.users.schemas.company import CompanyCreate, CompanyUpdate
from sqlalchemy.orm import Session
class CRUDCompany(CRUDBase[Company, CompanyCreate, CompanyUpdate]):
def get_by_id(self, db: Session, *, id: str) -> Optional[Company]:
return db.query(self.model).filter(Company.id == id).first()
company = CRUDCompany(Company)

View File

@ -0,0 +1,16 @@
from typing import Optional, List
from private_gpt.users.crud.base import CRUDBase
from private_gpt.users.models.subscription import Subscription
from private_gpt.users.schemas.subscription import SubscriptionCreate, SubscriptionUpdate
from sqlalchemy.orm import Session
class CRUDSubscription(CRUDBase[Subscription, SubscriptionCreate, SubscriptionUpdate]):
def get_by_id(self, db: Session, *, subscription_id: int) -> Optional[Subscription]:
return db.query(self.model).filter(Subscription.sub_id == subscription_id).first()
def get_by_company_id(self, db: Session, *, company_id: int) -> List[Subscription]:
return db.query(self.model).filter(Subscription.company_id == company_id).all()
subscription = CRUDSubscription(Subscription)

View File

@ -0,0 +1,14 @@
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import relationship
from private_gpt.users.db.base_class import Base
class Company(Base):
"""Models a Company table."""
__tablename__ = "companies"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, index=True)
subscriptions = relationship("Subscription", back_populates="company")

View File

@ -0,0 +1,19 @@
from sqlalchemy import Column, Integer, String, Boolean, Float, ForeignKey, DateTime
from sqlalchemy.orm import relationship
from datetime import datetime, timedelta
from fastapi import Depends
from private_gpt.users.db.base_class import Base
class Subscription(Base):
"""Models a Subscription table."""
__tablename__ = "subscriptions"
sub_id = Column(Integer, primary_key=True, index=True)
company_id = Column(Integer, ForeignKey("companies.id"))
start_date = Column(DateTime, default=datetime.utcnow)
end_date = Column(DateTime, default=datetime.utcnow() + timedelta(days=30)) # Example: 30 days subscription period
is_active = Column(Boolean, default=True)
company = relationship("Company", back_populates="subscriptions")

View File

@ -2,7 +2,6 @@ from private_gpt.users.db.base_class import Base
from sqlalchemy import Column, ForeignKey, UniqueConstraint, Integer
from sqlalchemy.orm import relationship
class UserRole(Base):
__tablename__ = "user_roles"
user_id = Column(

View File

@ -1,4 +1,6 @@
from .role import Role, RoleCreate, RoleInDB, RoleUpdate
from .token import TokenSchema, TokenPayload
from .user import User, UserCreate, UserInDB, UserUpdate
from .user_role import UserRole, UserRoleCreate, UserRoleInDB, UserRoleUpdate
from .user_role import UserRole, UserRoleCreate, UserRoleInDB, UserRoleUpdate
from .subscription import Subscription, SubscriptionBase, SubscriptionCreate, SubscriptionUpdate
from .company import Company, CompanyBase, CompanyCreate, CompanyUpdate

View File

@ -0,0 +1,19 @@
from typing import List
from datetime import datetime
from pydantic import BaseModel
class CompanyBase(BaseModel):
name: str
class CompanyCreate(CompanyBase):
pass
class CompanyUpdate(CompanyBase):
pass
class Company(CompanyBase):
id: int
class Config:
orm_mode = True

View File

@ -0,0 +1,30 @@
from typing import List
from datetime import datetime
from pydantic import BaseModel
from private_gpt.users.schemas.company import Company
class SubscriptionBase(BaseModel):
start_date: datetime
end_date: datetime
is_active: bool
class SubscriptionCreate(SubscriptionBase):
company_id: int
class SubscriptionSchema(SubscriptionBase):
id: int
company: Company
class Config:
orm_mode = True
class SubscriptionUpdate(SubscriptionBase):
pass
class Subscription(SubscriptionBase):
id: int
company_id: int
class Config:
orm_mode = True