diff --git a/.env b/.env index 48179c0b..aabef866 100644 --- a/.env +++ b/.env @@ -4,7 +4,7 @@ ENVIRONMENT=dev DB_HOST=localhost DB_USER=postgres DB_PORT=5432 -DB_PASSWORD=admin +DB_PASSWORD=quick DB_NAME=QuickGpt SUPER_ADMIN_EMAIL=superadmin@email.com diff --git a/private_gpt/server/ingest/ingest_router.py b/private_gpt/server/ingest/ingest_router.py index 3a54203b..5eba3708 100644 --- a/private_gpt/server/ingest/ingest_router.py +++ b/private_gpt/server/ingest/ingest_router.py @@ -175,12 +175,9 @@ def ingest_file( ingested_documents = service.ingest_bin_data(file.filename, f) return IngestResponse(object="list", model="private-gpt", data=ingested_documents) - except Exception as e: logger.error(f"There was an error uploading the file(s): {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal Server Error: Unable to ingest file.", ) - finally: - file.file.close() diff --git a/private_gpt/users/api/deps.py b/private_gpt/users/api/deps.py index 00b0a5e3..83abe934 100644 --- a/private_gpt/users/api/deps.py +++ b/private_gpt/users/api/deps.py @@ -90,36 +90,24 @@ def get_current_user( return user -async def get_current_active_user( - current_user: models.User = Security(get_current_user, scopes=[],), -) -> models.User: - if not crud.user.is_active(current_user): - raise HTTPException(status_code=400, detail="Inactive user") +def check_user_role(current_user: models.User = Depends(get_current_user), role: str = ""): + if current_user.role != role: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You don't have the necessary permissions to perform this action", + ) return current_user - -async def is_company_admin(current_user: models.User = Depends(get_current_user)): - if current_user.role == Role.ADMIN["name"]: - return current_user - else: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="You don't have the necessary permissions to perform this action", - ) +def is_company_admin(current_user: models.User = Depends(get_current_user)): + return check_user_role(current_user, role=Role.ADMIN["name"]) -async def is_super_admin(current_user: models.User = Depends(get_current_user)): - if current_user.role == Role.SUPER_ADMIN["name"]: - return current_user - else: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="You don't have the necessary permissions to perform this action", - ) +def is_super_admin(current_user: models.User = Depends(get_current_user)): + return check_user_role(current_user, role=Role.SUPER_ADMIN["name"]) -async def get_company_name(company_id: int, db: Session = Depends(get_db)) -> str: +def get_company_name(company_id: int, db: Session = Depends(get_db)) -> str: company = crud.company.get(db=db, id=company_id) if not company: raise HTTPException(status_code=404, detail="Company not found") @@ -141,4 +129,4 @@ def get_active_subscription( raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Access Forbidden - No Active Subscription", - ) \ No newline at end of file + ) diff --git a/private_gpt/users/api/v1/routers/auth.py b/private_gpt/users/api/v1/routers/auth.py index 7c6494e2..37d8f2c2 100644 --- a/private_gpt/users/api/v1/routers/auth.py +++ b/private_gpt/users/api/v1/routers/auth.py @@ -189,10 +189,6 @@ def register( user_role_name = role_name or Role.ADMIN["name"] user_role = create_user_role(db, user, user_role_name, None) - print("USER REGISTERED: ", user.email, user.fullname, user.company_id) - print("USER ROLE REGISTERED: ", user_role.user.email, - user_role.role.name, user_role.company_id) - token_payload = create_token_payload(user, user_role) response_dict = { "access_token": security.create_access_token(token_payload, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)), diff --git a/private_gpt/users/api/v1/routers/users.py b/private_gpt/users/api/v1/routers/users.py index b4cac929..ceba3846 100644 --- a/private_gpt/users/api/v1/routers/users.py +++ b/private_gpt/users/api/v1/routers/users.py @@ -31,19 +31,25 @@ def read_users( return users -@router.get("/{company_name}") +@router.get("/{company_name}", response_model=List[schemas.User]) def read_users_by_company( - company_name: Optional[str] = Path(..., title="Company Name", description="Only for company admin"), + company_name: Optional[str] = Path(..., title="Company Name", + description="Only for company admin"), db: Session = Depends(deps.get_db), current_user: models.User = Security( deps.get_current_user, scopes=[Role.ADMIN["name"], Role.SUPER_ADMIN["name"]], ), ): - """ - Retrieve all users of that company only - """ + """Retrieve all users of that company only""" company = crud.company.get_by_company_name(db, company_name=company_name) + + if company is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Company with name '{company_name}' not found", + ) + users = crud.user.get_multi_by_company_id(db, company_id=company.id) return users diff --git a/private_gpt/users/core/config.py b/private_gpt/users/core/config.py index 58b6e395..f37be46f 100644 --- a/private_gpt/users/core/config.py +++ b/private_gpt/users/core/config.py @@ -1,18 +1,9 @@ from functools import lru_cache from typing import Any, Dict, Optional -from pydantic import PostgresDsn, validator from pydantic_settings import BaseSettings -SQLALCHEMY_DATABASE_URI = "postgresql+psycopg2://{username}:{password}@{host}:{port}/{db_name}".format( - host='localhost', - port='5432', - db_name='QuickGpt', - username='postgres', - password="quick", - ) - class Settings(BaseSettings): PROJECT_NAME: str = "AUTHENTICATION AND AUTHORIZATION" API_V1_STR: str = "/v1" @@ -39,24 +30,9 @@ class Settings(BaseSettings): SMTP_USERNAME: str SMTP_PASSWORD: str - - # SQLALCHEMY_DATABASE_URI: Optional[PostgresDsn] = None - - # @validator("SQLALCHEMY_DATABASE_URI", pre=True) - # def assemble_db_connection( - # cls, v: Optional[str], values: Dict[str, Any] - # ) -> Any: - # if isinstance(v, str): - # return v - # return PostgresDsn.build( - # scheme="postgresql", - # user=values.get("DB_USER"), - # password=values.get("DB_PASS"), - # host=values.get("DB_HOST"), - # path=f"/{values.get('DB_NAME') or ''}", - # ) - # Database url configuration - + @property + def SQLALCHEMY_DATABASE_URI(self) -> str: + return f"postgresql+psycopg2://{self.DB_USER}:{self.DB_PASSWORD}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}" class Config: case_sensitive = True @@ -64,8 +40,8 @@ class Settings(BaseSettings): @lru_cache() -def get_settings(): +def get_settings() -> Settings: return Settings() -settings = get_settings() \ No newline at end of file +settings = get_settings() diff --git a/private_gpt/users/db/session.py b/private_gpt/users/db/session.py index 65e6fab2..ef48bda3 100644 --- a/private_gpt/users/db/session.py +++ b/private_gpt/users/db/session.py @@ -1,13 +1,6 @@ -from private_gpt.users.core.config import SQLALCHEMY_DATABASE_URI +from private_gpt.users.core.config import settings from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -engine = create_engine(SQLALCHEMY_DATABASE_URI, echo=True, future=True, pool_pre_ping=True) +engine = create_engine(settings.SQLALCHEMY_DATABASE_URI, echo=True, future=True, pool_pre_ping=True) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - -# test_engine = create_engine( -# f"{settings.SQLALCHEMY_DATABASE_URI}_test", pool_pre_ping=True -# ) -# TestingSessionLocal = sessionmaker( -# autocommit=False, autoflush=False, bind=test_engine -# ) \ No newline at end of file