Added subscription module

This commit is contained in:
Saurab-Shrestha 2024-01-17 17:56:24 +05:45
parent c66feb4187
commit 4c5fed7ea4
17 changed files with 430 additions and 20 deletions

View File

@ -12,6 +12,8 @@ from private_gpt.users.core.config import SQLALCHEMY_DATABASE_URI
from private_gpt.users.models.user import User from private_gpt.users.models.user import User
from private_gpt.users.models.role import Role from private_gpt.users.models.role import Role
from private_gpt.users.models.user_role import UserRole from private_gpt.users.models.user_role import UserRole
from private_gpt.users.models.subscription import Subscription
from private_gpt.users.models.company import Company
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
# access to the values within the .ini file in use. # access to the values within the .ini file in use.
config = context.config config = context.config

View File

@ -0,0 +1,52 @@
"""Add Subscription and Company model
Revision ID: 0e0eb0a1a514
Revises: 65688535c5a5
Create Date: 2024-01-17 15:53:13.091801
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '0e0eb0a1a514'
down_revision: Union[str, None] = '65688535c5a5'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('companies',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_companies_id'), 'companies', ['id'], unique=False)
op.create_index(op.f('ix_companies_name'), 'companies', ['name'], unique=False)
op.create_table('subscriptions',
sa.Column('sub_id', sa.Integer(), nullable=False),
sa.Column('company_id', sa.Integer(), nullable=True),
sa.Column('start_date', sa.DateTime(), nullable=True),
sa.Column('end_date', sa.DateTime(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True),
sa.ForeignKeyConstraint(['company_id'], ['companies.id'], ),
sa.PrimaryKeyConstraint('sub_id')
)
op.create_index(op.f('ix_subscriptions_sub_id'), 'subscriptions', ['sub_id'], unique=False)
# op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id'])
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint('unique_user_role', 'user_roles', type_='unique')
op.drop_index(op.f('ix_subscriptions_sub_id'), table_name='subscriptions')
op.drop_table('subscriptions')
op.drop_index(op.f('ix_companies_name'), table_name='companies')
op.drop_index(op.f('ix_companies_id'), table_name='companies')
op.drop_table('companies')
# ### end Alembic commands ###

View File

@ -0,0 +1,31 @@
"""Add Subscription and Company model:
Revision ID: 65688535c5a5
Revises: cba9e6e394ca
Create Date: 2024-01-17 15:45:28.636265
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '65688535c5a5'
down_revision: Union[str, None] = 'cba9e6e394ca'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###

View File

@ -26,12 +26,12 @@ def create_app(root_injector: Injector) -> FastAPI:
app = FastAPI(dependencies=[Depends(bind_injector_to_request)]) app = FastAPI(dependencies=[Depends(bind_injector_to_request)])
app.include_router(completions_router) # app.include_router(completions_router)
app.include_router(chat_router) # app.include_router(chat_router)
app.include_router(chunks_router) # app.include_router(chunks_router)
app.include_router(ingest_router) # app.include_router(ingest_router)
app.include_router(embeddings_router) # app.include_router(embeddings_router)
app.include_router(health_router) # app.include_router(health_router)
app.include_router(api_router) app.include_router(api_router)
@ -48,14 +48,14 @@ def create_app(root_injector: Injector) -> FastAPI:
allow_headers=settings.server.cors.allow_headers, allow_headers=settings.server.cors.allow_headers,
) )
if settings.ui.enabled: # if settings.ui.enabled:
logger.debug("Importing the UI module") # logger.debug("Importing the UI module")
from private_gpt.ui.admin_ui import PrivateAdminGptUi # from private_gpt.ui.admin_ui import PrivateAdminGptUi
admin_ui = root_injector.get(PrivateAdminGptUi) # admin_ui = root_injector.get(PrivateAdminGptUi)
admin_ui.mount_in_admin_app(app, '/admin') # admin_ui.mount_in_admin_app(app, '/admin')
from private_gpt.ui.ui import PrivateGptUi # from private_gpt.ui.ui import PrivateGptUi
ui = root_injector.get(PrivateGptUi) # ui = root_injector.get(PrivateGptUi)
ui.mount_in_app(app, settings.ui.path) # ui.mount_in_app(app, settings.ui.path)
return app return app

View File

@ -250,7 +250,6 @@ class PrivateAdminGptUi:
self._set_system_prompt, self._set_system_prompt,
inputs=system_prompt_input, inputs=system_prompt_input,
) )
with gr.Column(scale=7, elem_id="col"): with gr.Column(scale=7, elem_id="col"):
_ = gr.ChatInterface( _ = gr.ChatInterface(
self._chat, self._chat,

View File

@ -1,4 +1,4 @@
from private_gpt.users.api.v1.routers import auth, roles, user_roles, users from private_gpt.users.api.v1.routers import auth, roles, user_roles, users, subscriptions, companies
from fastapi import APIRouter from fastapi import APIRouter
api_router = APIRouter(prefix="/v1") api_router = APIRouter(prefix="/v1")
@ -7,3 +7,6 @@ api_router.include_router(auth.router)
api_router.include_router(users.router) api_router.include_router(users.router)
api_router.include_router(roles.router) api_router.include_router(roles.router)
api_router.include_router(user_roles.router) api_router.include_router(user_roles.router)
api_router.include_router(companies.router)
api_router.include_router(subscriptions.router)

View File

@ -0,0 +1,94 @@
from typing import Any, List
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from private_gpt.users import crud, models, schemas
from private_gpt.users.constants.role import Role
from private_gpt.users.api import deps
from fastapi.encoders import jsonable_encoder
router = APIRouter(prefix="/companies", tags=["Companies"])
@router.get("", response_model=List[schemas.Company])
def list_companies(
db: Session = Depends(deps.get_db),
skip: int = 0,
limit: int = 100,
) -> List[schemas.Company]:
"""
List companies
"""
companies = crud.company.get_multi(db, skip=skip, limit=limit)
return companies
@router.post("/create", response_model=schemas.Company)
def create_company(
company_in: schemas.CompanyCreate,
db: Session = Depends(deps.get_db),
) -> schemas.Company:
"""
Create a new company
"""
company = crud.company.create(db=db, obj_in=company_in)
return company
@router.get("/{company_id}", response_model=schemas.Company)
def read_company(
company_id: int,
db: Session = Depends(deps.get_db),
) -> schemas.Company:
"""
Read a company by ID
"""
company = crud.company.get_by_id(db, id=company_id)
if company is None:
raise HTTPException(status_code=404, detail="Company not found")
return company
@router.put("/{company_id}", response_model=schemas.Company)
def update_company(
company_id: int,
company_in: schemas.CompanyUpdate,
db: Session = Depends(deps.get_db),
) -> schemas.Company:
"""
Update a company by ID
"""
company = crud.company.get_by_id(db, id=company_id)
if company is None:
raise HTTPException(status_code=404, detail="Company not found")
updated_company = crud.company.update(db=db, db_obj=company, obj_in=company_in)
updated_company = jsonable_encoder(updated_company)
return JSONResponse(
status_code=status.HTTP_200_OK,
content={
"message": f"{company_id} Company updated successfully",
"company": updated_company
},
)
@router.delete("/{company_id}", response_model=schemas.Company)
def delete_company(
company_id: int,
db: Session = Depends(deps.get_db),
) -> schemas.Company:
"""
Delete a company by ID
"""
company = crud.company.remove(db=db, id=company_id)
if company is None:
raise HTTPException(status_code=404, detail="Company not found")
company = jsonable_encoder(company)
return JSONResponse(
status_code=status.HTTP_200_OK,
content={
"message": "Company deleted successfully",
"company": company
},
)

View File

@ -0,0 +1,115 @@
from typing import Any, List
from private_gpt.users import crud, models, schemas
from private_gpt.users.api import deps
from private_gpt.users.constants.role import Role
from fastapi import APIRouter, Body, Depends, HTTPException, Security,status
from fastapi.encoders import jsonable_encoder
from pydantic.networks import EmailStr
from sqlalchemy.orm import Session
from fastapi.responses import JSONResponse
router = APIRouter(prefix="/subscriptions", tags=["Subscriptions"])
@router.post("/create", response_model=schemas.Subscription)
def create_subscription(
subscription_in: schemas.SubscriptionCreate,
db: Session = Depends(deps.get_db),
) -> Any:
"""
Create a new subscription
"""
existing_subscription = crud.subscription.get_by_company_id(db, company_id=subscription_in.company_id)
if existing_subscription:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Company is already subscribed to a plan.",
)
subscription = crud.subscription.create(db=db, obj_in=subscription_in)
subscription_dict = jsonable_encoder(subscription)
return JSONResponse(
status_code=status.HTTP_201_CREATED,
content={
"message": "Subscription created successfully",
"subscription": subscription_dict
},
)
@router.get("/{subscription_id}", response_model=schemas.Subscription)
def read_subscription(
subscription_id: int,
db: Session = Depends(deps.get_db),
):
subscription = crud.subscription.get_by_id(db, subscription_id=subscription_id)
if subscription is None:
raise HTTPException(status_code=404, detail="Subscription not found")
subscription_dict = jsonable_encoder(subscription)
return JSONResponse(
status_code=status.HTTP_200_OK,
content={
"message": "Subscription retrieved successfully",
"subscription": subscription_dict
},
)
@router.get("/company/{company_id}", response_model=List[schemas.Subscription])
def read_subscriptions_by_company(
company_id: int,
db: Session = Depends(deps.get_db),
):
subscriptions = crud.subscription.get_by_company_id(db, company_id=company_id)
subscriptions_list = [jsonable_encoder(subscription) for subscription in subscriptions]
return JSONResponse(
status_code=status.HTTP_200_OK,
content={
"message": "Subscriptions retrieved successfully",
"subscriptions": subscriptions_list
},
)
@router.put("/{subscription_id}", response_model=schemas.Subscription)
def update_subscription(
subscription_id: int,
subscription_in: schemas.SubscriptionUpdate,
db: Session = Depends(deps.get_db),
):
subscription = crud.subscription.get_by_id(db, subscription_id=subscription_id)
if subscription is None:
raise HTTPException(status_code=404, detail="Subscription not found")
if not subscription.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Company subscription is not active. Please contact support for assistance.",
)
updated_subscription = crud.subscription.update(db=db, db_obj=subscription, obj_in=subscription_in)
return JSONResponse(
status_code=status.HTTP_200_OK,
content={
"message": "Subscription updated successfully",
"subscription": jsonable_encoder(updated_subscription)
},
)
@router.delete("/{subscription_id}")
def delete_subscription(
subscription_id: int,
db: Session = Depends(deps.get_db),
):
subscription = crud.subscription.remove(db=db, id=subscription_id)
if subscription is None:
raise HTTPException(status_code=404, detail="Subscription not found")
return JSONResponse(
content={"message": "Subscription deleted successfully"},
status_code=status.HTTP_200_OK
)

View File

@ -1,3 +1,5 @@
from .role_crud import role from .role_crud import role
from .user_crud import user from .user_crud import user
from .user_role_crud import user_role from .user_role_crud import user_role
from .company_crud import company
from .subscription_crud import subscription

View File

@ -0,0 +1,13 @@
from typing import Optional
from private_gpt.users.crud.base import CRUDBase
from private_gpt.users.models.company import Company
from private_gpt.users.schemas.company import CompanyCreate, CompanyUpdate
from sqlalchemy.orm import Session
class CRUDCompany(CRUDBase[Company, CompanyCreate, CompanyUpdate]):
def get_by_id(self, db: Session, *, id: str) -> Optional[Company]:
return db.query(self.model).filter(Company.id == id).first()
company = CRUDCompany(Company)

View File

@ -0,0 +1,16 @@
from typing import Optional, List
from private_gpt.users.crud.base import CRUDBase
from private_gpt.users.models.subscription import Subscription
from private_gpt.users.schemas.subscription import SubscriptionCreate, SubscriptionUpdate
from sqlalchemy.orm import Session
class CRUDSubscription(CRUDBase[Subscription, SubscriptionCreate, SubscriptionUpdate]):
def get_by_id(self, db: Session, *, subscription_id: int) -> Optional[Subscription]:
return db.query(self.model).filter(Subscription.sub_id == subscription_id).first()
def get_by_company_id(self, db: Session, *, company_id: int) -> List[Subscription]:
return db.query(self.model).filter(Subscription.company_id == company_id).all()
subscription = CRUDSubscription(Subscription)

View File

@ -0,0 +1,14 @@
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import relationship
from private_gpt.users.db.base_class import Base
class Company(Base):
"""Models a Company table."""
__tablename__ = "companies"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, index=True)
subscriptions = relationship("Subscription", back_populates="company")

View File

@ -0,0 +1,19 @@
from sqlalchemy import Column, Integer, String, Boolean, Float, ForeignKey, DateTime
from sqlalchemy.orm import relationship
from datetime import datetime, timedelta
from fastapi import Depends
from private_gpt.users.db.base_class import Base
class Subscription(Base):
"""Models a Subscription table."""
__tablename__ = "subscriptions"
sub_id = Column(Integer, primary_key=True, index=True)
company_id = Column(Integer, ForeignKey("companies.id"))
start_date = Column(DateTime, default=datetime.utcnow)
end_date = Column(DateTime, default=datetime.utcnow() + timedelta(days=30)) # Example: 30 days subscription period
is_active = Column(Boolean, default=True)
company = relationship("Company", back_populates="subscriptions")

View File

@ -2,7 +2,6 @@ from private_gpt.users.db.base_class import Base
from sqlalchemy import Column, ForeignKey, UniqueConstraint, Integer from sqlalchemy import Column, ForeignKey, UniqueConstraint, Integer
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
class UserRole(Base): class UserRole(Base):
__tablename__ = "user_roles" __tablename__ = "user_roles"
user_id = Column( user_id = Column(

View File

@ -2,3 +2,5 @@ from .role import Role, RoleCreate, RoleInDB, RoleUpdate
from .token import TokenSchema, TokenPayload from .token import TokenSchema, TokenPayload
from .user import User, UserCreate, UserInDB, UserUpdate from .user import User, UserCreate, UserInDB, UserUpdate
from .user_role import UserRole, UserRoleCreate, UserRoleInDB, UserRoleUpdate from .user_role import UserRole, UserRoleCreate, UserRoleInDB, UserRoleUpdate
from .subscription import Subscription, SubscriptionBase, SubscriptionCreate, SubscriptionUpdate
from .company import Company, CompanyBase, CompanyCreate, CompanyUpdate

View File

@ -0,0 +1,19 @@
from typing import List
from datetime import datetime
from pydantic import BaseModel
class CompanyBase(BaseModel):
name: str
class CompanyCreate(CompanyBase):
pass
class CompanyUpdate(CompanyBase):
pass
class Company(CompanyBase):
id: int
class Config:
orm_mode = True

View File

@ -0,0 +1,30 @@
from typing import List
from datetime import datetime
from pydantic import BaseModel
from private_gpt.users.schemas.company import Company
class SubscriptionBase(BaseModel):
start_date: datetime
end_date: datetime
is_active: bool
class SubscriptionCreate(SubscriptionBase):
company_id: int
class SubscriptionSchema(SubscriptionBase):
id: int
company: Company
class Config:
orm_mode = True
class SubscriptionUpdate(SubscriptionBase):
pass
class Subscription(SubscriptionBase):
id: int
company_id: int
class Config:
orm_mode = True