diff --git a/apps/applications/api/application.py b/apps/applications/api/application.py index 8aa9c0550..4d04e894f 100644 --- a/apps/applications/api/application.py +++ b/apps/applications/api/application.py @@ -1,19 +1,16 @@ # coding: utf-8 # from orgs.mixins.api import OrgBulkModelViewSet -from rest_framework import status from rest_framework.decorators import action from rest_framework.response import Response -from rest_framework.viewsets import GenericViewSet from common.tree import TreeNodeSerializer from common.mixins.api import SuggestionMixin -from ..utils import db_port_manager from .. import serializers from ..models import Application -__all__ = ['ApplicationViewSet', 'DBListenPortViewSet'] +__all__ = ['ApplicationViewSet'] class ApplicationViewSet(SuggestionMixin, OrgBulkModelViewSet): @@ -41,30 +38,3 @@ class ApplicationViewSet(SuggestionMixin, OrgBulkModelViewSet): tree_nodes = Application.create_tree_nodes(queryset, show_count=show_count) serializer = self.get_serializer(tree_nodes, many=True) return Response(serializer.data) - - -class DBListenPortViewSet(GenericViewSet): - rbac_perms = { - 'GET': 'applications.view_application', - 'list': 'applications.view_application', - 'db_info': 'applications.view_application', - } - - http_method_names = ['get', 'post'] - - def list(self, request, *args, **kwargs): - ports = db_port_manager.get_already_use_ports() - return Response(data=ports, status=status.HTTP_200_OK) - - @action(methods=['post'], detail=False, url_path='db-info') - def db_info(self, request, *args, **kwargs): - port = request.data.get("port") - db, msg = db_port_manager.get_db_by_port(port) - if db: - serializer = serializers.AppSerializer(instance=db) - data = serializer.data - _status = status.HTTP_200_OK - else: - data = {'error': msg} - _status = status.HTTP_404_NOT_FOUND - return Response(data=data, status=_status) diff --git a/apps/applications/signal_handlers.py b/apps/applications/signal_handlers.py index 0d92a9f96..4aa11c79b 100644 --- a/apps/applications/signal_handlers.py +++ b/apps/applications/signal_handlers.py @@ -1,36 +1,2 @@ # -*- coding: utf-8 -*- -# -from django.db.models.signals import post_save, post_delete - -from common.signals import django_ready -from django.dispatch import receiver -from common.utils import get_logger -from .models import Application -from .utils import db_port_manager, DBPortManager - -db_port_manager: DBPortManager - - -logger = get_logger(__file__) - - -@receiver(django_ready) -def init_db_port_mapper(sender, **kwargs): - logger.info('Init db port mapper') - db_port_manager.init() - - -@receiver(post_save, sender=Application) -def on_db_app_created(sender, instance: Application, created, **kwargs): - if not instance.category_db: - return - if not created: - return - db_port_manager.add(instance) - - -@receiver(post_delete, sender=Application) -def on_db_app_delete(sender, instance, **kwargs): - if not instance.category_db: - return - db_port_manager.pop(instance) +# \ No newline at end of file diff --git a/apps/applications/urls/api_urls.py b/apps/applications/urls/api_urls.py index 813c047a2..4fdf006b0 100644 --- a/apps/applications/urls/api_urls.py +++ b/apps/applications/urls/api_urls.py @@ -13,7 +13,6 @@ router.register(r'applications', api.ApplicationViewSet, 'application') router.register(r'accounts', api.ApplicationAccountViewSet, 'application-account') router.register(r'system-users-apps-relations', api.SystemUserAppRelationViewSet, 'system-users-apps-relation') router.register(r'account-secrets', api.ApplicationAccountSecretViewSet, 'application-account-secret') -router.register(r'db-listen-ports', api.DBListenPortViewSet, 'db-listen-ports') urlpatterns = [ diff --git a/apps/terminal/api/__init__.py b/apps/terminal/api/__init__.py index 16021a5ed..4cdabc9dd 100644 --- a/apps/terminal/api/__init__.py +++ b/apps/terminal/api/__init__.py @@ -8,3 +8,4 @@ from .storage import * from .status import * from .sharing import * from .endpoint import * +from .db_listen_port import * diff --git a/apps/terminal/api/db_listen_port.py b/apps/terminal/api/db_listen_port.py new file mode 100644 index 000000000..4170a33d4 --- /dev/null +++ b/apps/terminal/api/db_listen_port.py @@ -0,0 +1,36 @@ +# coding: utf-8 +# +from rest_framework import status +from rest_framework.decorators import action +from rest_framework.response import Response +from rest_framework.viewsets import GenericViewSet + +from ..utils import db_port_manager, DBPortManager +from applications import serializers + + +db_port_manager: DBPortManager + + +__all__ = ['DBListenPortViewSet'] + + +class DBListenPortViewSet(GenericViewSet): + rbac_perms = { + 'GET': 'applications.view_application', + 'list': 'applications.view_application', + 'db_info': 'applications.view_application', + } + + http_method_names = ['get', 'post'] + + def list(self, request, *args, **kwargs): + ports = db_port_manager.get_already_use_ports() + return Response(data=ports, status=status.HTTP_200_OK) + + @action(methods=['get'], detail=False, url_path='db-info') + def db_info(self, request, *args, **kwargs): + port = request.query_params.get("port") + db = db_port_manager.get_db_by_port(port) + serializer = serializers.AppSerializer(instance=db) + return Response(data=serializer.data, status=status.HTTP_200_OK) diff --git a/apps/terminal/models/endpoint.py b/apps/terminal/models/endpoint.py index 98cc6a328..9d65fa271 100644 --- a/apps/terminal/models/endpoint.py +++ b/apps/terminal/models/endpoint.py @@ -2,12 +2,14 @@ from django.db import models from django.utils.translation import ugettext_lazy as _ from django.core.validators import MinValueValidator, MaxValueValidator from applications.models import Application -from applications.utils import db_port_manager +from ..utils import db_port_manager, DBPortManager from common.db.models import JMSModel from common.db.fields import PortField from common.utils.ip import contains_ip from common.exceptions import JMSException +db_port_manager: DBPortManager + class Endpoint(JMSModel): name = models.CharField(max_length=128, verbose_name=_('Name'), unique=True) @@ -34,10 +36,6 @@ class Endpoint(JMSModel): port = getattr(self, f'{protocol}_port', 0) elif isinstance(target_instance, Application) and target_instance.category_db: port = db_port_manager.get_port_by_db(target_instance) - if port is None: - error = 'No application port is matched, application id: {}' \ - ''.format(target_instance.id) - raise JMSException(error) else: port = 0 return port diff --git a/apps/terminal/serializers/endpoint.py b/apps/terminal/serializers/endpoint.py index 75eb60352..ce45ffa3d 100644 --- a/apps/terminal/serializers/endpoint.py +++ b/apps/terminal/serializers/endpoint.py @@ -2,9 +2,10 @@ from rest_framework import serializers from django.utils.translation import ugettext_lazy as _ from common.drf.serializers import BulkModelSerializer from acls.serializers.rules import ip_group_child_validator, ip_group_help_text -from applications.utils import db_port_manager +from ..utils import db_port_manager from ..models import Endpoint, EndpointRule + __all__ = ['EndpointSerializer', 'EndpointRuleSerializer'] diff --git a/apps/terminal/signal_handlers.py b/apps/terminal/signal_handlers.py index ec51c5a2b..0d92a9f96 100644 --- a/apps/terminal/signal_handlers.py +++ b/apps/terminal/signal_handlers.py @@ -1,2 +1,36 @@ # -*- coding: utf-8 -*- # +from django.db.models.signals import post_save, post_delete + +from common.signals import django_ready +from django.dispatch import receiver +from common.utils import get_logger +from .models import Application +from .utils import db_port_manager, DBPortManager + +db_port_manager: DBPortManager + + +logger = get_logger(__file__) + + +@receiver(django_ready) +def init_db_port_mapper(sender, **kwargs): + logger.info('Init db port mapper') + db_port_manager.init() + + +@receiver(post_save, sender=Application) +def on_db_app_created(sender, instance: Application, created, **kwargs): + if not instance.category_db: + return + if not created: + return + db_port_manager.add(instance) + + +@receiver(post_delete, sender=Application) +def on_db_app_delete(sender, instance, **kwargs): + if not instance.category_db: + return + db_port_manager.pop(instance) diff --git a/apps/terminal/urls/api_urls.py b/apps/terminal/urls/api_urls.py index 3f0445350..8adcb8f52 100644 --- a/apps/terminal/urls/api_urls.py +++ b/apps/terminal/urls/api_urls.py @@ -24,6 +24,7 @@ router.register(r'session-sharings', api.SessionSharingViewSet, 'session-sharing router.register(r'session-join-records', api.SessionJoinRecordsViewSet, 'session-sharing-record') router.register(r'endpoints', api.EndpointViewSet, 'endpoint') router.register(r'endpoint-rules', api.EndpointRuleViewSet, 'endpoint-rule') +router.register(r'db-listen-ports', api.DBListenPortViewSet, 'db-listen-ports') urlpatterns = [ path('my-sessions/', api.MySessionAPIView.as_view(), name='my-session'), diff --git a/apps/terminal/utils/__init__.py b/apps/terminal/utils/__init__.py new file mode 100644 index 000000000..b55827a8d --- /dev/null +++ b/apps/terminal/utils/__init__.py @@ -0,0 +1,4 @@ +from .components import * +from .common import * +from .session_replay import * +from .db_port_mapper import * diff --git a/apps/terminal/utils/common.py b/apps/terminal/utils/common.py new file mode 100644 index 000000000..26fd303b2 --- /dev/null +++ b/apps/terminal/utils/common.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +# + +from common.utils import get_logger +from .. import const +from tickets.models import TicketSession + + +logger = get_logger(__name__) + + +class ComputeStatUtil: + # system status + @staticmethod + def _common_compute_system_status(value, thresholds): + if thresholds[0] <= value <= thresholds[1]: + return const.ComponentStatusChoices.normal.value + elif thresholds[1] < value <= thresholds[2]: + return const.ComponentStatusChoices.high.value + else: + return const.ComponentStatusChoices.critical.value + + @classmethod + def _compute_system_stat_status(cls, stat): + system_stat_thresholds_mapper = { + 'cpu_load': [0, 5, 20], + 'memory_used': [0, 85, 95], + 'disk_used': [0, 80, 99] + } + system_status = {} + for stat_key, thresholds in system_stat_thresholds_mapper.items(): + stat_value = getattr(stat, stat_key) + if stat_value is None: + msg = 'stat: {}, stat_key: {}, stat_value: {}' + logger.debug(msg.format(stat, stat_key, stat_value)) + stat_value = 0 + status = cls._common_compute_system_status(stat_value, thresholds) + system_status[stat_key] = status + return system_status + + @classmethod + def compute_component_status(cls, stat): + if not stat: + return const.ComponentStatusChoices.offline + system_status_values = cls._compute_system_stat_status(stat).values() + if const.ComponentStatusChoices.critical in system_status_values: + return const.ComponentStatusChoices.critical + elif const.ComponentStatusChoices.high in system_status_values: + return const.ComponentStatusChoices.high + else: + return const.ComponentStatusChoices.normal + + +def is_session_approver(session_id, user_id): + ticket = TicketSession.get_ticket_by_session_id(session_id) + if not ticket: + return False + ok = ticket.has_all_assignee(user_id) + return ok diff --git a/apps/terminal/utils.py b/apps/terminal/utils/components.py similarity index 55% rename from apps/terminal/utils.py rename to apps/terminal/utils/components.py index abdfbd738..0610b0b79 100644 --- a/apps/terminal/utils.py +++ b/apps/terminal/utils/components.py @@ -1,124 +1,13 @@ # -*- coding: utf-8 -*- # -import os -from itertools import groupby, chain - -from django.conf import settings -from django.core.files.storage import default_storage - -import jms_storage +from itertools import groupby from common.utils import get_logger -from . import const -from .models import ReplayStorage -from tickets.models import TicketSession, TicketStep, TicketAssignee -from tickets.const import StepState logger = get_logger(__name__) -def find_session_replay_local(session): - # 存在外部存储上,所有可能的路径名 - session_paths = session.get_all_possible_relative_path() - - # 存在本地存储上,所有可能的路径名 - local_paths = session.get_all_possible_local_path() - - for _local_path in chain(session_paths, local_paths): - if default_storage.exists(_local_path): - url = default_storage.url(_local_path) - return _local_path, url - return None, None - - -def download_session_replay(session): - replay_storages = ReplayStorage.objects.all() - configs = { - storage.name: storage.config - for storage in replay_storages - if not storage.type_null_or_server - } - if settings.SERVER_REPLAY_STORAGE: - configs['SERVER_REPLAY_STORAGE'] = settings.SERVER_REPLAY_STORAGE - if not configs: - msg = "Not found replay file, and not remote storage set" - return None, msg - storage = jms_storage.get_multi_object_storage(configs) - - # 获取外部存储路径名 - session_path = session.find_ok_relative_path_in_storage(storage) - if not session_path: - msg = "Not found session replay file" - return None, msg - - # 通过外部存储路径名后缀,构造真实的本地存储路径 - local_path = session.get_local_path_by_relative_path(session_path) - - # 保存到storage的路径 - target_path = os.path.join(default_storage.base_location, local_path) - target_dir = os.path.dirname(target_path) - if not os.path.isdir(target_dir): - os.makedirs(target_dir, exist_ok=True) - - ok, err = storage.download(session_path, target_path) - if not ok: - msg = "Failed download replay file: {}".format(err) - logger.error(msg) - return None, msg - url = default_storage.url(local_path) - return local_path, url - - -def get_session_replay_url(session): - local_path, url = find_session_replay_local(session) - if local_path is None: - local_path, url = download_session_replay(session) - return local_path, url - - -class ComputeStatUtil: - # system status - @staticmethod - def _common_compute_system_status(value, thresholds): - if thresholds[0] <= value <= thresholds[1]: - return const.ComponentStatusChoices.normal.value - elif thresholds[1] < value <= thresholds[2]: - return const.ComponentStatusChoices.high.value - else: - return const.ComponentStatusChoices.critical.value - - @classmethod - def _compute_system_stat_status(cls, stat): - system_stat_thresholds_mapper = { - 'cpu_load': [0, 5, 20], - 'memory_used': [0, 85, 95], - 'disk_used': [0, 80, 99] - } - system_status = {} - for stat_key, thresholds in system_stat_thresholds_mapper.items(): - stat_value = getattr(stat, stat_key) - if stat_value is None: - msg = 'stat: {}, stat_key: {}, stat_value: {}' - logger.debug(msg.format(stat, stat_key, stat_value)) - stat_value = 0 - status = cls._common_compute_system_status(stat_value, thresholds) - system_status[stat_key] = status - return system_status - - @classmethod - def compute_component_status(cls, stat): - if not stat: - return const.ComponentStatusChoices.offline - system_status_values = cls._compute_system_stat_status(stat).values() - if const.ComponentStatusChoices.critical in system_status_values: - return const.ComponentStatusChoices.critical - elif const.ComponentStatusChoices.high in system_status_values: - return const.ComponentStatusChoices.high - else: - return const.ComponentStatusChoices.normal - - class TypedComponentsStatusMetricsUtil(object): def __init__(self): self.components = [] @@ -126,7 +15,7 @@ class TypedComponentsStatusMetricsUtil(object): self.get_components() def get_components(self): - from .models import Terminal + from ..models import Terminal components = Terminal.objects.filter(is_deleted=False).order_by('type') grouped_components = groupby(components, lambda c: c.type) grouped_components = [(i[0], list(i[1])) for i in grouped_components] @@ -251,10 +140,3 @@ class ComponentsPrometheusMetricsUtil(TypedComponentsStatusMetricsUtil): prometheus_metrics_text = '\n'.join(prometheus_metrics) return prometheus_metrics_text - -def is_session_approver(session_id, user_id): - ticket = TicketSession.get_ticket_by_session_id(session_id) - if not ticket: - return False - ok = ticket.has_all_assignee(user_id) - return ok diff --git a/apps/applications/utils/db_port_mapper.py b/apps/terminal/utils/db_port_mapper.py similarity index 57% rename from apps/applications/utils/db_port_mapper.py rename to apps/terminal/utils/db_port_mapper.py index 13e185b2e..2ca84cc0e 100644 --- a/apps/applications/utils/db_port_mapper.py +++ b/apps/terminal/utils/db_port_mapper.py @@ -6,6 +6,7 @@ from applications.models import Application from common.utils import get_logger from common.utils import get_object_or_none from orgs.utils import tmp_to_root_org +from common.exceptions import JMSException logger = get_logger(__file__) @@ -22,24 +23,23 @@ class DBPortManager(object): self.port_limit = settings.MAGNUS_DB_PORTS_LIMIT_COUNT self.port_end = self.port_start + self.port_limit # 可以使用的端口列表 - self.all_usable_ports = [i for i in range(self.port_start, self.port_end+1)] + self.all_available_ports = list(range(self.port_start, self.port_end + 1)) @property def magnus_listen_port_range(self): return f'{self.port_start}-{self.port_end}' def init(self): - db_ids = Application.objects.filter(category=AppCategory.db).values_list('id', flat=True) + with tmp_to_root_org(): + db_ids = Application.objects.filter(category=AppCategory.db).values_list('id', flat=True) db_ids = [str(i) for i in db_ids] - mapper = dict(zip(self.all_usable_ports, list(db_ids))) + mapper = dict(zip(self.all_available_ports, list(db_ids))) self.set_mapper(mapper) def add(self, db: Application): mapper = self.get_mapper() - usable_port = self.get_next_usable_port() - if not usable_port: - return False - mapper.update({usable_port: str(db.id)}) + available_port = self.get_next_available_port() + mapper.update({available_port: str(db.id)}) self.set_mapper(mapper) return True @@ -54,43 +54,42 @@ class DBPortManager(object): for port, db_id in mapper.items(): if db_id == str(db.id): return port - logger.warning( - 'Not matched db port, db_id: {}, mapper length: {}'.format(db.id, len(mapper)) + raise JMSException( + 'Not matched db port, db id: {}, mapper length: {}'.format(db.id, len(mapper)) ) def get_db_by_port(self, port): mapper = self.get_mapper() db_id = mapper.get(port, None) - if db_id: - with tmp_to_root_org(): - db = get_object_or_none(Application, id=db_id) - if not db: - msg = 'Database not exists, database id: {}'.format(db_id) - else: - msg = '' - else: - db = None - msg = 'Port not in port-db mapper, port: {}'.format(port) - return db, msg + if not db_id: + raise JMSException('Database not in port-db mapper, port: {}'.format(port)) + with tmp_to_root_org(): + db = get_object_or_none(Application, id=db_id) + if not db: + raise JMSException('Database not exists, db id: {}'.format(db_id)) + return db - def get_next_usable_port(self): + def get_next_available_port(self): already_use_ports = self.get_already_use_ports() - usable_ports = sorted(list(set(self.all_usable_ports) - set(already_use_ports))) - if len(usable_ports) > 1: - port = usable_ports[0] - logger.debug('Get next usable port: {}'.format(port)) - return port - - msg = 'No port is usable, All usable port count: {}, Already use port count: {}'.format( - len(self.all_usable_ports), len(already_use_ports) - ) - logger.warning(msg) + available_ports = sorted(list(set(self.all_available_ports) - set(already_use_ports))) + if len(available_ports) <= 0: + raise JMSException( + 'No port is available, All available port count: {}, Already use port count: {}' + ''.format(len(self.all_available_ports), len(already_use_ports)) + ) + port = available_ports[0] + logger.debug('Get next available port: {}'.format(port)) + return port def get_already_use_ports(self): mapper = self.get_mapper() return sorted(list(mapper.keys())) def get_mapper(self): + mapper = cache.get(self.CACHE_KEY, {}) + if not mapper: + # redis 可能被清空,重新初始化一下 + self.init() return cache.get(self.CACHE_KEY, {}) def set_mapper(self, value): diff --git a/apps/terminal/utils/session_replay.py b/apps/terminal/utils/session_replay.py new file mode 100644 index 000000000..f1b061cb0 --- /dev/null +++ b/apps/terminal/utils/session_replay.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# +import os +from itertools import groupby, chain + +from django.conf import settings +from django.core.files.storage import default_storage + +import jms_storage + +from common.utils import get_logger +from ..models import ReplayStorage + + +logger = get_logger(__name__) + + +def find_session_replay_local(session): + # 存在外部存储上,所有可能的路径名 + session_paths = session.get_all_possible_relative_path() + + # 存在本地存储上,所有可能的路径名 + local_paths = session.get_all_possible_local_path() + + for _local_path in chain(session_paths, local_paths): + if default_storage.exists(_local_path): + url = default_storage.url(_local_path) + return _local_path, url + return None, None + + +def download_session_replay(session): + replay_storages = ReplayStorage.objects.all() + configs = { + storage.name: storage.config + for storage in replay_storages + if not storage.type_null_or_server + } + if settings.SERVER_REPLAY_STORAGE: + configs['SERVER_REPLAY_STORAGE'] = settings.SERVER_REPLAY_STORAGE + if not configs: + msg = "Not found replay file, and not remote storage set" + return None, msg + storage = jms_storage.get_multi_object_storage(configs) + + # 获取外部存储路径名 + session_path = session.find_ok_relative_path_in_storage(storage) + if not session_path: + msg = "Not found session replay file" + return None, msg + + # 通过外部存储路径名后缀,构造真实的本地存储路径 + local_path = session.get_local_path_by_relative_path(session_path) + + # 保存到storage的路径 + target_path = os.path.join(default_storage.base_location, local_path) + target_dir = os.path.dirname(target_path) + if not os.path.isdir(target_dir): + os.makedirs(target_dir, exist_ok=True) + + ok, err = storage.download(session_path, target_path) + if not ok: + msg = "Failed download replay file: {}".format(err) + logger.error(msg) + return None, msg + url = default_storage.url(local_path) + return local_path, url + + +def get_session_replay_url(session): + local_path, url = find_session_replay_local(session) + if local_path is None: + local_path, url = download_session_replay(session) + return local_path, url +