From 0981cd1ed1d80aa79c8afa4f36b94a1a53143303 Mon Sep 17 00:00:00 2001 From: ibuler Date: Tue, 29 Nov 2022 14:42:04 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E4=BF=AE=E6=94=B9=20Connect=20token=20?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/authentication/api/connection_token.py | 64 ++++++++----------- .../migrations/0016_auto_20221125_2240.py | 15 ++++- .../migrations/0017_auto_20221128_1839.py | 13 ---- .../0018_connectiontoken_endpoint_protocol.py | 19 ++++++ .../authentication/models/connection_token.py | 3 + .../serializers/connection_token.py | 10 ++- apps/common/utils/django.py | 19 +++++- apps/common/utils/http.py | 7 +- apps/terminal/api/component/terminal.py | 9 +-- apps/terminal/const.py | 55 ++++++++++------ apps/terminal/serializers/terminal.py | 4 +- 11 files changed, 129 insertions(+), 89 deletions(-) create mode 100644 apps/authentication/migrations/0018_connectiontoken_endpoint_protocol.py diff --git a/apps/authentication/api/connection_token.py b/apps/authentication/api/connection_token.py index 777c02e6a..b8f461d52 100644 --- a/apps/authentication/api/connection_token.py +++ b/apps/authentication/api/connection_token.py @@ -16,10 +16,11 @@ from rest_framework.serializers import ValidationError from common.drf.api import JMSModelViewSet from common.http import is_true from common.utils import random_string +from common.utils.django import get_request_os from orgs.mixins.api import RootOrgViewMixin from perms.models import ActionChoices -from terminal.models import EndpointRule from terminal.const import NativeClient +from terminal.models import EndpointRule from ..models import ConnectionToken from ..serializers import ( ConnectionTokenSerializer, ConnectionTokenSecretSerializer, @@ -130,42 +131,32 @@ class RDPFileClientProtocolURLMixin: return true_value if is_true(os.getenv(env_key, env_default)) else false_value def get_client_protocol_data(self, token: ConnectionToken): - username = token.user.username - rdp_config = ssh_token = '' - connect_method = token.connect_method + _os = get_request_os(self.request) - if connect_method == NativeClient.ssh: - filename, ssh_token = self.get_ssh_token(token) - elif connect_method == NativeClient.mstsc: - filename, rdp_config = self.get_rdp_file_info(token) - else: - raise ValueError('Protocol not support: {}'.format(connect_method)) + connect_method = getattr(NativeClient, token.connect_method, None) + if connect_method is None: + raise ValueError('Connect method not support: {}'.format(token.connect_method)) - return { - "filename": filename, - "protocol": token.protocol, - "username": username, - "token": ssh_token, - "config": rdp_config - } - - def get_ssh_token(self, token: ConnectionToken): - if token.asset: - name = token.asset.name - else: - name = '*' - prefix_name = f'{token.user.username}-{name}' - filename = self.get_connect_filename(prefix_name) - - endpoint = self.get_smart_endpoint(protocol='ssh', asset=token.asset) data = { - 'ip': endpoint.host, - 'port': str(endpoint.ssh_port), - 'username': 'JMS-{}'.format(str(token.id)), - 'password': token.value + 'id': str(token.id), + 'value': token.value, + 'cmd': '', + 'file': {} } - token = json.dumps(data) - return filename, token + + if connect_method == NativeClient.mstsc: + filename, content = self.get_rdp_file_info(token) + data.update({ + 'file': { + 'filename': filename, + 'content': content, + } + }) + else: + endpoint = self.get_smart_endpoint(protocol=token.endpoint_protocol, asset=token.asset) + cmd = NativeClient.get_launch_command(connect_method, token, endpoint) + data.update({'cmd': cmd}) + return data def get_smart_endpoint(self, protocol, asset=None): target_ip = asset.get_target_ip() if asset else '' @@ -223,6 +214,7 @@ class ConnectionTokenViewSet(ExtraActionApiMixin, RootOrgViewMixin, JMSModelView 'get_secret_detail': ConnectionTokenSecretSerializer, } rbac_perms = { + 'list': 'authentication.view_connectiontoken', 'retrieve': 'authentication.view_connectiontoken', 'create': 'authentication.add_connectiontoken', 'expire': 'authentication.add_connectiontoken', @@ -252,9 +244,9 @@ class ConnectionTokenViewSet(ExtraActionApiMixin, RootOrgViewMixin, JMSModelView return Response(serializer.data, status=status.HTTP_200_OK) def get_queryset(self): - queryset = ConnectionToken.objects\ - .filter(user=self.request.user)\ - .filter(date_expired__lt=timezone.now()) + queryset = ConnectionToken.objects \ + .filter(user=self.request.user) \ + .filter(date_expired__gt=timezone.now()) return queryset def get_user(self, serializer): diff --git a/apps/authentication/migrations/0016_auto_20221125_2240.py b/apps/authentication/migrations/0016_auto_20221125_2240.py index 745fcef2b..041a29fc6 100644 --- a/apps/authentication/migrations/0016_auto_20221125_2240.py +++ b/apps/authentication/migrations/0016_auto_20221125_2240.py @@ -1,11 +1,11 @@ # Generated by Django 3.2.14 on 2022-11-25 14:40 -import common.db.fields from django.db import migrations, models +import common.db.fields + class Migration(migrations.Migration): - dependencies = [ ('authentication', '0015_alter_connectiontoken_login'), ] @@ -36,4 +36,15 @@ class Migration(migrations.Migration): name='value', field=models.CharField(default='', max_length=64, verbose_name='Value'), ), + migrations.AddField( + model_name='connectiontoken', + name='input_secret', + field=common.db.fields.EncryptCharField(blank=True, default='', max_length=128, + verbose_name='Input Secret'), + ), + migrations.AlterField( + model_name='connectiontoken', + name='input_username', + field=models.CharField(blank=True, default='', max_length=128, verbose_name='Input Username'), + ), ] diff --git a/apps/authentication/migrations/0017_auto_20221128_1839.py b/apps/authentication/migrations/0017_auto_20221128_1839.py index 8c392cb92..bcdb71020 100644 --- a/apps/authentication/migrations/0017_auto_20221128_1839.py +++ b/apps/authentication/migrations/0017_auto_20221128_1839.py @@ -1,11 +1,9 @@ # Generated by Django 3.2.14 on 2022-11-28 10:39 -import common.db.fields from django.db import migrations, models class Migration(migrations.Migration): - dependencies = [ ('authentication', '0016_auto_20221125_2240'), ] @@ -17,15 +15,4 @@ class Migration(migrations.Migration): field=models.CharField(default='web_ui', max_length=32, verbose_name='Connect method'), preserve_default=False, ), - migrations.AddField( - model_name='connectiontoken', - name='input_secret', - field=common.db.fields.EncryptCharField(blank=True, default='', max_length=128, - verbose_name='Input Secret'), - ), - migrations.AlterField( - model_name='connectiontoken', - name='input_username', - field=models.CharField(blank=True, default='', max_length=128, verbose_name='Input Username'), - ), ] diff --git a/apps/authentication/migrations/0018_connectiontoken_endpoint_protocol.py b/apps/authentication/migrations/0018_connectiontoken_endpoint_protocol.py new file mode 100644 index 000000000..f267a62fd --- /dev/null +++ b/apps/authentication/migrations/0018_connectiontoken_endpoint_protocol.py @@ -0,0 +1,19 @@ +# Generated by Django 3.2.14 on 2022-11-29 04:49 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('authentication', '0017_auto_20221128_1839'), + ] + + operations = [ + migrations.AddField( + model_name='connectiontoken', + name='endpoint_protocol', + field=models.CharField(choices=[('ssh', 'SSH'), ('rdp', 'RDP'), ('telnet', 'Telnet'), ('vnc', 'VNC'), ('mysql', 'MySQL'), ('mariadb', 'MariaDB'), ('oracle', 'Oracle'), ('postgresql', 'PostgreSQL'), ('sqlserver', 'SQLServer'), ('redis', 'Redis'), ('mongodb', 'MongoDB'), ('k8s', 'K8S'), ('http', 'HTTP'), ('None', ' Settings')], default='', max_length=16, verbose_name='Endpoint protocol'), + preserve_default=False, + ), + ] diff --git a/apps/authentication/models/connection_token.py b/apps/authentication/models/connection_token.py index 5505f81a3..7f4e7f42b 100644 --- a/apps/authentication/models/connection_token.py +++ b/apps/authentication/models/connection_token.py @@ -35,6 +35,9 @@ class ConnectionToken(OrgModelMixin, JMSBaseModel): choices=Protocol.choices, max_length=16, default=Protocol.ssh, verbose_name=_("Protocol") ) connect_method = models.CharField(max_length=32, verbose_name=_("Connect method")) + endpoint_protocol = models.CharField( + choices=Protocol.choices, max_length=16, verbose_name=_("Endpoint protocol") + ) user_display = models.CharField(max_length=128, default='', verbose_name=_("User display")) asset_display = models.CharField(max_length=128, default='', verbose_name=_("Asset display")) date_expired = models.DateTimeField( diff --git a/apps/authentication/serializers/connection_token.py b/apps/authentication/serializers/connection_token.py index 0a223306a..5915f542b 100644 --- a/apps/authentication/serializers/connection_token.py +++ b/apps/authentication/serializers/connection_token.py @@ -1,7 +1,7 @@ from django.utils.translation import ugettext_lazy as _ from rest_framework import serializers -from assets.models import Asset, Domain, CommandFilterRule, Account, Platform +from assets.models import Asset, CommandFilterRule, Account, Platform from assets.serializers import PlatformSerializer, AssetProtocolsSerializer from authentication.models import ConnectionToken from orgs.mixins.serializers import OrgResourceModelSerializerMixin @@ -21,21 +21,19 @@ class ConnectionTokenSerializer(OrgResourceModelSerializerMixin): model = ConnectionToken fields_mini = ['id', 'value'] fields_small = fields_mini + [ - 'protocol', 'account_name', + 'user', 'asset', 'account_name', 'input_username', 'input_secret', + 'connect_method', 'endpoint_protocol', 'protocol', 'actions', 'date_expired', 'date_created', 'date_updated', 'created_by', 'updated_by', 'org_id', 'org_name', ] - fields_fk = [ - 'user', 'asset', - ] read_only_fields = [ # 普通 Token 不支持指定 user 'user', 'expire_time', 'user_display', 'asset_display', ] - fields = fields_small + fields_fk + read_only_fields + fields = fields_small + read_only_fields extra_kwargs = { 'value': {'read_only': True}, } diff --git a/apps/common/utils/django.py b/apps/common/utils/django.py index 1f3f83282..2c4808d16 100644 --- a/apps/common/utils/django.py +++ b/apps/common/utils/django.py @@ -2,11 +2,11 @@ # import re -from django.shortcuts import reverse as dj_reverse from django.conf import settings -from django.utils import timezone from django.db import models from django.db.models.signals import post_save, pre_save +from django.shortcuts import reverse as dj_reverse +from django.utils import timezone UUID_PATTERN = re.compile(r'[0-9a-zA-Z\-]{36}') @@ -80,3 +80,18 @@ def bulk_create_with_signal(cls: models.Model, items, **kwargs): for i in items: post_save.send(sender=cls, instance=i, created=True) return result + + +def get_request_os(request): + """获取请求的操作系统""" + agent = request.META.get('HTTP_USER_AGENT', '').lower() + + if agent is None: + return 'unknown' + if 'windows' in agent.lower(): + return 'windows' + if 'mac' in agent.lower(): + return 'mac' + if 'linux' in agent.lower(): + return 'linux' + return 'unknown' diff --git a/apps/common/utils/http.py b/apps/common/utils/http.py index 185397881..ab39c4fa2 100644 --- a/apps/common/utils/http.py +++ b/apps/common/utils/http.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- # -import time -from email.utils import formatdate import calendar import threading +import time +from email.utils import formatdate _STRPTIME_LOCK = threading.Lock() @@ -35,3 +35,6 @@ def http_to_unixtime(time_string): def iso8601_to_unixtime(time_string): """把ISO8601时间字符串(形如,2012-02-24T06:07:48.000Z)转换为UNIX时间,精确到秒。""" return to_unixtime(time_string, _ISO8601_FORMAT) + + + diff --git a/apps/terminal/api/component/terminal.py b/apps/terminal/api/component/terminal.py index e3aa4afb3..df14296f5 100644 --- a/apps/terminal/api/component/terminal.py +++ b/apps/terminal/api/component/terminal.py @@ -12,6 +12,7 @@ from common.drf.api import JMSBulkModelViewSet from common.exceptions import JMSException from common.permissions import IsValidUser from common.permissions import WithBootstrapToken +from common.utils import get_request_os from terminal import serializers from terminal.const import TerminalType from terminal.models import Terminal @@ -77,13 +78,7 @@ class ConnectMethodListApi(generics.ListAPIView): permission_classes = [IsValidUser] def get_queryset(self): - user_agent = self.request.META['HTTP_USER_AGENT'].lower() - if 'macintosh' in user_agent: - os = 'macos' - elif 'windows' in user_agent: - os = 'windows' - else: - os = 'linux' + os = get_request_os(self.request) return TerminalType.get_protocols_connect_methods(os) def list(self, request, *args, **kwargs): diff --git a/apps/terminal/const.py b/apps/terminal/const.py index 177d50e20..40c89ff24 100644 --- a/apps/terminal/const.py +++ b/apps/terminal/const.py @@ -56,7 +56,11 @@ class NativeClient(TextChoices): xshell = 'xshell', 'Xshell' # Magnus - db_client = 'db_client', _('DB Client') + mysql = 'db_client_mysql', _('DB Client') + psql = 'db_client_psql', _('DB Client') + sqlplus = 'db_client_sqlplus', _('DB Client') + redis = 'db_client_redis', _('DB Client') + mongodb = 'db_client_mongodb', _('DB Client') # Razor mstsc = 'mstsc', 'Remote Desktop' @@ -69,14 +73,23 @@ class NativeClient(TextChoices): 'windows': [cls.putty], }, Protocol.rdp: [cls.mstsc], - Protocol.mysql: [cls.db_client], - Protocol.oracle: [cls.db_client], - Protocol.postgresql: [cls.db_client], - Protocol.redis: [cls.db_client], - Protocol.mongodb: [cls.db_client], + Protocol.mysql: [cls.mysql], + Protocol.oracle: [cls.sqlplus], + Protocol.postgresql: [cls.psql], + Protocol.redis: [cls.redis], + Protocol.mongodb: [cls.mongodb], } return clients + @classmethod + def get_target_protocol(cls, name, os): + for protocol, clients in cls.get_native_clients().items(): + if isinstance(clients, dict): + clients = clients.get(os) or clients.get('default') + if name in clients: + return protocol + return None + @classmethod def get_methods(cls, os='windows'): clients_map = cls.get_native_clients() @@ -94,23 +107,18 @@ class NativeClient(TextChoices): return methods @classmethod - def get_launch_command(cls, name, os='windows'): + def get_launch_command(cls, name, token, endpoint, os='windows'): commands = { - cls.ssh: 'ssh {token.id}@{endpoint.ip} -p {endpoint.port}', - cls.putty: 'putty-ssh {token.id}@{endpoint.ip} -P {endpoint.port}', - cls.xshell: 'xshell -url ssh://{token.id}:{token.value}@{endpoint.ip}:{endpoint.port}', - # 'mysql': 'mysql -h {hostname} -P {port} -u {username} -p', - # 'psql': { + cls.ssh: f'ssh {token.id}@{endpoint.host} -p {endpoint.ssh_port}', + cls.putty: f'putty -ssh {token.id}@{endpoint.host} -P {endpoint.ssh_port}', + cls.xshell: f'xshell -url ssh://{token.id}:{token.value}@{endpoint.host}:{endpoint.ssh_port}', + # cls.mysql: 'mysql -h {hostname} -P {port} -u {username} -p', + # cls.psql: { # 'default': 'psql -h {hostname} -p {port} -U {username} -W', # 'windows': 'psql /h {hostname} /p {port} /U {username} -W', # }, - # 'sqlplus': 'sqlplus {username}/{password}@{hostname}:{port}', - # 'redis': 'redis-cli -h {hostname} -p {port} -a {password}', - cls.mstsc: { - 'command': "$open_file$", - 'file': { - } - }, + # cls.sqlplus: 'sqlplus {username}/{password}@{hostname}:{port}', + # cls.redis: 'redis-cli -h {hostname} -p {port} -a {password}', } command = commands.get(name) if isinstance(command, dict): @@ -217,19 +225,26 @@ class TerminalType(TextChoices): methods[protocol.value].append({ 'value': web_protocol.value, 'label': web_protocol.label, + 'endpoint_protocol': 'http', 'type': 'web', 'component': component.value, }) # Native method methods[protocol.value].extend([ - {'component': component.value, 'type': 'native', **method} + { + 'component': component.value, + 'type': 'native', + 'endpoint_protocol': listen_protocol, + **method + } for method in native_methods[listen_protocol] ]) for protocol, applet_methods in applet_methods.items(): for method in applet_methods: method['type'] = 'applet' + method['listen'] = 'rdp' method['component'] = cls.tinker.value methods[protocol].extend(applet_methods) return methods diff --git a/apps/terminal/serializers/terminal.py b/apps/terminal/serializers/terminal.py index df32d89c2..f7f935f50 100644 --- a/apps/terminal/serializers/terminal.py +++ b/apps/terminal/serializers/terminal.py @@ -138,4 +138,6 @@ class TerminalRegistrationSerializer(serializers.ModelSerializer): class ConnectMethodSerializer(serializers.Serializer): value = serializers.CharField(max_length=128) label = serializers.CharField(max_length=128) - group = serializers.CharField(max_length=128) + type = serializers.CharField(max_length=128) + listen = serializers.CharField(max_length=128) + component = serializers.CharField(max_length=128)