Merge pull request #5952 from jumpserver/dev

v2.9.0 rc2
This commit is contained in:
Jiangjie.Bai 2021-04-13 19:19:43 +08:00 committed by GitHub
commit 4bf2371cf0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 193 additions and 43 deletions

View File

@ -307,6 +307,15 @@ class NodeAllAssetsMappingMixin:
org_id = str(org_id) org_id = str(org_id)
cls.orgid_nodekey_assetsid_mapping.pop(org_id, None) cls.orgid_nodekey_assetsid_mapping.pop(org_id, None)
@classmethod
def expire_all_orgs_node_all_asset_ids_mapping_from_memory(cls):
orgs = Organization.objects.all()
org_ids = [str(org.id) for org in orgs]
org_ids.append(Organization.ROOT_ID)
for id in org_ids:
cls.expire_node_all_asset_ids_mapping_from_memory(id)
# get order: from memory -> (from cache -> to generate) # get order: from memory -> (from cache -> to generate)
@classmethod @classmethod
def get_node_all_asset_ids_mapping_from_cache_or_generate_to_cache(cls, org_id): def get_node_all_asset_ids_mapping_from_cache_or_generate_to_cache(cls, org_id):

View File

@ -13,6 +13,7 @@ from common.signals import django_ready
from common.utils.connection import RedisPubSub from common.utils.connection import RedisPubSub
from common.utils import get_logger from common.utils import get_logger
from assets.models import Asset, Node from assets.models import Asset, Node
from orgs.models import Organization
logger = get_logger(__file__) logger = get_logger(__file__)
@ -36,13 +37,18 @@ node_assets_mapping_for_memory_pub_sub = NodeAssetsMappingForMemoryPubSub()
def expire_node_assets_mapping_for_memory(org_id): def expire_node_assets_mapping_for_memory(org_id):
# 所有进程清除(自己的 memory 数据) # 所有进程清除(自己的 memory 数据)
org_id = str(org_id) org_id = str(org_id)
node_assets_mapping_for_memory_pub_sub.publish(org_id) root_org_id = Organization.ROOT_ID
# 当前进程清除(cache 数据) # 当前进程清除(cache 数据)
logger.debug( logger.debug(
"Expire node assets id mapping from cache of org={}, pid={}" "Expire node assets id mapping from cache of org={}, pid={}"
"".format(org_id, os.getpid()) "".format(org_id, os.getpid())
) )
Node.expire_node_all_asset_ids_mapping_from_cache(org_id) Node.expire_node_all_asset_ids_mapping_from_cache(org_id)
Node.expire_node_all_asset_ids_mapping_from_cache(root_org_id)
node_assets_mapping_for_memory_pub_sub.publish(org_id)
node_assets_mapping_for_memory_pub_sub.publish(root_org_id)
@receiver(post_save, sender=Node) @receiver(post_save, sender=Node)
@ -73,6 +79,8 @@ def subscribe_node_assets_mapping_expire(sender, **kwargs):
logger.debug("Start subscribe for expire node assets id mapping from memory") logger.debug("Start subscribe for expire node assets id mapping from memory")
def keep_subscribe(): def keep_subscribe():
while True:
try:
subscribe = node_assets_mapping_for_memory_pub_sub.subscribe() subscribe = node_assets_mapping_for_memory_pub_sub.subscribe()
for message in subscribe.listen(): for message in subscribe.listen():
if message["type"] != "message": if message["type"] != "message":
@ -83,6 +91,10 @@ def subscribe_node_assets_mapping_expire(sender, **kwargs):
"Expire node assets id mapping from memory of org={}, pid={}" "Expire node assets id mapping from memory of org={}, pid={}"
"".format(str(org_id), os.getpid()) "".format(str(org_id), os.getpid())
) )
except Exception as e:
logger.exception(f'subscribe_node_assets_mapping_expire: {e}')
Node.expire_all_orgs_node_all_asset_ids_mapping_from_memory()
t = threading.Thread(target=keep_subscribe) t = threading.Thread(target=keep_subscribe)
t.daemon = True t.daemon = True
t.start() t.start()

View File

@ -44,6 +44,7 @@ class AuthBackendLabelMapping(LazyObject):
backend_label_mapping[backend] = source.label backend_label_mapping[backend] = source.label
backend_label_mapping[settings.AUTH_BACKEND_PUBKEY] = _('SSH Key') backend_label_mapping[settings.AUTH_BACKEND_PUBKEY] = _('SSH Key')
backend_label_mapping[settings.AUTH_BACKEND_MODEL] = _('Password') backend_label_mapping[settings.AUTH_BACKEND_MODEL] = _('Password')
backend_label_mapping[settings.AUTH_BACKEND_SSO] = _('SSO')
return backend_label_mapping return backend_label_mapping
def _setup(self): def _setup(self):

View File

@ -16,6 +16,7 @@ reason_user_not_exist = 'user_not_exist'
reason_password_expired = 'password_expired' reason_password_expired = 'password_expired'
reason_user_invalid = 'user_invalid' reason_user_invalid = 'user_invalid'
reason_user_inactive = 'user_inactive' reason_user_inactive = 'user_inactive'
reason_user_expired = 'user_expired'
reason_backend_not_match = 'backend_not_match' reason_backend_not_match = 'backend_not_match'
reason_acl_not_allow = 'acl_not_allow' reason_acl_not_allow = 'acl_not_allow'
@ -28,6 +29,7 @@ reason_choices = {
reason_password_expired: _("Password expired"), reason_password_expired: _("Password expired"),
reason_user_invalid: _('Disabled or expired'), reason_user_invalid: _('Disabled or expired'),
reason_user_inactive: _("This account is inactive."), reason_user_inactive: _("This account is inactive."),
reason_user_expired: _("This account is expired"),
reason_backend_not_match: _("Auth backend not match"), reason_backend_not_match: _("Auth backend not match"),
reason_acl_not_allow: _("ACL is not allowed"), reason_acl_not_allow: _("ACL is not allowed"),
} }

View File

@ -171,7 +171,7 @@ class AuthMixin:
if not user: if not user:
self.raise_credential_error(errors.reason_password_failed) self.raise_credential_error(errors.reason_password_failed)
elif user.is_expired: elif user.is_expired:
self.raise_credential_error(errors.reason_user_inactive) self.raise_credential_error(errors.reason_user_expired)
elif not user.is_active: elif not user.is_active:
self.raise_credential_error(errors.reason_user_inactive) self.raise_credential_error(errors.reason_user_inactive)
return user return user

View File

@ -69,6 +69,7 @@ class UserLoginView(mixins.AuthMixin, FormView):
new_form = form_cls(data=form.data) new_form = form_cls(data=form.data)
new_form._errors = form.errors new_form._errors = form.errors
context = self.get_context_data(form=new_form) context = self.get_context_data(form=new_form)
self.request.session.set_test_cookie()
return self.render_to_response(context) return self.render_to_response(context)
except (errors.PasswdTooSimple, errors.PasswordRequireResetError) as e: except (errors.PasswdTooSimple, errors.PasswordRequireResetError) as e:
return redirect(e.url) return redirect(e.url)

View File

@ -1,3 +1,5 @@
import time
from django.core.cache import cache from django.core.cache import cache
from django.utils import timezone from django.utils import timezone
from django.utils.timesince import timesince from django.utils.timesince import timesince
@ -6,6 +8,8 @@ from django.http.response import JsonResponse, HttpResponse
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.permissions import AllowAny from rest_framework.permissions import AllowAny
from collections import Counter from collections import Counter
from django.conf import settings
from rest_framework.response import Response
from users.models import User from users.models import User
from assets.models import Asset from assets.models import Asset
@ -307,7 +311,68 @@ class IndexApi(TotalCountMixin, DatesLoginMetricMixin, APIView):
return JsonResponse(data, status=200) return JsonResponse(data, status=200)
class PrometheusMetricsApi(APIView): class HealthApiMixin(APIView):
def is_token_right(self):
token = self.request.query_params.get('token')
ok_token = settings.HEALTH_CHECK_TOKEN
if ok_token and token != ok_token:
return False
return True
def check_permissions(self, request):
if not self.is_token_right():
msg = 'Health check token error, ' \
'Please set query param in url and same with setting HEALTH_CHECK_TOKEN. ' \
'eg: $PATH/?token=$HEALTH_CHECK_TOKEN'
self.permission_denied(request, message={'error': msg}, code=403)
class HealthCheckView(HealthApiMixin):
permission_classes = (AllowAny,)
@staticmethod
def get_db_status():
t1 = time.time()
try:
User.objects.first()
t2 = time.time()
return True, t2 - t1
except:
t2 = time.time()
return False, t2 - t1
def get_redis_status(self):
key = 'HEALTH_CHECK'
t1 = time.time()
try:
value = '1'
cache.set(key, '1', 10)
got = cache.get(key)
t2 = time.time()
if value == got:
return True, t2 -t1
return False, t2 -t1
except:
t2 = time.time()
return False, t2 - t1
def get(self, request):
redis_status, redis_time = self.get_redis_status()
db_status, db_time = self.get_db_status()
status = all([redis_status, db_status])
data = {
'status': status,
'db_status': db_status,
'db_time': db_time,
'redis_status': redis_status,
'redis_time': redis_time,
'time': int(time.time())
}
return Response(data)
class PrometheusMetricsApi(HealthApiMixin):
permission_classes = (AllowAny,) permission_classes = (AllowAny,)
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):

View File

@ -289,6 +289,7 @@ class Config(dict):
'SESSION_SAVE_EVERY_REQUEST': True, 'SESSION_SAVE_EVERY_REQUEST': True,
'SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE': False, 'SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE': False,
'FORGOT_PASSWORD_URL': '', 'FORGOT_PASSWORD_URL': '',
'HEALTH_CHECK_TOKEN': ''
} }
def compatible_auth_openid_of_key(self): def compatible_auth_openid_of_key(self):

View File

@ -123,3 +123,4 @@ FORGOT_PASSWORD_URL = CONFIG.FORGOT_PASSWORD_URL
# 自定义默认组织名 # 自定义默认组织名
GLOBAL_ORG_DISPLAY_NAME = CONFIG.GLOBAL_ORG_DISPLAY_NAME GLOBAL_ORG_DISPLAY_NAME = CONFIG.GLOBAL_ORG_DISPLAY_NAME
HEALTH_CHECK_TOKEN = CONFIG.HEALTH_CHECK_TOKEN

View File

@ -48,7 +48,8 @@ urlpatterns = [
path('', views.IndexView.as_view(), name='index'), path('', views.IndexView.as_view(), name='index'),
path('api/v1/', include(api_v1)), path('api/v1/', include(api_v1)),
re_path('api/(?P<app>\w+)/(?P<version>v\d)/.*', views.redirect_format_api), re_path('api/(?P<app>\w+)/(?P<version>v\d)/.*', views.redirect_format_api),
path('api/health/', views.HealthCheckView.as_view(), name="health"), path('api/health/', api.HealthCheckView.as_view(), name="health"),
path('api/v1/health/', api.HealthCheckView.as_view(), name="health_v1"),
# External apps url # External apps url
path('core/auth/captcha/', include('captcha.urls')), path('core/auth/captcha/', include('captcha.urls')),
path('core/', include(app_view_patterns)), path('core/', include(app_view_patterns)),

View File

@ -17,7 +17,7 @@ from common.http import HttpResponseTemporaryRedirect
__all__ = [ __all__ = [
'LunaView', 'I18NView', 'KokoView', 'WsView', 'HealthCheckView', 'LunaView', 'I18NView', 'KokoView', 'WsView',
'redirect_format_api', 'redirect_old_apps_view', 'UIView' 'redirect_format_api', 'redirect_old_apps_view', 'UIView'
] ]
@ -64,13 +64,6 @@ def redirect_old_apps_view(request, *args, **kwargs):
return HttpResponseTemporaryRedirect(new_path) return HttpResponseTemporaryRedirect(new_path)
class HealthCheckView(APIView):
permission_classes = (AllowAny,)
def get(self, request):
return JsonResponse({"status": 1, "time": int(time.time())})
class WsView(APIView): class WsView(APIView):
ws_port = settings.HTTP_LISTEN_PORT + 1 ws_port = settings.HTTP_LISTEN_PORT + 1

View File

@ -46,12 +46,19 @@ def subscribe_orgs_mapping_expire(sender, **kwargs):
logger.debug("Start subscribe for expire orgs mapping from memory") logger.debug("Start subscribe for expire orgs mapping from memory")
def keep_subscribe(): def keep_subscribe():
while True:
try:
subscribe = orgs_mapping_for_memory_pub_sub.subscribe() subscribe = orgs_mapping_for_memory_pub_sub.subscribe()
for message in subscribe.listen(): for message in subscribe.listen():
if message['type'] != 'message': if message['type'] != 'message':
continue continue
if message['data'] == b'error':
raise ValueError
Organization.expire_orgs_mapping() Organization.expire_orgs_mapping()
logger.debug('Expire orgs mapping') logger.debug('Expire orgs mapping')
except Exception as e:
logger.exception(f'subscribe_orgs_mapping_expire: {e}')
Organization.expire_orgs_mapping()
t = threading.Thread(target=keep_subscribe) t = threading.Thread(target=keep_subscribe)
t.daemon = True t.daemon = True

View File

@ -6,6 +6,7 @@ import threading
from django.dispatch import receiver from django.dispatch import receiver
from django.db.models.signals import post_save, pre_save from django.db.models.signals import post_save, pre_save
from django.utils.functional import LazyObject from django.utils.functional import LazyObject
from django.db import close_old_connections
from jumpserver.utils import current_request from jumpserver.utils import current_request
from common.decorator import on_transaction_commit from common.decorator import on_transaction_commit
@ -71,13 +72,21 @@ def subscribe_settings_change(sender, **kwargs):
logger.debug("Start subscribe setting change") logger.debug("Start subscribe setting change")
def keep_subscribe(): def keep_subscribe():
while True:
try:
sub = setting_pub_sub.subscribe() sub = setting_pub_sub.subscribe()
for msg in sub.listen(): for msg in sub.listen():
close_old_connections()
if msg["type"] != "message": if msg["type"] != "message":
continue continue
item = msg['data'].decode() item = msg['data'].decode()
logger.debug("Found setting change: {}".format(str(item))) logger.debug("Found setting change: {}".format(str(item)))
Setting.refresh_item(item) Setting.refresh_item(item)
except Exception as e:
logger.exception(f'subscribe_settings_change: {e}')
close_old_connections()
Setting.refresh_all_settings()
t = threading.Thread(target=keep_subscribe) t = threading.Thread(target=keep_subscribe)
t.daemon = True t.daemon = True
t.start() t.start()

View File

@ -11,6 +11,7 @@ from rest_framework.response import Response
from rest_framework.decorators import action from rest_framework.decorators import action
from django.template import loader from django.template import loader
from common.http import is_true
from terminal.models import CommandStorage, Command from terminal.models import CommandStorage, Command
from terminal.filters import CommandFilter from terminal.filters import CommandFilter
from orgs.utils import current_org from orgs.utils import current_org
@ -140,7 +141,21 @@ class CommandViewSet(viewsets.ModelViewSet):
if session_id and not command_storage_id: if session_id and not command_storage_id:
# 会话里的命令列表肯定会提供 session_id这里防止 merge 的时候取全量的数据 # 会话里的命令列表肯定会提供 session_id这里防止 merge 的时候取全量的数据
return self.merge_all_storage_list(request, *args, **kwargs) return self.merge_all_storage_list(request, *args, **kwargs)
return super().list(request, *args, **kwargs)
queryset = self.filter_queryset(self.get_queryset())
page = self.paginate_queryset(queryset)
if page is not None:
serializer = self.get_serializer(page, many=True)
return self.get_paginated_response(serializer.data)
query_all = self.request.query_params.get('all', False)
if is_true(query_all):
# 适配像 ES 这种没有指定分页只返回少量数据的情况
queryset = queryset[:]
serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data)
def get_queryset(self): def get_queryset(self):
command_storage_id = self.request.query_params.get('command_storage_id') command_storage_id = self.request.query_params.get('command_storage_id')

View File

@ -10,6 +10,7 @@ import inspect
from django.db.models import QuerySet as DJQuerySet from django.db.models import QuerySet as DJQuerySet
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk from elasticsearch.helpers import bulk
from elasticsearch.exceptions import RequestError
from common.utils.common import lazyproperty from common.utils.common import lazyproperty
from common.utils import get_logger from common.utils import get_logger
@ -31,6 +32,15 @@ class CommandStore():
kwargs['verify_certs'] = None kwargs['verify_certs'] = None
self.es = Elasticsearch(hosts=hosts, max_retries=0, **kwargs) self.es = Elasticsearch(hosts=hosts, max_retries=0, **kwargs)
def pre_use_check(self):
self._ensure_index_exists()
def _ensure_index_exists(self):
try:
self.es.indices.create(self.index)
except RequestError:
pass
@staticmethod @staticmethod
def make_data(command): def make_data(command):
data = dict( data = dict(
@ -234,6 +244,7 @@ class QuerySet(DJQuerySet):
uqs = QuerySet(self._command_store_config) uqs = QuerySet(self._command_store_config)
uqs._method_calls = self._method_calls.copy() uqs._method_calls = self._method_calls.copy()
uqs._slice = self._slice uqs._slice = self._slice
uqs.model = self.model
return uqs return uqs
def count(self, limit_to_max_result_window=True): def count(self, limit_to_max_result_window=True):

View File

@ -76,6 +76,15 @@ class CommandStorage(CommonModelMixin):
qs.model = Command qs.model = Command
return qs return qs
def save(self, force_insert=False, force_update=False, using=None,
update_fields=None):
super().save()
if self.type in TYPE_ENGINE_MAPPING:
engine_mod = import_module(TYPE_ENGINE_MAPPING[self.type])
backend = engine_mod.CommandStore(self.config)
backend.pre_use_check()
class ReplayStorage(CommonModelMixin): class ReplayStorage(CommonModelMixin):
name = models.CharField(max_length=128, verbose_name=_("Name"), unique=True) name = models.CharField(max_length=128, verbose_name=_("Name"), unique=True)

View File

@ -7,6 +7,7 @@ from django.conf import settings
from common.utils import get_logger from common.utils import get_logger
from users.models import User from users.models import User
from orgs.utils import tmp_to_root_org
from .status import Status from .status import Status
from .. import const from .. import const
from ..const import ComponentStatusChoices as StatusChoice from ..const import ComponentStatusChoices as StatusChoice
@ -112,7 +113,6 @@ class Terminal(StorageMixin, TerminalStatusMixin, models.Model):
date_created = models.DateTimeField(auto_now_add=True) date_created = models.DateTimeField(auto_now_add=True)
comment = models.TextField(blank=True, verbose_name=_('Comment')) comment = models.TextField(blank=True, verbose_name=_('Comment'))
@property @property
def is_active(self): def is_active(self):
if self.user and self.user.is_active: if self.user and self.user.is_active:
@ -126,6 +126,7 @@ class Terminal(StorageMixin, TerminalStatusMixin, models.Model):
self.user.save() self.user.save()
def get_online_sessions(self): def get_online_sessions(self):
with tmp_to_root_org():
return Session.objects.filter(terminal=self, is_finished=False) return Session.objects.filter(terminal=self, is_finished=False)
def get_online_session_count(self): def get_online_session_count(self):

View File

@ -1,8 +1,8 @@
# ~*~ coding: utf-8 ~*~ # ~*~ coding: utf-8 ~*~
from django.core.cache import cache from collections import defaultdict
from django.utils.translation import ugettext as _ from django.utils.translation import ugettext as _
from rest_framework.decorators import action from rest_framework.decorators import action
from django.conf import settings
from rest_framework import generics from rest_framework import generics
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework_bulk import BulkModelViewSet from rest_framework_bulk import BulkModelViewSet
@ -155,10 +155,17 @@ class UserViewSet(CommonApiMixin, UserQuerysetMixin, BulkModelViewSet):
serializer = serializer_cls(data=data, many=True) serializer = serializer_cls(data=data, many=True)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
validated_data = serializer.validated_data validated_data = serializer.validated_data
users_by_role = defaultdict(list)
for i in validated_data: for i in validated_data:
i['org_id'] = current_org.org_id() users_by_role[i['role']].append(i['user'])
relations = [OrganizationMember(**i) for i in validated_data]
OrganizationMember.objects.bulk_create(relations, ignore_conflicts=True) OrganizationMember.objects.add_users_by_role(
current_org,
users=users_by_role[ORG_ROLE.USER],
admins=users_by_role[ORG_ROLE.ADMIN],
auditors=users_by_role[ORG_ROLE.AUDITOR]
)
return Response(serializer.data, status=201) return Response(serializer.data, status=201)
@action(methods=['post'], detail=True, permission_classes=(IsOrgAdmin,)) @action(methods=['post'], detail=True, permission_classes=(IsOrgAdmin,))

View File

@ -667,6 +667,11 @@ class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, AbstractUser):
else: else:
return user_default return user_default
def unblock_login(self):
from users.utils import LoginBlockUtil, MFABlockUtils
LoginBlockUtil.unblock_user(self.username)
MFABlockUtils.unblock_user(self.username)
@property @property
def login_blocked(self): def login_blocked(self):
from users.utils import LoginBlockUtil, MFABlockUtils from users.utils import LoginBlockUtil, MFABlockUtils