diff --git a/alembic/env.py b/alembic/env.py index 2f14acfe..238b121a 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -12,6 +12,8 @@ from private_gpt.users.core.config import SQLALCHEMY_DATABASE_URI from private_gpt.users.models.user import User from private_gpt.users.models.role import Role from private_gpt.users.models.user_role import UserRole +from private_gpt.users.models.subscription import Subscription +from private_gpt.users.models.company import Company # this is the Alembic Config object, which provides # access to the values within the .ini file in use. config = context.config diff --git a/alembic/versions/0e0eb0a1a514_add_subscription_and_company_model.py b/alembic/versions/0e0eb0a1a514_add_subscription_and_company_model.py new file mode 100644 index 00000000..daa5d06c --- /dev/null +++ b/alembic/versions/0e0eb0a1a514_add_subscription_and_company_model.py @@ -0,0 +1,52 @@ +"""Add Subscription and Company model + +Revision ID: 0e0eb0a1a514 +Revises: 65688535c5a5 +Create Date: 2024-01-17 15:53:13.091801 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '0e0eb0a1a514' +down_revision: Union[str, None] = '65688535c5a5' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('companies', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('name', sa.String(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_companies_id'), 'companies', ['id'], unique=False) + op.create_index(op.f('ix_companies_name'), 'companies', ['name'], unique=False) + op.create_table('subscriptions', + sa.Column('sub_id', sa.Integer(), nullable=False), + sa.Column('company_id', sa.Integer(), nullable=True), + sa.Column('start_date', sa.DateTime(), nullable=True), + sa.Column('end_date', sa.DateTime(), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=True), + sa.ForeignKeyConstraint(['company_id'], ['companies.id'], ), + sa.PrimaryKeyConstraint('sub_id') + ) + op.create_index(op.f('ix_subscriptions_sub_id'), 'subscriptions', ['sub_id'], unique=False) + # op.create_unique_constraint('unique_user_role', 'user_roles', ['user_id', 'role_id']) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint('unique_user_role', 'user_roles', type_='unique') + op.drop_index(op.f('ix_subscriptions_sub_id'), table_name='subscriptions') + op.drop_table('subscriptions') + op.drop_index(op.f('ix_companies_name'), table_name='companies') + op.drop_index(op.f('ix_companies_id'), table_name='companies') + op.drop_table('companies') + # ### end Alembic commands ### diff --git a/alembic/versions/65688535c5a5_add_subscription_and_company_model.py b/alembic/versions/65688535c5a5_add_subscription_and_company_model.py new file mode 100644 index 00000000..9caf4827 --- /dev/null +++ b/alembic/versions/65688535c5a5_add_subscription_and_company_model.py @@ -0,0 +1,31 @@ +"""Add Subscription and Company model: + + +Revision ID: 65688535c5a5 +Revises: cba9e6e394ca +Create Date: 2024-01-17 15:45:28.636265 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '65688535c5a5' +down_revision: Union[str, None] = 'cba9e6e394ca' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### diff --git a/private_gpt/launcher.py b/private_gpt/launcher.py index f6fb8708..5ae70e2c 100644 --- a/private_gpt/launcher.py +++ b/private_gpt/launcher.py @@ -26,12 +26,12 @@ def create_app(root_injector: Injector) -> FastAPI: app = FastAPI(dependencies=[Depends(bind_injector_to_request)]) - app.include_router(completions_router) - app.include_router(chat_router) - app.include_router(chunks_router) - app.include_router(ingest_router) - app.include_router(embeddings_router) - app.include_router(health_router) + # app.include_router(completions_router) + # app.include_router(chat_router) + # app.include_router(chunks_router) + # app.include_router(ingest_router) + # app.include_router(embeddings_router) + # app.include_router(health_router) app.include_router(api_router) @@ -48,14 +48,14 @@ def create_app(root_injector: Injector) -> FastAPI: allow_headers=settings.server.cors.allow_headers, ) - if settings.ui.enabled: - logger.debug("Importing the UI module") - from private_gpt.ui.admin_ui import PrivateAdminGptUi - admin_ui = root_injector.get(PrivateAdminGptUi) - admin_ui.mount_in_admin_app(app, '/admin') + # if settings.ui.enabled: + # logger.debug("Importing the UI module") + # from private_gpt.ui.admin_ui import PrivateAdminGptUi + # admin_ui = root_injector.get(PrivateAdminGptUi) + # admin_ui.mount_in_admin_app(app, '/admin') - from private_gpt.ui.ui import PrivateGptUi - ui = root_injector.get(PrivateGptUi) - ui.mount_in_app(app, settings.ui.path) + # from private_gpt.ui.ui import PrivateGptUi + # ui = root_injector.get(PrivateGptUi) + # ui.mount_in_app(app, settings.ui.path) return app \ No newline at end of file diff --git a/private_gpt/ui/admin_ui.py b/private_gpt/ui/admin_ui.py index e77e2b8a..ce45f64e 100644 --- a/private_gpt/ui/admin_ui.py +++ b/private_gpt/ui/admin_ui.py @@ -250,7 +250,6 @@ class PrivateAdminGptUi: self._set_system_prompt, inputs=system_prompt_input, ) - with gr.Column(scale=7, elem_id="col"): _ = gr.ChatInterface( self._chat, diff --git a/private_gpt/users/api/v1/api.py b/private_gpt/users/api/v1/api.py index bc9f6709..e7f173f0 100644 --- a/private_gpt/users/api/v1/api.py +++ b/private_gpt/users/api/v1/api.py @@ -1,4 +1,4 @@ -from private_gpt.users.api.v1.routers import auth, roles, user_roles, users +from private_gpt.users.api.v1.routers import auth, roles, user_roles, users, subscriptions, companies from fastapi import APIRouter api_router = APIRouter(prefix="/v1") @@ -6,4 +6,7 @@ api_router = APIRouter(prefix="/v1") api_router.include_router(auth.router) api_router.include_router(users.router) api_router.include_router(roles.router) -api_router.include_router(user_roles.router) \ No newline at end of file +api_router.include_router(user_roles.router) +api_router.include_router(companies.router) +api_router.include_router(subscriptions.router) + diff --git a/private_gpt/users/api/v1/routers/companies.py b/private_gpt/users/api/v1/routers/companies.py new file mode 100644 index 00000000..a346e40c --- /dev/null +++ b/private_gpt/users/api/v1/routers/companies.py @@ -0,0 +1,94 @@ +from typing import Any, List +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.responses import JSONResponse +from sqlalchemy.orm import Session +from private_gpt.users import crud, models, schemas +from private_gpt.users.constants.role import Role +from private_gpt.users.api import deps +from fastapi.encoders import jsonable_encoder + + +router = APIRouter(prefix="/companies", tags=["Companies"]) + +@router.get("", response_model=List[schemas.Company]) +def list_companies( + db: Session = Depends(deps.get_db), + skip: int = 0, + limit: int = 100, +) -> List[schemas.Company]: + """ + List companies + """ + companies = crud.company.get_multi(db, skip=skip, limit=limit) + return companies + + +@router.post("/create", response_model=schemas.Company) +def create_company( + company_in: schemas.CompanyCreate, + db: Session = Depends(deps.get_db), +) -> schemas.Company: + """ + Create a new company + """ + company = crud.company.create(db=db, obj_in=company_in) + return company + + +@router.get("/{company_id}", response_model=schemas.Company) +def read_company( + company_id: int, + db: Session = Depends(deps.get_db), +) -> schemas.Company: + """ + Read a company by ID + """ + company = crud.company.get_by_id(db, id=company_id) + if company is None: + raise HTTPException(status_code=404, detail="Company not found") + return company + + +@router.put("/{company_id}", response_model=schemas.Company) +def update_company( + company_id: int, + company_in: schemas.CompanyUpdate, + db: Session = Depends(deps.get_db), +) -> schemas.Company: + """ + Update a company by ID + """ + company = crud.company.get_by_id(db, id=company_id) + if company is None: + raise HTTPException(status_code=404, detail="Company not found") + + updated_company = crud.company.update(db=db, db_obj=company, obj_in=company_in) + updated_company = jsonable_encoder(updated_company) + return JSONResponse( + status_code=status.HTTP_200_OK, + content={ + "message": f"{company_id} Company updated successfully", + "company": updated_company + }, + ) + +@router.delete("/{company_id}", response_model=schemas.Company) +def delete_company( + company_id: int, + db: Session = Depends(deps.get_db), +) -> schemas.Company: + """ + Delete a company by ID + """ + + company = crud.company.remove(db=db, id=company_id) + if company is None: + raise HTTPException(status_code=404, detail="Company not found") + company = jsonable_encoder(company) + return JSONResponse( + status_code=status.HTTP_200_OK, + content={ + "message": "Company deleted successfully", + "company": company + }, + ) diff --git a/private_gpt/users/api/v1/routers/subscriptions.py b/private_gpt/users/api/v1/routers/subscriptions.py new file mode 100644 index 00000000..b007b90a --- /dev/null +++ b/private_gpt/users/api/v1/routers/subscriptions.py @@ -0,0 +1,115 @@ +from typing import Any, List +from private_gpt.users import crud, models, schemas +from private_gpt.users.api import deps +from private_gpt.users.constants.role import Role +from fastapi import APIRouter, Body, Depends, HTTPException, Security,status +from fastapi.encoders import jsonable_encoder +from pydantic.networks import EmailStr +from sqlalchemy.orm import Session +from fastapi.responses import JSONResponse + +router = APIRouter(prefix="/subscriptions", tags=["Subscriptions"]) + + +@router.post("/create", response_model=schemas.Subscription) +def create_subscription( + subscription_in: schemas.SubscriptionCreate, + db: Session = Depends(deps.get_db), +) -> Any: + """ + Create a new subscription + """ + existing_subscription = crud.subscription.get_by_company_id(db, company_id=subscription_in.company_id) + if existing_subscription: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Company is already subscribed to a plan.", + ) + + subscription = crud.subscription.create(db=db, obj_in=subscription_in) + subscription_dict = jsonable_encoder(subscription) + + return JSONResponse( + status_code=status.HTTP_201_CREATED, + content={ + "message": "Subscription created successfully", + "subscription": subscription_dict + }, + ) + + +@router.get("/{subscription_id}", response_model=schemas.Subscription) +def read_subscription( + subscription_id: int, + db: Session = Depends(deps.get_db), +): + subscription = crud.subscription.get_by_id(db, subscription_id=subscription_id) + if subscription is None: + raise HTTPException(status_code=404, detail="Subscription not found") + + subscription_dict = jsonable_encoder(subscription) + return JSONResponse( + status_code=status.HTTP_200_OK, + content={ + "message": "Subscription retrieved successfully", + "subscription": subscription_dict + }, + ) + + +@router.get("/company/{company_id}", response_model=List[schemas.Subscription]) +def read_subscriptions_by_company( + company_id: int, + db: Session = Depends(deps.get_db), +): + subscriptions = crud.subscription.get_by_company_id(db, company_id=company_id) + subscriptions_list = [jsonable_encoder(subscription) for subscription in subscriptions] + + return JSONResponse( + status_code=status.HTTP_200_OK, + content={ + "message": "Subscriptions retrieved successfully", + "subscriptions": subscriptions_list + }, + ) + + +@router.put("/{subscription_id}", response_model=schemas.Subscription) +def update_subscription( + subscription_id: int, + subscription_in: schemas.SubscriptionUpdate, + db: Session = Depends(deps.get_db), +): + subscription = crud.subscription.get_by_id(db, subscription_id=subscription_id) + if subscription is None: + raise HTTPException(status_code=404, detail="Subscription not found") + + if not subscription.is_active: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Company subscription is not active. Please contact support for assistance.", + ) + + updated_subscription = crud.subscription.update(db=db, db_obj=subscription, obj_in=subscription_in) + + return JSONResponse( + status_code=status.HTTP_200_OK, + content={ + "message": "Subscription updated successfully", + "subscription": jsonable_encoder(updated_subscription) + }, + ) + + +@router.delete("/{subscription_id}") +def delete_subscription( + subscription_id: int, + db: Session = Depends(deps.get_db), +): + subscription = crud.subscription.remove(db=db, id=subscription_id) + if subscription is None: + raise HTTPException(status_code=404, detail="Subscription not found") + return JSONResponse( + content={"message": "Subscription deleted successfully"}, + status_code=status.HTTP_200_OK + ) \ No newline at end of file diff --git a/private_gpt/users/crud/__init__.py b/private_gpt/users/crud/__init__.py index 07bcfb3d..928dfe1d 100644 --- a/private_gpt/users/crud/__init__.py +++ b/private_gpt/users/crud/__init__.py @@ -1,3 +1,5 @@ from .role_crud import role from .user_crud import user -from .user_role_crud import user_role \ No newline at end of file +from .user_role_crud import user_role +from .company_crud import company +from .subscription_crud import subscription \ No newline at end of file diff --git a/private_gpt/users/crud/company_crud.py b/private_gpt/users/crud/company_crud.py new file mode 100644 index 00000000..53f145ce --- /dev/null +++ b/private_gpt/users/crud/company_crud.py @@ -0,0 +1,13 @@ +from typing import Optional + +from private_gpt.users.crud.base import CRUDBase +from private_gpt.users.models.company import Company +from private_gpt.users.schemas.company import CompanyCreate, CompanyUpdate +from sqlalchemy.orm import Session + + +class CRUDCompany(CRUDBase[Company, CompanyCreate, CompanyUpdate]): + def get_by_id(self, db: Session, *, id: str) -> Optional[Company]: + return db.query(self.model).filter(Company.id == id).first() + +company = CRUDCompany(Company) \ No newline at end of file diff --git a/private_gpt/users/crud/subscription_crud.py b/private_gpt/users/crud/subscription_crud.py new file mode 100644 index 00000000..d74957af --- /dev/null +++ b/private_gpt/users/crud/subscription_crud.py @@ -0,0 +1,16 @@ +from typing import Optional, List + +from private_gpt.users.crud.base import CRUDBase +from private_gpt.users.models.subscription import Subscription +from private_gpt.users.schemas.subscription import SubscriptionCreate, SubscriptionUpdate +from sqlalchemy.orm import Session + + +class CRUDSubscription(CRUDBase[Subscription, SubscriptionCreate, SubscriptionUpdate]): + def get_by_id(self, db: Session, *, subscription_id: int) -> Optional[Subscription]: + return db.query(self.model).filter(Subscription.sub_id == subscription_id).first() + + def get_by_company_id(self, db: Session, *, company_id: int) -> List[Subscription]: + return db.query(self.model).filter(Subscription.company_id == company_id).all() + +subscription = CRUDSubscription(Subscription) \ No newline at end of file diff --git a/private_gpt/users/models/company.py b/private_gpt/users/models/company.py new file mode 100644 index 00000000..e8749474 --- /dev/null +++ b/private_gpt/users/models/company.py @@ -0,0 +1,14 @@ +from sqlalchemy import Column, Integer, String +from sqlalchemy.orm import relationship +from private_gpt.users.db.base_class import Base + +class Company(Base): + """Models a Company table.""" + + __tablename__ = "companies" + + id = Column(Integer, primary_key=True, index=True) + name = Column(String, index=True) + subscriptions = relationship("Subscription", back_populates="company") + + \ No newline at end of file diff --git a/private_gpt/users/models/subscription.py b/private_gpt/users/models/subscription.py new file mode 100644 index 00000000..fbd09d3f --- /dev/null +++ b/private_gpt/users/models/subscription.py @@ -0,0 +1,19 @@ +from sqlalchemy import Column, Integer, String, Boolean, Float, ForeignKey, DateTime +from sqlalchemy.orm import relationship +from datetime import datetime, timedelta +from fastapi import Depends +from private_gpt.users.db.base_class import Base + + +class Subscription(Base): + """Models a Subscription table.""" + __tablename__ = "subscriptions" + + sub_id = Column(Integer, primary_key=True, index=True) + company_id = Column(Integer, ForeignKey("companies.id")) + start_date = Column(DateTime, default=datetime.utcnow) + end_date = Column(DateTime, default=datetime.utcnow() + timedelta(days=30)) # Example: 30 days subscription period + is_active = Column(Boolean, default=True) + + company = relationship("Company", back_populates="subscriptions") + diff --git a/private_gpt/users/models/user_role.py b/private_gpt/users/models/user_role.py index fab20261..9aef0424 100644 --- a/private_gpt/users/models/user_role.py +++ b/private_gpt/users/models/user_role.py @@ -2,7 +2,6 @@ from private_gpt.users.db.base_class import Base from sqlalchemy import Column, ForeignKey, UniqueConstraint, Integer from sqlalchemy.orm import relationship - class UserRole(Base): __tablename__ = "user_roles" user_id = Column( diff --git a/private_gpt/users/schemas/__init__.py b/private_gpt/users/schemas/__init__.py index ec2e02f0..05fadef3 100644 --- a/private_gpt/users/schemas/__init__.py +++ b/private_gpt/users/schemas/__init__.py @@ -1,4 +1,6 @@ from .role import Role, RoleCreate, RoleInDB, RoleUpdate from .token import TokenSchema, TokenPayload from .user import User, UserCreate, UserInDB, UserUpdate -from .user_role import UserRole, UserRoleCreate, UserRoleInDB, UserRoleUpdate \ No newline at end of file +from .user_role import UserRole, UserRoleCreate, UserRoleInDB, UserRoleUpdate +from .subscription import Subscription, SubscriptionBase, SubscriptionCreate, SubscriptionUpdate +from .company import Company, CompanyBase, CompanyCreate, CompanyUpdate \ No newline at end of file diff --git a/private_gpt/users/schemas/company.py b/private_gpt/users/schemas/company.py new file mode 100644 index 00000000..69c21063 --- /dev/null +++ b/private_gpt/users/schemas/company.py @@ -0,0 +1,19 @@ +from typing import List +from datetime import datetime +from pydantic import BaseModel + +class CompanyBase(BaseModel): + name: str + +class CompanyCreate(CompanyBase): + pass + +class CompanyUpdate(CompanyBase): + pass + + +class Company(CompanyBase): + id: int + + class Config: + orm_mode = True \ No newline at end of file diff --git a/private_gpt/users/schemas/subscription.py b/private_gpt/users/schemas/subscription.py new file mode 100644 index 00000000..6a403ef6 --- /dev/null +++ b/private_gpt/users/schemas/subscription.py @@ -0,0 +1,30 @@ +from typing import List +from datetime import datetime +from pydantic import BaseModel +from private_gpt.users.schemas.company import Company + + +class SubscriptionBase(BaseModel): + start_date: datetime + end_date: datetime + is_active: bool + +class SubscriptionCreate(SubscriptionBase): + company_id: int + +class SubscriptionSchema(SubscriptionBase): + id: int + company: Company + + class Config: + orm_mode = True + +class SubscriptionUpdate(SubscriptionBase): + pass + +class Subscription(SubscriptionBase): + id: int + company_id: int + + class Config: + orm_mode = True