diff --git a/apps/assets/models/node.py b/apps/assets/models/node.py index ad17a8be9..973df4b4a 100644 --- a/apps/assets/models/node.py +++ b/apps/assets/models/node.py @@ -307,6 +307,15 @@ class NodeAllAssetsMappingMixin: org_id = str(org_id) cls.orgid_nodekey_assetsid_mapping.pop(org_id, None) + @classmethod + def expire_all_orgs_node_all_asset_ids_mapping_from_memory(cls): + orgs = Organization.objects.all() + org_ids = [str(org.id) for org in orgs] + org_ids.append(Organization.ROOT_ID) + + for id in org_ids: + cls.expire_node_all_asset_ids_mapping_from_memory(id) + # get order: from memory -> (from cache -> to generate) @classmethod def get_node_all_asset_ids_mapping_from_cache_or_generate_to_cache(cls, org_id): diff --git a/apps/assets/signals_handler/node_assets_mapping.py b/apps/assets/signals_handler/node_assets_mapping.py index 4e2b0d07b..efa3cb29f 100644 --- a/apps/assets/signals_handler/node_assets_mapping.py +++ b/apps/assets/signals_handler/node_assets_mapping.py @@ -13,6 +13,7 @@ from common.signals import django_ready from common.utils.connection import RedisPubSub from common.utils import get_logger from assets.models import Asset, Node +from orgs.models import Organization logger = get_logger(__file__) @@ -36,13 +37,18 @@ node_assets_mapping_for_memory_pub_sub = NodeAssetsMappingForMemoryPubSub() def expire_node_assets_mapping_for_memory(org_id): # 所有进程清除(自己的 memory 数据) org_id = str(org_id) - node_assets_mapping_for_memory_pub_sub.publish(org_id) + root_org_id = Organization.ROOT_ID + # 当前进程清除(cache 数据) logger.debug( "Expire node assets id mapping from cache of org={}, pid={}" "".format(org_id, os.getpid()) ) Node.expire_node_all_asset_ids_mapping_from_cache(org_id) + Node.expire_node_all_asset_ids_mapping_from_cache(root_org_id) + + node_assets_mapping_for_memory_pub_sub.publish(org_id) + node_assets_mapping_for_memory_pub_sub.publish(root_org_id) @receiver(post_save, sender=Node) @@ -73,16 +79,22 @@ def subscribe_node_assets_mapping_expire(sender, **kwargs): logger.debug("Start subscribe for expire node assets id mapping from memory") def keep_subscribe(): - subscribe = node_assets_mapping_for_memory_pub_sub.subscribe() - for message in subscribe.listen(): - if message["type"] != "message": - continue - org_id = message['data'].decode() - Node.expire_node_all_asset_ids_mapping_from_memory(org_id) - logger.debug( - "Expire node assets id mapping from memory of org={}, pid={}" - "".format(str(org_id), os.getpid()) - ) + while True: + try: + subscribe = node_assets_mapping_for_memory_pub_sub.subscribe() + for message in subscribe.listen(): + if message["type"] != "message": + continue + org_id = message['data'].decode() + Node.expire_node_all_asset_ids_mapping_from_memory(org_id) + logger.debug( + "Expire node assets id mapping from memory of org={}, pid={}" + "".format(str(org_id), os.getpid()) + ) + except Exception as e: + logger.exception(f'subscribe_node_assets_mapping_expire: {e}') + Node.expire_all_orgs_node_all_asset_ids_mapping_from_memory() + t = threading.Thread(target=keep_subscribe) t.daemon = True t.start() diff --git a/apps/audits/signals_handler.py b/apps/audits/signals_handler.py index 80ff68d1d..c604e9b6b 100644 --- a/apps/audits/signals_handler.py +++ b/apps/audits/signals_handler.py @@ -44,6 +44,7 @@ class AuthBackendLabelMapping(LazyObject): backend_label_mapping[backend] = source.label backend_label_mapping[settings.AUTH_BACKEND_PUBKEY] = _('SSH Key') backend_label_mapping[settings.AUTH_BACKEND_MODEL] = _('Password') + backend_label_mapping[settings.AUTH_BACKEND_SSO] = _('SSO') return backend_label_mapping def _setup(self): diff --git a/apps/authentication/errors.py b/apps/authentication/errors.py index 06631742a..8f9ba9307 100644 --- a/apps/authentication/errors.py +++ b/apps/authentication/errors.py @@ -16,6 +16,7 @@ reason_user_not_exist = 'user_not_exist' reason_password_expired = 'password_expired' reason_user_invalid = 'user_invalid' reason_user_inactive = 'user_inactive' +reason_user_expired = 'user_expired' reason_backend_not_match = 'backend_not_match' reason_acl_not_allow = 'acl_not_allow' @@ -28,6 +29,7 @@ reason_choices = { reason_password_expired: _("Password expired"), reason_user_invalid: _('Disabled or expired'), reason_user_inactive: _("This account is inactive."), + reason_user_expired: _("This account is expired"), reason_backend_not_match: _("Auth backend not match"), reason_acl_not_allow: _("ACL is not allowed"), } diff --git a/apps/authentication/mixins.py b/apps/authentication/mixins.py index e13a88c87..89d7b85fd 100644 --- a/apps/authentication/mixins.py +++ b/apps/authentication/mixins.py @@ -171,7 +171,7 @@ class AuthMixin: if not user: self.raise_credential_error(errors.reason_password_failed) elif user.is_expired: - self.raise_credential_error(errors.reason_user_inactive) + self.raise_credential_error(errors.reason_user_expired) elif not user.is_active: self.raise_credential_error(errors.reason_user_inactive) return user diff --git a/apps/authentication/views/login.py b/apps/authentication/views/login.py index 9f628e6e8..e1a3ab4b6 100644 --- a/apps/authentication/views/login.py +++ b/apps/authentication/views/login.py @@ -69,6 +69,7 @@ class UserLoginView(mixins.AuthMixin, FormView): new_form = form_cls(data=form.data) new_form._errors = form.errors context = self.get_context_data(form=new_form) + self.request.session.set_test_cookie() return self.render_to_response(context) except (errors.PasswdTooSimple, errors.PasswordRequireResetError) as e: return redirect(e.url) diff --git a/apps/jumpserver/api.py b/apps/jumpserver/api.py index e31a1d843..bda2537c8 100644 --- a/apps/jumpserver/api.py +++ b/apps/jumpserver/api.py @@ -1,3 +1,5 @@ +import time + from django.core.cache import cache from django.utils import timezone from django.utils.timesince import timesince @@ -6,6 +8,8 @@ from django.http.response import JsonResponse, HttpResponse from rest_framework.views import APIView from rest_framework.permissions import AllowAny from collections import Counter +from django.conf import settings +from rest_framework.response import Response from users.models import User from assets.models import Asset @@ -307,7 +311,68 @@ class IndexApi(TotalCountMixin, DatesLoginMetricMixin, APIView): return JsonResponse(data, status=200) -class PrometheusMetricsApi(APIView): +class HealthApiMixin(APIView): + def is_token_right(self): + token = self.request.query_params.get('token') + ok_token = settings.HEALTH_CHECK_TOKEN + if ok_token and token != ok_token: + return False + return True + + def check_permissions(self, request): + if not self.is_token_right(): + msg = 'Health check token error, ' \ + 'Please set query param in url and same with setting HEALTH_CHECK_TOKEN. ' \ + 'eg: $PATH/?token=$HEALTH_CHECK_TOKEN' + self.permission_denied(request, message={'error': msg}, code=403) + + +class HealthCheckView(HealthApiMixin): + permission_classes = (AllowAny,) + + @staticmethod + def get_db_status(): + t1 = time.time() + try: + User.objects.first() + t2 = time.time() + return True, t2 - t1 + except: + t2 = time.time() + return False, t2 - t1 + + def get_redis_status(self): + key = 'HEALTH_CHECK' + + t1 = time.time() + try: + value = '1' + cache.set(key, '1', 10) + got = cache.get(key) + t2 = time.time() + if value == got: + return True, t2 -t1 + return False, t2 -t1 + except: + t2 = time.time() + return False, t2 - t1 + + def get(self, request): + redis_status, redis_time = self.get_redis_status() + db_status, db_time = self.get_db_status() + status = all([redis_status, db_status]) + data = { + 'status': status, + 'db_status': db_status, + 'db_time': db_time, + 'redis_status': redis_status, + 'redis_time': redis_time, + 'time': int(time.time()) + } + return Response(data) + + +class PrometheusMetricsApi(HealthApiMixin): permission_classes = (AllowAny,) def get(self, request, *args, **kwargs): diff --git a/apps/jumpserver/conf.py b/apps/jumpserver/conf.py index d09a81bd2..179376828 100644 --- a/apps/jumpserver/conf.py +++ b/apps/jumpserver/conf.py @@ -289,6 +289,7 @@ class Config(dict): 'SESSION_SAVE_EVERY_REQUEST': True, 'SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE': False, 'FORGOT_PASSWORD_URL': '', + 'HEALTH_CHECK_TOKEN': '' } def compatible_auth_openid_of_key(self): diff --git a/apps/jumpserver/settings/custom.py b/apps/jumpserver/settings/custom.py index 89b8d6d53..936b27582 100644 --- a/apps/jumpserver/settings/custom.py +++ b/apps/jumpserver/settings/custom.py @@ -123,3 +123,4 @@ FORGOT_PASSWORD_URL = CONFIG.FORGOT_PASSWORD_URL # 自定义默认组织名 GLOBAL_ORG_DISPLAY_NAME = CONFIG.GLOBAL_ORG_DISPLAY_NAME +HEALTH_CHECK_TOKEN = CONFIG.HEALTH_CHECK_TOKEN diff --git a/apps/jumpserver/urls.py b/apps/jumpserver/urls.py index 044d09310..759c6f271 100644 --- a/apps/jumpserver/urls.py +++ b/apps/jumpserver/urls.py @@ -48,7 +48,8 @@ urlpatterns = [ path('', views.IndexView.as_view(), name='index'), path('api/v1/', include(api_v1)), re_path('api/(?P\w+)/(?Pv\d)/.*', views.redirect_format_api), - path('api/health/', views.HealthCheckView.as_view(), name="health"), + path('api/health/', api.HealthCheckView.as_view(), name="health"), + path('api/v1/health/', api.HealthCheckView.as_view(), name="health_v1"), # External apps url path('core/auth/captcha/', include('captcha.urls')), path('core/', include(app_view_patterns)), diff --git a/apps/jumpserver/views/other.py b/apps/jumpserver/views/other.py index da8046bfc..9ab561c4c 100644 --- a/apps/jumpserver/views/other.py +++ b/apps/jumpserver/views/other.py @@ -17,7 +17,7 @@ from common.http import HttpResponseTemporaryRedirect __all__ = [ - 'LunaView', 'I18NView', 'KokoView', 'WsView', 'HealthCheckView', + 'LunaView', 'I18NView', 'KokoView', 'WsView', 'redirect_format_api', 'redirect_old_apps_view', 'UIView' ] @@ -64,13 +64,6 @@ def redirect_old_apps_view(request, *args, **kwargs): return HttpResponseTemporaryRedirect(new_path) -class HealthCheckView(APIView): - permission_classes = (AllowAny,) - - def get(self, request): - return JsonResponse({"status": 1, "time": int(time.time())}) - - class WsView(APIView): ws_port = settings.HTTP_LISTEN_PORT + 1 diff --git a/apps/orgs/signals_handler/common.py b/apps/orgs/signals_handler/common.py index 59dc2806f..f59c4cb47 100644 --- a/apps/orgs/signals_handler/common.py +++ b/apps/orgs/signals_handler/common.py @@ -46,12 +46,19 @@ def subscribe_orgs_mapping_expire(sender, **kwargs): logger.debug("Start subscribe for expire orgs mapping from memory") def keep_subscribe(): - subscribe = orgs_mapping_for_memory_pub_sub.subscribe() - for message in subscribe.listen(): - if message['type'] != 'message': - continue - Organization.expire_orgs_mapping() - logger.debug('Expire orgs mapping') + while True: + try: + subscribe = orgs_mapping_for_memory_pub_sub.subscribe() + for message in subscribe.listen(): + if message['type'] != 'message': + continue + if message['data'] == b'error': + raise ValueError + Organization.expire_orgs_mapping() + logger.debug('Expire orgs mapping') + except Exception as e: + logger.exception(f'subscribe_orgs_mapping_expire: {e}') + Organization.expire_orgs_mapping() t = threading.Thread(target=keep_subscribe) t.daemon = True diff --git a/apps/settings/signals_handler.py b/apps/settings/signals_handler.py index 9625df3f6..6264fb9e6 100644 --- a/apps/settings/signals_handler.py +++ b/apps/settings/signals_handler.py @@ -6,6 +6,7 @@ import threading from django.dispatch import receiver from django.db.models.signals import post_save, pre_save from django.utils.functional import LazyObject +from django.db import close_old_connections from jumpserver.utils import current_request from common.decorator import on_transaction_commit @@ -71,13 +72,21 @@ def subscribe_settings_change(sender, **kwargs): logger.debug("Start subscribe setting change") def keep_subscribe(): - sub = setting_pub_sub.subscribe() - for msg in sub.listen(): - if msg["type"] != "message": - continue - item = msg['data'].decode() - logger.debug("Found setting change: {}".format(str(item))) - Setting.refresh_item(item) + while True: + try: + sub = setting_pub_sub.subscribe() + for msg in sub.listen(): + close_old_connections() + if msg["type"] != "message": + continue + item = msg['data'].decode() + logger.debug("Found setting change: {}".format(str(item))) + Setting.refresh_item(item) + except Exception as e: + logger.exception(f'subscribe_settings_change: {e}') + close_old_connections() + Setting.refresh_all_settings() + t = threading.Thread(target=keep_subscribe) t.daemon = True t.start() diff --git a/apps/terminal/api/command.py b/apps/terminal/api/command.py index f969e24d7..d7868e4a0 100644 --- a/apps/terminal/api/command.py +++ b/apps/terminal/api/command.py @@ -11,6 +11,7 @@ from rest_framework.response import Response from rest_framework.decorators import action from django.template import loader +from common.http import is_true from terminal.models import CommandStorage, Command from terminal.filters import CommandFilter from orgs.utils import current_org @@ -140,7 +141,21 @@ class CommandViewSet(viewsets.ModelViewSet): if session_id and not command_storage_id: # 会话里的命令列表肯定会提供 session_id,这里防止 merge 的时候取全量的数据 return self.merge_all_storage_list(request, *args, **kwargs) - return super().list(request, *args, **kwargs) + + queryset = self.filter_queryset(self.get_queryset()) + + page = self.paginate_queryset(queryset) + if page is not None: + serializer = self.get_serializer(page, many=True) + return self.get_paginated_response(serializer.data) + + query_all = self.request.query_params.get('all', False) + if is_true(query_all): + # 适配像 ES 这种没有指定分页只返回少量数据的情况 + queryset = queryset[:] + + serializer = self.get_serializer(queryset, many=True) + return Response(serializer.data) def get_queryset(self): command_storage_id = self.request.query_params.get('command_storage_id') diff --git a/apps/terminal/backends/command/es.py b/apps/terminal/backends/command/es.py index fc0f247f4..d8197391d 100644 --- a/apps/terminal/backends/command/es.py +++ b/apps/terminal/backends/command/es.py @@ -10,6 +10,7 @@ import inspect from django.db.models import QuerySet as DJQuerySet from elasticsearch import Elasticsearch from elasticsearch.helpers import bulk +from elasticsearch.exceptions import RequestError from common.utils.common import lazyproperty from common.utils import get_logger @@ -31,6 +32,15 @@ class CommandStore(): kwargs['verify_certs'] = None self.es = Elasticsearch(hosts=hosts, max_retries=0, **kwargs) + def pre_use_check(self): + self._ensure_index_exists() + + def _ensure_index_exists(self): + try: + self.es.indices.create(self.index) + except RequestError: + pass + @staticmethod def make_data(command): data = dict( @@ -234,6 +244,7 @@ class QuerySet(DJQuerySet): uqs = QuerySet(self._command_store_config) uqs._method_calls = self._method_calls.copy() uqs._slice = self._slice + uqs.model = self.model return uqs def count(self, limit_to_max_result_window=True): diff --git a/apps/terminal/models/storage.py b/apps/terminal/models/storage.py index 4826e2eef..883e5f67a 100644 --- a/apps/terminal/models/storage.py +++ b/apps/terminal/models/storage.py @@ -76,6 +76,15 @@ class CommandStorage(CommonModelMixin): qs.model = Command return qs + def save(self, force_insert=False, force_update=False, using=None, + update_fields=None): + super().save() + + if self.type in TYPE_ENGINE_MAPPING: + engine_mod = import_module(TYPE_ENGINE_MAPPING[self.type]) + backend = engine_mod.CommandStore(self.config) + backend.pre_use_check() + class ReplayStorage(CommonModelMixin): name = models.CharField(max_length=128, verbose_name=_("Name"), unique=True) diff --git a/apps/terminal/models/terminal.py b/apps/terminal/models/terminal.py index e13902251..77c9b1ce8 100644 --- a/apps/terminal/models/terminal.py +++ b/apps/terminal/models/terminal.py @@ -7,6 +7,7 @@ from django.conf import settings from common.utils import get_logger from users.models import User +from orgs.utils import tmp_to_root_org from .status import Status from .. import const from ..const import ComponentStatusChoices as StatusChoice @@ -112,7 +113,6 @@ class Terminal(StorageMixin, TerminalStatusMixin, models.Model): date_created = models.DateTimeField(auto_now_add=True) comment = models.TextField(blank=True, verbose_name=_('Comment')) - @property def is_active(self): if self.user and self.user.is_active: @@ -126,7 +126,8 @@ class Terminal(StorageMixin, TerminalStatusMixin, models.Model): self.user.save() def get_online_sessions(self): - return Session.objects.filter(terminal=self, is_finished=False) + with tmp_to_root_org(): + return Session.objects.filter(terminal=self, is_finished=False) def get_online_session_count(self): return self.get_online_sessions().count() diff --git a/apps/users/api/user.py b/apps/users/api/user.py index 6f39f40e5..ebab1ec3c 100644 --- a/apps/users/api/user.py +++ b/apps/users/api/user.py @@ -1,8 +1,8 @@ # ~*~ coding: utf-8 ~*~ -from django.core.cache import cache +from collections import defaultdict + from django.utils.translation import ugettext as _ from rest_framework.decorators import action -from django.conf import settings from rest_framework import generics from rest_framework.response import Response from rest_framework_bulk import BulkModelViewSet @@ -155,10 +155,17 @@ class UserViewSet(CommonApiMixin, UserQuerysetMixin, BulkModelViewSet): serializer = serializer_cls(data=data, many=True) serializer.is_valid(raise_exception=True) validated_data = serializer.validated_data + + users_by_role = defaultdict(list) for i in validated_data: - i['org_id'] = current_org.org_id() - relations = [OrganizationMember(**i) for i in validated_data] - OrganizationMember.objects.bulk_create(relations, ignore_conflicts=True) + users_by_role[i['role']].append(i['user']) + + OrganizationMember.objects.add_users_by_role( + current_org, + users=users_by_role[ORG_ROLE.USER], + admins=users_by_role[ORG_ROLE.ADMIN], + auditors=users_by_role[ORG_ROLE.AUDITOR] + ) return Response(serializer.data, status=201) @action(methods=['post'], detail=True, permission_classes=(IsOrgAdmin,)) diff --git a/apps/users/models/user.py b/apps/users/models/user.py index 52dccbeaa..fab14e252 100644 --- a/apps/users/models/user.py +++ b/apps/users/models/user.py @@ -667,6 +667,11 @@ class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, AbstractUser): else: return user_default + def unblock_login(self): + from users.utils import LoginBlockUtil, MFABlockUtils + LoginBlockUtil.unblock_user(self.username) + MFABlockUtils.unblock_user(self.username) + @property def login_blocked(self): from users.utils import LoginBlockUtil, MFABlockUtils