mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-07-12 22:58:09 +00:00
Added subscription module
This commit is contained in:
parent
c66feb4187
commit
4c5fed7ea4
@ -12,6 +12,8 @@ from private_gpt.users.core.config import SQLALCHEMY_DATABASE_URI
|
||||
from private_gpt.users.models.user import User
|
||||
from private_gpt.users.models.role import Role
|
||||
from private_gpt.users.models.user_role import UserRole
|
||||
from private_gpt.users.models.subscription import Subscription
|
||||
from private_gpt.users.models.company import Company
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
@ -0,0 +1,52 @@
|
||||
"""Add Subscription and Company model
|
||||
|
||||
Revision ID: 0e0eb0a1a514
|
||||
Revises: 65688535c5a5
|
||||
Create Date: 2024-01-17 15:53:13.091801
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '0e0eb0a1a514'
|
||||
down_revision: Union[str, None] = '65688535c5a5'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('companies',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_companies_id'), 'companies', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_companies_name'), 'companies', ['name'], unique=False)
|
||||
op.create_table('subscriptions',
|
||||
sa.Column('sub_id', sa.Integer(), nullable=False),
|
||||
sa.Column('company_id', sa.Integer(), nullable=True),
|
||||
sa.Column('start_date', sa.DateTime(), nullable=True),
|
||||
sa.Column('end_date', sa.DateTime(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['company_id'], ['companies.id'], ),
|
||||
sa.PrimaryKeyConstraint('sub_id')
|
||||
)
|
||||
op.create_index(op.f('ix_subscriptions_sub_id'), 'subscriptions', ['sub_id'], unique=False)
|
||||
# op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id'])
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint('unique_user_role', 'user_roles', type_='unique')
|
||||
op.drop_index(op.f('ix_subscriptions_sub_id'), table_name='subscriptions')
|
||||
op.drop_table('subscriptions')
|
||||
op.drop_index(op.f('ix_companies_name'), table_name='companies')
|
||||
op.drop_index(op.f('ix_companies_id'), table_name='companies')
|
||||
op.drop_table('companies')
|
||||
# ### end Alembic commands ###
|
@ -0,0 +1,31 @@
|
||||
"""Add Subscription and Company model:
|
||||
|
||||
|
||||
Revision ID: 65688535c5a5
|
||||
Revises: cba9e6e394ca
|
||||
Create Date: 2024-01-17 15:45:28.636265
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '65688535c5a5'
|
||||
down_revision: Union[str, None] = 'cba9e6e394ca'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
pass
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
pass
|
||||
# ### end Alembic commands ###
|
@ -26,12 +26,12 @@ def create_app(root_injector: Injector) -> FastAPI:
|
||||
|
||||
app = FastAPI(dependencies=[Depends(bind_injector_to_request)])
|
||||
|
||||
app.include_router(completions_router)
|
||||
app.include_router(chat_router)
|
||||
app.include_router(chunks_router)
|
||||
app.include_router(ingest_router)
|
||||
app.include_router(embeddings_router)
|
||||
app.include_router(health_router)
|
||||
# app.include_router(completions_router)
|
||||
# app.include_router(chat_router)
|
||||
# app.include_router(chunks_router)
|
||||
# app.include_router(ingest_router)
|
||||
# app.include_router(embeddings_router)
|
||||
# app.include_router(health_router)
|
||||
|
||||
app.include_router(api_router)
|
||||
|
||||
@ -48,14 +48,14 @@ def create_app(root_injector: Injector) -> FastAPI:
|
||||
allow_headers=settings.server.cors.allow_headers,
|
||||
)
|
||||
|
||||
if settings.ui.enabled:
|
||||
logger.debug("Importing the UI module")
|
||||
from private_gpt.ui.admin_ui import PrivateAdminGptUi
|
||||
admin_ui = root_injector.get(PrivateAdminGptUi)
|
||||
admin_ui.mount_in_admin_app(app, '/admin')
|
||||
# if settings.ui.enabled:
|
||||
# logger.debug("Importing the UI module")
|
||||
# from private_gpt.ui.admin_ui import PrivateAdminGptUi
|
||||
# admin_ui = root_injector.get(PrivateAdminGptUi)
|
||||
# admin_ui.mount_in_admin_app(app, '/admin')
|
||||
|
||||
from private_gpt.ui.ui import PrivateGptUi
|
||||
ui = root_injector.get(PrivateGptUi)
|
||||
ui.mount_in_app(app, settings.ui.path)
|
||||
# from private_gpt.ui.ui import PrivateGptUi
|
||||
# ui = root_injector.get(PrivateGptUi)
|
||||
# ui.mount_in_app(app, settings.ui.path)
|
||||
|
||||
return app
|
@ -250,7 +250,6 @@ class PrivateAdminGptUi:
|
||||
self._set_system_prompt,
|
||||
inputs=system_prompt_input,
|
||||
)
|
||||
|
||||
with gr.Column(scale=7, elem_id="col"):
|
||||
_ = gr.ChatInterface(
|
||||
self._chat,
|
||||
|
@ -1,4 +1,4 @@
|
||||
from private_gpt.users.api.v1.routers import auth, roles, user_roles, users
|
||||
from private_gpt.users.api.v1.routers import auth, roles, user_roles, users, subscriptions, companies
|
||||
from fastapi import APIRouter
|
||||
|
||||
api_router = APIRouter(prefix="/v1")
|
||||
@ -6,4 +6,7 @@ api_router = APIRouter(prefix="/v1")
|
||||
api_router.include_router(auth.router)
|
||||
api_router.include_router(users.router)
|
||||
api_router.include_router(roles.router)
|
||||
api_router.include_router(user_roles.router)
|
||||
api_router.include_router(user_roles.router)
|
||||
api_router.include_router(companies.router)
|
||||
api_router.include_router(subscriptions.router)
|
||||
|
||||
|
94
private_gpt/users/api/v1/routers/companies.py
Normal file
94
private_gpt/users/api/v1/routers/companies.py
Normal file
@ -0,0 +1,94 @@
|
||||
from typing import Any, List
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from private_gpt.users import crud, models, schemas
|
||||
from private_gpt.users.constants.role import Role
|
||||
from private_gpt.users.api import deps
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
||||
|
||||
router = APIRouter(prefix="/companies", tags=["Companies"])
|
||||
|
||||
@router.get("", response_model=List[schemas.Company])
|
||||
def list_companies(
|
||||
db: Session = Depends(deps.get_db),
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
) -> List[schemas.Company]:
|
||||
"""
|
||||
List companies
|
||||
"""
|
||||
companies = crud.company.get_multi(db, skip=skip, limit=limit)
|
||||
return companies
|
||||
|
||||
|
||||
@router.post("/create", response_model=schemas.Company)
|
||||
def create_company(
|
||||
company_in: schemas.CompanyCreate,
|
||||
db: Session = Depends(deps.get_db),
|
||||
) -> schemas.Company:
|
||||
"""
|
||||
Create a new company
|
||||
"""
|
||||
company = crud.company.create(db=db, obj_in=company_in)
|
||||
return company
|
||||
|
||||
|
||||
@router.get("/{company_id}", response_model=schemas.Company)
|
||||
def read_company(
|
||||
company_id: int,
|
||||
db: Session = Depends(deps.get_db),
|
||||
) -> schemas.Company:
|
||||
"""
|
||||
Read a company by ID
|
||||
"""
|
||||
company = crud.company.get_by_id(db, id=company_id)
|
||||
if company is None:
|
||||
raise HTTPException(status_code=404, detail="Company not found")
|
||||
return company
|
||||
|
||||
|
||||
@router.put("/{company_id}", response_model=schemas.Company)
|
||||
def update_company(
|
||||
company_id: int,
|
||||
company_in: schemas.CompanyUpdate,
|
||||
db: Session = Depends(deps.get_db),
|
||||
) -> schemas.Company:
|
||||
"""
|
||||
Update a company by ID
|
||||
"""
|
||||
company = crud.company.get_by_id(db, id=company_id)
|
||||
if company is None:
|
||||
raise HTTPException(status_code=404, detail="Company not found")
|
||||
|
||||
updated_company = crud.company.update(db=db, db_obj=company, obj_in=company_in)
|
||||
updated_company = jsonable_encoder(updated_company)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={
|
||||
"message": f"{company_id} Company updated successfully",
|
||||
"company": updated_company
|
||||
},
|
||||
)
|
||||
|
||||
@router.delete("/{company_id}", response_model=schemas.Company)
|
||||
def delete_company(
|
||||
company_id: int,
|
||||
db: Session = Depends(deps.get_db),
|
||||
) -> schemas.Company:
|
||||
"""
|
||||
Delete a company by ID
|
||||
"""
|
||||
|
||||
company = crud.company.remove(db=db, id=company_id)
|
||||
if company is None:
|
||||
raise HTTPException(status_code=404, detail="Company not found")
|
||||
company = jsonable_encoder(company)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={
|
||||
"message": "Company deleted successfully",
|
||||
"company": company
|
||||
},
|
||||
)
|
115
private_gpt/users/api/v1/routers/subscriptions.py
Normal file
115
private_gpt/users/api/v1/routers/subscriptions.py
Normal file
@ -0,0 +1,115 @@
|
||||
from typing import Any, List
|
||||
from private_gpt.users import crud, models, schemas
|
||||
from private_gpt.users.api import deps
|
||||
from private_gpt.users.constants.role import Role
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Security,status
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic.networks import EmailStr
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
router = APIRouter(prefix="/subscriptions", tags=["Subscriptions"])
|
||||
|
||||
|
||||
@router.post("/create", response_model=schemas.Subscription)
|
||||
def create_subscription(
|
||||
subscription_in: schemas.SubscriptionCreate,
|
||||
db: Session = Depends(deps.get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Create a new subscription
|
||||
"""
|
||||
existing_subscription = crud.subscription.get_by_company_id(db, company_id=subscription_in.company_id)
|
||||
if existing_subscription:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Company is already subscribed to a plan.",
|
||||
)
|
||||
|
||||
subscription = crud.subscription.create(db=db, obj_in=subscription_in)
|
||||
subscription_dict = jsonable_encoder(subscription)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
content={
|
||||
"message": "Subscription created successfully",
|
||||
"subscription": subscription_dict
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{subscription_id}", response_model=schemas.Subscription)
|
||||
def read_subscription(
|
||||
subscription_id: int,
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
subscription = crud.subscription.get_by_id(db, subscription_id=subscription_id)
|
||||
if subscription is None:
|
||||
raise HTTPException(status_code=404, detail="Subscription not found")
|
||||
|
||||
subscription_dict = jsonable_encoder(subscription)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={
|
||||
"message": "Subscription retrieved successfully",
|
||||
"subscription": subscription_dict
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/company/{company_id}", response_model=List[schemas.Subscription])
|
||||
def read_subscriptions_by_company(
|
||||
company_id: int,
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
subscriptions = crud.subscription.get_by_company_id(db, company_id=company_id)
|
||||
subscriptions_list = [jsonable_encoder(subscription) for subscription in subscriptions]
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={
|
||||
"message": "Subscriptions retrieved successfully",
|
||||
"subscriptions": subscriptions_list
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{subscription_id}", response_model=schemas.Subscription)
|
||||
def update_subscription(
|
||||
subscription_id: int,
|
||||
subscription_in: schemas.SubscriptionUpdate,
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
subscription = crud.subscription.get_by_id(db, subscription_id=subscription_id)
|
||||
if subscription is None:
|
||||
raise HTTPException(status_code=404, detail="Subscription not found")
|
||||
|
||||
if not subscription.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Company subscription is not active. Please contact support for assistance.",
|
||||
)
|
||||
|
||||
updated_subscription = crud.subscription.update(db=db, db_obj=subscription, obj_in=subscription_in)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={
|
||||
"message": "Subscription updated successfully",
|
||||
"subscription": jsonable_encoder(updated_subscription)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{subscription_id}")
|
||||
def delete_subscription(
|
||||
subscription_id: int,
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
subscription = crud.subscription.remove(db=db, id=subscription_id)
|
||||
if subscription is None:
|
||||
raise HTTPException(status_code=404, detail="Subscription not found")
|
||||
return JSONResponse(
|
||||
content={"message": "Subscription deleted successfully"},
|
||||
status_code=status.HTTP_200_OK
|
||||
)
|
@ -1,3 +1,5 @@
|
||||
from .role_crud import role
|
||||
from .user_crud import user
|
||||
from .user_role_crud import user_role
|
||||
from .user_role_crud import user_role
|
||||
from .company_crud import company
|
||||
from .subscription_crud import subscription
|
13
private_gpt/users/crud/company_crud.py
Normal file
13
private_gpt/users/crud/company_crud.py
Normal file
@ -0,0 +1,13 @@
|
||||
from typing import Optional
|
||||
|
||||
from private_gpt.users.crud.base import CRUDBase
|
||||
from private_gpt.users.models.company import Company
|
||||
from private_gpt.users.schemas.company import CompanyCreate, CompanyUpdate
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class CRUDCompany(CRUDBase[Company, CompanyCreate, CompanyUpdate]):
|
||||
def get_by_id(self, db: Session, *, id: str) -> Optional[Company]:
|
||||
return db.query(self.model).filter(Company.id == id).first()
|
||||
|
||||
company = CRUDCompany(Company)
|
16
private_gpt/users/crud/subscription_crud.py
Normal file
16
private_gpt/users/crud/subscription_crud.py
Normal file
@ -0,0 +1,16 @@
|
||||
from typing import Optional, List
|
||||
|
||||
from private_gpt.users.crud.base import CRUDBase
|
||||
from private_gpt.users.models.subscription import Subscription
|
||||
from private_gpt.users.schemas.subscription import SubscriptionCreate, SubscriptionUpdate
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class CRUDSubscription(CRUDBase[Subscription, SubscriptionCreate, SubscriptionUpdate]):
|
||||
def get_by_id(self, db: Session, *, subscription_id: int) -> Optional[Subscription]:
|
||||
return db.query(self.model).filter(Subscription.sub_id == subscription_id).first()
|
||||
|
||||
def get_by_company_id(self, db: Session, *, company_id: int) -> List[Subscription]:
|
||||
return db.query(self.model).filter(Subscription.company_id == company_id).all()
|
||||
|
||||
subscription = CRUDSubscription(Subscription)
|
14
private_gpt/users/models/company.py
Normal file
14
private_gpt/users/models/company.py
Normal file
@ -0,0 +1,14 @@
|
||||
from sqlalchemy import Column, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from private_gpt.users.db.base_class import Base
|
||||
|
||||
class Company(Base):
|
||||
"""Models a Company table."""
|
||||
|
||||
__tablename__ = "companies"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, index=True)
|
||||
subscriptions = relationship("Subscription", back_populates="company")
|
||||
|
||||
|
19
private_gpt/users/models/subscription.py
Normal file
19
private_gpt/users/models/subscription.py
Normal file
@ -0,0 +1,19 @@
|
||||
from sqlalchemy import Column, Integer, String, Boolean, Float, ForeignKey, DateTime
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime, timedelta
|
||||
from fastapi import Depends
|
||||
from private_gpt.users.db.base_class import Base
|
||||
|
||||
|
||||
class Subscription(Base):
|
||||
"""Models a Subscription table."""
|
||||
__tablename__ = "subscriptions"
|
||||
|
||||
sub_id = Column(Integer, primary_key=True, index=True)
|
||||
company_id = Column(Integer, ForeignKey("companies.id"))
|
||||
start_date = Column(DateTime, default=datetime.utcnow)
|
||||
end_date = Column(DateTime, default=datetime.utcnow() + timedelta(days=30)) # Example: 30 days subscription period
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
||||
company = relationship("Company", back_populates="subscriptions")
|
||||
|
@ -2,7 +2,6 @@ from private_gpt.users.db.base_class import Base
|
||||
from sqlalchemy import Column, ForeignKey, UniqueConstraint, Integer
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
|
||||
class UserRole(Base):
|
||||
__tablename__ = "user_roles"
|
||||
user_id = Column(
|
||||
|
@ -1,4 +1,6 @@
|
||||
from .role import Role, RoleCreate, RoleInDB, RoleUpdate
|
||||
from .token import TokenSchema, TokenPayload
|
||||
from .user import User, UserCreate, UserInDB, UserUpdate
|
||||
from .user_role import UserRole, UserRoleCreate, UserRoleInDB, UserRoleUpdate
|
||||
from .user_role import UserRole, UserRoleCreate, UserRoleInDB, UserRoleUpdate
|
||||
from .subscription import Subscription, SubscriptionBase, SubscriptionCreate, SubscriptionUpdate
|
||||
from .company import Company, CompanyBase, CompanyCreate, CompanyUpdate
|
19
private_gpt/users/schemas/company.py
Normal file
19
private_gpt/users/schemas/company.py
Normal file
@ -0,0 +1,19 @@
|
||||
from typing import List
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel
|
||||
|
||||
class CompanyBase(BaseModel):
|
||||
name: str
|
||||
|
||||
class CompanyCreate(CompanyBase):
|
||||
pass
|
||||
|
||||
class CompanyUpdate(CompanyBase):
|
||||
pass
|
||||
|
||||
|
||||
class Company(CompanyBase):
|
||||
id: int
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
30
private_gpt/users/schemas/subscription.py
Normal file
30
private_gpt/users/schemas/subscription.py
Normal file
@ -0,0 +1,30 @@
|
||||
from typing import List
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel
|
||||
from private_gpt.users.schemas.company import Company
|
||||
|
||||
|
||||
class SubscriptionBase(BaseModel):
|
||||
start_date: datetime
|
||||
end_date: datetime
|
||||
is_active: bool
|
||||
|
||||
class SubscriptionCreate(SubscriptionBase):
|
||||
company_id: int
|
||||
|
||||
class SubscriptionSchema(SubscriptionBase):
|
||||
id: int
|
||||
company: Company
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
class SubscriptionUpdate(SubscriptionBase):
|
||||
pass
|
||||
|
||||
class Subscription(SubscriptionBase):
|
||||
id: int
|
||||
company_id: int
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
Loading…
Reference in New Issue
Block a user