perf: 修改 Connect token 数据结构

This commit is contained in:
ibuler 2022-11-29 14:42:04 +08:00
parent e4edf3be02
commit 0981cd1ed1
11 changed files with 129 additions and 89 deletions

View File

@ -16,10 +16,11 @@ from rest_framework.serializers import ValidationError
from common.drf.api import JMSModelViewSet from common.drf.api import JMSModelViewSet
from common.http import is_true from common.http import is_true
from common.utils import random_string from common.utils import random_string
from common.utils.django import get_request_os
from orgs.mixins.api import RootOrgViewMixin from orgs.mixins.api import RootOrgViewMixin
from perms.models import ActionChoices from perms.models import ActionChoices
from terminal.models import EndpointRule
from terminal.const import NativeClient from terminal.const import NativeClient
from terminal.models import EndpointRule
from ..models import ConnectionToken from ..models import ConnectionToken
from ..serializers import ( from ..serializers import (
ConnectionTokenSerializer, ConnectionTokenSecretSerializer, ConnectionTokenSerializer, ConnectionTokenSecretSerializer,
@ -130,42 +131,32 @@ class RDPFileClientProtocolURLMixin:
return true_value if is_true(os.getenv(env_key, env_default)) else false_value return true_value if is_true(os.getenv(env_key, env_default)) else false_value
def get_client_protocol_data(self, token: ConnectionToken): def get_client_protocol_data(self, token: ConnectionToken):
username = token.user.username _os = get_request_os(self.request)
rdp_config = ssh_token = ''
connect_method = token.connect_method
if connect_method == NativeClient.ssh: connect_method = getattr(NativeClient, token.connect_method, None)
filename, ssh_token = self.get_ssh_token(token) if connect_method is None:
elif connect_method == NativeClient.mstsc: raise ValueError('Connect method not support: {}'.format(token.connect_method))
filename, rdp_config = self.get_rdp_file_info(token)
else:
raise ValueError('Protocol not support: {}'.format(connect_method))
return {
"filename": filename,
"protocol": token.protocol,
"username": username,
"token": ssh_token,
"config": rdp_config
}
def get_ssh_token(self, token: ConnectionToken):
if token.asset:
name = token.asset.name
else:
name = '*'
prefix_name = f'{token.user.username}-{name}'
filename = self.get_connect_filename(prefix_name)
endpoint = self.get_smart_endpoint(protocol='ssh', asset=token.asset)
data = { data = {
'ip': endpoint.host, 'id': str(token.id),
'port': str(endpoint.ssh_port), 'value': token.value,
'username': 'JMS-{}'.format(str(token.id)), 'cmd': '',
'password': token.value 'file': {}
} }
token = json.dumps(data)
return filename, token if connect_method == NativeClient.mstsc:
filename, content = self.get_rdp_file_info(token)
data.update({
'file': {
'filename': filename,
'content': content,
}
})
else:
endpoint = self.get_smart_endpoint(protocol=token.endpoint_protocol, asset=token.asset)
cmd = NativeClient.get_launch_command(connect_method, token, endpoint)
data.update({'cmd': cmd})
return data
def get_smart_endpoint(self, protocol, asset=None): def get_smart_endpoint(self, protocol, asset=None):
target_ip = asset.get_target_ip() if asset else '' target_ip = asset.get_target_ip() if asset else ''
@ -223,6 +214,7 @@ class ConnectionTokenViewSet(ExtraActionApiMixin, RootOrgViewMixin, JMSModelView
'get_secret_detail': ConnectionTokenSecretSerializer, 'get_secret_detail': ConnectionTokenSecretSerializer,
} }
rbac_perms = { rbac_perms = {
'list': 'authentication.view_connectiontoken',
'retrieve': 'authentication.view_connectiontoken', 'retrieve': 'authentication.view_connectiontoken',
'create': 'authentication.add_connectiontoken', 'create': 'authentication.add_connectiontoken',
'expire': 'authentication.add_connectiontoken', 'expire': 'authentication.add_connectiontoken',
@ -252,9 +244,9 @@ class ConnectionTokenViewSet(ExtraActionApiMixin, RootOrgViewMixin, JMSModelView
return Response(serializer.data, status=status.HTTP_200_OK) return Response(serializer.data, status=status.HTTP_200_OK)
def get_queryset(self): def get_queryset(self):
queryset = ConnectionToken.objects\ queryset = ConnectionToken.objects \
.filter(user=self.request.user)\ .filter(user=self.request.user) \
.filter(date_expired__lt=timezone.now()) .filter(date_expired__gt=timezone.now())
return queryset return queryset
def get_user(self, serializer): def get_user(self, serializer):

View File

@ -1,11 +1,11 @@
# Generated by Django 3.2.14 on 2022-11-25 14:40 # Generated by Django 3.2.14 on 2022-11-25 14:40
import common.db.fields
from django.db import migrations, models from django.db import migrations, models
import common.db.fields
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
('authentication', '0015_alter_connectiontoken_login'), ('authentication', '0015_alter_connectiontoken_login'),
] ]
@ -36,4 +36,15 @@ class Migration(migrations.Migration):
name='value', name='value',
field=models.CharField(default='', max_length=64, verbose_name='Value'), field=models.CharField(default='', max_length=64, verbose_name='Value'),
), ),
migrations.AddField(
model_name='connectiontoken',
name='input_secret',
field=common.db.fields.EncryptCharField(blank=True, default='', max_length=128,
verbose_name='Input Secret'),
),
migrations.AlterField(
model_name='connectiontoken',
name='input_username',
field=models.CharField(blank=True, default='', max_length=128, verbose_name='Input Username'),
),
] ]

View File

@ -1,11 +1,9 @@
# Generated by Django 3.2.14 on 2022-11-28 10:39 # Generated by Django 3.2.14 on 2022-11-28 10:39
import common.db.fields
from django.db import migrations, models from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [
('authentication', '0016_auto_20221125_2240'), ('authentication', '0016_auto_20221125_2240'),
] ]
@ -17,15 +15,4 @@ class Migration(migrations.Migration):
field=models.CharField(default='web_ui', max_length=32, verbose_name='Connect method'), field=models.CharField(default='web_ui', max_length=32, verbose_name='Connect method'),
preserve_default=False, preserve_default=False,
), ),
migrations.AddField(
model_name='connectiontoken',
name='input_secret',
field=common.db.fields.EncryptCharField(blank=True, default='', max_length=128,
verbose_name='Input Secret'),
),
migrations.AlterField(
model_name='connectiontoken',
name='input_username',
field=models.CharField(blank=True, default='', max_length=128, verbose_name='Input Username'),
),
] ]

View File

@ -0,0 +1,19 @@
# Generated by Django 3.2.14 on 2022-11-29 04:49
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('authentication', '0017_auto_20221128_1839'),
]
operations = [
migrations.AddField(
model_name='connectiontoken',
name='endpoint_protocol',
field=models.CharField(choices=[('ssh', 'SSH'), ('rdp', 'RDP'), ('telnet', 'Telnet'), ('vnc', 'VNC'), ('mysql', 'MySQL'), ('mariadb', 'MariaDB'), ('oracle', 'Oracle'), ('postgresql', 'PostgreSQL'), ('sqlserver', 'SQLServer'), ('redis', 'Redis'), ('mongodb', 'MongoDB'), ('k8s', 'K8S'), ('http', 'HTTP'), ('None', ' Settings')], default='', max_length=16, verbose_name='Endpoint protocol'),
preserve_default=False,
),
]

View File

@ -35,6 +35,9 @@ class ConnectionToken(OrgModelMixin, JMSBaseModel):
choices=Protocol.choices, max_length=16, default=Protocol.ssh, verbose_name=_("Protocol") choices=Protocol.choices, max_length=16, default=Protocol.ssh, verbose_name=_("Protocol")
) )
connect_method = models.CharField(max_length=32, verbose_name=_("Connect method")) connect_method = models.CharField(max_length=32, verbose_name=_("Connect method"))
endpoint_protocol = models.CharField(
choices=Protocol.choices, max_length=16, verbose_name=_("Endpoint protocol")
)
user_display = models.CharField(max_length=128, default='', verbose_name=_("User display")) user_display = models.CharField(max_length=128, default='', verbose_name=_("User display"))
asset_display = models.CharField(max_length=128, default='', verbose_name=_("Asset display")) asset_display = models.CharField(max_length=128, default='', verbose_name=_("Asset display"))
date_expired = models.DateTimeField( date_expired = models.DateTimeField(

View File

@ -1,7 +1,7 @@
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers from rest_framework import serializers
from assets.models import Asset, Domain, CommandFilterRule, Account, Platform from assets.models import Asset, CommandFilterRule, Account, Platform
from assets.serializers import PlatformSerializer, AssetProtocolsSerializer from assets.serializers import PlatformSerializer, AssetProtocolsSerializer
from authentication.models import ConnectionToken from authentication.models import ConnectionToken
from orgs.mixins.serializers import OrgResourceModelSerializerMixin from orgs.mixins.serializers import OrgResourceModelSerializerMixin
@ -21,21 +21,19 @@ class ConnectionTokenSerializer(OrgResourceModelSerializerMixin):
model = ConnectionToken model = ConnectionToken
fields_mini = ['id', 'value'] fields_mini = ['id', 'value']
fields_small = fields_mini + [ fields_small = fields_mini + [
'protocol', 'account_name', 'user', 'asset', 'account_name',
'input_username', 'input_secret', 'input_username', 'input_secret',
'connect_method', 'endpoint_protocol', 'protocol',
'actions', 'date_expired', 'date_created', 'actions', 'date_expired', 'date_created',
'date_updated', 'created_by', 'date_updated', 'created_by',
'updated_by', 'org_id', 'org_name', 'updated_by', 'org_id', 'org_name',
] ]
fields_fk = [
'user', 'asset',
]
read_only_fields = [ read_only_fields = [
# 普通 Token 不支持指定 user # 普通 Token 不支持指定 user
'user', 'expire_time', 'user', 'expire_time',
'user_display', 'asset_display', 'user_display', 'asset_display',
] ]
fields = fields_small + fields_fk + read_only_fields fields = fields_small + read_only_fields
extra_kwargs = { extra_kwargs = {
'value': {'read_only': True}, 'value': {'read_only': True},
} }

View File

@ -2,11 +2,11 @@
# #
import re import re
from django.shortcuts import reverse as dj_reverse
from django.conf import settings from django.conf import settings
from django.utils import timezone
from django.db import models from django.db import models
from django.db.models.signals import post_save, pre_save from django.db.models.signals import post_save, pre_save
from django.shortcuts import reverse as dj_reverse
from django.utils import timezone
UUID_PATTERN = re.compile(r'[0-9a-zA-Z\-]{36}') UUID_PATTERN = re.compile(r'[0-9a-zA-Z\-]{36}')
@ -80,3 +80,18 @@ def bulk_create_with_signal(cls: models.Model, items, **kwargs):
for i in items: for i in items:
post_save.send(sender=cls, instance=i, created=True) post_save.send(sender=cls, instance=i, created=True)
return result return result
def get_request_os(request):
"""获取请求的操作系统"""
agent = request.META.get('HTTP_USER_AGENT', '').lower()
if agent is None:
return 'unknown'
if 'windows' in agent.lower():
return 'windows'
if 'mac' in agent.lower():
return 'mac'
if 'linux' in agent.lower():
return 'linux'
return 'unknown'

View File

@ -1,9 +1,9 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import time
from email.utils import formatdate
import calendar import calendar
import threading import threading
import time
from email.utils import formatdate
_STRPTIME_LOCK = threading.Lock() _STRPTIME_LOCK = threading.Lock()
@ -35,3 +35,6 @@ def http_to_unixtime(time_string):
def iso8601_to_unixtime(time_string): def iso8601_to_unixtime(time_string):
"""把ISO8601时间字符串形如2012-02-24T06:07:48.000Z转换为UNIX时间精确到秒。""" """把ISO8601时间字符串形如2012-02-24T06:07:48.000Z转换为UNIX时间精确到秒。"""
return to_unixtime(time_string, _ISO8601_FORMAT) return to_unixtime(time_string, _ISO8601_FORMAT)

View File

@ -12,6 +12,7 @@ from common.drf.api import JMSBulkModelViewSet
from common.exceptions import JMSException from common.exceptions import JMSException
from common.permissions import IsValidUser from common.permissions import IsValidUser
from common.permissions import WithBootstrapToken from common.permissions import WithBootstrapToken
from common.utils import get_request_os
from terminal import serializers from terminal import serializers
from terminal.const import TerminalType from terminal.const import TerminalType
from terminal.models import Terminal from terminal.models import Terminal
@ -77,13 +78,7 @@ class ConnectMethodListApi(generics.ListAPIView):
permission_classes = [IsValidUser] permission_classes = [IsValidUser]
def get_queryset(self): def get_queryset(self):
user_agent = self.request.META['HTTP_USER_AGENT'].lower() os = get_request_os(self.request)
if 'macintosh' in user_agent:
os = 'macos'
elif 'windows' in user_agent:
os = 'windows'
else:
os = 'linux'
return TerminalType.get_protocols_connect_methods(os) return TerminalType.get_protocols_connect_methods(os)
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):

View File

@ -56,7 +56,11 @@ class NativeClient(TextChoices):
xshell = 'xshell', 'Xshell' xshell = 'xshell', 'Xshell'
# Magnus # Magnus
db_client = 'db_client', _('DB Client') mysql = 'db_client_mysql', _('DB Client')
psql = 'db_client_psql', _('DB Client')
sqlplus = 'db_client_sqlplus', _('DB Client')
redis = 'db_client_redis', _('DB Client')
mongodb = 'db_client_mongodb', _('DB Client')
# Razor # Razor
mstsc = 'mstsc', 'Remote Desktop' mstsc = 'mstsc', 'Remote Desktop'
@ -69,14 +73,23 @@ class NativeClient(TextChoices):
'windows': [cls.putty], 'windows': [cls.putty],
}, },
Protocol.rdp: [cls.mstsc], Protocol.rdp: [cls.mstsc],
Protocol.mysql: [cls.db_client], Protocol.mysql: [cls.mysql],
Protocol.oracle: [cls.db_client], Protocol.oracle: [cls.sqlplus],
Protocol.postgresql: [cls.db_client], Protocol.postgresql: [cls.psql],
Protocol.redis: [cls.db_client], Protocol.redis: [cls.redis],
Protocol.mongodb: [cls.db_client], Protocol.mongodb: [cls.mongodb],
} }
return clients return clients
@classmethod
def get_target_protocol(cls, name, os):
for protocol, clients in cls.get_native_clients().items():
if isinstance(clients, dict):
clients = clients.get(os) or clients.get('default')
if name in clients:
return protocol
return None
@classmethod @classmethod
def get_methods(cls, os='windows'): def get_methods(cls, os='windows'):
clients_map = cls.get_native_clients() clients_map = cls.get_native_clients()
@ -94,23 +107,18 @@ class NativeClient(TextChoices):
return methods return methods
@classmethod @classmethod
def get_launch_command(cls, name, os='windows'): def get_launch_command(cls, name, token, endpoint, os='windows'):
commands = { commands = {
cls.ssh: 'ssh {token.id}@{endpoint.ip} -p {endpoint.port}', cls.ssh: f'ssh {token.id}@{endpoint.host} -p {endpoint.ssh_port}',
cls.putty: 'putty-ssh {token.id}@{endpoint.ip} -P {endpoint.port}', cls.putty: f'putty -ssh {token.id}@{endpoint.host} -P {endpoint.ssh_port}',
cls.xshell: 'xshell -url ssh://{token.id}:{token.value}@{endpoint.ip}:{endpoint.port}', cls.xshell: f'xshell -url ssh://{token.id}:{token.value}@{endpoint.host}:{endpoint.ssh_port}',
# 'mysql': 'mysql -h {hostname} -P {port} -u {username} -p', # cls.mysql: 'mysql -h {hostname} -P {port} -u {username} -p',
# 'psql': { # cls.psql: {
# 'default': 'psql -h {hostname} -p {port} -U {username} -W', # 'default': 'psql -h {hostname} -p {port} -U {username} -W',
# 'windows': 'psql /h {hostname} /p {port} /U {username} -W', # 'windows': 'psql /h {hostname} /p {port} /U {username} -W',
# }, # },
# 'sqlplus': 'sqlplus {username}/{password}@{hostname}:{port}', # cls.sqlplus: 'sqlplus {username}/{password}@{hostname}:{port}',
# 'redis': 'redis-cli -h {hostname} -p {port} -a {password}', # cls.redis: 'redis-cli -h {hostname} -p {port} -a {password}',
cls.mstsc: {
'command': "$open_file$",
'file': {
}
},
} }
command = commands.get(name) command = commands.get(name)
if isinstance(command, dict): if isinstance(command, dict):
@ -217,19 +225,26 @@ class TerminalType(TextChoices):
methods[protocol.value].append({ methods[protocol.value].append({
'value': web_protocol.value, 'value': web_protocol.value,
'label': web_protocol.label, 'label': web_protocol.label,
'endpoint_protocol': 'http',
'type': 'web', 'type': 'web',
'component': component.value, 'component': component.value,
}) })
# Native method # Native method
methods[protocol.value].extend([ methods[protocol.value].extend([
{'component': component.value, 'type': 'native', **method} {
'component': component.value,
'type': 'native',
'endpoint_protocol': listen_protocol,
**method
}
for method in native_methods[listen_protocol] for method in native_methods[listen_protocol]
]) ])
for protocol, applet_methods in applet_methods.items(): for protocol, applet_methods in applet_methods.items():
for method in applet_methods: for method in applet_methods:
method['type'] = 'applet' method['type'] = 'applet'
method['listen'] = 'rdp'
method['component'] = cls.tinker.value method['component'] = cls.tinker.value
methods[protocol].extend(applet_methods) methods[protocol].extend(applet_methods)
return methods return methods

View File

@ -138,4 +138,6 @@ class TerminalRegistrationSerializer(serializers.ModelSerializer):
class ConnectMethodSerializer(serializers.Serializer): class ConnectMethodSerializer(serializers.Serializer):
value = serializers.CharField(max_length=128) value = serializers.CharField(max_length=128)
label = serializers.CharField(max_length=128) label = serializers.CharField(max_length=128)
group = serializers.CharField(max_length=128) type = serializers.CharField(max_length=128)
listen = serializers.CharField(max_length=128)
component = serializers.CharField(max_length=128)