mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-08-09 03:17:48 +00:00
Made changes to get database args from the env
This commit is contained in:
parent
78b5e6c6b3
commit
3d58a3d568
2
.env
2
.env
@ -4,7 +4,7 @@ ENVIRONMENT=dev
|
|||||||
DB_HOST=localhost
|
DB_HOST=localhost
|
||||||
DB_USER=postgres
|
DB_USER=postgres
|
||||||
DB_PORT=5432
|
DB_PORT=5432
|
||||||
DB_PASSWORD=admin
|
DB_PASSWORD=quick
|
||||||
DB_NAME=QuickGpt
|
DB_NAME=QuickGpt
|
||||||
|
|
||||||
SUPER_ADMIN_EMAIL=superadmin@email.com
|
SUPER_ADMIN_EMAIL=superadmin@email.com
|
||||||
|
@ -175,12 +175,9 @@ def ingest_file(
|
|||||||
ingested_documents = service.ingest_bin_data(file.filename, f)
|
ingested_documents = service.ingest_bin_data(file.filename, f)
|
||||||
|
|
||||||
return IngestResponse(object="list", model="private-gpt", data=ingested_documents)
|
return IngestResponse(object="list", model="private-gpt", data=ingested_documents)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"There was an error uploading the file(s): {str(e)}")
|
logger.error(f"There was an error uploading the file(s): {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Internal Server Error: Unable to ingest file.",
|
detail="Internal Server Error: Unable to ingest file.",
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
file.file.close()
|
|
||||||
|
@ -90,36 +90,24 @@ def get_current_user(
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
async def get_current_active_user(
|
def check_user_role(current_user: models.User = Depends(get_current_user), role: str = ""):
|
||||||
current_user: models.User = Security(get_current_user, scopes=[],),
|
if current_user.role != role:
|
||||||
) -> models.User:
|
raise HTTPException(
|
||||||
if not crud.user.is_active(current_user):
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
raise HTTPException(status_code=400, detail="Inactive user")
|
detail="You don't have the necessary permissions to perform this action",
|
||||||
|
)
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
def is_company_admin(current_user: models.User = Depends(get_current_user)):
|
||||||
async def is_company_admin(current_user: models.User = Depends(get_current_user)):
|
return check_user_role(current_user, role=Role.ADMIN["name"])
|
||||||
if current_user.role == Role.ADMIN["name"]:
|
|
||||||
return current_user
|
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="You don't have the necessary permissions to perform this action",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def is_super_admin(current_user: models.User = Depends(get_current_user)):
|
def is_super_admin(current_user: models.User = Depends(get_current_user)):
|
||||||
if current_user.role == Role.SUPER_ADMIN["name"]:
|
return check_user_role(current_user, role=Role.SUPER_ADMIN["name"])
|
||||||
return current_user
|
|
||||||
else:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="You don't have the necessary permissions to perform this action",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_company_name(company_id: int, db: Session = Depends(get_db)) -> str:
|
def get_company_name(company_id: int, db: Session = Depends(get_db)) -> str:
|
||||||
company = crud.company.get(db=db, id=company_id)
|
company = crud.company.get(db=db, id=company_id)
|
||||||
if not company:
|
if not company:
|
||||||
raise HTTPException(status_code=404, detail="Company not found")
|
raise HTTPException(status_code=404, detail="Company not found")
|
||||||
|
@ -189,10 +189,6 @@ def register(
|
|||||||
user_role_name = role_name or Role.ADMIN["name"]
|
user_role_name = role_name or Role.ADMIN["name"]
|
||||||
user_role = create_user_role(db, user, user_role_name, None)
|
user_role = create_user_role(db, user, user_role_name, None)
|
||||||
|
|
||||||
print("USER REGISTERED: ", user.email, user.fullname, user.company_id)
|
|
||||||
print("USER ROLE REGISTERED: ", user_role.user.email,
|
|
||||||
user_role.role.name, user_role.company_id)
|
|
||||||
|
|
||||||
token_payload = create_token_payload(user, user_role)
|
token_payload = create_token_payload(user, user_role)
|
||||||
response_dict = {
|
response_dict = {
|
||||||
"access_token": security.create_access_token(token_payload, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)),
|
"access_token": security.create_access_token(token_payload, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)),
|
||||||
|
@ -31,19 +31,25 @@ def read_users(
|
|||||||
return users
|
return users
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{company_name}")
|
@router.get("/{company_name}", response_model=List[schemas.User])
|
||||||
def read_users_by_company(
|
def read_users_by_company(
|
||||||
company_name: Optional[str] = Path(..., title="Company Name", description="Only for company admin"),
|
company_name: Optional[str] = Path(..., title="Company Name",
|
||||||
|
description="Only for company admin"),
|
||||||
db: Session = Depends(deps.get_db),
|
db: Session = Depends(deps.get_db),
|
||||||
current_user: models.User = Security(
|
current_user: models.User = Security(
|
||||||
deps.get_current_user,
|
deps.get_current_user,
|
||||||
scopes=[Role.ADMIN["name"], Role.SUPER_ADMIN["name"]],
|
scopes=[Role.ADMIN["name"], Role.SUPER_ADMIN["name"]],
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
"""
|
"""Retrieve all users of that company only"""
|
||||||
Retrieve all users of that company only
|
|
||||||
"""
|
|
||||||
company = crud.company.get_by_company_name(db, company_name=company_name)
|
company = crud.company.get_by_company_name(db, company_name=company_name)
|
||||||
|
|
||||||
|
if company is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Company with name '{company_name}' not found",
|
||||||
|
)
|
||||||
|
|
||||||
users = crud.user.get_multi_by_company_id(db, company_id=company.id)
|
users = crud.user.get_multi_by_company_id(db, company_id=company.id)
|
||||||
return users
|
return users
|
||||||
|
|
||||||
|
@ -1,18 +1,9 @@
|
|||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from pydantic import PostgresDsn, validator
|
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
SQLALCHEMY_DATABASE_URI = "postgresql+psycopg2://{username}:{password}@{host}:{port}/{db_name}".format(
|
|
||||||
host='localhost',
|
|
||||||
port='5432',
|
|
||||||
db_name='QuickGpt',
|
|
||||||
username='postgres',
|
|
||||||
password="quick",
|
|
||||||
)
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
PROJECT_NAME: str = "AUTHENTICATION AND AUTHORIZATION"
|
PROJECT_NAME: str = "AUTHENTICATION AND AUTHORIZATION"
|
||||||
API_V1_STR: str = "/v1"
|
API_V1_STR: str = "/v1"
|
||||||
@ -39,24 +30,9 @@ class Settings(BaseSettings):
|
|||||||
SMTP_USERNAME: str
|
SMTP_USERNAME: str
|
||||||
SMTP_PASSWORD: str
|
SMTP_PASSWORD: str
|
||||||
|
|
||||||
|
@property
|
||||||
# SQLALCHEMY_DATABASE_URI: Optional[PostgresDsn] = None
|
def SQLALCHEMY_DATABASE_URI(self) -> str:
|
||||||
|
return f"postgresql+psycopg2://{self.DB_USER}:{self.DB_PASSWORD}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
|
||||||
# @validator("SQLALCHEMY_DATABASE_URI", pre=True)
|
|
||||||
# def assemble_db_connection(
|
|
||||||
# cls, v: Optional[str], values: Dict[str, Any]
|
|
||||||
# ) -> Any:
|
|
||||||
# if isinstance(v, str):
|
|
||||||
# return v
|
|
||||||
# return PostgresDsn.build(
|
|
||||||
# scheme="postgresql",
|
|
||||||
# user=values.get("DB_USER"),
|
|
||||||
# password=values.get("DB_PASS"),
|
|
||||||
# host=values.get("DB_HOST"),
|
|
||||||
# path=f"/{values.get('DB_NAME') or ''}",
|
|
||||||
# )
|
|
||||||
# Database url configuration
|
|
||||||
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
case_sensitive = True
|
case_sensitive = True
|
||||||
@ -64,7 +40,7 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def get_settings():
|
def get_settings() -> Settings:
|
||||||
return Settings()
|
return Settings()
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,13 +1,6 @@
|
|||||||
from private_gpt.users.core.config import SQLALCHEMY_DATABASE_URI
|
from private_gpt.users.core.config import settings
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
engine = create_engine(SQLALCHEMY_DATABASE_URI, echo=True, future=True, pool_pre_ping=True)
|
engine = create_engine(settings.SQLALCHEMY_DATABASE_URI, echo=True, future=True, pool_pre_ping=True)
|
||||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|
||||||
# test_engine = create_engine(
|
|
||||||
# f"{settings.SQLALCHEMY_DATABASE_URI}_test", pool_pre_ping=True
|
|
||||||
# )
|
|
||||||
# TestingSessionLocal = sessionmaker(
|
|
||||||
# autocommit=False, autoflush=False, bind=test_engine
|
|
||||||
# )
|
|
Loading…
Reference in New Issue
Block a user