diff --git a/private_gpt/users/api/v1/routers/auth.py b/private_gpt/users/api/v1/routers/auth.py index 895e5c9d..05a34e3f 100644 --- a/private_gpt/users/api/v1/routers/auth.py +++ b/private_gpt/users/api/v1/routers/auth.py @@ -53,12 +53,13 @@ def register_user( def ldap_login(db, username, password): ldap = Ldap(LDAP_SERVER, username, password) - print("LDAP LOGIN:: ", ldap.who_am_i()) + username = ldap.who_am_i() + department = ldap.get_department(username) if not ldap: raise HTTPException( status_code=400, detail="Incorrect email or password" ) - return ldap.who_am_i() + return username, department[0] def create_user_role( db: Session, @@ -92,11 +93,12 @@ def ad_user_register( email: str, fullname: str, password: str, + department_id: int, ) -> models.User: """ Register a new user in the database. Company id is directly given here. """ - user_in = schemas.UserCreate(email=email, password=password, fullname=fullname, company_id=1) + user_in = schemas.UserCreate(email=email, password=password, fullname=fullname, company_id=1, department_id=department_id) print("user is: ", user_in) user = crud.user.create(db, obj_in=user_in) user_role_name = Role.GUEST["name"] @@ -124,10 +126,15 @@ def login_access_token( if existing_user.user_role.role.name == "SUPER_ADMIN": return True else: - ldap = ldap_login(db=db, username=form_data.username, password=form_data.password) + username, department = ldap_login(db=db, username=form_data.username, password=form_data.password) + else: - ldap = ldap_login(db=db, username=form_data.username, password=form_data.password) - ad_user_register(db=db, email=form_data.username, fullname=ldap, password=form_data.password) + username, department = ldap_login(db=db, username=form_data.username, password=form_data.password) + depart = crud.department.get_by_department_name(db, name=department) + if depart: + ad_user_register(db=db, email=form_data.username, fullname=username, password=form_data.password, department_id=depart.id) + else: + ad_user_register(db=db, email=form_data.username, fullname=username, password=form_data.password, department_id=1) return True return False if not (ad_auth(LDAP_ENABLE)): diff --git a/private_gpt/users/utils/ad_auth.py b/private_gpt/users/utils/ad_auth.py index c71f73de..1fef7f7f 100644 --- a/private_gpt/users/utils/ad_auth.py +++ b/private_gpt/users/utils/ad_auth.py @@ -1,14 +1,27 @@ import ldap3 +from ldap3 import SUBTREE class Ldap: """Class for LDAP related connections/operations.""" def __init__(self, server_uri, ldap_user, ldap_pass): self.server = ldap3.Server(server_uri, get_info=ldap3.ALL) - print(f"Connected to ldap server: {self.server}") self.conn = ldap3.Connection(self.server, user=ldap_user, password=ldap_pass, auto_bind=True) def who_am_i(self): - return self.conn.extend.standard.who_am_i() + account = self.conn.extend.standard.who_am_i() + account = account.split('\\')[1] + return account + + def get_department(self, user): + attributes = ['cn', 'givenName','sAMAccountName', 'department'] + filter = f"(&(objectclass=person)(objectclass=user)(sAMAccountName={user}))" + result = self.conn.search('ou=GLOBAL IME BANK LIMITED,dc=gibl,dc=org', filter, search_scope=SUBTREE, attributes=attributes) + if result: + department = [entry.department.value for entry in self.conn.entries ] + return department + else: + return + \ No newline at end of file