Made changes to get database args from the env

This commit is contained in:
Saurab-Shrestha 2024-02-04 12:51:01 +05:45
parent 78b5e6c6b3
commit 3d58a3d568
7 changed files with 31 additions and 75 deletions

2
.env
View File

@ -4,7 +4,7 @@ ENVIRONMENT=dev
DB_HOST=localhost
DB_USER=postgres
DB_PORT=5432
DB_PASSWORD=admin
DB_PASSWORD=quick
DB_NAME=QuickGpt
SUPER_ADMIN_EMAIL=superadmin@email.com

View File

@ -175,12 +175,9 @@ def ingest_file(
ingested_documents = service.ingest_bin_data(file.filename, f)
return IngestResponse(object="list", model="private-gpt", data=ingested_documents)
except Exception as e:
logger.error(f"There was an error uploading the file(s): {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal Server Error: Unable to ingest file.",
)
finally:
file.file.close()

View File

@ -90,36 +90,24 @@ def get_current_user(
return user
async def get_current_active_user(
current_user: models.User = Security(get_current_user, scopes=[],),
) -> models.User:
if not crud.user.is_active(current_user):
raise HTTPException(status_code=400, detail="Inactive user")
def check_user_role(current_user: models.User = Depends(get_current_user), role: str = ""):
if current_user.role != role:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have the necessary permissions to perform this action",
)
return current_user
async def is_company_admin(current_user: models.User = Depends(get_current_user)):
if current_user.role == Role.ADMIN["name"]:
return current_user
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have the necessary permissions to perform this action",
)
def is_company_admin(current_user: models.User = Depends(get_current_user)):
return check_user_role(current_user, role=Role.ADMIN["name"])
async def is_super_admin(current_user: models.User = Depends(get_current_user)):
if current_user.role == Role.SUPER_ADMIN["name"]:
return current_user
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have the necessary permissions to perform this action",
)
def is_super_admin(current_user: models.User = Depends(get_current_user)):
return check_user_role(current_user, role=Role.SUPER_ADMIN["name"])
async def get_company_name(company_id: int, db: Session = Depends(get_db)) -> str:
def get_company_name(company_id: int, db: Session = Depends(get_db)) -> str:
company = crud.company.get(db=db, id=company_id)
if not company:
raise HTTPException(status_code=404, detail="Company not found")
@ -141,4 +129,4 @@ def get_active_subscription(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access Forbidden - No Active Subscription",
)
)

View File

@ -189,10 +189,6 @@ def register(
user_role_name = role_name or Role.ADMIN["name"]
user_role = create_user_role(db, user, user_role_name, None)
print("USER REGISTERED: ", user.email, user.fullname, user.company_id)
print("USER ROLE REGISTERED: ", user_role.user.email,
user_role.role.name, user_role.company_id)
token_payload = create_token_payload(user, user_role)
response_dict = {
"access_token": security.create_access_token(token_payload, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)),

View File

@ -31,19 +31,25 @@ def read_users(
return users
@router.get("/{company_name}")
@router.get("/{company_name}", response_model=List[schemas.User])
def read_users_by_company(
company_name: Optional[str] = Path(..., title="Company Name", description="Only for company admin"),
company_name: Optional[str] = Path(..., title="Company Name",
description="Only for company admin"),
db: Session = Depends(deps.get_db),
current_user: models.User = Security(
deps.get_current_user,
scopes=[Role.ADMIN["name"], Role.SUPER_ADMIN["name"]],
),
):
"""
Retrieve all users of that company only
"""
"""Retrieve all users of that company only"""
company = crud.company.get_by_company_name(db, company_name=company_name)
if company is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Company with name '{company_name}' not found",
)
users = crud.user.get_multi_by_company_id(db, company_id=company.id)
return users

View File

@ -1,18 +1,9 @@
from functools import lru_cache
from typing import Any, Dict, Optional
from pydantic import PostgresDsn, validator
from pydantic_settings import BaseSettings
SQLALCHEMY_DATABASE_URI = "postgresql+psycopg2://{username}:{password}@{host}:{port}/{db_name}".format(
host='localhost',
port='5432',
db_name='QuickGpt',
username='postgres',
password="quick",
)
class Settings(BaseSettings):
PROJECT_NAME: str = "AUTHENTICATION AND AUTHORIZATION"
API_V1_STR: str = "/v1"
@ -39,24 +30,9 @@ class Settings(BaseSettings):
SMTP_USERNAME: str
SMTP_PASSWORD: str
# SQLALCHEMY_DATABASE_URI: Optional[PostgresDsn] = None
# @validator("SQLALCHEMY_DATABASE_URI", pre=True)
# def assemble_db_connection(
# cls, v: Optional[str], values: Dict[str, Any]
# ) -> Any:
# if isinstance(v, str):
# return v
# return PostgresDsn.build(
# scheme="postgresql",
# user=values.get("DB_USER"),
# password=values.get("DB_PASS"),
# host=values.get("DB_HOST"),
# path=f"/{values.get('DB_NAME') or ''}",
# )
# Database url configuration
@property
def SQLALCHEMY_DATABASE_URI(self) -> str:
return f"postgresql+psycopg2://{self.DB_USER}:{self.DB_PASSWORD}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
class Config:
case_sensitive = True
@ -64,8 +40,8 @@ class Settings(BaseSettings):
@lru_cache()
def get_settings():
def get_settings() -> Settings:
return Settings()
settings = get_settings()
settings = get_settings()

View File

@ -1,13 +1,6 @@
from private_gpt.users.core.config import SQLALCHEMY_DATABASE_URI
from private_gpt.users.core.config import settings
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
engine = create_engine(SQLALCHEMY_DATABASE_URI, echo=True, future=True, pool_pre_ping=True)
engine = create_engine(settings.SQLALCHEMY_DATABASE_URI, echo=True, future=True, pool_pre_ping=True)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# test_engine = create_engine(
# f"{settings.SQLALCHEMY_DATABASE_URI}_test", pool_pre_ping=True
# )
# TestingSessionLocal = sessionmaker(
# autocommit=False, autoflush=False, bind=test_engine
# )