mirror of
https://github.com/jumpserver/jumpserver.git
synced 2025-07-17 00:11:42 +00:00
commit
ce24c1c3fd
@ -36,9 +36,11 @@ ARG TOOLS=" \
|
||||
curl \
|
||||
default-libmysqlclient-dev \
|
||||
default-mysql-client \
|
||||
iputils-ping \
|
||||
locales \
|
||||
nmap \
|
||||
openssh-client \
|
||||
patch \
|
||||
sshpass \
|
||||
telnet \
|
||||
vim \
|
||||
|
@ -12,8 +12,6 @@
|
||||
|
||||
|
||||
<p align="center">
|
||||
JumpServer <a href="https://github.com/jumpserver/jumpserver/releases/tag/v3.0.0">v3.0</a> 正式发布。
|
||||
<br>
|
||||
9 年时间,倾情投入,用心做好一款开源堡垒机。
|
||||
</p>
|
||||
|
||||
|
@ -6,11 +6,12 @@ from rest_framework.status import HTTP_200_OK
|
||||
|
||||
from accounts import serializers
|
||||
from accounts.filters import AccountFilterSet
|
||||
from accounts.models import Account
|
||||
from accounts.mixins import AccountRecordViewLogMixin
|
||||
from accounts.models import Account
|
||||
from assets.models import Asset, Node
|
||||
from authentication.permissions import UserConfirmation, ConfirmType
|
||||
from common.api.mixin import ExtraFilterFieldsMixin
|
||||
from common.permissions import UserConfirmation, ConfirmType, IsValidUser
|
||||
from common.permissions import IsValidUser
|
||||
from orgs.mixins.api import OrgBulkModelViewSet
|
||||
from rbac.permissions import RBACPermission
|
||||
|
||||
@ -57,19 +58,19 @@ class AccountViewSet(OrgBulkModelViewSet):
|
||||
permission_classes=[IsValidUser]
|
||||
)
|
||||
def username_suggestions(self, request, *args, **kwargs):
|
||||
asset_ids = request.data.get('assets')
|
||||
node_ids = request.data.get('nodes')
|
||||
username = request.data.get('username')
|
||||
asset_ids = request.data.get('assets', [])
|
||||
node_ids = request.data.get('nodes', [])
|
||||
username = request.data.get('username', '')
|
||||
|
||||
assets = Asset.objects.all()
|
||||
if asset_ids:
|
||||
assets = assets.filter(id__in=asset_ids)
|
||||
accounts = Account.objects.all()
|
||||
if node_ids:
|
||||
nodes = Node.objects.filter(id__in=node_ids)
|
||||
node_asset_ids = Node.get_nodes_all_assets(*nodes).values_list('id', flat=True)
|
||||
assets = assets.filter(id__in=set(list(asset_ids) + list(node_asset_ids)))
|
||||
asset_ids.extend(node_asset_ids)
|
||||
|
||||
if asset_ids:
|
||||
accounts = accounts.filter(asset_id__in=list(set(asset_ids)))
|
||||
|
||||
accounts = Account.objects.filter(asset__in=assets)
|
||||
if username:
|
||||
accounts = accounts.filter(username__icontains=username)
|
||||
usernames = list(accounts.values_list('username', flat=True).distinct()[:10])
|
||||
|
@ -1,13 +1,15 @@
|
||||
from django_filters import rest_framework as drf_filters
|
||||
from rest_framework import status
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.response import Response
|
||||
|
||||
from accounts import serializers
|
||||
from accounts.models import AccountTemplate
|
||||
from accounts.mixins import AccountRecordViewLogMixin
|
||||
from accounts.models import AccountTemplate
|
||||
from accounts.tasks import template_sync_related_accounts
|
||||
from assets.const import Protocol
|
||||
from authentication.permissions import UserConfirmation, ConfirmType
|
||||
from common.drf.filters import BaseFilterSet
|
||||
from common.permissions import UserConfirmation, ConfirmType
|
||||
from orgs.mixins.api import OrgBulkModelViewSet
|
||||
from rbac.permissions import RBACPermission
|
||||
|
||||
@ -44,6 +46,7 @@ class AccountTemplateViewSet(OrgBulkModelViewSet):
|
||||
}
|
||||
rbac_perms = {
|
||||
'su_from_account_templates': 'accounts.view_accounttemplate',
|
||||
'sync_related_accounts': 'accounts.change_account',
|
||||
}
|
||||
|
||||
@action(methods=['get'], detail=False, url_path='su-from-account-templates')
|
||||
@ -54,6 +57,13 @@ class AccountTemplateViewSet(OrgBulkModelViewSet):
|
||||
serializer = self.get_serializer(templates, many=True)
|
||||
return Response(data=serializer.data)
|
||||
|
||||
@action(methods=['patch'], detail=True, url_path='sync-related-accounts')
|
||||
def sync_related_accounts(self, request, *args, **kwargs):
|
||||
instance = self.get_object()
|
||||
user_id = str(request.user.id)
|
||||
task = template_sync_related_accounts.delay(str(instance.id), user_id)
|
||||
return Response({'task': task.id}, status=status.HTTP_200_OK)
|
||||
|
||||
|
||||
class AccountTemplateSecretsViewSet(AccountRecordViewLogMixin, AccountTemplateViewSet):
|
||||
serializer_classes = {
|
||||
|
@ -5,8 +5,7 @@ from rest_framework import mixins
|
||||
|
||||
from accounts import serializers
|
||||
from accounts.const import AutomationTypes
|
||||
from accounts.models import ChangeSecretAutomation, ChangeSecretRecord, AutomationExecution
|
||||
from common.utils import get_object_or_none
|
||||
from accounts.models import ChangeSecretAutomation, ChangeSecretRecord
|
||||
from orgs.mixins.api import OrgBulkModelViewSet, OrgGenericViewSet
|
||||
from .base import (
|
||||
AutomationAssetsListApi, AutomationRemoveAssetApi, AutomationAddAssetApi,
|
||||
@ -30,8 +29,8 @@ class ChangeSecretAutomationViewSet(OrgBulkModelViewSet):
|
||||
|
||||
class ChangeSecretRecordViewSet(mixins.ListModelMixin, OrgGenericViewSet):
|
||||
serializer_class = serializers.ChangeSecretRecordSerializer
|
||||
filter_fields = ['asset', 'execution_id']
|
||||
search_fields = ['asset__hostname']
|
||||
filter_fields = ('asset', 'execution_id')
|
||||
search_fields = ('asset__address',)
|
||||
|
||||
def get_queryset(self):
|
||||
return ChangeSecretRecord.objects.filter(
|
||||
@ -41,10 +40,7 @@ class ChangeSecretRecordViewSet(mixins.ListModelMixin, OrgGenericViewSet):
|
||||
def filter_queryset(self, queryset):
|
||||
queryset = super().filter_queryset(queryset)
|
||||
eid = self.request.query_params.get('execution_id')
|
||||
execution = get_object_or_none(AutomationExecution, pk=eid)
|
||||
if execution:
|
||||
queryset = queryset.filter(execution=execution)
|
||||
return queryset
|
||||
return queryset.filter(execution_id=eid)
|
||||
|
||||
|
||||
class ChangSecretExecutionViewSet(AutomationExecutionViewSet):
|
||||
|
@ -47,4 +47,8 @@
|
||||
login_password: "{{ account.secret }}"
|
||||
login_host: "{{ jms_asset.address }}"
|
||||
login_port: "{{ jms_asset.port }}"
|
||||
become: false
|
||||
become: "{{ account.become.ansible_become | default(False) }}"
|
||||
become_method: su
|
||||
become_user: "{{ account.become.ansible_user | default('') }}"
|
||||
become_password: "{{ account.become.ansible_password | default('') }}"
|
||||
become_private_key_path: "{{ account.become.ansible_ssh_private_key_file | default(None) }}"
|
||||
|
@ -80,7 +80,11 @@
|
||||
login_host: "{{ jms_asset.address }}"
|
||||
login_port: "{{ jms_asset.port }}"
|
||||
gateway_args: "{{ jms_asset.ansible_ssh_common_args | default('') }}"
|
||||
become: false
|
||||
become: "{{ account.become.ansible_become | default(False) }}"
|
||||
become_method: su
|
||||
become_user: "{{ account.become.ansible_user | default('') }}"
|
||||
become_password: "{{ account.become.ansible_password | default('') }}"
|
||||
become_private_key_path: "{{ account.become.ansible_ssh_private_key_file | default(None) }}"
|
||||
when: account.secret_type == "password"
|
||||
delegate_to: localhost
|
||||
|
||||
@ -91,6 +95,5 @@
|
||||
login_user: "{{ account.username }}"
|
||||
login_private_key_path: "{{ account.private_key_path }}"
|
||||
gateway_args: "{{ jms_asset.ansible_ssh_common_args | default('') }}"
|
||||
become: false
|
||||
when: account.secret_type == "ssh_key"
|
||||
delegate_to: localhost
|
||||
|
@ -80,7 +80,11 @@
|
||||
login_host: "{{ jms_asset.address }}"
|
||||
login_port: "{{ jms_asset.port }}"
|
||||
gateway_args: "{{ jms_asset.ansible_ssh_common_args | default('') }}"
|
||||
become: false
|
||||
become: "{{ account.become.ansible_become | default(False) }}"
|
||||
become_method: su
|
||||
become_user: "{{ account.become.ansible_user | default('') }}"
|
||||
become_password: "{{ account.become.ansible_password | default('') }}"
|
||||
become_private_key_path: "{{ account.become.ansible_ssh_private_key_file | default(None) }}"
|
||||
when: account.secret_type == "password"
|
||||
delegate_to: localhost
|
||||
|
||||
@ -91,6 +95,5 @@
|
||||
login_user: "{{ account.username }}"
|
||||
login_private_key_path: "{{ account.private_key_path }}"
|
||||
gateway_args: "{{ jms_asset.ansible_ssh_common_args | default('') }}"
|
||||
become: false
|
||||
when: account.secret_type == "ssh_key"
|
||||
delegate_to: localhost
|
||||
|
@ -80,7 +80,11 @@
|
||||
login_host: "{{ jms_asset.address }}"
|
||||
login_port: "{{ jms_asset.port }}"
|
||||
gateway_args: "{{ jms_asset.ansible_ssh_common_args | default('') }}"
|
||||
become: false
|
||||
become: "{{ account.become.ansible_become | default(False) }}"
|
||||
become_method: su
|
||||
become_user: "{{ account.become.ansible_user | default('') }}"
|
||||
become_password: "{{ account.become.ansible_password | default('') }}"
|
||||
become_private_key_path: "{{ account.become.ansible_ssh_private_key_file | default(None) }}"
|
||||
when: account.secret_type == "password"
|
||||
delegate_to: localhost
|
||||
|
||||
@ -91,7 +95,6 @@
|
||||
login_user: "{{ account.username }}"
|
||||
login_private_key_path: "{{ account.private_key_path }}"
|
||||
gateway_args: "{{ jms_asset.ansible_ssh_common_args | default('') }}"
|
||||
become: false
|
||||
when: account.secret_type == "ssh_key"
|
||||
delegate_to: localhost
|
||||
|
||||
|
@ -80,7 +80,11 @@
|
||||
login_host: "{{ jms_asset.address }}"
|
||||
login_port: "{{ jms_asset.port }}"
|
||||
gateway_args: "{{ jms_asset.ansible_ssh_common_args | default('') }}"
|
||||
become: false
|
||||
become: "{{ account.become.ansible_become | default(False) }}"
|
||||
become_method: su
|
||||
become_user: "{{ account.become.ansible_user | default('') }}"
|
||||
become_password: "{{ account.become.ansible_password | default('') }}"
|
||||
become_private_key_path: "{{ account.become.ansible_ssh_private_key_file | default(None) }}"
|
||||
when: account.secret_type == "password"
|
||||
delegate_to: localhost
|
||||
|
||||
@ -91,7 +95,6 @@
|
||||
login_user: "{{ account.username }}"
|
||||
login_private_key_path: "{{ account.private_key_path }}"
|
||||
gateway_args: "{{ jms_asset.ansible_ssh_common_args | default('') }}"
|
||||
become: false
|
||||
when: account.secret_type == "ssh_key"
|
||||
delegate_to: localhost
|
||||
|
||||
|
@ -8,7 +8,7 @@
|
||||
- name: Verify account (pyfreerdp)
|
||||
rdp_ping:
|
||||
login_host: "{{ jms_asset.address }}"
|
||||
login_port: "{{ jms_asset.port }}"
|
||||
login_port: "{{ jms_asset.protocols | selectattr('name', 'equalto', 'rdp') | map(attribute='port') | first }}"
|
||||
login_user: "{{ account.username }}"
|
||||
login_password: "{{ account.secret }}"
|
||||
login_secret_type: "{{ account.secret_type }}"
|
||||
|
@ -13,8 +13,8 @@
|
||||
login_password: "{{ account.secret }}"
|
||||
login_secret_type: "{{ account.secret_type }}"
|
||||
login_private_key_path: "{{ account.private_key_path }}"
|
||||
become: "{{ custom_become | default(False) }}"
|
||||
become_method: "{{ custom_become_method | default('su') }}"
|
||||
become_user: "{{ custom_become_user | default('') }}"
|
||||
become_password: "{{ custom_become_password | default('') }}"
|
||||
become_private_key_path: "{{ custom_become_private_key_path | default(None) }}"
|
||||
become: "{{ account.become.ansible_become | default(False) }}"
|
||||
become_method: "{{ account.become.ansible_become_method | default('su') }}"
|
||||
become_user: "{{ account.become.ansible_user | default('') }}"
|
||||
become_password: "{{ account.become.ansible_password | default('') }}"
|
||||
become_private_key_path: "{{ account.become.ansible_ssh_private_key_file | default(None) }}"
|
||||
|
@ -1,11 +1,23 @@
|
||||
- hosts: demo
|
||||
gather_facts: no
|
||||
tasks:
|
||||
- name: Verify account connectivity
|
||||
become: no
|
||||
- name: Verify account connectivity(Do not switch)
|
||||
ansible.builtin.ping:
|
||||
vars:
|
||||
ansible_become: no
|
||||
ansible_user: "{{ account.username }}"
|
||||
ansible_password: "{{ account.secret }}"
|
||||
ansible_ssh_private_key_file: "{{ account.private_key_path }}"
|
||||
when: not account.become.ansible_become
|
||||
|
||||
- name: Verify account connectivity(Switch)
|
||||
ansible.builtin.ping:
|
||||
vars:
|
||||
ansible_become: yes
|
||||
ansible_user: "{{ account.become.ansible_user }}"
|
||||
ansible_password: "{{ account.become.ansible_password }}"
|
||||
ansible_ssh_private_key_file: "{{ account.become.ansible_ssh_private_key_file }}"
|
||||
ansible_become_method: "{{ account.become.ansible_become_method }}"
|
||||
ansible_become_user: "{{ account.become.ansible_become_user }}"
|
||||
ansible_become_password: "{{ account.become.ansible_become_password }}"
|
||||
when: account.become.ansible_become
|
||||
|
@ -42,7 +42,6 @@ class VerifyAccountManager(AccountBasePlaybookManager):
|
||||
if host.get('error'):
|
||||
return host
|
||||
|
||||
# host['ssh_args'] = '-o ControlMaster=no -o ControlPersist=no'
|
||||
accounts = asset.accounts.all()
|
||||
accounts = self.get_accounts(account, accounts)
|
||||
inventory_hosts = []
|
||||
@ -64,7 +63,8 @@ class VerifyAccountManager(AccountBasePlaybookManager):
|
||||
'username': account.username,
|
||||
'secret_type': account.secret_type,
|
||||
'secret': secret,
|
||||
'private_key_path': private_key_path
|
||||
'private_key_path': private_key_path,
|
||||
'become': account.get_ansible_become_auth(),
|
||||
}
|
||||
if account.platform.type == 'oracle':
|
||||
h['account']['mode'] = 'sysdba' if account.privileged else None
|
||||
|
@ -13,11 +13,11 @@ class Migration(migrations.Migration):
|
||||
migrations.AlterField(
|
||||
model_name='changesecretautomation',
|
||||
name='secret_strategy',
|
||||
field=models.CharField(choices=[('specific', 'Specific password'), ('random', 'Random')], default='specific', max_length=16, verbose_name='Secret strategy'),
|
||||
field=models.CharField(choices=[('specific', 'Specific secret'), ('random', 'Random generate')], default='specific', max_length=16, verbose_name='Secret strategy'),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='pushaccountautomation',
|
||||
name='secret_strategy',
|
||||
field=models.CharField(choices=[('specific', 'Specific password'), ('random', 'Random')], default='specific', max_length=16, verbose_name='Secret strategy'),
|
||||
field=models.CharField(choices=[('specific', 'Specific secret'), ('random', 'Random generate')], default='specific', max_length=16, verbose_name='Secret strategy'),
|
||||
),
|
||||
]
|
||||
|
@ -29,6 +29,6 @@ class Migration(migrations.Migration):
|
||||
migrations.AddField(
|
||||
model_name='accounttemplate',
|
||||
name='secret_strategy',
|
||||
field=models.CharField(choices=[('specific', 'Specific password'), ('random', 'Random')], default='specific', max_length=16, verbose_name='Secret strategy'),
|
||||
field=models.CharField(choices=[('specific', 'Specific secret'), ('random', 'Random generate')], default='specific', max_length=16, verbose_name='Secret strategy'),
|
||||
),
|
||||
]
|
||||
|
@ -95,6 +95,33 @@ class Account(AbsConnectivity, BaseAccount):
|
||||
""" 排除自己和以自己为 su-from 的账号 """
|
||||
return self.asset.accounts.exclude(id=self.id).exclude(su_from=self)
|
||||
|
||||
@staticmethod
|
||||
def make_account_ansible_vars(su_from):
|
||||
var = {
|
||||
'ansible_user': su_from.username,
|
||||
}
|
||||
if not su_from.secret:
|
||||
return var
|
||||
var['ansible_password'] = su_from.secret
|
||||
var['ansible_ssh_private_key_file'] = su_from.private_key_path
|
||||
return var
|
||||
|
||||
def get_ansible_become_auth(self):
|
||||
su_from = self.su_from
|
||||
platform = self.platform
|
||||
auth = {'ansible_become': False}
|
||||
if not (platform.su_enabled and su_from):
|
||||
return auth
|
||||
|
||||
auth.update(self.make_account_ansible_vars(su_from))
|
||||
become_method = 'sudo' if platform.su_method != 'su' else 'su'
|
||||
password = su_from.secret if become_method == 'sudo' else self.secret
|
||||
auth['ansible_become'] = True
|
||||
auth['ansible_become_method'] = become_method
|
||||
auth['ansible_become_user'] = self.username
|
||||
auth['ansible_become_password'] = password
|
||||
return auth
|
||||
|
||||
|
||||
def replace_history_model_with_mixin():
|
||||
"""
|
||||
|
@ -1,9 +1,9 @@
|
||||
from django.conf import settings
|
||||
from django.db import models
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from accounts.const import AutomationTypes
|
||||
from accounts.models import Account
|
||||
from jumpserver.utils import has_valid_xpack_license
|
||||
from .base import AccountBaseAutomation
|
||||
from .change_secret import ChangeSecretMixin
|
||||
|
||||
@ -41,7 +41,7 @@ class PushAccountAutomation(ChangeSecretMixin, AccountBaseAutomation):
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
self.type = AutomationTypes.push_account
|
||||
if not has_valid_xpack_license():
|
||||
if not settings.XPACK_LICENSE_IS_VALID:
|
||||
self.is_periodic = False
|
||||
super().save(*args, **kwargs)
|
||||
|
||||
|
@ -37,8 +37,9 @@ class VaultManagerMixin(models.Manager):
|
||||
post_save.send(obj.__class__, instance=obj, created=True)
|
||||
return objs
|
||||
|
||||
def bulk_update(self, objs, batch_size=None, ignore_conflicts=False):
|
||||
objs = super().bulk_update(objs, batch_size=batch_size, ignore_conflicts=ignore_conflicts)
|
||||
def bulk_update(self, objs, fields, batch_size=None):
|
||||
fields = ["_secret" if field == "secret" else field for field in fields]
|
||||
super().bulk_update(objs, fields, batch_size=batch_size)
|
||||
for obj in objs:
|
||||
post_save.send(obj.__class__, instance=obj, created=False)
|
||||
return objs
|
||||
|
@ -49,8 +49,7 @@ class AccountTemplate(BaseAccount, SecretWithRandomMixin):
|
||||
).first()
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def bulk_update_accounts(accounts, data):
|
||||
def bulk_update_accounts(self, accounts):
|
||||
history_model = Account.history.model
|
||||
account_ids = accounts.values_list('id', flat=True)
|
||||
history_accounts = history_model.objects.filter(id__in=account_ids)
|
||||
@ -63,8 +62,7 @@ class AccountTemplate(BaseAccount, SecretWithRandomMixin):
|
||||
for account in accounts:
|
||||
account_id = str(account.id)
|
||||
account.version = account_id_count_map.get(account_id) + 1
|
||||
for k, v in data.items():
|
||||
setattr(account, k, v)
|
||||
account.secret = self.get_secret()
|
||||
Account.objects.bulk_update(accounts, ['version', 'secret'])
|
||||
|
||||
@staticmethod
|
||||
@ -86,7 +84,5 @@ class AccountTemplate(BaseAccount, SecretWithRandomMixin):
|
||||
|
||||
def bulk_sync_account_secret(self, accounts, user_id):
|
||||
""" 批量同步账号密码 """
|
||||
if not accounts:
|
||||
return
|
||||
self.bulk_update_accounts(accounts, {'secret': self.secret})
|
||||
self.bulk_update_accounts(accounts)
|
||||
self.bulk_create_history_accounts(accounts, user_id)
|
||||
|
@ -78,7 +78,8 @@ class AccountCreateUpdateSerializerMixin(serializers.Serializer):
|
||||
def get_template_attr_for_account(template):
|
||||
# Set initial data from template
|
||||
field_names = [
|
||||
'username', 'secret', 'secret_type', 'privileged', 'is_active'
|
||||
'name', 'username', 'secret',
|
||||
'secret_type', 'privileged', 'is_active'
|
||||
]
|
||||
|
||||
attrs = {}
|
||||
|
@ -1,7 +1,9 @@
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework import serializers
|
||||
|
||||
from accounts.models import AccountTemplate, Account
|
||||
from accounts.const import SecretStrategy, SecretType
|
||||
from accounts.models import AccountTemplate
|
||||
from accounts.utils import SecretGenerator
|
||||
from common.serializers import SecretReadableMixin
|
||||
from common.serializers.fields import ObjectRelatedField
|
||||
from .base import BaseAccountSerializer
|
||||
@ -16,9 +18,6 @@ class PasswordRulesSerializer(serializers.Serializer):
|
||||
|
||||
|
||||
class AccountTemplateSerializer(BaseAccountSerializer):
|
||||
is_sync_account = serializers.BooleanField(default=False, write_only=True)
|
||||
_is_sync_account = False
|
||||
|
||||
password_rules = PasswordRulesSerializer(required=False, label=_('Password rules'))
|
||||
su_from = ObjectRelatedField(
|
||||
required=False, queryset=AccountTemplate.objects, allow_null=True,
|
||||
@ -30,7 +29,7 @@ class AccountTemplateSerializer(BaseAccountSerializer):
|
||||
fields = BaseAccountSerializer.Meta.fields + [
|
||||
'secret_strategy', 'password_rules',
|
||||
'auto_push', 'push_params', 'platforms',
|
||||
'is_sync_account', 'su_from'
|
||||
'su_from'
|
||||
]
|
||||
extra_kwargs = {
|
||||
'secret_strategy': {'help_text': _('Secret generation strategy for account creation')},
|
||||
@ -44,34 +43,21 @@ class AccountTemplateSerializer(BaseAccountSerializer):
|
||||
},
|
||||
}
|
||||
|
||||
def sync_accounts_secret(self, instance, diff):
|
||||
if not self._is_sync_account or 'secret' not in diff:
|
||||
@staticmethod
|
||||
def generate_secret(attrs):
|
||||
secret_type = attrs.get('secret_type', SecretType.PASSWORD)
|
||||
secret_strategy = attrs.get('secret_strategy', SecretStrategy.custom)
|
||||
password_rules = attrs.get('password_rules')
|
||||
if secret_strategy != SecretStrategy.random:
|
||||
return
|
||||
query_data = {
|
||||
'source_id': instance.id,
|
||||
'username': instance.username,
|
||||
'secret_type': instance.secret_type
|
||||
}
|
||||
accounts = Account.objects.filter(**query_data)
|
||||
instance.bulk_sync_account_secret(accounts, self.context['request'].user.id)
|
||||
generator = SecretGenerator(secret_strategy, secret_type, password_rules)
|
||||
attrs['secret'] = generator.get_secret()
|
||||
|
||||
def validate(self, attrs):
|
||||
self._is_sync_account = attrs.pop('is_sync_account', None)
|
||||
attrs = super().validate(attrs)
|
||||
self.generate_secret(attrs)
|
||||
return attrs
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
diff = {
|
||||
k: v for k, v in validated_data.items()
|
||||
if getattr(instance, k, None) != v
|
||||
}
|
||||
instance = super().update(instance, validated_data)
|
||||
if {'username', 'secret_type'} & set(diff.keys()):
|
||||
Account.objects.filter(source_id=instance.id).update(source_id=None)
|
||||
else:
|
||||
self.sync_accounts_secret(instance, diff)
|
||||
return instance
|
||||
|
||||
|
||||
class AccountTemplateSecretSerializer(SecretReadableMixin, AccountTemplateSerializer):
|
||||
class Meta(AccountTemplateSerializer.Meta):
|
||||
|
@ -1,5 +1,6 @@
|
||||
from .backup_account import *
|
||||
from .automation import *
|
||||
from .push_account import *
|
||||
from .verify_account import *
|
||||
from .backup_account import *
|
||||
from .gather_accounts import *
|
||||
from .push_account import *
|
||||
from .template import *
|
||||
from .verify_account import *
|
||||
|
60
apps/accounts/tasks/template.py
Normal file
60
apps/accounts/tasks/template.py
Normal file
@ -0,0 +1,60 @@
|
||||
from datetime import datetime
|
||||
|
||||
from celery import shared_task
|
||||
from django.shortcuts import get_object_or_404
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from orgs.utils import tmp_to_root_org, tmp_to_org
|
||||
|
||||
|
||||
@shared_task(
|
||||
verbose_name=_('Template sync info to related accounts'),
|
||||
activity_callback=lambda self, template_id, *args, **kwargs: (template_id, None)
|
||||
)
|
||||
def template_sync_related_accounts(template_id, user_id=None):
|
||||
from accounts.models import Account, AccountTemplate
|
||||
with tmp_to_root_org():
|
||||
template = get_object_or_404(AccountTemplate, id=template_id)
|
||||
org_id = template.org_id
|
||||
|
||||
with tmp_to_org(org_id):
|
||||
accounts = Account.objects.filter(source_id=template_id)
|
||||
if not accounts:
|
||||
print('\033[35m>>> 没有需要同步的账号, 结束任务')
|
||||
print('\033[0m')
|
||||
return
|
||||
|
||||
failed, succeeded = 0, 0
|
||||
succeeded_account_ids = []
|
||||
name = template.name
|
||||
username = template.username
|
||||
secret_type = template.secret_type
|
||||
print(f'\033[32m>>> 开始同步模版名称、用户名、密钥类型到相关联的账号 ({datetime.now().strftime("%Y-%m-%d %H:%M:%S")})')
|
||||
with tmp_to_org(org_id):
|
||||
for account in accounts:
|
||||
account.name = name
|
||||
account.username = username
|
||||
account.secret_type = secret_type
|
||||
try:
|
||||
account.save(update_fields=['name', 'username', 'secret_type'])
|
||||
succeeded += 1
|
||||
succeeded_account_ids.append(account.id)
|
||||
except Exception as e:
|
||||
account.source_id = None
|
||||
account.save(update_fields=['source_id'])
|
||||
print(f'\033[31m- 同步失败: [{account}] 原因: [{e}]')
|
||||
failed += 1
|
||||
accounts = Account.objects.filter(id__in=succeeded_account_ids)
|
||||
if accounts:
|
||||
print(f'\033[33m>>> 批量更新账号密文 ({datetime.now().strftime("%Y-%m-%d %H:%M:%S")})')
|
||||
template.bulk_sync_account_secret(accounts, user_id)
|
||||
|
||||
total = succeeded + failed
|
||||
print(
|
||||
f'\033[33m>>> 同步完成:, '
|
||||
f'共计: {total}, '
|
||||
f'成功: {succeeded}, '
|
||||
f'失败: {failed}, '
|
||||
f'({datetime.now().strftime("%Y-%m-%d %H:%M:%S")}) '
|
||||
)
|
||||
print('\033[0m')
|
@ -16,7 +16,7 @@ class SecretGenerator:
|
||||
|
||||
@staticmethod
|
||||
def generate_ssh_key():
|
||||
private_key, public_key = ssh_key_gen()
|
||||
private_key, __ = ssh_key_gen()
|
||||
return private_key
|
||||
|
||||
def generate_password(self):
|
||||
|
@ -10,7 +10,7 @@ __all__ = ['CommandFilterACLViewSet', 'CommandGroupViewSet']
|
||||
|
||||
class CommandGroupViewSet(OrgBulkModelViewSet):
|
||||
model = models.CommandGroup
|
||||
filterset_fields = ('name',)
|
||||
filterset_fields = ('name', 'command_filters')
|
||||
search_fields = filterset_fields
|
||||
serializer_class = serializers.CommandGroupSerializer
|
||||
|
||||
|
@ -7,3 +7,4 @@ class ActionChoices(models.TextChoices):
|
||||
accept = 'accept', _('Accept')
|
||||
review = 'review', _('Review')
|
||||
warning = 'warning', _('Warning')
|
||||
notice = 'notice', _('Notifications')
|
||||
|
@ -0,0 +1,18 @@
|
||||
# Generated by Django 4.1.10 on 2023-10-18 10:44
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('acls', '0017_alter_connectmethodacl_options'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='commandfilteracl',
|
||||
name='command_groups',
|
||||
field=models.ManyToManyField(related_name='command_filters', to='acls.commandgroup', verbose_name='Command group'),
|
||||
),
|
||||
]
|
@ -93,7 +93,10 @@ class CommandGroup(JMSOrgBaseModel):
|
||||
|
||||
|
||||
class CommandFilterACL(UserAssetAccountBaseACL):
|
||||
command_groups = models.ManyToManyField(CommandGroup, verbose_name=_('Command group'))
|
||||
command_groups = models.ManyToManyField(
|
||||
CommandGroup, verbose_name=_('Command group'),
|
||||
related_name='command_filters'
|
||||
)
|
||||
|
||||
class Meta(UserAssetAccountBaseACL.Meta):
|
||||
abstract = False
|
||||
|
68
apps/acls/notifications.py
Normal file
68
apps/acls/notifications.py
Normal file
@ -0,0 +1,68 @@
|
||||
from django.template.loader import render_to_string
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from assets.models import Asset
|
||||
from audits.models import UserLoginLog
|
||||
from notifications.notifications import UserMessage
|
||||
from users.models import User
|
||||
|
||||
|
||||
class UserLoginReminderMsg(UserMessage):
|
||||
subject = _('User login reminder')
|
||||
|
||||
def __init__(self, user, user_log: UserLoginLog):
|
||||
self.user_log = user_log
|
||||
super().__init__(user)
|
||||
|
||||
def get_html_msg(self) -> dict:
|
||||
user_log = self.user_log
|
||||
|
||||
context = {
|
||||
'ip': user_log.ip,
|
||||
'city': user_log.city,
|
||||
'username': user_log.username,
|
||||
'recipient': self.user.username,
|
||||
'user_agent': user_log.user_agent,
|
||||
}
|
||||
message = render_to_string('acls/user_login_reminder.html', context)
|
||||
|
||||
return {
|
||||
'subject': str(self.subject),
|
||||
'message': message
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def gen_test_msg(cls):
|
||||
user = User.objects.first()
|
||||
user_log = UserLoginLog.objects.first()
|
||||
return cls(user, user_log)
|
||||
|
||||
|
||||
class AssetLoginReminderMsg(UserMessage):
|
||||
subject = _('Asset login reminder')
|
||||
|
||||
def __init__(self, user, asset: Asset, login_user: User, account_username):
|
||||
self.asset = asset
|
||||
self.login_user = login_user
|
||||
self.account_username = account_username
|
||||
super().__init__(user)
|
||||
|
||||
def get_html_msg(self) -> dict:
|
||||
context = {
|
||||
'recipient': self.user.username,
|
||||
'username': self.login_user.username,
|
||||
'asset': str(self.asset),
|
||||
'account': self.account_username,
|
||||
}
|
||||
message = render_to_string('acls/asset_login_reminder.html', context)
|
||||
|
||||
return {
|
||||
'subject': str(self.subject),
|
||||
'message': message
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def gen_test_msg(cls):
|
||||
user = User.objects.first()
|
||||
asset = Asset.objects.first()
|
||||
return cls(user, asset, user)
|
@ -1,9 +1,9 @@
|
||||
from django.conf import settings
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework import serializers
|
||||
|
||||
from acls.models.base import BaseACL
|
||||
from common.serializers.fields import JSONManyToManyField, LabeledChoiceField
|
||||
from jumpserver.utils import has_valid_xpack_license
|
||||
from orgs.models import Organization
|
||||
from ..const import ActionChoices
|
||||
|
||||
@ -68,7 +68,7 @@ class ActionAclSerializer(serializers.Serializer):
|
||||
field_action = self.fields.get("action")
|
||||
if not field_action:
|
||||
return
|
||||
if not has_valid_xpack_license():
|
||||
if not settings.XPACK_LICENSE_IS_VALID:
|
||||
field_action._choices.pop(ActionChoices.review, None)
|
||||
for choice in self.Meta.action_choices_exclude:
|
||||
field_action._choices.pop(choice, None)
|
||||
|
@ -8,6 +8,7 @@ from orgs.mixins.serializers import BulkOrgResourceModelSerializer
|
||||
from orgs.utils import tmp_to_root_org
|
||||
from terminal.models import Session
|
||||
from .base import BaseUserAssetAccountACLSerializer as BaseSerializer
|
||||
from ..const import ActionChoices
|
||||
|
||||
__all__ = ["CommandFilterACLSerializer", "CommandGroupSerializer", "CommandReviewSerializer"]
|
||||
|
||||
@ -31,8 +32,7 @@ class CommandFilterACLSerializer(BaseSerializer, BulkOrgResourceModelSerializer)
|
||||
class Meta(BaseSerializer.Meta):
|
||||
model = CommandFilterACL
|
||||
fields = BaseSerializer.Meta.fields + ['command_groups']
|
||||
# 默认都支持所有的 actions
|
||||
action_choices_exclude = []
|
||||
action_choices_exclude = [ActionChoices.notice]
|
||||
|
||||
|
||||
class CommandReviewSerializer(serializers.Serializer):
|
||||
|
@ -1,7 +1,7 @@
|
||||
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
|
||||
from .base import BaseUserAssetAccountACLSerializer as BaseSerializer
|
||||
from ..models import ConnectMethodACL
|
||||
from ..const import ActionChoices
|
||||
from ..models import ConnectMethodACL
|
||||
|
||||
__all__ = ["ConnectMethodACLSerializer"]
|
||||
|
||||
@ -14,5 +14,5 @@ class ConnectMethodACLSerializer(BaseSerializer, BulkOrgResourceModelSerializer)
|
||||
if i not in ['assets', 'accounts']
|
||||
]
|
||||
action_choices_exclude = BaseSerializer.Meta.action_choices_exclude + [
|
||||
ActionChoices.review, ActionChoices.accept
|
||||
ActionChoices.review, ActionChoices.accept, ActionChoices.notice
|
||||
]
|
||||
|
13
apps/acls/templates/acls/asset_login_reminder.html
Normal file
13
apps/acls/templates/acls/asset_login_reminder.html
Normal file
@ -0,0 +1,13 @@
|
||||
{% load i18n %}
|
||||
|
||||
<h3>{% trans 'Respectful' %}{{ recipient }},</h3>
|
||||
<hr>
|
||||
<p><strong>{% trans 'Username' %}:</strong> [{{ username }}]</p>
|
||||
<p><strong>{% trans 'Assets' %}:</strong> [{{ asset }}]</p>
|
||||
<p><strong>{% trans 'Account' %}:</strong> [{{ account }}]</p>
|
||||
<hr>
|
||||
|
||||
<p>{% trans 'The user has just logged in to the asset. Please ensure that this is an authorized operation. If you suspect that this is an unauthorized access, please take appropriate measures immediately.' %}</p>
|
||||
|
||||
<p>{% trans 'Thank you' %}!</p>
|
||||
|
14
apps/acls/templates/acls/user_login_reminder.html
Normal file
14
apps/acls/templates/acls/user_login_reminder.html
Normal file
@ -0,0 +1,14 @@
|
||||
{% load i18n %}
|
||||
|
||||
<h3>{% trans 'Respectful' %}{{ recipient }},</h3>
|
||||
<hr>
|
||||
<p><strong>{% trans 'Username' %}:</strong> [{{ username }}]</p>
|
||||
<p><strong>IP:</strong> [{{ ip }}]</p>
|
||||
<p><strong>{% trans 'Login city' %}:</strong> [{{ city }}]</p>
|
||||
<p><strong>{% trans 'User agent' %}:</strong> [{{ user_agent }}]</p>
|
||||
<hr>
|
||||
|
||||
<p>{% trans 'The user has just successfully logged into the system. Please ensure that this is an authorized operation. If you suspect that this is an unauthorized access, please take appropriate measures immediately.' %}</p>
|
||||
|
||||
<p>{% trans 'Thank you' %}!</p>
|
||||
|
@ -1,9 +1,12 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
from collections import defaultdict
|
||||
|
||||
import django_filters
|
||||
from django.db.models import Q
|
||||
from django.shortcuts import get_object_or_404
|
||||
from django.utils.translation import gettext as _
|
||||
from rest_framework import status
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.status import HTTP_200_OK
|
||||
@ -12,7 +15,7 @@ from accounts.tasks import push_accounts_to_assets_task, verify_accounts_connect
|
||||
from assets import serializers
|
||||
from assets.exceptions import NotSupportedTemporarilyError
|
||||
from assets.filters import IpInFilterBackend, LabelFilterBackend, NodeFilterBackend
|
||||
from assets.models import Asset, Gateway, Platform
|
||||
from assets.models import Asset, Gateway, Platform, Protocol
|
||||
from assets.tasks import test_assets_connectivity_manual, update_assets_hardware_info_manual
|
||||
from common.api import SuggestionMixin
|
||||
from common.drf.filters import BaseFilterSet, AttrRulesFilterBackend
|
||||
@ -115,6 +118,7 @@ class AssetViewSet(SuggestionMixin, NodeFilterMixin, OrgBulkModelViewSet):
|
||||
("gateways", "assets.view_gateway"),
|
||||
("spec_info", "assets.view_asset"),
|
||||
("gathered_info", "assets.view_asset"),
|
||||
("sync_platform_protocols", "assets.change_asset"),
|
||||
)
|
||||
extra_filter_backends = [
|
||||
LabelFilterBackend, IpInFilterBackend,
|
||||
@ -152,6 +156,39 @@ class AssetViewSet(SuggestionMixin, NodeFilterMixin, OrgBulkModelViewSet):
|
||||
gateways = asset.domain.gateways
|
||||
return self.get_paginated_response_from_queryset(gateways)
|
||||
|
||||
@action(methods=['post'], detail=False, url_path='sync-platform-protocols')
|
||||
def sync_platform_protocols(self, request, *args, **kwargs):
|
||||
platform_id = request.data.get('platform_id')
|
||||
platform = get_object_or_404(Platform, pk=platform_id)
|
||||
assets = platform.assets.all()
|
||||
|
||||
platform_protocols = {
|
||||
p['name']: p['port']
|
||||
for p in platform.protocols.values('name', 'port')
|
||||
}
|
||||
asset_protocols_map = defaultdict(set)
|
||||
protocols = assets.prefetch_related('protocols').values_list(
|
||||
'id', 'protocols__name'
|
||||
)
|
||||
for asset_id, protocol in protocols:
|
||||
asset_id = str(asset_id)
|
||||
asset_protocols_map[asset_id].add(protocol)
|
||||
objs = []
|
||||
for asset_id, protocols in asset_protocols_map.items():
|
||||
protocol_names = set(platform_protocols) - protocols
|
||||
if not protocol_names:
|
||||
continue
|
||||
for name in protocol_names:
|
||||
objs.append(
|
||||
Protocol(
|
||||
name=name,
|
||||
port=platform_protocols[name],
|
||||
asset_id=asset_id,
|
||||
)
|
||||
)
|
||||
Protocol.objects.bulk_create(objs)
|
||||
return Response(status=status.HTTP_200_OK)
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
if request.path.find('/api/v1/assets/assets/') > -1:
|
||||
error = _('Cannot create asset directly, you should create a host or other')
|
||||
|
@ -10,6 +10,6 @@
|
||||
login_user: "{{ jms_account.username }}"
|
||||
login_password: "{{ jms_account.secret }}"
|
||||
login_host: "{{ jms_asset.address }}"
|
||||
login_port: "{{ jms_asset.port }}"
|
||||
login_port: "{{ jms_asset.protocols | selectattr('name', 'equalto', 'rdp') | map(attribute='port') | first }}"
|
||||
login_secret_type: "{{ jms_account.secret_type }}"
|
||||
login_private_key_path: "{{ jms_account.private_key_path }}"
|
||||
|
@ -1,9 +1,8 @@
|
||||
from django.conf import settings
|
||||
from django.db import models
|
||||
from django.db.models import TextChoices
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from jumpserver.utils import has_valid_xpack_license
|
||||
|
||||
|
||||
class Type:
|
||||
def __init__(self, label, value):
|
||||
@ -113,7 +112,7 @@ class BaseType(TextChoices):
|
||||
|
||||
@classmethod
|
||||
def get_choices(cls):
|
||||
if not has_valid_xpack_license():
|
||||
if not settings.XPACK_LICENSE_IS_VALID:
|
||||
return [
|
||||
(tp.value, tp.label)
|
||||
for tp in cls.get_community_types()
|
||||
|
@ -7,6 +7,7 @@ class DatabaseTypes(BaseType):
|
||||
POSTGRESQL = 'postgresql', 'PostgreSQL'
|
||||
ORACLE = 'oracle', 'Oracle'
|
||||
SQLSERVER = 'sqlserver', 'SQLServer'
|
||||
DB2 = 'db2', 'DB2'
|
||||
CLICKHOUSE = 'clickhouse', 'ClickHouse'
|
||||
MONGODB = 'mongodb', 'MongoDB'
|
||||
REDIS = 'redis', 'Redis'
|
||||
@ -45,6 +46,15 @@ class DatabaseTypes(BaseType):
|
||||
'change_secret_enabled': False,
|
||||
'push_account_enabled': False,
|
||||
},
|
||||
cls.DB2: {
|
||||
'ansible_enabled': False,
|
||||
'ping_enabled': False,
|
||||
'gather_facts_enabled': False,
|
||||
'gather_accounts_enabled': False,
|
||||
'verify_account_enabled': False,
|
||||
'change_secret_enabled': False,
|
||||
'push_account_enabled': False,
|
||||
},
|
||||
cls.CLICKHOUSE: {
|
||||
'ansible_enabled': False,
|
||||
'ping_enabled': False,
|
||||
@ -73,6 +83,7 @@ class DatabaseTypes(BaseType):
|
||||
cls.POSTGRESQL: [{'name': 'PostgreSQL'}],
|
||||
cls.ORACLE: [{'name': 'Oracle'}],
|
||||
cls.SQLSERVER: [{'name': 'SQLServer'}],
|
||||
cls.DB2: [{'name': 'DB2'}],
|
||||
cls.CLICKHOUSE: [{'name': 'ClickHouse'}],
|
||||
cls.MONGODB: [{'name': 'MongoDB'}],
|
||||
cls.REDIS: [
|
||||
|
@ -22,6 +22,7 @@ class Protocol(ChoicesMixin, models.TextChoices):
|
||||
oracle = 'oracle', 'Oracle'
|
||||
postgresql = 'postgresql', 'PostgreSQL'
|
||||
sqlserver = 'sqlserver', 'SQLServer'
|
||||
db2 = 'db2', 'DB2'
|
||||
clickhouse = 'clickhouse', 'ClickHouse'
|
||||
redis = 'redis', 'Redis'
|
||||
mongodb = 'mongodb', 'MongoDB'
|
||||
@ -170,6 +171,12 @@ class Protocol(ChoicesMixin, models.TextChoices):
|
||||
}
|
||||
}
|
||||
},
|
||||
cls.db2: {
|
||||
'port': 5000,
|
||||
'required': True,
|
||||
'secret_types': ['password'],
|
||||
'xpack': True,
|
||||
},
|
||||
cls.clickhouse: {
|
||||
'port': 9000,
|
||||
'required': True,
|
||||
@ -269,7 +276,7 @@ class Protocol(ChoicesMixin, models.TextChoices):
|
||||
}
|
||||
}
|
||||
}
|
||||
if settings.XPACK_ENABLED:
|
||||
if settings.XPACK_LICENSE_IS_VALID:
|
||||
choices = protocols[cls.chatgpt]['setting']['api_mode']['choices']
|
||||
choices.extend([
|
||||
('gpt-4', 'GPT-4'),
|
||||
|
@ -25,7 +25,7 @@ def migrate_asset_accounts(apps, schema_editor):
|
||||
count += len(auth_books)
|
||||
# auth book 和 account 相同的属性
|
||||
same_attrs = [
|
||||
'id', 'username', 'comment', 'date_created', 'date_updated',
|
||||
'username', 'comment', 'date_created', 'date_updated',
|
||||
'created_by', 'asset_id', 'org_id',
|
||||
]
|
||||
# 认证的属性,可能是 auth_book 的,可能是 system_user 的
|
||||
|
26
apps/assets/migrations/0124_auto_20231007_1437.py
Normal file
26
apps/assets/migrations/0124_auto_20231007_1437.py
Normal file
@ -0,0 +1,26 @@
|
||||
# Generated by Django 4.1.10 on 2023-10-07 06:37
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
def add_db2_platform(apps, schema_editor):
|
||||
platform_cls = apps.get_model('assets', 'Platform')
|
||||
automation_cls = apps.get_model('assets', 'PlatformAutomation')
|
||||
platform = platform_cls.objects.create(
|
||||
name='DB2', internal=True, category='database', type='db2',
|
||||
domain_enabled=True, su_enabled=False, comment='DB2',
|
||||
created_by='System', updated_by='System',
|
||||
)
|
||||
platform.protocols.create(name='db2', port=5000, primary=True, setting={})
|
||||
automation_cls.objects.create(ansible_enabled=False, platform=platform)
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('assets', '0123_device_automation_ansible_enabled'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RunPython(add_db2_platform)
|
||||
]
|
21
apps/assets/migrations/0125_auto_20231011_1053.py
Normal file
21
apps/assets/migrations/0125_auto_20231011_1053.py
Normal file
@ -0,0 +1,21 @@
|
||||
# Generated by Django 4.1.10 on 2023-10-11 02:53
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
def change_windows_ping_method(apps, schema_editor):
|
||||
platform_automation_cls = apps.get_model('assets', 'PlatformAutomation')
|
||||
automations = platform_automation_cls.objects.filter(platform__name__in=['Windows', 'Windows2016'])
|
||||
automations.update(ping_method='ping_by_rdp')
|
||||
automations.update(verify_account_method='verify_account_by_rdp')
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('assets', '0124_auto_20231007_1437'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RunPython(change_windows_ping_method)
|
||||
]
|
@ -402,12 +402,7 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
|
||||
return Asset.objects.filter(q).distinct()
|
||||
|
||||
def get_assets_amount(self):
|
||||
q = Q(node__key__startswith=f'{self.key}:') | Q(node__key=self.key)
|
||||
return self.assets.through.objects.filter(q).count()
|
||||
|
||||
def get_assets_account_by_children(self):
|
||||
children = self.get_all_children().values_list()
|
||||
return self.assets.through.objects.filter(node_id__in=children).count()
|
||||
return self.get_all_assets().count()
|
||||
|
||||
@classmethod
|
||||
def get_node_all_assets_by_key_v2(cls, key):
|
||||
|
@ -175,6 +175,8 @@ class AssetSerializer(BulkOrgResourceModelSerializer, WritableNestedModelSeriali
|
||||
protocols = self.initial_data.get('protocols')
|
||||
if protocols is not None:
|
||||
return
|
||||
if getattr(self, 'instance', None):
|
||||
return
|
||||
|
||||
protocols_required, protocols_default = self._get_protocols_required_default()
|
||||
protocol_map = {str(protocol.id): protocol for protocol in protocols_required + protocols_default}
|
||||
@ -281,14 +283,52 @@ class AssetSerializer(BulkOrgResourceModelSerializer, WritableNestedModelSeriali
|
||||
return protocols_data_map.values()
|
||||
|
||||
@staticmethod
|
||||
def accounts_create(accounts_data, asset):
|
||||
def update_account_su_from(accounts, include_su_from_accounts):
|
||||
if not include_su_from_accounts:
|
||||
return
|
||||
name_map = {account.name: account for account in accounts}
|
||||
username_secret_type_map = {
|
||||
(account.username, account.secret_type): account for account in accounts
|
||||
}
|
||||
|
||||
for name, username_secret_type in include_su_from_accounts.items():
|
||||
account = name_map.get(name)
|
||||
if not account:
|
||||
continue
|
||||
su_from_account = username_secret_type_map.get(username_secret_type)
|
||||
if su_from_account:
|
||||
account.su_from = su_from_account
|
||||
account.save()
|
||||
|
||||
def accounts_create(self, accounts_data, asset):
|
||||
from accounts.models import AccountTemplate
|
||||
if not accounts_data:
|
||||
return
|
||||
|
||||
if not isinstance(accounts_data[0], dict):
|
||||
raise serializers.ValidationError({'accounts': _("Invalid data")})
|
||||
|
||||
su_from_name_username_secret_type_map = {}
|
||||
for data in accounts_data:
|
||||
data['asset'] = asset.id
|
||||
name = data.get('name')
|
||||
su_from = data.pop('su_from', None)
|
||||
template_id = data.get('template', None)
|
||||
if template_id:
|
||||
template = AccountTemplate.objects.get(id=template_id)
|
||||
if template and template.su_from:
|
||||
su_from_name_username_secret_type_map[template.name] = (
|
||||
template.su_from.username, template.su_from.secret_type
|
||||
)
|
||||
elif isinstance(su_from, dict):
|
||||
su_from = Account.objects.get(id=su_from.get('id'))
|
||||
su_from_name_username_secret_type_map[name] = (
|
||||
su_from.username, su_from.secret_type
|
||||
)
|
||||
s = AssetAccountSerializer(data=accounts_data, many=True)
|
||||
s.is_valid(raise_exception=True)
|
||||
s.save()
|
||||
accounts = s.save()
|
||||
self.update_account_su_from(accounts, su_from_name_username_secret_type_map)
|
||||
|
||||
@atomic
|
||||
def create(self, validated_data):
|
||||
@ -298,10 +338,37 @@ class AssetSerializer(BulkOrgResourceModelSerializer, WritableNestedModelSeriali
|
||||
self.perform_nodes_display_create(instance, nodes_display)
|
||||
return instance
|
||||
|
||||
@staticmethod
|
||||
def sync_platform_protocols(instance, old_platform):
|
||||
platform = instance.platform
|
||||
|
||||
if str(old_platform.id) == str(instance.platform_id):
|
||||
return
|
||||
|
||||
platform_protocols = {
|
||||
p['name']: p['port']
|
||||
for p in platform.protocols.values('name', 'port')
|
||||
}
|
||||
|
||||
protocols = set(instance.protocols.values_list('name', flat=True))
|
||||
protocol_names = set(platform_protocols) - protocols
|
||||
objs = []
|
||||
for name in protocol_names:
|
||||
objs.append(
|
||||
Protocol(
|
||||
name=name,
|
||||
port=platform_protocols[name],
|
||||
asset_id=instance.id,
|
||||
)
|
||||
)
|
||||
Protocol.objects.bulk_create(objs)
|
||||
|
||||
@atomic
|
||||
def update(self, instance, validated_data):
|
||||
old_platform = instance.platform
|
||||
nodes_display = validated_data.pop('nodes_display', '')
|
||||
instance = super().update(instance, validated_data)
|
||||
self.sync_platform_protocols(instance, old_platform)
|
||||
self.perform_nodes_display_create(instance, nodes_display)
|
||||
return instance
|
||||
|
||||
|
@ -1,8 +1,7 @@
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework import serializers
|
||||
from rest_framework.serializers import ValidationError
|
||||
|
||||
from assets.models import Database
|
||||
from assets.models import Database, Platform
|
||||
from assets.serializers.gateway import GatewayWithAccountSecretSerializer
|
||||
from .common import AssetSerializer
|
||||
|
||||
@ -20,13 +19,44 @@ class DatabaseSerializer(AssetSerializer):
|
||||
]
|
||||
fields = AssetSerializer.Meta.fields + extra_fields
|
||||
|
||||
def validate(self, attrs):
|
||||
platform = attrs.get('platform')
|
||||
db_type_required = ('mongodb', 'postgresql')
|
||||
if platform and getattr(platform, 'type') in db_type_required \
|
||||
and not attrs.get('db_name'):
|
||||
raise ValidationError({'db_name': _('This field is required.')})
|
||||
return attrs
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.set_db_name_required()
|
||||
|
||||
def get_platform(self):
|
||||
platform = None
|
||||
platform_id = None
|
||||
|
||||
if getattr(self, 'initial_data', None):
|
||||
platform_id = self.initial_data.get('platform')
|
||||
if isinstance(platform_id, dict):
|
||||
platform_id = platform_id.get('id') or platform_id.get('pk')
|
||||
if not platform_id and self.instance:
|
||||
platform = self.instance.platform
|
||||
elif getattr(self, 'instance', None):
|
||||
if isinstance(self.instance, list):
|
||||
return
|
||||
platform = self.instance.platform
|
||||
elif self.context.get('request'):
|
||||
platform_id = self.context['request'].query_params.get('platform')
|
||||
|
||||
if not platform and platform_id:
|
||||
platform = Platform.objects.filter(id=platform_id).first()
|
||||
return platform
|
||||
|
||||
def set_db_name_required(self):
|
||||
db_field = self.fields.get('db_name')
|
||||
if not db_field:
|
||||
return
|
||||
|
||||
platform = self.get_platform()
|
||||
if not platform:
|
||||
return
|
||||
|
||||
if platform.type in ['mysql', 'mariadb']:
|
||||
db_field.required = False
|
||||
db_field.allow_blank = True
|
||||
db_field.allow_null = True
|
||||
|
||||
|
||||
class DatabaseWithGatewaySerializer(DatabaseSerializer):
|
||||
|
@ -30,8 +30,9 @@ class NodeSerializer(BulkOrgResourceModelSerializer):
|
||||
if '/' in data:
|
||||
error = _("Can't contains: " + "/")
|
||||
raise serializers.ValidationError(error)
|
||||
if self.instance:
|
||||
instance = self.instance
|
||||
view = self.context['view']
|
||||
instance = self.instance or getattr(view, 'instance', None)
|
||||
if instance:
|
||||
siblings = instance.get_siblings()
|
||||
else:
|
||||
instance = Node.org_root()
|
||||
|
@ -6,7 +6,6 @@ from importlib import import_module
|
||||
from django.conf import settings
|
||||
from django.db.models import F, Value, CharField, Q
|
||||
from django.http import HttpResponse, FileResponse
|
||||
from django.utils import timezone
|
||||
from django.utils.encoding import escape_uri_path
|
||||
from rest_framework import generics
|
||||
from rest_framework import status
|
||||
@ -185,6 +184,8 @@ class ResourceActivityAPIView(generics.ListAPIView):
|
||||
'r_user', 'r_action', 'r_type'
|
||||
)
|
||||
org_q = Q(org_id=Organization.SYSTEM_ID) | Q(org_id=current_org.id)
|
||||
if resource_id:
|
||||
org_q |= Q(org_id='') | Q(org_id=Organization.ROOT_ID)
|
||||
with tmp_to_root_org():
|
||||
qs1 = self.get_operate_log_qs(fields, limit, org_q, resource_id=resource_id)
|
||||
qs2 = self.get_activity_log_qs(fields, limit, org_q, resource_id=resource_id)
|
||||
@ -216,11 +217,10 @@ class OperateLogViewSet(OrgReadonlyModelViewSet):
|
||||
return super().get_serializer_class()
|
||||
|
||||
def get_queryset(self):
|
||||
org_q = Q(org_id=current_org.id)
|
||||
qs = OperateLog.objects.all()
|
||||
if self.is_action_detail:
|
||||
org_q |= Q(org_id=Organization.SYSTEM_ID)
|
||||
with tmp_to_root_org():
|
||||
qs = OperateLog.objects.filter(org_q)
|
||||
with tmp_to_root_org():
|
||||
qs |= OperateLog.objects.filter(org_id=Organization.SYSTEM_ID)
|
||||
es_config = settings.OPERATE_LOG_ELASTICSEARCH_CONFIG
|
||||
if es_config:
|
||||
engine_mod = import_module(TYPE_ENGINE_MAPPING['es'])
|
||||
@ -257,9 +257,8 @@ class UserSessionViewSet(CommonApiMixin, viewsets.ModelViewSet):
|
||||
serializer_class = UserSessionSerializer
|
||||
filterset_fields = ['id', 'ip', 'city', 'type']
|
||||
search_fields = ['id', 'ip', 'city']
|
||||
|
||||
rbac_perms = {
|
||||
'offline': ['users.offline_usersession']
|
||||
'offline': ['audits.offline_usersession']
|
||||
}
|
||||
|
||||
@property
|
||||
@ -269,9 +268,7 @@ class UserSessionViewSet(CommonApiMixin, viewsets.ModelViewSet):
|
||||
|
||||
def get_queryset(self):
|
||||
keys = UserSession.get_keys()
|
||||
queryset = UserSession.objects.filter(
|
||||
date_expired__gt=timezone.now(), key__in=keys
|
||||
)
|
||||
queryset = UserSession.objects.filter(key__in=keys)
|
||||
if current_org.is_root():
|
||||
return queryset
|
||||
user_ids = self.org_user_ids
|
||||
@ -281,7 +278,9 @@ class UserSessionViewSet(CommonApiMixin, viewsets.ModelViewSet):
|
||||
@action(['POST'], detail=False, url_path='offline')
|
||||
def offline(self, request, *args, **kwargs):
|
||||
ids = request.data.get('ids', [])
|
||||
queryset = self.get_queryset().exclude(key=request.session.session_key).filter(id__in=ids)
|
||||
queryset = self.get_queryset()
|
||||
session_key = request.session.session_key
|
||||
queryset = queryset.exclude(key=session_key).filter(id__in=ids)
|
||||
if not queryset.exists():
|
||||
return Response(status=status.HTTP_200_OK)
|
||||
|
||||
|
@ -58,7 +58,7 @@ class OperateLogStore(object):
|
||||
return diff_list
|
||||
|
||||
def save(self, **kwargs):
|
||||
log_id = kwargs.get('id', '')
|
||||
log_id = kwargs.pop('id', None)
|
||||
before = kwargs.pop('before') or {}
|
||||
after = kwargs.pop('after') or {}
|
||||
|
||||
|
@ -30,6 +30,13 @@ class ActionChoices(TextChoices):
|
||||
login = "login", _("Login")
|
||||
change_auth = "change_password", _("Change password")
|
||||
|
||||
accept = 'accept', _('Accept')
|
||||
review = 'review', _('Review')
|
||||
notice = 'notice', _('Notifications')
|
||||
reject = 'reject', _('Reject')
|
||||
approve = 'approve', _('Approve')
|
||||
close = 'close', _('Close')
|
||||
|
||||
|
||||
class LoginTypeChoices(TextChoices):
|
||||
web = "W", _("Web")
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Generated by Django 4.1.10 on 2023-09-06 05:31
|
||||
|
||||
from django.db import migrations, models
|
||||
import django.utils.timezone
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
@ -19,7 +19,7 @@ class Migration(migrations.Migration):
|
||||
migrations.AlterField(
|
||||
model_name='operatelog',
|
||||
name='action',
|
||||
field=models.CharField(choices=[('view', 'View'), ('update', 'Update'), ('delete', 'Delete'), ('create', 'Create'), ('download', 'Download'), ('connect', 'Connect'), ('login', 'Login'), ('change_password', 'Change password')], max_length=16, verbose_name='Action'),
|
||||
field=models.CharField(choices=[('view', 'View'), ('update', 'Update'), ('delete', 'Delete'), ('create', 'Create'), ('download', 'Download'), ('connect', 'Connect'), ('login', 'Login'), ('change_password', 'Change password'), ('accept', 'Accept'), ('review', 'Review'), ('notice', 'Notifications'), ('reject', 'Reject'), ('approve', 'Approve'), ('close', 'Close')], max_length=16, verbose_name='Action'),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='userloginlog',
|
||||
|
@ -0,0 +1,17 @@
|
||||
# Generated by Django 4.1.10 on 2023-10-18 08:01
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('audits', '0024_usersession'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.RemoveField(
|
||||
model_name='usersession',
|
||||
name='date_expired',
|
||||
),
|
||||
]
|
@ -1,5 +1,6 @@
|
||||
import os
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
from importlib import import_module
|
||||
|
||||
from django.conf import settings
|
||||
@ -10,7 +11,7 @@ from django.utils import timezone
|
||||
from django.utils.translation import gettext, gettext_lazy as _
|
||||
|
||||
from common.db.encoder import ModelJSONFieldEncoder
|
||||
from common.utils import lazyproperty
|
||||
from common.utils import lazyproperty, i18n_trans
|
||||
from ops.models import JobExecution
|
||||
from orgs.mixins.models import OrgModelMixin, Organization
|
||||
from orgs.utils import current_org
|
||||
@ -155,6 +156,10 @@ class ActivityLog(OrgModelMixin):
|
||||
verbose_name = _("Activity log")
|
||||
ordering = ('-datetime',)
|
||||
|
||||
def __str__(self):
|
||||
detail = i18n_trans(self.detail)
|
||||
return "{} {}".format(detail, self.resource_id)
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
if current_org.is_root() and not self.org_id:
|
||||
self.org_id = Organization.ROOT_ID
|
||||
@ -259,7 +264,6 @@ class UserSession(models.Model):
|
||||
type = models.CharField(choices=LoginTypeChoices.choices, max_length=2, verbose_name=_("Login type"))
|
||||
backend = models.CharField(max_length=32, default="", verbose_name=_("Authentication backend"))
|
||||
date_created = models.DateTimeField(null=True, blank=True, verbose_name=_('Date created'))
|
||||
date_expired = models.DateTimeField(null=True, blank=True, verbose_name=_("Date expired"), db_index=True)
|
||||
user = models.ForeignKey(
|
||||
'users.User', verbose_name=_('User'), related_name='sessions', on_delete=models.CASCADE
|
||||
)
|
||||
@ -271,6 +275,14 @@ class UserSession(models.Model):
|
||||
def backend_display(self):
|
||||
return gettext(self.backend)
|
||||
|
||||
@property
|
||||
def date_expired(self):
|
||||
session_store_cls = import_module(settings.SESSION_ENGINE).SessionStore
|
||||
session_store = session_store_cls(session_key=self.key)
|
||||
cache_key = session_store.cache_key
|
||||
ttl = caches[settings.SESSION_CACHE_ALIAS].ttl(cache_key)
|
||||
return timezone.now() + timedelta(seconds=ttl)
|
||||
|
||||
@staticmethod
|
||||
def get_keys():
|
||||
session_store_cls = import_module(settings.SESSION_ENGINE).SessionStore
|
||||
@ -280,8 +292,8 @@ class UserSession(models.Model):
|
||||
|
||||
@classmethod
|
||||
def clear_expired_sessions(cls):
|
||||
cls.objects.filter(date_expired__lt=timezone.now()).delete()
|
||||
cls.objects.exclude(key__in=cls.get_keys()).delete()
|
||||
keys = cls.get_keys()
|
||||
cls.objects.exclude(key__in=keys).delete()
|
||||
|
||||
class Meta:
|
||||
ordering = ['-date_created']
|
||||
|
@ -169,6 +169,7 @@ class FileSerializer(serializers.Serializer):
|
||||
class UserSessionSerializer(serializers.ModelSerializer):
|
||||
type = LabeledChoiceField(choices=LoginTypeChoices.choices, label=_("Type"))
|
||||
user = ObjectRelatedField(required=False, queryset=User.objects, label=_('User'))
|
||||
date_expired = serializers.DateTimeField(format="%Y/%m/%d %H:%M:%S", label=_('Date expired'))
|
||||
is_current_user_session = serializers.SerializerMethodField()
|
||||
|
||||
class Meta:
|
||||
|
@ -70,6 +70,8 @@ class ActivityLogHandler:
|
||||
def create_activities(resource_ids, detail, detail_id, action, org_id):
|
||||
if not resource_ids:
|
||||
return
|
||||
if not org_id:
|
||||
org_id = Organization.ROOT_ID
|
||||
activities = [
|
||||
ActivityLog(
|
||||
resource_id=getattr(resource_id, 'pk', resource_id),
|
||||
@ -92,6 +94,8 @@ def after_task_publish_for_activity_log(headers=None, body=None, **kwargs):
|
||||
logger.error(f'Get celery task info error: {e}', exc_info=True)
|
||||
else:
|
||||
logger.debug(f'Create activity log for celery task: {task_id}')
|
||||
if not resource_ids:
|
||||
return
|
||||
create_activities(resource_ids, detail, task_id, action=ActivityChoices.task, org_id=org_id)
|
||||
|
||||
|
||||
@ -110,6 +114,8 @@ def on_session_or_login_log_created(sender, instance=None, created=False, **kwar
|
||||
logger.error('Activity log handler not found: {}'.format(sender))
|
||||
|
||||
resource_ids, detail, act_type, org_id = func(instance)
|
||||
if not resource_ids:
|
||||
return
|
||||
return create_activities(resource_ids, detail, instance.id, act_type, org_id)
|
||||
|
||||
|
||||
|
@ -1,7 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
from datetime import timedelta
|
||||
from importlib import import_module
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import BACKEND_SESSION_KEY
|
||||
@ -11,6 +9,8 @@ from django.utils.functional import LazyObject
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.request import Request
|
||||
|
||||
from acls.models import LoginACL
|
||||
from acls.notifications import UserLoginReminderMsg
|
||||
from audits.models import UserLoginLog
|
||||
from authentication.signals import post_auth_failed, post_auth_success
|
||||
from authentication.utils import check_different_city_login_if_need
|
||||
@ -82,10 +82,10 @@ def generate_data(username, request, login_type=None):
|
||||
|
||||
|
||||
def create_user_session(request, user_id, instance: UserLoginLog):
|
||||
# TODO 目前只记录 web 登录的 session
|
||||
if instance.type != LoginTypeChoices.web:
|
||||
return
|
||||
session_key = request.session.session_key or '-'
|
||||
session_store_cls = import_module(settings.SESSION_ENGINE).SessionStore
|
||||
session_store = session_store_cls(session_key=session_key)
|
||||
ttl = session_store.get_expiry_age()
|
||||
|
||||
online_session_data = {
|
||||
'user_id': user_id,
|
||||
@ -96,26 +96,45 @@ def create_user_session(request, user_id, instance: UserLoginLog):
|
||||
'backend': instance.backend,
|
||||
'user_agent': instance.user_agent,
|
||||
'date_created': instance.datetime,
|
||||
'date_expired': instance.datetime + timedelta(seconds=ttl),
|
||||
}
|
||||
user_session = UserSession.objects.create(**online_session_data)
|
||||
request.session['user_session_id'] = user_session.id
|
||||
request.session['user_session_id'] = str(user_session.id)
|
||||
|
||||
|
||||
def send_login_info_to_reviewers(instance: UserLoginLog | str, auth_acl_id):
|
||||
if isinstance(instance, str):
|
||||
instance = UserLoginLog.objects.filter(id=instance).first()
|
||||
|
||||
if not instance:
|
||||
return
|
||||
|
||||
acl = LoginACL.objects.filter(id=auth_acl_id).first()
|
||||
if not acl or not acl.reviewers.exists():
|
||||
return
|
||||
|
||||
reviewers = acl.reviewers.all()
|
||||
for reviewer in reviewers:
|
||||
UserLoginReminderMsg(reviewer, instance).publish_async()
|
||||
|
||||
|
||||
@receiver(post_auth_success)
|
||||
def on_user_auth_success(sender, user, request, login_type=None, **kwargs):
|
||||
logger.debug('User login success: {}'.format(user.username))
|
||||
check_different_city_login_if_need(user, request)
|
||||
data = generate_data(
|
||||
user.username, request, login_type=login_type
|
||||
)
|
||||
request.session['login_time'] = data['datetime'].strftime("%Y-%m-%d %H:%M:%S")
|
||||
data = generate_data(user.username, request, login_type=login_type)
|
||||
request.session['login_time'] = data['datetime'].strftime('%Y-%m-%d %H:%M:%S')
|
||||
data.update({'mfa': int(user.mfa_enabled), 'status': True})
|
||||
instance = write_login_log(**data)
|
||||
# TODO 目前只记录 web 登录的 session
|
||||
if instance.type != LoginTypeChoices.web:
|
||||
return
|
||||
|
||||
create_user_session(request, user.id, instance)
|
||||
request.session['user_log_id'] = str(instance.id)
|
||||
request.session['can_send_notifications'] = True
|
||||
auth_notice_required = request.session.get('auth_notice_required')
|
||||
if not auth_notice_required:
|
||||
return
|
||||
|
||||
auth_acl_id = request.session.get('auth_acl_id')
|
||||
send_login_info_to_reviewers(instance, auth_acl_id)
|
||||
|
||||
|
||||
@receiver(post_auth_failed)
|
||||
|
@ -1,20 +1,48 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
from django.utils.translation import gettext as _
|
||||
from rest_framework import serializers
|
||||
from rest_framework.response import Response
|
||||
|
||||
from common.api import JMSModelViewSet
|
||||
from rbac.permissions import RBACPermission
|
||||
from ..serializers import AccessKeySerializer
|
||||
from ..const import ConfirmType
|
||||
from ..permissions import UserConfirmation
|
||||
from ..serializers import AccessKeySerializer, AccessKeyCreateSerializer
|
||||
|
||||
|
||||
class AccessKeyViewSet(ModelViewSet):
|
||||
serializer_class = AccessKeySerializer
|
||||
search_fields = ['^id', '^secret']
|
||||
class AccessKeyViewSet(JMSModelViewSet):
|
||||
serializer_classes = {
|
||||
'default': AccessKeySerializer,
|
||||
'create': AccessKeyCreateSerializer
|
||||
}
|
||||
search_fields = ['^id']
|
||||
permission_classes = [RBACPermission]
|
||||
|
||||
def get_queryset(self):
|
||||
return self.request.user.access_keys.all()
|
||||
|
||||
def get_permissions(self):
|
||||
if self.is_swagger_request():
|
||||
return super().get_permissions()
|
||||
|
||||
if self.action == 'create':
|
||||
self.permission_classes = [
|
||||
RBACPermission, UserConfirmation.require(ConfirmType.PASSWORD)
|
||||
]
|
||||
return super().get_permissions()
|
||||
|
||||
def perform_create(self, serializer):
|
||||
user = self.request.user
|
||||
user.create_access_key()
|
||||
if user.access_keys.count() >= 10:
|
||||
raise serializers.ValidationError(_('Access keys can be created at most 10'))
|
||||
key = user.create_access_key()
|
||||
return key
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
serializer = self.get_serializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
key = self.perform_create(serializer)
|
||||
serializer = self.get_serializer(instance=key)
|
||||
return Response(serializer.data, status=201)
|
||||
|
@ -4,27 +4,37 @@ import time
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework import status
|
||||
from rest_framework.generics import RetrieveAPIView, CreateAPIView
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.generics import RetrieveAPIView
|
||||
from rest_framework.response import Response
|
||||
|
||||
from common.permissions import IsValidUser, UserConfirmation
|
||||
from authentication.permissions import UserConfirmation
|
||||
from common.api import JMSGenericViewSet
|
||||
from common.permissions import IsValidUser
|
||||
from ..const import ConfirmType
|
||||
from ..serializers import ConfirmSerializer
|
||||
|
||||
|
||||
class ConfirmBindORUNBindOAuth(RetrieveAPIView):
|
||||
permission_classes = (IsValidUser, UserConfirmation.require(ConfirmType.ReLogin),)
|
||||
permission_classes = (IsValidUser, UserConfirmation.require(ConfirmType.RELOGIN),)
|
||||
|
||||
def retrieve(self, request, *args, **kwargs):
|
||||
return Response('ok')
|
||||
|
||||
|
||||
class ConfirmApi(RetrieveAPIView, CreateAPIView):
|
||||
class UserConfirmationViewSet(JMSGenericViewSet):
|
||||
permission_classes = (IsValidUser,)
|
||||
serializer_class = ConfirmSerializer
|
||||
|
||||
@action(methods=['get'], detail=False)
|
||||
def check(self, request):
|
||||
confirm_type = request.query_params.get('confirm_type', 'password')
|
||||
permission = UserConfirmation.require(confirm_type)()
|
||||
permission.has_permission(request, self)
|
||||
return Response('ok')
|
||||
|
||||
def get_confirm_backend(self, confirm_type):
|
||||
backend_classes = ConfirmType.get_can_confirm_backend_classes(confirm_type)
|
||||
backend_classes = ConfirmType.get_prop_backends(confirm_type)
|
||||
if not backend_classes:
|
||||
return
|
||||
for backend_cls in backend_classes:
|
||||
@ -33,12 +43,12 @@ class ConfirmApi(RetrieveAPIView, CreateAPIView):
|
||||
continue
|
||||
return backend
|
||||
|
||||
def retrieve(self, request, *args, **kwargs):
|
||||
confirm_type = request.query_params.get('confirm_type')
|
||||
def list(self, request, *args, **kwargs):
|
||||
confirm_type = request.query_params.get('confirm_type', 'password')
|
||||
backend = self.get_confirm_backend(confirm_type)
|
||||
if backend is None:
|
||||
msg = _('This action require verify your MFA')
|
||||
return Response(data={'error': msg}, status=status.HTTP_404_NOT_FOUND)
|
||||
return Response(data={'error': msg}, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
data = {
|
||||
'confirm_type': backend.name,
|
||||
@ -51,7 +61,7 @@ class ConfirmApi(RetrieveAPIView, CreateAPIView):
|
||||
serializer.is_valid(raise_exception=True)
|
||||
validated_data = serializer.validated_data
|
||||
|
||||
confirm_type = validated_data.get('confirm_type')
|
||||
confirm_type = validated_data.get('confirm_type', 'password')
|
||||
mfa_type = validated_data.get('mfa_type')
|
||||
secret_key = validated_data.get('secret_key')
|
||||
|
||||
|
@ -15,12 +15,14 @@ from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
|
||||
from accounts.const import AliasAccount
|
||||
from acls.notifications import AssetLoginReminderMsg
|
||||
from common.api import JMSModelViewSet
|
||||
from common.exceptions import JMSException
|
||||
from common.utils import random_string, get_logger, get_request_ip
|
||||
from common.utils import random_string, get_logger, get_request_ip_or_data
|
||||
from common.utils.django import get_request_os
|
||||
from common.utils.http import is_true, is_false
|
||||
from orgs.mixins.api import RootOrgViewMixin
|
||||
from orgs.utils import tmp_to_org
|
||||
from perms.models import ActionChoices
|
||||
from terminal.connect_methods import NativeClient, ConnectMethodUtil
|
||||
from terminal.models import EndpointRule, Endpoint
|
||||
@ -298,6 +300,7 @@ class ConnectionTokenViewSet(ExtraActionApiMixin, RootOrgViewMixin, JMSModelView
|
||||
'get_rdp_file': 'authentication.add_connectiontoken',
|
||||
'get_client_protocol_url': 'authentication.add_connectiontoken',
|
||||
}
|
||||
input_username = ''
|
||||
|
||||
def get_queryset(self):
|
||||
queryset = ConnectionToken.objects \
|
||||
@ -313,21 +316,42 @@ class ConnectionTokenViewSet(ExtraActionApiMixin, RootOrgViewMixin, JMSModelView
|
||||
return super().perform_create(serializer)
|
||||
|
||||
def _insert_connect_options(self, data, user):
|
||||
name = 'file_name_conflict_resolution'
|
||||
connect_options = data.pop('connect_options', {})
|
||||
preference = Preference.objects.filter(
|
||||
name=name, user=user, category='koko'
|
||||
).first()
|
||||
value = preference.value if preference else FileNameConflictResolution.REPLACE
|
||||
connect_options[name] = value
|
||||
default_name_opts = {
|
||||
'file_name_conflict_resolution': FileNameConflictResolution.REPLACE,
|
||||
'terminal_theme_name': 'Default',
|
||||
}
|
||||
preferences_query = Preference.objects.filter(
|
||||
user=user, category='koko', name__in=default_name_opts.keys()
|
||||
).values_list('name', 'value')
|
||||
preferences = dict(preferences_query)
|
||||
for name in default_name_opts.keys():
|
||||
value = preferences.get(name, default_name_opts[name])
|
||||
connect_options[name] = value
|
||||
data['connect_options'] = connect_options
|
||||
|
||||
@staticmethod
|
||||
def get_input_username(data):
|
||||
input_username = data.get('input_username', '')
|
||||
if input_username:
|
||||
return input_username
|
||||
|
||||
account = data.get('account', '')
|
||||
if account == '@USER':
|
||||
input_username = str(data.get('user', ''))
|
||||
elif account == '@INPUT':
|
||||
input_username = '@INPUT'
|
||||
else:
|
||||
input_username = account
|
||||
return input_username
|
||||
|
||||
def validate_serializer(self, serializer):
|
||||
data = serializer.validated_data
|
||||
user = self.get_user(serializer)
|
||||
self._insert_connect_options(data, user)
|
||||
asset = data.get('asset')
|
||||
account_name = data.get('account')
|
||||
self.input_username = self.get_input_username(data)
|
||||
_data = self._validate(user, asset, account_name)
|
||||
data.update(_data)
|
||||
return serializer
|
||||
@ -374,28 +398,62 @@ class ConnectionTokenViewSet(ExtraActionApiMixin, RootOrgViewMixin, JMSModelView
|
||||
raise JMSException(code='perm_expired', detail=msg)
|
||||
return account
|
||||
|
||||
def _record_operate_log(self, acl, asset):
|
||||
from audits.handler import create_or_update_operate_log
|
||||
with tmp_to_org(asset.org_id):
|
||||
after = {
|
||||
str(_('Assets')): str(asset),
|
||||
str(_('Account')): self.input_username
|
||||
}
|
||||
object_name = acl._meta.object_name
|
||||
resource_type = acl._meta.verbose_name
|
||||
create_or_update_operate_log(
|
||||
acl.action, resource_type, resource=acl,
|
||||
after=after, object_name=object_name
|
||||
)
|
||||
|
||||
def _validate_acl(self, user, asset, account):
|
||||
from acls.models import LoginAssetACL
|
||||
acls = LoginAssetACL.filter_queryset(user=user, asset=asset, account=account)
|
||||
ip = get_request_ip(self.request)
|
||||
ip = get_request_ip_or_data(self.request)
|
||||
acl = LoginAssetACL.get_match_rule_acls(user, ip, acls)
|
||||
if not acl:
|
||||
return
|
||||
if acl.is_action(acl.ActionChoices.accept):
|
||||
self._record_operate_log(acl, asset)
|
||||
return
|
||||
if acl.is_action(acl.ActionChoices.reject):
|
||||
self._record_operate_log(acl, asset)
|
||||
msg = _('ACL action is reject: {}({})'.format(acl.name, acl.id))
|
||||
raise JMSException(code='acl_reject', detail=msg)
|
||||
if acl.is_action(acl.ActionChoices.review):
|
||||
if not self.request.query_params.get('create_ticket'):
|
||||
msg = _('ACL action is review')
|
||||
raise JMSException(code='acl_review', detail=msg)
|
||||
|
||||
self._record_operate_log(acl, asset)
|
||||
ticket = LoginAssetACL.create_login_asset_review_ticket(
|
||||
user=user, asset=asset, account_username=account.username,
|
||||
user=user, asset=asset, account_username=self.input_username,
|
||||
assignees=acl.reviewers.all(), org_id=asset.org_id
|
||||
)
|
||||
return ticket
|
||||
if acl.is_action(acl.ActionChoices.notice):
|
||||
reviewers = acl.reviewers.all()
|
||||
if not reviewers:
|
||||
return
|
||||
|
||||
self._record_operate_log(acl, asset)
|
||||
for reviewer in reviewers:
|
||||
AssetLoginReminderMsg(
|
||||
reviewer, asset, user, self.input_username
|
||||
).publish_async()
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
try:
|
||||
response = super().create(request, *args, **kwargs)
|
||||
except JMSException as e:
|
||||
data = {'code': e.detail.code, 'detail': e.detail}
|
||||
return Response(data, status=e.status_code)
|
||||
return response
|
||||
|
||||
|
||||
class SuperConnectionTokenViewSet(ConnectionTokenViewSet):
|
||||
|
@ -4,8 +4,9 @@ from rest_framework.views import APIView
|
||||
|
||||
from authentication import errors
|
||||
from authentication.const import ConfirmType
|
||||
from authentication.permissions import UserConfirmation
|
||||
from common.api import RoleUserMixin, RoleAdminMixin
|
||||
from common.permissions import UserConfirmation, IsValidUser
|
||||
from common.permissions import IsValidUser
|
||||
from common.utils import get_logger
|
||||
from users.models import User
|
||||
|
||||
@ -27,7 +28,7 @@ class DingTalkQRUnBindBase(APIView):
|
||||
|
||||
|
||||
class DingTalkQRUnBindForUserApi(RoleUserMixin, DingTalkQRUnBindBase):
|
||||
permission_classes = (IsValidUser, UserConfirmation.require(ConfirmType.ReLogin),)
|
||||
permission_classes = (IsValidUser, UserConfirmation.require(ConfirmType.RELOGIN),)
|
||||
|
||||
|
||||
class DingTalkQRUnBindForAdminApi(RoleAdminMixin, DingTalkQRUnBindBase):
|
||||
|
@ -4,12 +4,13 @@ from rest_framework.views import APIView
|
||||
|
||||
from authentication import errors
|
||||
from authentication.const import ConfirmType
|
||||
from authentication.permissions import UserConfirmation
|
||||
from common.api import RoleUserMixin, RoleAdminMixin
|
||||
from common.permissions import UserConfirmation, IsValidUser
|
||||
from common.permissions import IsValidUser
|
||||
from common.utils import get_logger
|
||||
from users.models import User
|
||||
|
||||
logger = get_logger(__file__)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class FeiShuQRUnBindBase(APIView):
|
||||
@ -27,7 +28,7 @@ class FeiShuQRUnBindBase(APIView):
|
||||
|
||||
|
||||
class FeiShuQRUnBindForUserApi(RoleUserMixin, FeiShuQRUnBindBase):
|
||||
permission_classes = (IsValidUser, UserConfirmation.require(ConfirmType.ReLogin),)
|
||||
permission_classes = (IsValidUser, UserConfirmation.require(ConfirmType.RELOGIN),)
|
||||
|
||||
|
||||
class FeiShuQRUnBindForAdminApi(RoleAdminMixin, FeiShuQRUnBindBase):
|
||||
|
@ -1,3 +1,5 @@
|
||||
import time
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.http import HttpResponseRedirect
|
||||
from django.shortcuts import reverse
|
||||
@ -7,7 +9,7 @@ from rest_framework.generics import CreateAPIView
|
||||
from rest_framework.permissions import AllowAny
|
||||
from rest_framework.response import Response
|
||||
|
||||
from authentication.errors import PasswordInvalid
|
||||
from authentication.errors import PasswordInvalid, IntervalTooShort
|
||||
from authentication.mixins import AuthMixin
|
||||
from authentication.mixins import authenticate
|
||||
from authentication.serializers import (
|
||||
@ -38,18 +40,18 @@ class UserResetPasswordSendCodeApi(CreateAPIView):
|
||||
return None, err_msg
|
||||
return user, None
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
token = request.GET.get('token')
|
||||
userinfo = cache.get(token)
|
||||
if not userinfo:
|
||||
return HttpResponseRedirect(reverse('authentication:forgot-previewing'))
|
||||
@staticmethod
|
||||
def safe_send_code(token, code, target, form_type, content):
|
||||
token_sent_key = '{}_send_at'.format(token)
|
||||
token_send_at = cache.get(token_sent_key, 0)
|
||||
if token_send_at:
|
||||
raise IntervalTooShort(60)
|
||||
SendAndVerifyCodeUtil(target, code, backend=form_type, **content).gen_and_send_async()
|
||||
cache.set(token_sent_key, int(time.time()), 60)
|
||||
|
||||
serializer = self.get_serializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
username = userinfo.get('username')
|
||||
def prepare_code_data(self, user_info, serializer):
|
||||
username = user_info.get('username')
|
||||
form_type = serializer.validated_data['form_type']
|
||||
code = random_string(6, lower=False, upper=False)
|
||||
other_args = {}
|
||||
|
||||
target = serializer.validated_data[form_type]
|
||||
if form_type == 'sms':
|
||||
@ -59,15 +61,30 @@ class UserResetPasswordSendCodeApi(CreateAPIView):
|
||||
query_key = form_type
|
||||
user, err = self.is_valid_user(username=username, **{query_key: target})
|
||||
if not user:
|
||||
return Response({'error': err}, status=400)
|
||||
raise ValueError(err)
|
||||
|
||||
code = random_string(6, lower=False, upper=False)
|
||||
subject = '%s: %s' % (get_login_title(), _('Forgot password'))
|
||||
context = {
|
||||
'user': user, 'title': subject, 'code': code,
|
||||
}
|
||||
message = render_to_string('authentication/_msg_reset_password_code.html', context)
|
||||
other_args['subject'], other_args['message'] = subject, message
|
||||
SendAndVerifyCodeUtil(target, code, backend=form_type, **other_args).gen_and_send_async()
|
||||
content = {'subject': subject, 'message': message}
|
||||
return code, target, form_type, content
|
||||
|
||||
def create(self, request, *args, **kwargs):
|
||||
token = request.GET.get('token')
|
||||
user_info = cache.get(token)
|
||||
if not user_info:
|
||||
return HttpResponseRedirect(reverse('authentication:forgot-previewing'))
|
||||
|
||||
serializer = self.get_serializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
try:
|
||||
code, target, form_type, content = self.prepare_code_data(user_info, serializer)
|
||||
except ValueError as e:
|
||||
return Response({'error': str(e)}, status=400)
|
||||
self.safe_send_code(token, code, target, form_type, content)
|
||||
return Response({'data': 'ok'}, status=200)
|
||||
|
||||
|
||||
|
@ -4,8 +4,9 @@ from rest_framework.views import APIView
|
||||
|
||||
from authentication import errors
|
||||
from authentication.const import ConfirmType
|
||||
from authentication.permissions import UserConfirmation
|
||||
from common.api import RoleUserMixin, RoleAdminMixin
|
||||
from common.permissions import UserConfirmation, IsValidUser
|
||||
from common.permissions import IsValidUser
|
||||
from common.utils import get_logger
|
||||
from users.models import User
|
||||
|
||||
@ -27,7 +28,7 @@ class WeComQRUnBindBase(APIView):
|
||||
|
||||
|
||||
class WeComQRUnBindForUserApi(RoleUserMixin, WeComQRUnBindBase):
|
||||
permission_classes = (IsValidUser, UserConfirmation.require(ConfirmType.ReLogin),)
|
||||
permission_classes = (IsValidUser, UserConfirmation.require(ConfirmType.RELOGIN),)
|
||||
|
||||
|
||||
class WeComQRUnBindForAdminApi(RoleAdminMixin, WeComQRUnBindBase):
|
||||
|
@ -1,119 +1,33 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.core.cache import cache
|
||||
from django.utils import timezone
|
||||
from django.utils.translation import gettext as _
|
||||
from rest_framework import HTTP_HEADER_ENCODING
|
||||
from rest_framework import authentication, exceptions
|
||||
from six import text_type
|
||||
|
||||
from common.auth import signature
|
||||
from common.utils import get_object_or_none, make_signature, http_to_unixtime
|
||||
from .base import JMSBaseAuthBackend
|
||||
from common.utils import get_object_or_none
|
||||
from ..models import AccessKey, PrivateToken
|
||||
|
||||
UserModel = get_user_model()
|
||||
|
||||
def date_more_than(d, seconds):
|
||||
return d is None or (timezone.now() - d).seconds > seconds
|
||||
|
||||
|
||||
def get_request_date_header(request):
|
||||
date = request.META.get('HTTP_DATE', b'')
|
||||
if isinstance(date, text_type):
|
||||
# Work around django test client oddness
|
||||
date = date.encode(HTTP_HEADER_ENCODING)
|
||||
return date
|
||||
def after_authenticate_update_date(user, token=None):
|
||||
if date_more_than(user.date_api_key_last_used, 60):
|
||||
user.date_api_key_last_used = timezone.now()
|
||||
user.save(update_fields=['date_api_key_last_used'])
|
||||
|
||||
|
||||
class AccessKeyAuthentication(authentication.BaseAuthentication):
|
||||
"""App使用Access key进行签名认证, 目前签名算法比较简单,
|
||||
app注册或者手动建立后,会生成 access_key_id 和 access_key_secret,
|
||||
然后使用 如下算法生成签名:
|
||||
Signature = md5(access_key_secret + '\n' + Date)
|
||||
example: Signature = md5('d32d2b8b-9a10-4b8d-85bb-1a66976f6fdc' + '\n' +
|
||||
'Thu, 12 Jan 2017 08:19:41 GMT')
|
||||
请求时设置请求header
|
||||
header['Authorization'] = 'Sign access_key_id:Signature' 如:
|
||||
header['Authorization'] =
|
||||
'Sign d32d2b8b-9a10-4b8d-85bb-1a66976f6fdc:OKOlmdxgYPZ9+SddnUUDbQ=='
|
||||
|
||||
验证时根据相同算法进行验证, 取到access_key_id对应的access_key_id, 从request
|
||||
headers取到Date, 然后进行md5, 判断得到的结果是否相同, 如果是认证通过, 否则 认证
|
||||
失败
|
||||
"""
|
||||
keyword = 'Sign'
|
||||
|
||||
def authenticate(self, request):
|
||||
auth = authentication.get_authorization_header(request).split()
|
||||
if not auth or auth[0].lower() != self.keyword.lower().encode():
|
||||
return None
|
||||
|
||||
if len(auth) == 1:
|
||||
msg = _('Invalid signature header. No credentials provided.')
|
||||
raise exceptions.AuthenticationFailed(msg)
|
||||
elif len(auth) > 2:
|
||||
msg = _('Invalid signature header. Signature '
|
||||
'string should not contain spaces.')
|
||||
raise exceptions.AuthenticationFailed(msg)
|
||||
|
||||
try:
|
||||
sign = auth[1].decode().split(':')
|
||||
if len(sign) != 2:
|
||||
msg = _('Invalid signature header. '
|
||||
'Format like AccessKeyId:Signature')
|
||||
raise exceptions.AuthenticationFailed(msg)
|
||||
except UnicodeError:
|
||||
msg = _('Invalid signature header. '
|
||||
'Signature string should not contain invalid characters.')
|
||||
raise exceptions.AuthenticationFailed(msg)
|
||||
|
||||
access_key_id = sign[0]
|
||||
try:
|
||||
uuid.UUID(access_key_id)
|
||||
except ValueError:
|
||||
raise exceptions.AuthenticationFailed('Access key id invalid')
|
||||
request_signature = sign[1]
|
||||
|
||||
return self.authenticate_credentials(
|
||||
request, access_key_id, request_signature
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def authenticate_credentials(request, access_key_id, request_signature):
|
||||
access_key = get_object_or_none(AccessKey, id=access_key_id)
|
||||
request_date = get_request_date_header(request)
|
||||
if access_key is None or not access_key.user:
|
||||
raise exceptions.AuthenticationFailed(_('Invalid signature.'))
|
||||
access_key_secret = access_key.secret
|
||||
|
||||
try:
|
||||
request_unix_time = http_to_unixtime(request_date)
|
||||
except ValueError:
|
||||
raise exceptions.AuthenticationFailed(
|
||||
_('HTTP header: Date not provide '
|
||||
'or not %a, %d %b %Y %H:%M:%S GMT'))
|
||||
|
||||
if int(time.time()) - request_unix_time > 15 * 60:
|
||||
raise exceptions.AuthenticationFailed(
|
||||
_('Expired, more than 15 minutes'))
|
||||
|
||||
signature = make_signature(access_key_secret, request_date)
|
||||
if not signature == request_signature:
|
||||
raise exceptions.AuthenticationFailed(_('Invalid signature.'))
|
||||
|
||||
if not access_key.user.is_active:
|
||||
raise exceptions.AuthenticationFailed(_('User disabled.'))
|
||||
return access_key.user, None
|
||||
|
||||
def authenticate_header(self, request):
|
||||
return 'Sign access_key_id:Signature'
|
||||
if token and hasattr(token, 'date_last_used') and date_more_than(token.date_last_used, 60):
|
||||
token.date_last_used = timezone.now()
|
||||
token.save(update_fields=['date_last_used'])
|
||||
|
||||
|
||||
class AccessTokenAuthentication(authentication.BaseAuthentication):
|
||||
keyword = 'Bearer'
|
||||
# expiration = settings.TOKEN_EXPIRATION or 3600
|
||||
model = get_user_model()
|
||||
|
||||
def authenticate(self, request):
|
||||
@ -125,19 +39,20 @@ class AccessTokenAuthentication(authentication.BaseAuthentication):
|
||||
msg = _('Invalid token header. No credentials provided.')
|
||||
raise exceptions.AuthenticationFailed(msg)
|
||||
elif len(auth) > 2:
|
||||
msg = _('Invalid token header. Sign string '
|
||||
'should not contain spaces.')
|
||||
msg = _('Invalid token header. Sign string should not contain spaces.')
|
||||
raise exceptions.AuthenticationFailed(msg)
|
||||
|
||||
try:
|
||||
token = auth[1].decode()
|
||||
except UnicodeError:
|
||||
msg = _('Invalid token header. Sign string '
|
||||
'should not contain invalid characters.')
|
||||
msg = _('Invalid token header. Sign string should not contain invalid characters.')
|
||||
raise exceptions.AuthenticationFailed(msg)
|
||||
return self.authenticate_credentials(token)
|
||||
user, header = self.authenticate_credentials(token)
|
||||
after_authenticate_update_date(user)
|
||||
return user, header
|
||||
|
||||
def authenticate_credentials(self, token):
|
||||
@staticmethod
|
||||
def authenticate_credentials(token):
|
||||
model = get_user_model()
|
||||
user_id = cache.get(token)
|
||||
user = get_object_or_none(model, id=user_id)
|
||||
@ -151,15 +66,23 @@ class AccessTokenAuthentication(authentication.BaseAuthentication):
|
||||
return self.keyword
|
||||
|
||||
|
||||
class PrivateTokenAuthentication(JMSBaseAuthBackend, authentication.TokenAuthentication):
|
||||
class PrivateTokenAuthentication(authentication.TokenAuthentication):
|
||||
model = PrivateToken
|
||||
|
||||
def authenticate(self, request):
|
||||
user_token = super().authenticate(request)
|
||||
if not user_token:
|
||||
return
|
||||
user, token = user_token
|
||||
after_authenticate_update_date(user, token)
|
||||
return user, token
|
||||
|
||||
|
||||
class SessionAuthentication(authentication.SessionAuthentication):
|
||||
def authenticate(self, request):
|
||||
"""
|
||||
Returns a `User` if the request session currently has a logged in user.
|
||||
Otherwise returns `None`.
|
||||
Otherwise, returns `None`.
|
||||
"""
|
||||
|
||||
# Get the session-based user from the underlying HttpRequest object
|
||||
@ -195,6 +118,7 @@ class SignatureAuthentication(signature.SignatureAuthentication):
|
||||
if not key.is_active:
|
||||
return None, None
|
||||
user, secret = key.user, str(key.secret)
|
||||
after_authenticate_update_date(user, key)
|
||||
return user, secret
|
||||
except (AccessKey.DoesNotExist, exceptions.ValidationError):
|
||||
return None, None
|
||||
|
@ -166,7 +166,7 @@ class OIDCAuthCallbackView(View):
|
||||
code_verifier = request.session.get('oidc_auth_code_verifier', None)
|
||||
logger.debug(log_prompt.format('Process authenticate'))
|
||||
user = auth.authenticate(nonce=nonce, request=request, code_verifier=code_verifier)
|
||||
if user and user.is_valid:
|
||||
if user:
|
||||
logger.debug(log_prompt.format('Login: {}'.format(user)))
|
||||
auth.login(self.request, user)
|
||||
# Stores an expiration timestamp in the user's session. This value will be used if
|
||||
|
@ -4,24 +4,37 @@ from django.shortcuts import render
|
||||
from django.utils.translation import gettext as _
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.permissions import IsAuthenticated, AllowAny
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentication.mixins import AuthMixin
|
||||
from common.api import JMSModelViewSet
|
||||
from .fido import register_begin, register_complete, auth_begin, auth_complete
|
||||
from .models import Passkey
|
||||
from .serializer import PasskeySerializer
|
||||
from ...const import ConfirmType
|
||||
from ...permissions import UserConfirmation
|
||||
from ...views import FlashMessageMixin
|
||||
|
||||
|
||||
class PasskeyViewSet(AuthMixin, FlashMessageMixin, ModelViewSet):
|
||||
class PasskeyViewSet(AuthMixin, FlashMessageMixin, JMSModelViewSet):
|
||||
serializer_class = PasskeySerializer
|
||||
permission_classes = (IsAuthenticated,)
|
||||
|
||||
def get_permissions(self):
|
||||
if self.is_swagger_request():
|
||||
return super().get_permissions()
|
||||
if self.action == 'register':
|
||||
self.permission_classes = [
|
||||
IsAuthenticated, UserConfirmation.require(ConfirmType.PASSWORD)
|
||||
]
|
||||
return super().get_permissions()
|
||||
|
||||
def get_queryset(self):
|
||||
return Passkey.objects.filter(user=self.request.user)
|
||||
|
||||
@action(methods=['get', 'post'], detail=False, url_path='register')
|
||||
def register(self, request):
|
||||
if request.user.source != 'local':
|
||||
return JsonResponse({'error': _('Only register passkey for local user')}, status=400)
|
||||
if request.method == 'GET':
|
||||
register_data, state = register_begin(request)
|
||||
return JsonResponse(dict(register_data))
|
||||
|
@ -7,3 +7,6 @@ class PasskeyAuthBackend(JMSModelBackend):
|
||||
@staticmethod
|
||||
def is_enabled():
|
||||
return settings.AUTH_PASSKEY
|
||||
|
||||
def user_can_authenticate(self, user):
|
||||
return user.source == 'local'
|
||||
|
@ -1,8 +1,9 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import get_user_model
|
||||
|
||||
from common.permissions import ServiceAccountSignaturePermission
|
||||
from .base import JMSBaseAuthBackend
|
||||
|
||||
UserModel = get_user_model()
|
||||
@ -18,6 +19,10 @@ class PublicKeyAuthBackend(JMSBaseAuthBackend):
|
||||
def authenticate(self, request, username=None, public_key=None, **kwargs):
|
||||
if not public_key:
|
||||
return None
|
||||
|
||||
permission = ServiceAccountSignaturePermission()
|
||||
if not permission.has_permission(request, None):
|
||||
return None
|
||||
if username is None:
|
||||
username = kwargs.get(UserModel.USERNAME_FIELD)
|
||||
try:
|
||||
@ -26,7 +31,7 @@ class PublicKeyAuthBackend(JMSBaseAuthBackend):
|
||||
return None
|
||||
else:
|
||||
if user.check_public_key(public_key) and \
|
||||
self.user_can_authenticate(user):
|
||||
self.user_can_authenticate(user):
|
||||
return user
|
||||
|
||||
def get_user(self, user_id):
|
||||
|
@ -2,7 +2,6 @@ import abc
|
||||
|
||||
|
||||
class BaseConfirm(abc.ABC):
|
||||
|
||||
def __init__(self, user, request):
|
||||
self.user = user
|
||||
self.request = request
|
||||
@ -23,7 +22,7 @@ class BaseConfirm(abc.ABC):
|
||||
|
||||
@property
|
||||
def content(self):
|
||||
return ''
|
||||
return []
|
||||
|
||||
@abc.abstractmethod
|
||||
def authenticate(self, secret_key, mfa_type) -> tuple:
|
||||
|
@ -15,3 +15,14 @@ class ConfirmPassword(BaseConfirm):
|
||||
ok = authenticate(self.request, username=self.user.username, password=secret_key)
|
||||
msg = '' if ok else _('Authentication failed password incorrect')
|
||||
return ok, msg
|
||||
|
||||
@property
|
||||
def content(self):
|
||||
return [
|
||||
{
|
||||
'name': 'password',
|
||||
'display_name': _('Password'),
|
||||
'disabled': False,
|
||||
'placeholder': _('Password'),
|
||||
}
|
||||
]
|
||||
|
@ -11,7 +11,7 @@ CONFIRM_BACKEND_MAP = {backend.name: backend for backend in CONFIRM_BACKENDS}
|
||||
|
||||
|
||||
class ConfirmType(TextChoices):
|
||||
ReLogin = ConfirmReLogin.name, ConfirmReLogin.display_name
|
||||
RELOGIN = ConfirmReLogin.name, ConfirmReLogin.display_name
|
||||
PASSWORD = ConfirmPassword.name, ConfirmPassword.display_name
|
||||
MFA = ConfirmMFA.name, ConfirmMFA.display_name
|
||||
|
||||
@ -23,10 +23,11 @@ class ConfirmType(TextChoices):
|
||||
return types
|
||||
|
||||
@classmethod
|
||||
def get_can_confirm_backend_classes(cls, confirm_type):
|
||||
def get_prop_backends(cls, confirm_type):
|
||||
types = cls.get_can_confirm_types(confirm_type)
|
||||
backend_classes = [
|
||||
CONFIRM_BACKEND_MAP[tp] for tp in types if tp in CONFIRM_BACKEND_MAP
|
||||
CONFIRM_BACKEND_MAP[tp]
|
||||
for tp in types if tp in CONFIRM_BACKEND_MAP
|
||||
]
|
||||
return backend_classes
|
||||
|
||||
|
@ -36,3 +36,11 @@ class FeiShuNotBound(JMSException):
|
||||
class PasswordInvalid(JMSException):
|
||||
default_code = 'passwd_invalid'
|
||||
default_detail = _('Your password is invalid')
|
||||
|
||||
|
||||
class IntervalTooShort(JMSException):
|
||||
default_code = 'interval_too_short'
|
||||
default_detail = _('Please wait for %s seconds before retry')
|
||||
|
||||
def __init__(self, interval, *args, **kwargs):
|
||||
super().__init__(detail=self.default_detail % interval, *args, **kwargs)
|
||||
|
@ -10,7 +10,7 @@ logger = get_logger(__file__)
|
||||
mfa_custom_method = None
|
||||
|
||||
if settings.MFA_CUSTOM:
|
||||
""" 保证自定义认证方法在服务运行时不能被更改,只在第一次调用时加载一次 """
|
||||
""" 保证自定义的方法在服务运行时不能被更改,只在第一次调用时加载一次 """
|
||||
try:
|
||||
mfa_custom_method_path = 'data.mfa.main.check_code'
|
||||
mfa_custom_method = import_string(mfa_custom_method_path)
|
||||
|
@ -8,6 +8,7 @@ from django.utils.deprecation import MiddlewareMixin
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
from apps.authentication import mixins
|
||||
from audits.signal_handlers import send_login_info_to_reviewers
|
||||
from authentication.signals import post_auth_failed
|
||||
from common.utils import gen_key_pair
|
||||
from common.utils import get_request_ip
|
||||
@ -92,12 +93,12 @@ class ThirdPartyLoginMiddleware(mixins.AuthMixin):
|
||||
'title': _('Authentication failed'),
|
||||
'message': _('Authentication failed (before login check failed): {}').format(e),
|
||||
'interval': 10,
|
||||
'redirect_url': reverse('authentication:login'),
|
||||
'redirect_url': reverse('authentication:login') + '?admin=1',
|
||||
'auto_redirect': True,
|
||||
}
|
||||
response = render(request, 'authentication/auth_fail_flash_message_standalone.html', context)
|
||||
else:
|
||||
if not self.request.session['auth_confirm_required']:
|
||||
if not self.request.session.get('auth_confirm_required'):
|
||||
return response
|
||||
guard_url = reverse('authentication:login-guard')
|
||||
args = request.META.get('QUERY_STRING', '')
|
||||
@ -105,6 +106,12 @@ class ThirdPartyLoginMiddleware(mixins.AuthMixin):
|
||||
guard_url = "%s?%s" % (guard_url, args)
|
||||
response = redirect(guard_url)
|
||||
finally:
|
||||
if request.session.get('can_send_notifications') and \
|
||||
self.request.session.get('auth_notice_required'):
|
||||
request.session['can_send_notifications'] = False
|
||||
user_log_id = self.request.session.get('user_log_id')
|
||||
auth_acl_id = self.request.session.get('auth_acl_id')
|
||||
send_login_info_to_reviewers(user_log_id, auth_acl_id)
|
||||
return response
|
||||
|
||||
|
||||
|
57
apps/authentication/migrations/0023_auto_20231010_1101.py
Normal file
57
apps/authentication/migrations/0023_auto_20231010_1101.py
Normal file
@ -0,0 +1,57 @@
|
||||
# Generated by Django 4.1.10 on 2023-10-10 02:47
|
||||
|
||||
import uuid
|
||||
import authentication.models.access_key
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
def migrate_access_key_secret(apps, schema_editor):
|
||||
access_key_model = apps.get_model('authentication', 'AccessKey')
|
||||
db_alias = schema_editor.connection.alias
|
||||
|
||||
batch_size = 100
|
||||
count = 0
|
||||
|
||||
while True:
|
||||
access_keys = access_key_model.objects.using(db_alias).all()[count:count + batch_size]
|
||||
if not access_keys:
|
||||
break
|
||||
|
||||
count += len(access_keys)
|
||||
access_keys_updated = []
|
||||
for access_key in access_keys:
|
||||
s = access_key.secret
|
||||
if len(s) != 32 or not s.islower():
|
||||
continue
|
||||
try:
|
||||
access_key.secret = '%s-%s-%s-%s-%s' % (s[:8], s[8:12], s[12:16], s[16:20], s[20:])
|
||||
access_keys_updated.append(access_key)
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
access_key_model.objects.bulk_update(access_keys_updated, fields=['secret'])
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('authentication', '0022_passkey'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='accesskey',
|
||||
name='date_last_used',
|
||||
field=models.DateTimeField(blank=True, null=True, verbose_name='Date last used'),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='privatetoken',
|
||||
name='date_last_used',
|
||||
field=models.DateTimeField(blank=True, null=True, verbose_name='Date last used'),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='accesskey',
|
||||
name='secret',
|
||||
field=models.CharField(default=authentication.models.access_key.default_secret, max_length=36, verbose_name='AccessKeySecret'),
|
||||
),
|
||||
migrations.RunPython(migrate_access_key_secret),
|
||||
]
|
@ -19,7 +19,7 @@ from django.utils.translation import gettext as _
|
||||
from rest_framework.request import Request
|
||||
|
||||
from acls.models import LoginACL
|
||||
from common.utils import get_request_ip, get_logger, bulk_get, FlashMessageUtil
|
||||
from common.utils import get_request_ip_or_data, get_request_ip, get_logger, bulk_get, FlashMessageUtil
|
||||
from users.models import User
|
||||
from users.utils import LoginBlockUtil, MFABlockUtils, LoginIpBlockUtil
|
||||
from . import errors
|
||||
@ -76,6 +76,12 @@ def authenticate(request=None, **credentials):
|
||||
if user is None:
|
||||
continue
|
||||
|
||||
if not user.is_valid:
|
||||
temp_user = user
|
||||
temp_user.backend = backend_path
|
||||
request.error_message = _('User is invalid')
|
||||
return temp_user
|
||||
|
||||
# 检查用户是否允许认证
|
||||
if not backend.user_allow_authenticate(user):
|
||||
temp_user = user
|
||||
@ -101,13 +107,12 @@ auth.authenticate = authenticate
|
||||
|
||||
class CommonMixin:
|
||||
request: Request
|
||||
_ip = ''
|
||||
|
||||
def get_request_ip(self):
|
||||
ip = ''
|
||||
if hasattr(self.request, 'data'):
|
||||
ip = self.request.data.get('remote_addr', '')
|
||||
ip = ip or get_request_ip(self.request)
|
||||
return ip
|
||||
if not self._ip:
|
||||
self._ip = get_request_ip_or_data(self.request)
|
||||
return self._ip
|
||||
|
||||
def raise_credential_error(self, error):
|
||||
raise self.partial_credential_error(error=error)
|
||||
@ -355,6 +360,11 @@ class AuthACLMixin:
|
||||
self.request.session['auth_acl_id'] = str(acl.id)
|
||||
return
|
||||
|
||||
if acl.is_action(acl.ActionChoices.notice):
|
||||
self.request.session['auth_notice_required'] = '1'
|
||||
self.request.session['auth_acl_id'] = str(acl.id)
|
||||
return
|
||||
|
||||
def _check_third_party_login_acl(self):
|
||||
request = self.request
|
||||
error_message = getattr(request, 'error_message', None)
|
||||
@ -513,7 +523,8 @@ class AuthMixin(CommonMixin, AuthPreCheckMixin, AuthACLMixin, MFAMixin, AuthPost
|
||||
def clear_auth_mark(self):
|
||||
keys = [
|
||||
'auth_password', 'user_id', 'auth_confirm_required',
|
||||
'auth_ticket_id', 'auth_acl_id'
|
||||
'auth_notice_required', 'auth_ticket_id', 'auth_acl_id',
|
||||
'user_session_id', 'user_log_id', 'can_send_notifications'
|
||||
]
|
||||
for k in keys:
|
||||
self.request.session.pop(k, '')
|
||||
|
@ -5,16 +5,20 @@ from django.db import models
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
import common.db.models
|
||||
from common.utils.random import random_string
|
||||
|
||||
|
||||
def default_secret():
|
||||
return random_string(36)
|
||||
|
||||
|
||||
class AccessKey(models.Model):
|
||||
id = models.UUIDField(verbose_name='AccessKeyID', primary_key=True,
|
||||
default=uuid.uuid4, editable=False)
|
||||
secret = models.UUIDField(verbose_name='AccessKeySecret',
|
||||
default=uuid.uuid4, editable=False)
|
||||
id = models.UUIDField(verbose_name='AccessKeyID', primary_key=True, default=uuid.uuid4, editable=False)
|
||||
secret = models.CharField(verbose_name='AccessKeySecret', default=default_secret, max_length=36)
|
||||
user = models.ForeignKey(settings.AUTH_USER_MODEL, verbose_name='User',
|
||||
on_delete=common.db.models.CASCADE_SIGNAL_SKIP, related_name='access_keys')
|
||||
is_active = models.BooleanField(default=True, verbose_name=_('Active'))
|
||||
date_last_used = models.DateTimeField(null=True, blank=True, verbose_name=_('Date last used'))
|
||||
date_created = models.DateTimeField(auto_now_add=True)
|
||||
|
||||
def get_id(self):
|
||||
|
@ -1,9 +1,11 @@
|
||||
from django.db import models
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.authtoken.models import Token
|
||||
|
||||
|
||||
class PrivateToken(Token):
|
||||
"""Inherit from auth token, otherwise migration is boring"""
|
||||
date_last_used = models.DateTimeField(null=True, blank=True, verbose_name=_('Date last used'))
|
||||
|
||||
class Meta:
|
||||
verbose_name = _('Private Token')
|
||||
|
58
apps/authentication/permissions.py
Normal file
58
apps/authentication/permissions.py
Normal file
@ -0,0 +1,58 @@
|
||||
import time
|
||||
|
||||
from django.conf import settings
|
||||
from rest_framework import permissions
|
||||
|
||||
from authentication.const import ConfirmType
|
||||
from authentication.models import ConnectionToken
|
||||
from common.exceptions import UserConfirmRequired
|
||||
from common.permissions import IsValidUser
|
||||
from common.utils import get_object_or_none
|
||||
from orgs.utils import tmp_to_root_org
|
||||
|
||||
|
||||
class UserConfirmation(permissions.BasePermission):
|
||||
ttl = 60 * 5
|
||||
min_level = 1
|
||||
confirm_type = 'relogin'
|
||||
|
||||
def has_permission(self, request, view):
|
||||
if not settings.SECURITY_VIEW_AUTH_NEED_MFA:
|
||||
return True
|
||||
|
||||
confirm_level = request.session.get('CONFIRM_LEVEL')
|
||||
confirm_time = request.session.get('CONFIRM_TIME')
|
||||
ttl = self.get_ttl()
|
||||
if not confirm_level or not confirm_time or \
|
||||
confirm_level < self.min_level or \
|
||||
confirm_time < time.time() - ttl:
|
||||
raise UserConfirmRequired(code=self.confirm_type)
|
||||
return True
|
||||
|
||||
def get_ttl(self):
|
||||
if self.confirm_type == ConfirmType.MFA:
|
||||
ttl = settings.SECURITY_MFA_VERIFY_TTL
|
||||
else:
|
||||
ttl = self.ttl
|
||||
return ttl
|
||||
|
||||
@classmethod
|
||||
def require(cls, confirm_type=ConfirmType.RELOGIN, ttl=60 * 5):
|
||||
min_level = ConfirmType.values.index(confirm_type) + 1
|
||||
name = 'UserConfirmationLevel{}TTL{}'.format(min_level, ttl)
|
||||
return type(name, (cls,), {'min_level': min_level, 'ttl': ttl, 'confirm_type': confirm_type})
|
||||
|
||||
|
||||
class IsValidUserOrConnectionToken(IsValidUser):
|
||||
def has_permission(self, request, view):
|
||||
return super().has_permission(request, view) \
|
||||
or self.is_valid_connection_token(request)
|
||||
|
||||
@staticmethod
|
||||
def is_valid_connection_token(request):
|
||||
token_id = request.query_params.get('token')
|
||||
if not token_id:
|
||||
return False
|
||||
with tmp_to_root_org():
|
||||
token = get_object_or_none(ConnectionToken, id=token_id)
|
||||
return token and token.is_valid
|
@ -7,4 +7,4 @@ from ..const import ConfirmType, MFAType
|
||||
class ConfirmSerializer(serializers.Serializer):
|
||||
confirm_type = serializers.ChoiceField(required=True, allow_blank=True, choices=ConfirmType.choices)
|
||||
mfa_type = serializers.ChoiceField(required=False, allow_blank=True, choices=MFAType.choices)
|
||||
secret_key = EncryptedField(allow_blank=True)
|
||||
secret_key = EncryptedField(allow_blank=True, required=False)
|
||||
|
@ -10,16 +10,22 @@ from users.serializers import UserProfileSerializer
|
||||
from ..models import AccessKey, TempToken
|
||||
|
||||
__all__ = [
|
||||
'AccessKeySerializer', 'BearerTokenSerializer',
|
||||
'AccessKeySerializer', 'BearerTokenSerializer',
|
||||
'SSOTokenSerializer', 'TempTokenSerializer',
|
||||
'AccessKeyCreateSerializer'
|
||||
]
|
||||
|
||||
|
||||
class AccessKeySerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = AccessKey
|
||||
fields = ['id', 'secret', 'is_active', 'date_created']
|
||||
read_only_fields = ['id', 'secret', 'date_created']
|
||||
fields = ['id', 'is_active', 'date_created', 'date_last_used']
|
||||
read_only_fields = ['id', 'date_created', 'date_last_used']
|
||||
|
||||
|
||||
class AccessKeyCreateSerializer(AccessKeySerializer):
|
||||
class Meta(AccessKeySerializer.Meta):
|
||||
fields = AccessKeySerializer.Meta.fields + ['secret']
|
||||
|
||||
|
||||
class BearerTokenSerializer(serializers.Serializer):
|
||||
@ -37,7 +43,8 @@ class BearerTokenSerializer(serializers.Serializer):
|
||||
def get_keyword(obj):
|
||||
return 'Bearer'
|
||||
|
||||
def update_last_login(self, user):
|
||||
@staticmethod
|
||||
def update_last_login(user):
|
||||
user.last_login = timezone.now()
|
||||
user.save(update_fields=['last_login'])
|
||||
|
||||
@ -96,7 +103,7 @@ class TempTokenSerializer(serializers.ModelSerializer):
|
||||
username = request.user.username
|
||||
kwargs = {
|
||||
'username': username, 'secret': secret,
|
||||
'date_expired': timezone.now() + timezone.timedelta(seconds=5*60),
|
||||
'date_expired': timezone.now() + timezone.timedelta(seconds=5 * 60),
|
||||
}
|
||||
token = TempToken(**kwargs)
|
||||
token.save()
|
||||
|
0
apps/authentication/tests/__init__.py
Normal file
0
apps/authentication/tests/__init__.py
Normal file
34
apps/authentication/tests/access_key.py
Normal file
34
apps/authentication/tests/access_key.py
Normal file
@ -0,0 +1,34 @@
|
||||
# Python 示例
|
||||
# pip install requests drf-httpsig
|
||||
import datetime
|
||||
import json
|
||||
|
||||
import requests
|
||||
from httpsig.requests_auth import HTTPSignatureAuth
|
||||
|
||||
|
||||
def get_auth(KeyID, SecretID):
|
||||
signature_headers = ['(request-target)', 'accept', 'date']
|
||||
auth = HTTPSignatureAuth(key_id=KeyID, secret=SecretID, algorithm='hmac-sha256', headers=signature_headers)
|
||||
return auth
|
||||
|
||||
|
||||
def get_user_info(jms_url, auth):
|
||||
url = jms_url + '/api/v1/users/users/?limit=1'
|
||||
gmt_form = '%a, %d %b %Y %H:%M:%S GMT'
|
||||
headers = {
|
||||
'Accept': 'application/json',
|
||||
'X-JMS-ORG': '00000000-0000-0000-0000-000000000002',
|
||||
'Date': datetime.datetime.utcnow().strftime(gmt_form)
|
||||
}
|
||||
|
||||
response = requests.get(url, auth=auth, headers=headers)
|
||||
print(json.loads(response.text))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
jms_url = 'http://localhost:8080'
|
||||
KeyID = '0753098d-810c-45fb-b42c-b27077147933'
|
||||
SecretID = 'a58d2530-d7ee-4390-a204-3492e44dde84'
|
||||
auth = get_auth(KeyID, SecretID)
|
||||
get_user_info(jms_url, auth)
|
@ -13,6 +13,7 @@ router.register('sso', api.SSOViewSet, 'sso')
|
||||
router.register('temp-tokens', api.TempTokenViewSet, 'temp-token')
|
||||
router.register('connection-token', api.ConnectionTokenViewSet, 'connection-token')
|
||||
router.register('super-connection-token', api.SuperConnectionTokenViewSet, 'super-connection-token')
|
||||
router.register('confirm', api.UserConfirmationViewSet, 'confirm')
|
||||
|
||||
urlpatterns = [
|
||||
path('wecom/qr/unbind/', api.WeComQRUnBindForUserApi.as_view(), name='wecom-qr-unbind'),
|
||||
@ -29,7 +30,6 @@ urlpatterns = [
|
||||
name='feishu-event-subscription-callback'),
|
||||
|
||||
path('auth/', api.TokenCreateApi.as_view(), name='user-auth'),
|
||||
path('confirm/', api.ConfirmApi.as_view(), name='user-confirm'),
|
||||
path('confirm-oauth/', api.ConfirmBindORUNBindOAuth.as_view(), name='confirm-oauth'),
|
||||
path('tokens/', api.TokenCreateApi.as_view(), name='auth-token'),
|
||||
path('mfa/verify/', api.MFAChallengeVerifyApi.as_view(), name='mfa-verify'),
|
||||
|
@ -20,19 +20,22 @@ def check_different_city_login_if_need(user, request):
|
||||
return
|
||||
|
||||
ip = get_request_ip(request) or '0.0.0.0'
|
||||
if not (ip and validate_ip(ip)):
|
||||
city = DEFAULT_CITY
|
||||
else:
|
||||
city = get_ip_city(ip) or DEFAULT_CITY
|
||||
|
||||
city_white = [_('LAN'), 'LAN']
|
||||
is_private = ipaddress.ip_address(ip).is_private
|
||||
if not is_private:
|
||||
last_user_login = UserLoginLog.objects.exclude(city__in=city_white) \
|
||||
.filter(username=user.username, status=True).first()
|
||||
if is_private:
|
||||
return
|
||||
last_user_login = UserLoginLog.objects.exclude(
|
||||
city__in=city_white
|
||||
).filter(username=user.username, status=True).first()
|
||||
if not last_user_login:
|
||||
return
|
||||
|
||||
if last_user_login and last_user_login.city != city:
|
||||
DifferentCityLoginMessage(user, ip, city).publish_async()
|
||||
city = get_ip_city(ip)
|
||||
last_city = get_ip_city(last_user_login.ip)
|
||||
if city == last_city:
|
||||
return
|
||||
|
||||
DifferentCityLoginMessage(user, ip, city).publish_async()
|
||||
|
||||
|
||||
def build_absolute_uri(request, path=None):
|
||||
|
@ -1,6 +1,7 @@
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import logout as auth_logout
|
||||
from django.db.utils import IntegrityError
|
||||
from django.http.request import HttpRequest
|
||||
from django.http.response import HttpResponseRedirect
|
||||
@ -13,7 +14,7 @@ from authentication import errors
|
||||
from authentication.const import ConfirmType
|
||||
from authentication.mixins import AuthMixin
|
||||
from authentication.notifications import OAuthBindMessage
|
||||
from common.permissions import UserConfirmation
|
||||
from authentication.permissions import UserConfirmation
|
||||
from common.sdk.im.dingtalk import URL, DingTalk
|
||||
from common.utils import get_logger
|
||||
from common.utils.common import get_request_ip
|
||||
@ -99,7 +100,7 @@ class DingTalkOAuthMixin(DingTalkBaseMixin, View):
|
||||
|
||||
|
||||
class DingTalkQRBindView(DingTalkQRMixin, View):
|
||||
permission_classes = (IsAuthenticated, UserConfirmation.require(ConfirmType.ReLogin))
|
||||
permission_classes = (IsAuthenticated, UserConfirmation.require(ConfirmType.RELOGIN))
|
||||
|
||||
def get(self, request: HttpRequest):
|
||||
user = request.user
|
||||
@ -158,6 +159,7 @@ class DingTalkQRBindCallbackView(DingTalkQRMixin, View):
|
||||
ip = get_request_ip(request)
|
||||
OAuthBindMessage(user, ip, _('DingTalk'), user_id).publish_async()
|
||||
msg = _('Binding DingTalk successfully')
|
||||
auth_logout(request)
|
||||
response = self.get_success_response(redirect_url, msg, msg)
|
||||
return response
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import logout as auth_logout
|
||||
from django.db.utils import IntegrityError
|
||||
from django.http.request import HttpRequest
|
||||
from django.http.response import HttpResponseRedirect
|
||||
@ -11,7 +12,7 @@ from rest_framework.permissions import AllowAny, IsAuthenticated
|
||||
|
||||
from authentication.const import ConfirmType
|
||||
from authentication.notifications import OAuthBindMessage
|
||||
from common.permissions import UserConfirmation
|
||||
from authentication.permissions import UserConfirmation
|
||||
from common.sdk.im.feishu import URL, FeiShu
|
||||
from common.utils import get_logger
|
||||
from common.utils.common import get_request_ip
|
||||
@ -69,7 +70,7 @@ class FeiShuQRMixin(UserConfirmRequiredExceptionMixin, PermissionsMixin, FlashMe
|
||||
|
||||
|
||||
class FeiShuQRBindView(FeiShuQRMixin, View):
|
||||
permission_classes = (IsAuthenticated, UserConfirmation.require(ConfirmType.ReLogin))
|
||||
permission_classes = (IsAuthenticated, UserConfirmation.require(ConfirmType.RELOGIN))
|
||||
|
||||
def get(self, request: HttpRequest):
|
||||
redirect_url = request.GET.get('redirect_url')
|
||||
@ -121,6 +122,7 @@ class FeiShuQRBindCallbackView(FeiShuQRMixin, View):
|
||||
ip = get_request_ip(request)
|
||||
OAuthBindMessage(user, ip, _('FeiShu'), user_id).publish_async()
|
||||
msg = _('Binding FeiShu successfully')
|
||||
auth_logout(request)
|
||||
response = self.get_success_response(redirect_url, msg, msg)
|
||||
return response
|
||||
|
||||
|
@ -310,12 +310,6 @@ class UserLoginGuardView(mixins.AuthMixin, RedirectView):
|
||||
age = self.request.session.get_expiry_age()
|
||||
self.request.session.set_expiry(age)
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
response = super().get(request, *args, **kwargs)
|
||||
if request.user.is_authenticated:
|
||||
response.set_cookie('jms_username', request.user.username)
|
||||
return response
|
||||
|
||||
def get_redirect_url(self, *args, **kwargs):
|
||||
try:
|
||||
user = self.get_user_from_session()
|
||||
|
@ -1,6 +1,7 @@
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import logout as auth_logout
|
||||
from django.db.utils import IntegrityError
|
||||
from django.http.request import HttpRequest
|
||||
from django.http.response import HttpResponseRedirect
|
||||
@ -13,7 +14,7 @@ from authentication import errors
|
||||
from authentication.const import ConfirmType
|
||||
from authentication.mixins import AuthMixin
|
||||
from authentication.notifications import OAuthBindMessage
|
||||
from common.permissions import UserConfirmation
|
||||
from authentication.permissions import UserConfirmation
|
||||
from common.sdk.im.wecom import URL
|
||||
from common.sdk.im.wecom import WeCom
|
||||
from common.utils import get_logger
|
||||
@ -100,7 +101,7 @@ class WeComOAuthMixin(WeComBaseMixin, View):
|
||||
|
||||
|
||||
class WeComQRBindView(WeComQRMixin, View):
|
||||
permission_classes = (IsAuthenticated, UserConfirmation.require(ConfirmType.ReLogin))
|
||||
permission_classes = (IsAuthenticated, UserConfirmation.require(ConfirmType.RELOGIN))
|
||||
|
||||
def get(self, request: HttpRequest):
|
||||
user = request.user
|
||||
@ -158,6 +159,7 @@ class WeComQRBindCallbackView(WeComQRMixin, View):
|
||||
ip = get_request_ip(request)
|
||||
OAuthBindMessage(user, ip, _('WeCom'), wecom_userid).publish_async()
|
||||
msg = _('Binding WeCom successfully')
|
||||
auth_logout(request)
|
||||
response = self.get_success_response(redirect_url, msg, msg)
|
||||
return response
|
||||
|
||||
|
@ -13,7 +13,7 @@ from common.drf.filters import (
|
||||
IDSpmFilterBackend, CustomFilterBackend, IDInFilterBackend,
|
||||
IDNotFilterBackend, NotOrRelFilterBackend
|
||||
)
|
||||
from common.utils import get_logger
|
||||
from common.utils import get_logger, lazyproperty
|
||||
from .action import RenderToJsonMixin
|
||||
from .serializer import SerializerMixin
|
||||
|
||||
@ -150,9 +150,9 @@ class OrderingFielderFieldsMixin:
|
||||
ordering_fields = None
|
||||
extra_ordering_fields = []
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.ordering_fields = self._get_ordering_fields()
|
||||
@lazyproperty
|
||||
def ordering_fields(self):
|
||||
return self._get_ordering_fields()
|
||||
|
||||
def _get_ordering_fields(self):
|
||||
if isinstance(self.__class__.ordering_fields, (list, tuple)):
|
||||
@ -179,7 +179,10 @@ class OrderingFielderFieldsMixin:
|
||||
model = self.queryset.model
|
||||
else:
|
||||
queryset = self.get_queryset()
|
||||
model = queryset.model
|
||||
if isinstance(queryset, list):
|
||||
model = None
|
||||
else:
|
||||
model = queryset.model
|
||||
|
||||
if not model:
|
||||
return []
|
||||
@ -201,4 +204,6 @@ class CommonApiMixin(
|
||||
SerializerMixin, ExtraFilterFieldsMixin, OrderingFielderFieldsMixin,
|
||||
QuerySetMixin, RenderToJsonMixin, PaginatedResponseMixin
|
||||
):
|
||||
pass
|
||||
def is_swagger_request(self):
|
||||
return getattr(self, 'swagger_fake_view', False) or \
|
||||
getattr(self, 'raw_action', '') == 'metadata'
|
||||
|
@ -1,5 +1,5 @@
|
||||
|
||||
CRONTAB_AT_AM_TWO = '0 14 * * *'
|
||||
CRONTAB_AT_AM_TWO = '0 2 * * *'
|
||||
CRONTAB_AT_AM_TEN = '0 10 * * *'
|
||||
CRONTAB_AT_PM_TWO = '0 2 * * *'
|
||||
CRONTAB_AT_PM_TWO = '0 14 * * *'
|
||||
|
||||
|
@ -1,8 +1,8 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from rest_framework.exceptions import APIException
|
||||
from rest_framework import status
|
||||
from rest_framework.exceptions import APIException
|
||||
|
||||
|
||||
class JMSException(APIException):
|
||||
@ -42,8 +42,11 @@ class ReferencedByOthers(JMSException):
|
||||
|
||||
|
||||
class UserConfirmRequired(JMSException):
|
||||
status_code = status.HTTP_412_PRECONDITION_FAILED
|
||||
|
||||
def __init__(self, code=None):
|
||||
detail = {
|
||||
'type': 'user_confirm_required',
|
||||
'code': code,
|
||||
'detail': _('This action require confirm current user')
|
||||
}
|
||||
|
@ -5,12 +5,6 @@ import time
|
||||
from django.conf import settings
|
||||
from rest_framework import permissions
|
||||
|
||||
from authentication.const import ConfirmType
|
||||
from authentication.models import ConnectionToken
|
||||
from common.exceptions import UserConfirmRequired
|
||||
from common.utils import get_object_or_none
|
||||
from orgs.utils import tmp_to_root_org
|
||||
|
||||
|
||||
class IsValidUser(permissions.IsAuthenticated):
|
||||
"""Allows access to valid user, is active and not expired"""
|
||||
@ -20,21 +14,6 @@ class IsValidUser(permissions.IsAuthenticated):
|
||||
and request.user.is_valid
|
||||
|
||||
|
||||
class IsValidUserOrConnectionToken(IsValidUser):
|
||||
def has_permission(self, request, view):
|
||||
return super().has_permission(request, view) \
|
||||
or self.is_valid_connection_token(request)
|
||||
|
||||
@staticmethod
|
||||
def is_valid_connection_token(request):
|
||||
token_id = request.query_params.get('token')
|
||||
if not token_id:
|
||||
return False
|
||||
with tmp_to_root_org():
|
||||
token = get_object_or_none(ConnectionToken, id=token_id)
|
||||
return token and token.is_valid
|
||||
|
||||
|
||||
class OnlySuperUser(IsValidUser):
|
||||
def has_permission(self, request, view):
|
||||
return super().has_permission(request, view) \
|
||||
@ -56,33 +35,36 @@ class WithBootstrapToken(permissions.BasePermission):
|
||||
return settings.BOOTSTRAP_TOKEN == request_bootstrap_token
|
||||
|
||||
|
||||
class UserConfirmation(permissions.BasePermission):
|
||||
ttl = 60 * 5
|
||||
min_level = 1
|
||||
confirm_type = ConfirmType.ReLogin
|
||||
|
||||
class ServiceAccountSignaturePermission(permissions.BasePermission):
|
||||
def has_permission(self, request, view):
|
||||
if not settings.SECURITY_VIEW_AUTH_NEED_MFA:
|
||||
from authentication.models import AccessKey
|
||||
from common.utils.crypto import get_aes_crypto
|
||||
signature = request.META.get('HTTP_X_JMS_SVC', '')
|
||||
if not signature or not signature.startswith('Sign'):
|
||||
return False
|
||||
data = signature[4:].strip()
|
||||
if not data or ':' not in data:
|
||||
return False
|
||||
ak_id, time_sign = data.split(':', 1)
|
||||
if not ak_id or not time_sign:
|
||||
return False
|
||||
ak = AccessKey.objects.filter(id=ak_id).first()
|
||||
if not ak or not ak.is_active:
|
||||
return False
|
||||
if not ak.user or not ak.user.is_active or not ak.user.is_service_account:
|
||||
return False
|
||||
aes = get_aes_crypto(str(ak.secret).replace('-', ''), mode='ECB')
|
||||
try:
|
||||
timestamp = aes.decrypt(time_sign)
|
||||
if not timestamp or not timestamp.isdigit():
|
||||
return False
|
||||
timestamp = int(timestamp)
|
||||
interval = abs(int(time.time()) - timestamp)
|
||||
if interval > 30:
|
||||
return False
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
confirm_level = request.session.get('CONFIRM_LEVEL')
|
||||
confirm_time = request.session.get('CONFIRM_TIME')
|
||||
ttl = self.get_ttl()
|
||||
if not confirm_level or not confirm_time or \
|
||||
confirm_level < self.min_level or \
|
||||
confirm_time < time.time() - ttl:
|
||||
raise UserConfirmRequired(code=self.confirm_type)
|
||||
return True
|
||||
|
||||
def get_ttl(self):
|
||||
if self.confirm_type == ConfirmType.MFA:
|
||||
ttl = settings.SECURITY_MFA_VERIFY_TTL
|
||||
else:
|
||||
ttl = self.ttl
|
||||
return ttl
|
||||
|
||||
@classmethod
|
||||
def require(cls, confirm_type=ConfirmType.ReLogin, ttl=60 * 5):
|
||||
min_level = ConfirmType.values.index(confirm_type) + 1
|
||||
name = 'UserConfirmationLevel{}TTL{}'.format(min_level, ttl)
|
||||
return type(name, (cls,), {'min_level': min_level, 'ttl': ttl, 'confirm_type': confirm_type})
|
||||
def has_object_permission(self, request, view, obj):
|
||||
return False
|
||||
|
@ -30,7 +30,7 @@ class CustomSMS(BaseSMSClient):
|
||||
code=template_param.get('code'), phone_numbers=phone_numbers_str
|
||||
)
|
||||
|
||||
logger.info(f'Custom sms send: phone_numbers={phone_numbers}param={params}')
|
||||
logger.info(f'Custom sms send: phone_numbers={phone_numbers}, param={params}')
|
||||
if settings.CUSTOM_SMS_REQUEST_METHOD == 'post':
|
||||
action = requests.post
|
||||
kwargs = {'json': params}
|
||||
|
50
apps/common/sdk/sms/custom_file.py
Normal file
50
apps/common/sdk/sms/custom_file.py
Normal file
@ -0,0 +1,50 @@
|
||||
import os
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
from django.conf import settings
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.utils.module_loading import import_string
|
||||
|
||||
from common.utils import get_logger
|
||||
from common.exceptions import JMSException
|
||||
from jumpserver.settings import get_file_md5
|
||||
|
||||
from .base import BaseSMSClient
|
||||
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
custom_sms_method = None
|
||||
SMS_CUSTOM_FILE_MD5 = settings.SMS_CUSTOM_FILE_MD5
|
||||
SMS_CUSTOM_FILE_PATH = os.path.join(settings.PROJECT_DIR, 'data', 'sms', 'main.py')
|
||||
if SMS_CUSTOM_FILE_MD5 == get_file_md5(SMS_CUSTOM_FILE_PATH):
|
||||
try:
|
||||
custom_sms_method_path = 'data.sms.main.send_sms'
|
||||
custom_sms_method = import_string(custom_sms_method_path)
|
||||
except Exception as e:
|
||||
logger.warning('Import custom sms method failed: {}, Maybe not enabled'.format(e))
|
||||
|
||||
|
||||
class CustomFileSMS(BaseSMSClient):
|
||||
@classmethod
|
||||
def new_from_settings(cls):
|
||||
return cls()
|
||||
|
||||
@staticmethod
|
||||
def need_pre_check():
|
||||
return False
|
||||
|
||||
def send_sms(self, phone_numbers: list, template_param: OrderedDict, **kwargs):
|
||||
if not callable(custom_sms_method):
|
||||
raise JMSException(_('The custom sms file is invalid'))
|
||||
|
||||
try:
|
||||
logger.info(f'Custom file sms send: phone_numbers={phone_numbers}, param={template_param}')
|
||||
custom_sms_method(phone_numbers, template_param, **kwargs)
|
||||
except Exception as err:
|
||||
raise JMSException(_('SMS sending failed[%s]: %s') % (f"{_('Custom type')}({_('File')})", err))
|
||||
|
||||
|
||||
client = CustomFileSMS
|
@ -17,7 +17,8 @@ class BACKENDS(TextChoices):
|
||||
TENCENT = 'tencent', _('Tencent cloud')
|
||||
HUAWEI = 'huawei', _('Huawei Cloud')
|
||||
CMPP2 = 'cmpp2', _('CMPP v2.0')
|
||||
Custom = 'custom', _('Custom type')
|
||||
CUSTOM = 'custom', _('Custom type')
|
||||
CUSTOM_FILE = 'custom_file', f"{_('Custom type')}({_('File')})"
|
||||
|
||||
|
||||
class SMS:
|
||||
|
@ -218,12 +218,13 @@ class PhoneField(serializers.CharField):
|
||||
code = data.get('code')
|
||||
phone = data.get('phone', '')
|
||||
if code and phone:
|
||||
data = '{}{}'.format(code, phone)
|
||||
code = code.replace('+', '')
|
||||
data = '+{}{}'.format(code, phone)
|
||||
else:
|
||||
data = phone
|
||||
try:
|
||||
phone = phonenumbers.parse(data, 'CN')
|
||||
data = '{}{}'.format(phone.country_code, phone.national_number)
|
||||
data = '+{}{}'.format(phone.country_code, phone.national_number)
|
||||
except phonenumbers.NumberParseException:
|
||||
data = '+86{}'.format(data)
|
||||
|
||||
|
@ -36,8 +36,8 @@ def send_mail_async(*args, **kwargs):
|
||||
args[0] = (settings.EMAIL_SUBJECT_PREFIX or '') + args[0]
|
||||
from_email = settings.EMAIL_FROM or settings.EMAIL_HOST_USER
|
||||
args.insert(2, from_email)
|
||||
args = tuple(args)
|
||||
|
||||
args = tuple(args)
|
||||
try:
|
||||
return send_mail(*args, **kwargs)
|
||||
except Exception as e:
|
||||
|
@ -17,6 +17,8 @@ import psutil
|
||||
from django.conf import settings
|
||||
from django.templatetags.static import static
|
||||
|
||||
from common.permissions import ServiceAccountSignaturePermission
|
||||
|
||||
UUID_PATTERN = re.compile(r'\w{8}(-\w{4}){3}-\w{12}')
|
||||
ipip_db = None
|
||||
|
||||
@ -153,19 +155,26 @@ def is_uuid(seq):
|
||||
|
||||
|
||||
def get_request_ip(request):
|
||||
x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR', '').split(',')
|
||||
x_real_ip = request.META.get('HTTP_X_REAL_IP', '')
|
||||
if x_real_ip:
|
||||
return x_real_ip
|
||||
|
||||
x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR', '').split(',')
|
||||
if x_forwarded_for and x_forwarded_for[0]:
|
||||
login_ip = x_forwarded_for[0]
|
||||
else:
|
||||
login_ip = request.META.get('REMOTE_ADDR', '')
|
||||
return login_ip
|
||||
|
||||
login_ip = request.META.get('REMOTE_ADDR', '')
|
||||
return login_ip
|
||||
|
||||
|
||||
def get_request_ip_or_data(request):
|
||||
ip = ''
|
||||
if hasattr(request, 'data'):
|
||||
ip = request.data.get('remote_addr', '')
|
||||
|
||||
if hasattr(request, 'data') and request.data.get('remote_addr', ''):
|
||||
permission = ServiceAccountSignaturePermission()
|
||||
if permission.has_permission(request, None):
|
||||
ip = request.data.get('remote_addr', '')
|
||||
ip = ip or get_request_ip(request)
|
||||
return ip
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
import random
|
||||
import secrets
|
||||
import socket
|
||||
import string
|
||||
import struct
|
||||
@ -17,32 +18,37 @@ def random_ip():
|
||||
return socket.inet_ntoa(struct.pack('>I', random.randint(1, 0xffffffff)))
|
||||
|
||||
|
||||
def random_replace_char(s, chars, length):
|
||||
using_index = set()
|
||||
seq = list(s)
|
||||
|
||||
while length > 0:
|
||||
index = secrets.randbelow(len(seq) - 1)
|
||||
if index in using_index or index == 0:
|
||||
continue
|
||||
seq[index] = secrets.choice(chars)
|
||||
using_index.add(index)
|
||||
length -= 1
|
||||
return ''.join(seq)
|
||||
|
||||
|
||||
def random_string(length: int, lower=True, upper=True, digit=True, special_char=False, symbols=string_punctuation):
|
||||
random.seed()
|
||||
args_names = ['lower', 'upper', 'digit']
|
||||
args_values = [lower, upper, digit]
|
||||
args_string = [string.ascii_lowercase, string.ascii_uppercase, string.digits]
|
||||
args_string_map = dict(zip(args_names, args_string))
|
||||
kwargs = dict(zip(args_names, args_values))
|
||||
kwargs_keys = list(kwargs.keys())
|
||||
kwargs_values = list(kwargs.values())
|
||||
args_true_count = len([i for i in kwargs_values if i])
|
||||
if not any([lower, upper, digit]):
|
||||
raise ValueError('At least one of `lower`, `upper`, `digit` must be `True`')
|
||||
if length < 4:
|
||||
raise ValueError('The length of the string must be greater than 3')
|
||||
|
||||
assert any(kwargs_values), f'Parameters {kwargs_keys} must have at least one `True`'
|
||||
assert length >= args_true_count, f'Expected length >= {args_true_count}, bug got {length}'
|
||||
|
||||
chars = ''.join([args_string_map[k] for k, v in kwargs.items() if v])
|
||||
password = list(random.choice(chars) for i in range(length))
|
||||
chars_map = (
|
||||
(lower, string.ascii_lowercase),
|
||||
(upper, string.ascii_uppercase),
|
||||
(digit, string.digits),
|
||||
)
|
||||
chars = ''.join([i[1] for i in chars_map if i[0]])
|
||||
texts = list(secrets.choice(chars) for __ in range(length))
|
||||
texts = ''.join(texts)
|
||||
|
||||
# 控制一下特殊字符的数量, 别随机出来太多
|
||||
if special_char:
|
||||
special_num = length // 16 + 1
|
||||
special_index = []
|
||||
for i in range(special_num):
|
||||
index = random.randint(1, length - 1)
|
||||
if index not in special_index:
|
||||
special_index.append(index)
|
||||
for i in special_index:
|
||||
password[i] = random.choice(symbols)
|
||||
|
||||
password = ''.join(password)
|
||||
return password
|
||||
symbol_num = length // 16 + 1
|
||||
texts = random_replace_char(texts, symbols, symbol_num)
|
||||
return texts
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user