mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-07-31 23:16:58 +00:00
Made changes to get database args from the env
This commit is contained in:
parent
78b5e6c6b3
commit
3d58a3d568
2
.env
2
.env
@ -4,7 +4,7 @@ ENVIRONMENT=dev
|
||||
DB_HOST=localhost
|
||||
DB_USER=postgres
|
||||
DB_PORT=5432
|
||||
DB_PASSWORD=admin
|
||||
DB_PASSWORD=quick
|
||||
DB_NAME=QuickGpt
|
||||
|
||||
SUPER_ADMIN_EMAIL=superadmin@email.com
|
||||
|
@ -175,12 +175,9 @@ def ingest_file(
|
||||
ingested_documents = service.ingest_bin_data(file.filename, f)
|
||||
|
||||
return IngestResponse(object="list", model="private-gpt", data=ingested_documents)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"There was an error uploading the file(s): {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal Server Error: Unable to ingest file.",
|
||||
)
|
||||
finally:
|
||||
file.file.close()
|
||||
|
@ -90,36 +90,24 @@ def get_current_user(
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: models.User = Security(get_current_user, scopes=[],),
|
||||
) -> models.User:
|
||||
if not crud.user.is_active(current_user):
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
def check_user_role(current_user: models.User = Depends(get_current_user), role: str = ""):
|
||||
if current_user.role != role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have the necessary permissions to perform this action",
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
|
||||
async def is_company_admin(current_user: models.User = Depends(get_current_user)):
|
||||
if current_user.role == Role.ADMIN["name"]:
|
||||
return current_user
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have the necessary permissions to perform this action",
|
||||
)
|
||||
def is_company_admin(current_user: models.User = Depends(get_current_user)):
|
||||
return check_user_role(current_user, role=Role.ADMIN["name"])
|
||||
|
||||
|
||||
async def is_super_admin(current_user: models.User = Depends(get_current_user)):
|
||||
if current_user.role == Role.SUPER_ADMIN["name"]:
|
||||
return current_user
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have the necessary permissions to perform this action",
|
||||
)
|
||||
def is_super_admin(current_user: models.User = Depends(get_current_user)):
|
||||
return check_user_role(current_user, role=Role.SUPER_ADMIN["name"])
|
||||
|
||||
|
||||
async def get_company_name(company_id: int, db: Session = Depends(get_db)) -> str:
|
||||
def get_company_name(company_id: int, db: Session = Depends(get_db)) -> str:
|
||||
company = crud.company.get(db=db, id=company_id)
|
||||
if not company:
|
||||
raise HTTPException(status_code=404, detail="Company not found")
|
||||
@ -141,4 +129,4 @@ def get_active_subscription(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access Forbidden - No Active Subscription",
|
||||
)
|
||||
)
|
||||
|
@ -189,10 +189,6 @@ def register(
|
||||
user_role_name = role_name or Role.ADMIN["name"]
|
||||
user_role = create_user_role(db, user, user_role_name, None)
|
||||
|
||||
print("USER REGISTERED: ", user.email, user.fullname, user.company_id)
|
||||
print("USER ROLE REGISTERED: ", user_role.user.email,
|
||||
user_role.role.name, user_role.company_id)
|
||||
|
||||
token_payload = create_token_payload(user, user_role)
|
||||
response_dict = {
|
||||
"access_token": security.create_access_token(token_payload, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)),
|
||||
|
@ -31,19 +31,25 @@ def read_users(
|
||||
return users
|
||||
|
||||
|
||||
@router.get("/{company_name}")
|
||||
@router.get("/{company_name}", response_model=List[schemas.User])
|
||||
def read_users_by_company(
|
||||
company_name: Optional[str] = Path(..., title="Company Name", description="Only for company admin"),
|
||||
company_name: Optional[str] = Path(..., title="Company Name",
|
||||
description="Only for company admin"),
|
||||
db: Session = Depends(deps.get_db),
|
||||
current_user: models.User = Security(
|
||||
deps.get_current_user,
|
||||
scopes=[Role.ADMIN["name"], Role.SUPER_ADMIN["name"]],
|
||||
),
|
||||
):
|
||||
"""
|
||||
Retrieve all users of that company only
|
||||
"""
|
||||
"""Retrieve all users of that company only"""
|
||||
company = crud.company.get_by_company_name(db, company_name=company_name)
|
||||
|
||||
if company is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Company with name '{company_name}' not found",
|
||||
)
|
||||
|
||||
users = crud.user.get_multi_by_company_id(db, company_id=company.id)
|
||||
return users
|
||||
|
||||
|
@ -1,18 +1,9 @@
|
||||
from functools import lru_cache
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import PostgresDsn, validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
SQLALCHEMY_DATABASE_URI = "postgresql+psycopg2://{username}:{password}@{host}:{port}/{db_name}".format(
|
||||
host='localhost',
|
||||
port='5432',
|
||||
db_name='QuickGpt',
|
||||
username='postgres',
|
||||
password="quick",
|
||||
)
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = "AUTHENTICATION AND AUTHORIZATION"
|
||||
API_V1_STR: str = "/v1"
|
||||
@ -39,24 +30,9 @@ class Settings(BaseSettings):
|
||||
SMTP_USERNAME: str
|
||||
SMTP_PASSWORD: str
|
||||
|
||||
|
||||
# SQLALCHEMY_DATABASE_URI: Optional[PostgresDsn] = None
|
||||
|
||||
# @validator("SQLALCHEMY_DATABASE_URI", pre=True)
|
||||
# def assemble_db_connection(
|
||||
# cls, v: Optional[str], values: Dict[str, Any]
|
||||
# ) -> Any:
|
||||
# if isinstance(v, str):
|
||||
# return v
|
||||
# return PostgresDsn.build(
|
||||
# scheme="postgresql",
|
||||
# user=values.get("DB_USER"),
|
||||
# password=values.get("DB_PASS"),
|
||||
# host=values.get("DB_HOST"),
|
||||
# path=f"/{values.get('DB_NAME') or ''}",
|
||||
# )
|
||||
# Database url configuration
|
||||
|
||||
@property
|
||||
def SQLALCHEMY_DATABASE_URI(self) -> str:
|
||||
return f"postgresql+psycopg2://{self.DB_USER}:{self.DB_PASSWORD}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
|
||||
|
||||
class Config:
|
||||
case_sensitive = True
|
||||
@ -64,8 +40,8 @@ class Settings(BaseSettings):
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_settings():
|
||||
def get_settings() -> Settings:
|
||||
return Settings()
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
settings = get_settings()
|
||||
|
@ -1,13 +1,6 @@
|
||||
from private_gpt.users.core.config import SQLALCHEMY_DATABASE_URI
|
||||
from private_gpt.users.core.config import settings
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
engine = create_engine(SQLALCHEMY_DATABASE_URI, echo=True, future=True, pool_pre_ping=True)
|
||||
engine = create_engine(settings.SQLALCHEMY_DATABASE_URI, echo=True, future=True, pool_pre_ping=True)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
# test_engine = create_engine(
|
||||
# f"{settings.SQLALCHEMY_DATABASE_URI}_test", pool_pre_ping=True
|
||||
# )
|
||||
# TestingSessionLocal = sessionmaker(
|
||||
# autocommit=False, autoflush=False, bind=test_engine
|
||||
# )
|
Loading…
Reference in New Issue
Block a user