mirror of
https://github.com/jumpserver/jumpserver.git
synced 2025-12-15 08:32:48 +00:00
Compare commits
20 Commits
revert-162
...
v2.22
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
44c78de941 | ||
|
|
f41b9e8bcf | ||
|
|
087248941b | ||
|
|
ad2199421e | ||
|
|
fadeeaee49 | ||
|
|
052b35ed97 | ||
|
|
bde91886ea | ||
|
|
74504ead98 | ||
|
|
a885f0c448 | ||
|
|
b9d62e02fd | ||
|
|
011262a37b | ||
|
|
0e3bfcc6ea | ||
|
|
e7296df57c | ||
|
|
36d1493f8e | ||
|
|
4b94dc77a9 | ||
|
|
e934c8b903 | ||
|
|
4b9fb4c796 | ||
|
|
c30b024f9c | ||
|
|
9c14eb5165 | ||
|
|
624f32bc6c |
@@ -14,7 +14,6 @@ def create_internal_platform(apps, schema_editor):
|
||||
model.objects.using(db_alias).update_or_create(
|
||||
name=name, defaults=defaults
|
||||
)
|
||||
migrations.RunPython(create_internal_platform)
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -133,6 +133,15 @@ class AuthMixin:
|
||||
self.password = password
|
||||
|
||||
def load_app_more_auth(self, app_id=None, username=None, user_id=None):
|
||||
# 清除认证信息
|
||||
self._clean_auth_info_if_manual_login_mode()
|
||||
|
||||
# 先加载临时认证信息
|
||||
if self.login_mode == self.LOGIN_MANUAL:
|
||||
self._load_tmp_auth_if_has(app_id, user_id)
|
||||
return
|
||||
|
||||
# Remote app
|
||||
from applications.models import Application
|
||||
app = get_object_or_none(Application, pk=app_id)
|
||||
if app and app.category_remote_app:
|
||||
@@ -141,11 +150,6 @@ class AuthMixin:
|
||||
return
|
||||
|
||||
# Other app
|
||||
self._clean_auth_info_if_manual_login_mode()
|
||||
# 加载临时认证信息
|
||||
if self.login_mode == self.LOGIN_MANUAL:
|
||||
self._load_tmp_auth_if_has(app_id, user_id)
|
||||
return
|
||||
# 更新用户名
|
||||
from users.models import User
|
||||
user = get_object_or_none(User, pk=user_id) if user_id else None
|
||||
|
||||
@@ -58,7 +58,15 @@ class AccountSerializer(AuthSerializerMixin, BulkOrgResourceModelSerializer):
|
||||
return attrs
|
||||
|
||||
def get_protocols(self, v):
|
||||
return v.protocols.replace(' ', ', ')
|
||||
""" protocols 是 queryset 中返回的,Post 创建成功后返回序列化时没有这个字段 """
|
||||
if hasattr(v, 'protocols'):
|
||||
protocols = v.protocols
|
||||
elif hasattr(v, 'asset') and v.asset:
|
||||
protocols = v.asset.protocols
|
||||
else:
|
||||
protocols = ''
|
||||
protocols = protocols.replace(' ', ', ')
|
||||
return protocols
|
||||
|
||||
@classmethod
|
||||
def setup_eager_loading(cls, queryset):
|
||||
|
||||
@@ -1,35 +1,35 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
import urllib.parse
|
||||
import json
|
||||
from typing import Callable
|
||||
import os
|
||||
import base64
|
||||
import ctypes
|
||||
import json
|
||||
import os
|
||||
import urllib.parse
|
||||
from typing import Callable
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.shortcuts import get_object_or_404
|
||||
from django.http import HttpResponse
|
||||
from django.shortcuts import get_object_or_404
|
||||
from django.utils import timezone
|
||||
from django.utils.translation import ugettext as _
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.viewsets import GenericViewSet
|
||||
from rest_framework import serializers
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import PermissionDenied
|
||||
from rest_framework import serializers
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import GenericViewSet
|
||||
|
||||
from applications.models import Application
|
||||
from authentication.signals import post_auth_failed
|
||||
from common.utils import get_logger, random_string
|
||||
from common.const.http import PATCH
|
||||
from common.http import is_true
|
||||
from common.mixins.api import SerializerMixin
|
||||
from common.utils import get_logger, random_string
|
||||
from common.utils.common import get_file_by_arch
|
||||
from orgs.mixins.api import RootOrgViewMixin
|
||||
from common.http import is_true
|
||||
from perms.models.base import Action
|
||||
from perms.utils.application.permission import get_application_actions
|
||||
from perms.utils.asset.permission import get_asset_actions
|
||||
from common.const.http import PATCH
|
||||
from terminal.models import EndpointRule
|
||||
from ..serializers import (
|
||||
ConnectionTokenSerializer, ConnectionTokenSecretSerializer, SuperConnectionTokenSerializer
|
||||
@@ -151,6 +151,8 @@ class ClientProtocolMixin:
|
||||
|
||||
if asset:
|
||||
name = asset.hostname
|
||||
if asset.platform.meta.get('console', None) == 'true':
|
||||
options['administrative session:i:'] = '1'
|
||||
elif application:
|
||||
name = application.name
|
||||
application.get_rdp_remote_app_setting()
|
||||
|
||||
@@ -157,6 +157,8 @@ class LDAPUser(_LDAPUser):
|
||||
|
||||
def _populate_user_from_attributes(self):
|
||||
for field, attr in self.settings.USER_ATTR_MAP.items():
|
||||
if field in ['groups']:
|
||||
continue
|
||||
try:
|
||||
value = self.attrs[attr][0]
|
||||
value = value.strip()
|
||||
|
||||
@@ -46,6 +46,8 @@ class SessionCookieMiddleware(MiddlewareMixin):
|
||||
|
||||
@staticmethod
|
||||
def set_cookie_public_key(request, response):
|
||||
if request.path.startswith('/api'):
|
||||
return
|
||||
pub_key_name = settings.SESSION_RSA_PUBLIC_KEY_NAME
|
||||
public_key = request.session.get(pub_key_name)
|
||||
cookie_key = request.COOKIES.get(pub_key_name)
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
<!-- Stylesheets -->
|
||||
<link href="{% static 'css/login-style.css' %}" rel="stylesheet">
|
||||
<link href="{% static 'css/jumpserver.css' %}" rel="stylesheet">
|
||||
<script src="{% static "js/jumpserver.js" %}"></script>
|
||||
<script src="{% static "js/jumpserver.js" %}?_=9"></script>
|
||||
|
||||
<style>
|
||||
.login-content {
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
<title>{{ title }}</title>
|
||||
{% include '_head_css_js.html' %}
|
||||
<link href="{% static "css/jumpserver.css" %}" rel="stylesheet">
|
||||
<script src="{% static "js/jumpserver.js" %}"></script>
|
||||
<script src="{% static "js/jumpserver.js" %}?_=9"></script>
|
||||
|
||||
</head>
|
||||
|
||||
|
||||
@@ -259,7 +259,7 @@ def decrypt_password(value):
|
||||
aes = get_aes_crypto(aes_key, 'ECB')
|
||||
try:
|
||||
password = aes.decrypt(password_cipher)
|
||||
except UnicodeDecodeError as e:
|
||||
except Exception as e:
|
||||
logging.error("Decript password error: {}, {}".format(password_cipher, e))
|
||||
return value
|
||||
return password
|
||||
|
||||
@@ -13,10 +13,6 @@ reader = None
|
||||
|
||||
|
||||
def get_ip_city_by_geoip(ip):
|
||||
if not ip or '.' not in ip or not isinstance(ip, str):
|
||||
return _("Invalid ip")
|
||||
if ':' in ip:
|
||||
return 'IPv6'
|
||||
global reader
|
||||
if reader is None:
|
||||
path = os.path.join(os.path.dirname(__file__), 'GeoLite2-City.mmdb')
|
||||
@@ -32,7 +28,7 @@ def get_ip_city_by_geoip(ip):
|
||||
try:
|
||||
response = reader.city(ip)
|
||||
except GeoIP2Error:
|
||||
return {}
|
||||
return _("Unknown")
|
||||
|
||||
city_names = response.city.names or {}
|
||||
lang = settings.LANGUAGE_CODE[:2]
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
import os
|
||||
from django.utils.translation import ugettext as _
|
||||
|
||||
import ipdb
|
||||
|
||||
@@ -11,13 +10,13 @@ ipip_db = None
|
||||
|
||||
def get_ip_city_by_ipip(ip):
|
||||
global ipip_db
|
||||
if not ip or not isinstance(ip, str):
|
||||
return _("Invalid ip")
|
||||
if ':' in ip:
|
||||
return 'IPv6'
|
||||
if ipip_db is None:
|
||||
ipip_db_path = os.path.join(os.path.dirname(__file__), 'ipipfree.ipdb')
|
||||
ipip_db = ipdb.City(ipip_db_path)
|
||||
|
||||
info = ipip_db.find_info(ip, 'CN')
|
||||
try:
|
||||
info = ipip_db.find_info(ip, 'CN')
|
||||
except ValueError:
|
||||
return None
|
||||
if not info:
|
||||
raise None
|
||||
return {'city': info.city_name, 'country': info.country_name}
|
||||
|
||||
@@ -74,13 +74,18 @@ def contains_ip(ip, ip_group):
|
||||
|
||||
|
||||
def get_ip_city(ip):
|
||||
info = get_ip_city_by_ipip(ip)
|
||||
city = info.get('city', _("Unknown"))
|
||||
country = info.get('country')
|
||||
if not ip or not isinstance(ip, str):
|
||||
return _("Invalid ip")
|
||||
if ':' in ip:
|
||||
return 'IPv6'
|
||||
|
||||
# 国内城市 并且 语言是中文就使用国内
|
||||
is_zh = settings.LANGUAGE_CODE.startswith('zh')
|
||||
if country == '中国' and is_zh:
|
||||
return city
|
||||
else:
|
||||
return get_ip_city_by_geoip(ip)
|
||||
info = get_ip_city_by_ipip(ip)
|
||||
if info:
|
||||
city = info.get('city', _("Unknown"))
|
||||
country = info.get('country')
|
||||
|
||||
# 国内城市 并且 语言是中文就使用国内
|
||||
is_zh = settings.LANGUAGE_CODE.startswith('zh')
|
||||
if country == '中国' and is_zh:
|
||||
return city
|
||||
return get_ip_city_by_geoip(ip)
|
||||
|
||||
@@ -153,3 +153,5 @@ ANSIBLE_LOG_DIR = os.path.join(PROJECT_DIR, 'data', 'ansible')
|
||||
REDIS_HOST = CONFIG.REDIS_HOST
|
||||
REDIS_PORT = CONFIG.REDIS_PORT
|
||||
REDIS_PASSWORD = CONFIG.REDIS_PASSWORD
|
||||
|
||||
DJANGO_REDIS_SCAN_ITERSIZE = 1000
|
||||
|
||||
@@ -126,6 +126,8 @@ class BuiltinRole:
|
||||
org_user = PredefineRole(
|
||||
'7', ugettext_noop('OrgUser'), Scope.org, user_perms
|
||||
)
|
||||
system_role_mapper = None
|
||||
org_role_mapper = None
|
||||
|
||||
@classmethod
|
||||
def get_roles(cls):
|
||||
@@ -138,22 +140,24 @@ class BuiltinRole:
|
||||
|
||||
@classmethod
|
||||
def get_system_role_by_old_name(cls, name):
|
||||
mapper = {
|
||||
'App': cls.system_component,
|
||||
'Admin': cls.system_admin,
|
||||
'User': cls.system_user,
|
||||
'Auditor': cls.system_auditor
|
||||
}
|
||||
return mapper[name].get_role()
|
||||
if not cls.system_role_mapper:
|
||||
cls.system_role_mapper = {
|
||||
'App': cls.system_component.get_role(),
|
||||
'Admin': cls.system_admin.get_role(),
|
||||
'User': cls.system_user.get_role(),
|
||||
'Auditor': cls.system_auditor.get_role()
|
||||
}
|
||||
return cls.system_role_mapper[name]
|
||||
|
||||
@classmethod
|
||||
def get_org_role_by_old_name(cls, name):
|
||||
mapper = {
|
||||
'Admin': cls.org_admin,
|
||||
'User': cls.org_user,
|
||||
'Auditor': cls.org_auditor,
|
||||
}
|
||||
return mapper[name].get_role()
|
||||
if not cls.org_role_mapper:
|
||||
cls.org_role_mapper = {
|
||||
'Admin': cls.org_admin.get_role(),
|
||||
'User': cls.org_user.get_role(),
|
||||
'Auditor': cls.org_auditor.get_role(),
|
||||
}
|
||||
return cls.org_role_mapper[name]
|
||||
|
||||
@classmethod
|
||||
def sync_to_db(cls, show_msg=False):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Generated by Django 3.1.13 on 2021-12-01 11:01
|
||||
|
||||
import time
|
||||
from django.db import migrations
|
||||
|
||||
from rbac.builtin import BuiltinRole
|
||||
@@ -9,33 +10,61 @@ def migrate_system_role_binding(apps, schema_editor):
|
||||
db_alias = schema_editor.connection.alias
|
||||
user_model = apps.get_model('users', 'User')
|
||||
role_binding_model = apps.get_model('rbac', 'SystemRoleBinding')
|
||||
users = user_model.objects.using(db_alias).all()
|
||||
|
||||
role_bindings = []
|
||||
for user in users:
|
||||
role = BuiltinRole.get_system_role_by_old_name(user.role)
|
||||
role_binding = role_binding_model(scope='system', user_id=user.id, role_id=role.id)
|
||||
role_bindings.append(role_binding)
|
||||
role_binding_model.objects.bulk_create(role_bindings, ignore_conflicts=True)
|
||||
count = 0
|
||||
bulk_size = 1000
|
||||
while True:
|
||||
users = user_model.objects.using(db_alias) \
|
||||
.only('role', 'id') \
|
||||
.all()[count:count+bulk_size]
|
||||
if not users:
|
||||
break
|
||||
|
||||
role_bindings = []
|
||||
start = time.time()
|
||||
for user in users:
|
||||
role = BuiltinRole.get_system_role_by_old_name(user.role)
|
||||
role_binding = role_binding_model(scope='system', user_id=user.id, role_id=role.id)
|
||||
role_bindings.append(role_binding)
|
||||
|
||||
role_binding_model.objects.bulk_create(role_bindings, ignore_conflicts=True)
|
||||
print("Create role binding: {}-{} using: {:.2f}s".format(
|
||||
count, count + len(users), time.time()-start
|
||||
))
|
||||
count += len(users)
|
||||
|
||||
|
||||
def migrate_org_role_binding(apps, schema_editor):
|
||||
db_alias = schema_editor.connection.alias
|
||||
org_member_model = apps.get_model('orgs', 'OrganizationMember')
|
||||
role_binding_model = apps.get_model('rbac', 'RoleBinding')
|
||||
members = org_member_model.objects.using(db_alias).all()
|
||||
|
||||
role_bindings = []
|
||||
for member in members:
|
||||
role = BuiltinRole.get_org_role_by_old_name(member.role)
|
||||
role_binding = role_binding_model(
|
||||
scope='org',
|
||||
user_id=member.user.id,
|
||||
role_id=role.id,
|
||||
org_id=member.org.id
|
||||
)
|
||||
role_bindings.append(role_binding)
|
||||
role_binding_model.objects.bulk_create(role_bindings)
|
||||
count = 0
|
||||
bulk_size = 1000
|
||||
|
||||
while True:
|
||||
members = org_member_model.objects.using(db_alias)\
|
||||
.only('role', 'user_id', 'org_id')\
|
||||
.all()[count:count+bulk_size]
|
||||
if not members:
|
||||
break
|
||||
role_bindings = []
|
||||
start = time.time()
|
||||
|
||||
for member in members:
|
||||
role = BuiltinRole.get_org_role_by_old_name(member.role)
|
||||
role_binding = role_binding_model(
|
||||
scope='org',
|
||||
user_id=member.user_id,
|
||||
role_id=role.id,
|
||||
org_id=member.org_id
|
||||
)
|
||||
role_bindings.append(role_binding)
|
||||
role_binding_model.objects.bulk_create(role_bindings, ignore_conflicts=True)
|
||||
print("Create role binding: {}-{} using: {:.2f}s".format(
|
||||
count, count + len(members), time.time()-start
|
||||
))
|
||||
count += len(members)
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
@@ -40,6 +40,6 @@ class PrivateSettingSerializer(PublicSettingSerializer):
|
||||
TERMINAL_KOKO_SSH_ENABLED = serializers.BooleanField()
|
||||
|
||||
ANNOUNCEMENT_ENABLED = serializers.BooleanField()
|
||||
ANNOUNCEMENT = serializers.CharField()
|
||||
ANNOUNCEMENT = serializers.DictField()
|
||||
|
||||
TICKETS_ENABLED = serializers.BooleanField()
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
|
||||
{% include '_head_css_js.html' %}
|
||||
<link href="{% static "css/jumpserver.css" %}" rel="stylesheet">
|
||||
<script src="{% static "js/jumpserver.js" %}"></script>
|
||||
<script src="{% static "js/jumpserver.js" %}?_=9"></script>
|
||||
<style>
|
||||
.outerBox {
|
||||
margin: 0 auto;
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
|
||||
{% include '_head_css_js.html' %}
|
||||
<link href="{% static "css/jumpserver.css" %}" rel="stylesheet">
|
||||
<script src="{% static "js/jumpserver.js" %}"></script>
|
||||
<script src="{% static "js/jumpserver.js" %}?_=9"></script>
|
||||
<style>
|
||||
.passwordBox {
|
||||
max-width: 560px;
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
<!-- Custom and plugin javascript -->
|
||||
<script src="{% static "js/plugins/toastr/toastr.min.js" %}"></script>
|
||||
<script src="{% static "js/inspinia.js" %}"></script>
|
||||
<script src="{% static "js/jumpserver.js" %}?v=8"></script>
|
||||
<script src="{% static "js/jumpserver.js" %}?v=9"></script>
|
||||
<script src="{% static 'js/plugins/select2/select2.full.min.js' %}"></script>
|
||||
<script src="{% static 'js/plugins/select2/i18n/zh-CN.js' %}"></script>
|
||||
<script>
|
||||
|
||||
@@ -43,6 +43,9 @@ class SessionJoinRecordsViewSet(OrgModelViewSet):
|
||||
)
|
||||
filterset_fields = search_fields
|
||||
model = models.SessionJoinRecord
|
||||
rbac_perms = {
|
||||
'finished': 'terminal.change_sessionjoinrecord'
|
||||
}
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
try:
|
||||
|
||||
@@ -297,6 +297,9 @@ class QuerySet(DJQuerySet):
|
||||
self._command_store_config = command_store_config
|
||||
self._storage = CommandStore(command_store_config)
|
||||
|
||||
# 命令列表模糊搜索时报错
|
||||
super().__init__()
|
||||
|
||||
@lazyproperty
|
||||
def _grouped_method_calls(self):
|
||||
_method_calls = {k: list(v) for k, v in groupby(self._method_calls, lambda x: x[0])}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import copy
|
||||
import os
|
||||
|
||||
from importlib import import_module
|
||||
@@ -77,7 +78,7 @@ class CommandStorage(CommonStorageModelMixin, CommonModelMixin):
|
||||
def config(self):
|
||||
config = self.meta
|
||||
config.update({'TYPE': self.type})
|
||||
return config
|
||||
return copy.deepcopy(config)
|
||||
|
||||
@property
|
||||
def valid_config(self):
|
||||
|
||||
@@ -27,7 +27,7 @@ class TicketViewSet(CommonApiMixin, viewsets.ModelViewSet):
|
||||
}
|
||||
filterset_class = TicketFilter
|
||||
search_fields = [
|
||||
'title', 'action', 'type', 'status', 'applicant_display'
|
||||
'title', 'type', 'status', 'applicant_display'
|
||||
]
|
||||
ordering_fields = (
|
||||
'title', 'applicant_display', 'status', 'state', 'action_display',
|
||||
|
||||
@@ -143,7 +143,7 @@ class PasswordExpirationReminderMsg(UserMessage):
|
||||
subject = _('Password is about expire')
|
||||
|
||||
date_password_expired_local = timezone.localtime(user.date_password_expired)
|
||||
update_password_url = urljoin(settings.SITE_URL, '/ui/#/users/profile/?activeTab=PasswordUpdate')
|
||||
update_password_url = urljoin(settings.SITE_URL, '/ui/#/profile/setting/?activeTab=PasswordUpdate')
|
||||
date_password_expired = date_password_expired_local.strftime('%Y-%m-%d %H:%M:%S')
|
||||
context = {
|
||||
'name': user.name,
|
||||
|
||||
@@ -12,33 +12,23 @@ BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
APPS_DIR = os.path.join(BASE_DIR, 'apps')
|
||||
|
||||
sys.path.insert(0, BASE_DIR)
|
||||
sys.path.insert(0, APPS_DIR)
|
||||
from apps.jumpserver.const import CONFIG
|
||||
from apps.jumpserver.settings import base as jms_settings
|
||||
|
||||
os.environ.setdefault('PYTHONOPTIMIZE', '1')
|
||||
if os.getuid() == 0:
|
||||
os.environ.setdefault('C_FORCE_ROOT', '1')
|
||||
|
||||
REDIS_SSL_KEYFILE = os.path.join(BASE_DIR, 'data', 'certs', 'redis_client.key')
|
||||
if not os.path.exists(REDIS_SSL_KEYFILE):
|
||||
REDIS_SSL_KEYFILE = None
|
||||
|
||||
REDIS_SSL_CERTFILE = os.path.join(BASE_DIR, 'data', 'certs', 'redis_client.crt')
|
||||
if not os.path.exists(REDIS_SSL_CERTFILE):
|
||||
REDIS_SSL_CERTFILE = None
|
||||
|
||||
REDIS_SSL_CA_CERTS = os.path.join(BASE_DIR, 'data', 'certs', 'redis_ca.crt')
|
||||
if not os.path.exists(REDIS_SSL_CA_CERTS):
|
||||
REDIS_SSL_CA_CERTS = os.path.join(BASE_DIR, 'data', 'certs', 'redis_ca.pem')
|
||||
|
||||
params = {
|
||||
'host': CONFIG.REDIS_HOST,
|
||||
'port': CONFIG.REDIS_PORT,
|
||||
'password': CONFIG.REDIS_PASSWORD,
|
||||
"ssl": CONFIG.REDIS_USE_SSL,
|
||||
'ssl_cert_reqs': CONFIG.REDIS_SSL_REQUIRED,
|
||||
"ssl_keyfile": REDIS_SSL_KEYFILE,
|
||||
"ssl_certfile": REDIS_SSL_CERTFILE,
|
||||
"ssl_ca_certs": REDIS_SSL_CA_CERTS
|
||||
"ssl_keyfile": jms_settings.REDIS_SSL_KEYFILE,
|
||||
"ssl_certfile": jms_settings.REDIS_SSL_CERTFILE,
|
||||
"ssl_ca_certs": jms_settings.REDIS_SSL_CA_CERTS
|
||||
}
|
||||
redis = Redis(**params)
|
||||
scheduler = "django_celery_beat.schedulers:DatabaseScheduler"
|
||||
|
||||
68
utils/test_run_migrations.py
Normal file
68
utils/test_run_migrations.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# Generated by Django 3.1.13 on 2021-12-01 11:01
|
||||
import os
|
||||
import sys
|
||||
import django
|
||||
import time
|
||||
|
||||
app_path = '***** Change me *******'
|
||||
sys.path.insert(0, app_path)
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "jumpserver.settings")
|
||||
django.setup()
|
||||
|
||||
from django.apps import apps
|
||||
from django.db import connection
|
||||
|
||||
# ========================== 添加到需要测试的 migrations 上方 ==========================
|
||||
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
from rbac.builtin import BuiltinRole
|
||||
|
||||
|
||||
def migrate_system_role_binding(apps, schema_editor):
|
||||
db_alias = schema_editor.connection.alias
|
||||
user_model = apps.get_model('users', 'User')
|
||||
role_binding_model = apps.get_model('rbac', 'SystemRoleBinding')
|
||||
|
||||
count = 0
|
||||
bulk_size = 1000
|
||||
while True:
|
||||
users = user_model.objects.using(db_alias) \
|
||||
.only('role', 'id') \
|
||||
.all()[count:count+bulk_size]
|
||||
if not users:
|
||||
break
|
||||
|
||||
role_bindings = []
|
||||
start = time.time()
|
||||
for user in users:
|
||||
role = BuiltinRole.get_system_role_by_old_name(user.role)
|
||||
role_binding = role_binding_model(scope='system', user_id=user.id, role_id=role.id)
|
||||
role_bindings.append(role_binding)
|
||||
|
||||
role_binding_model.objects.bulk_create(role_bindings, ignore_conflicts=True)
|
||||
print("Create role binding: {}-{} using: {:.2f}s".format(
|
||||
count, count + len(users), time.time()-start
|
||||
))
|
||||
count += len(users)
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('rbac', '0003_auto_20211130_1037'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RunPython(migrate_system_role_binding),
|
||||
]
|
||||
|
||||
|
||||
# ================== 添加到下方 ======================
|
||||
def main():
|
||||
schema_editor = connection.schema_editor()
|
||||
migrate_system_role_binding(apps, schema_editor)
|
||||
|
||||
|
||||
# main()
|
||||
Reference in New Issue
Block a user