diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 000000000..e59d309dc --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,3 @@ +[settings] +line_length=120 +known_first_party=common,users,assets,perms,authentication,jumpserver,notification,ops,orgs,rbac,settings,terminal,tickets diff --git a/apps/assets/api/asset/asset.py b/apps/assets/api/asset/asset.py index 4e1176d17..ad04966e3 100644 --- a/apps/assets/api/asset/asset.py +++ b/apps/assets/api/asset/asset.py @@ -1,89 +1,91 @@ # -*- coding: utf-8 -*- # + import django_filters from rest_framework.decorators import action from rest_framework.response import Response -from common.utils import get_logger -from common.drf.filters import BaseFilterSet -from common.mixins.api import SuggestionMixin -from orgs.mixins.api import OrgBulkModelViewSet -from orgs.mixins import generics from assets import serializers +from assets.filters import IpInFilterBackend, LabelFilterBackend, NodeFilterBackend from assets.models import Asset, Gateway from assets.tasks import ( push_accounts_to_assets, - verify_accounts_connectivity, test_assets_connectivity_manual, update_assets_hardware_info_manual, + verify_accounts_connectivity, ) -from assets.filters import NodeFilterBackend, LabelFilterBackend, IpInFilterBackend +from common.drf.filters import BaseFilterSet +from common.mixins.api import SuggestionMixin +from common.utils import get_logger +from orgs.mixins import generics +from orgs.mixins.api import OrgBulkModelViewSet from ..mixin import NodeFilterMixin logger = get_logger(__file__) __all__ = [ - 'AssetViewSet', 'AssetTaskCreateApi', 'AssetsTaskCreateApi', + "AssetViewSet", + "AssetTaskCreateApi", + "AssetsTaskCreateApi", ] class AssetFilterSet(BaseFilterSet): - type = django_filters.CharFilter(field_name='platform__type', lookup_expr='exact') - category = django_filters.CharFilter(field_name='platform__category', lookup_expr='exact') - hostname = django_filters.CharFilter(field_name='name', lookup_expr='exact') + type = django_filters.CharFilter(field_name="platform__type", lookup_expr="exact") + category = django_filters.CharFilter( + field_name="platform__category", lookup_expr="exact" + ) + hostname = django_filters.CharFilter(field_name="name", lookup_expr="exact") class Meta: model = Asset - fields = ['name', 'address', 'is_active', 'type', 'category', 'hostname'] + fields = ["name", "address", "is_active", "type", "category", "hostname"] class AssetViewSet(SuggestionMixin, NodeFilterMixin, OrgBulkModelViewSet): """ API endpoint that allows Asset to be viewed or edited. """ + model = Asset filterset_class = AssetFilterSet search_fields = ("name", "address") ordering_fields = ("name", "address") - ordering = ('name',) + ordering = ("name",) serializer_classes = ( - ('default', serializers.AssetSerializer), - ('suggestion', serializers.MiniAssetSerializer), - ('platform', serializers.PlatformSerializer), - ('gateways', serializers.GatewayWithAuthSerializer) + ("default", serializers.AssetSerializer), + ("suggestion", serializers.MiniAssetSerializer), + ("platform", serializers.PlatformSerializer), + ("gateways", serializers.GatewayWithAuthSerializer), ) rbac_perms = ( - ('match', 'assets.match_asset'), - ('platform', 'assets.view_platform'), - ('gateways', 'assets.view_gateway') + ("match", "assets.match_asset"), + ("platform", "assets.view_platform"), + ("gateways", "assets.view_gateway"), ) - extra_filter_backends = [ - LabelFilterBackend, - IpInFilterBackend, - NodeFilterBackend - ] + extra_filter_backends = [LabelFilterBackend, IpInFilterBackend, NodeFilterBackend] - @action(methods=['GET'], detail=True, url_path='platform') + @action(methods=["GET"], detail=True, url_path="platform") def platform(self, *args, **kwargs): asset = self.get_object() serializer = self.get_serializer(asset.platform) return Response(serializer.data) - @action(methods=['GET'], detail=True, url_path='gateways') + @action(methods=["GET"], detail=True, url_path="gateways") def gateways(self, *args, **kwargs): asset = self.get_object() if not asset.domain: gateways = Gateway.objects.none() else: - gateways = asset.domain.gateways.filter(protocol='ssh') + gateways = asset.domain.gateways.filter(protocol="ssh") return self.get_paginated_response_from_queryset(gateways) class AssetsTaskMixin: def perform_assets_task(self, serializer): data = serializer.validated_data - assets = data.get('assets', []) + assets = data.get("assets", []) asset_ids = [asset.id for asset in assets] - if data['action'] == "refresh": + if data["action"] == "refresh": task = update_assets_hardware_info_manual.delay(asset_ids) else: task = test_assets_connectivity_manual.delay(asset_ids) @@ -94,9 +96,9 @@ class AssetsTaskMixin: self.set_task_to_serializer_data(serializer, task) def set_task_to_serializer_data(self, serializer, task): - data = getattr(serializer, '_data', {}) + data = getattr(serializer, "_data", {}) data["task"] = task.id - setattr(serializer, '_data', data) + setattr(serializer, "_data", data) class AssetTaskCreateApi(AssetsTaskMixin, generics.CreateAPIView): @@ -104,18 +106,18 @@ class AssetTaskCreateApi(AssetsTaskMixin, generics.CreateAPIView): serializer_class = serializers.AssetTaskSerializer def create(self, request, *args, **kwargs): - pk = self.kwargs.get('pk') - request.data['asset'] = pk - request.data['assets'] = [pk] + pk = self.kwargs.get("pk") + request.data["asset"] = pk + request.data["assets"] = [pk] return super().create(request, *args, **kwargs) def check_permissions(self, request): - action = request.data.get('action') + action = request.data.get("action") action_perm_require = { - 'refresh': 'assets.refresh_assethardwareinfo', - 'push_account': 'assets.push_assetsystemuser', - 'test': 'assets.test_assetconnectivity', - 'test_account': 'assets.test_assetconnectivity' + "refresh": "assets.refresh_assethardwareinfo", + "push_account": "assets.push_assetsystemuser", + "test": "assets.test_assetconnectivity", + "test_account": "assets.test_assetconnectivity", } perm_required = action_perm_require.get(action) has = self.request.user.has_perm(perm_required) @@ -126,19 +128,19 @@ class AssetTaskCreateApi(AssetsTaskMixin, generics.CreateAPIView): @staticmethod def perform_asset_task(serializer): data = serializer.validated_data - if data['action'] not in ['push_system_user', 'test_system_user']: + if data["action"] not in ["push_system_user", "test_system_user"]: return - asset = data['asset'] - accounts = data.get('accounts') + asset = data["asset"] + accounts = data.get("accounts") if not accounts: accounts = asset.accounts.all() asset_ids = [asset.id] - account_ids = accounts.values_list('id', flat=True) - if action == 'push_account': + account_ids = accounts.values_list("id", flat=True) + if action == "push_account": task = push_accounts_to_assets.delay(account_ids, asset_ids) - elif action == 'test_account': + elif action == "test_account": task = verify_accounts_connectivity.delay(account_ids, asset_ids) else: task = None @@ -156,9 +158,9 @@ class AssetsTaskCreateApi(AssetsTaskMixin, generics.CreateAPIView): serializer_class = serializers.AssetsTaskSerializer def check_permissions(self, request): - action = request.data.get('action') + action = request.data.get("action") action_perm_require = { - 'refresh': 'assets.refresh_assethardwareinfo', + "refresh": "assets.refresh_assethardwareinfo", } perm_required = action_perm_require.get(action) has = self.request.user.has_perm(perm_required) diff --git a/apps/assets/filters.py b/apps/assets/filters.py index de2550ceb..f1b869805 100644 --- a/apps/assets/filters.py +++ b/apps/assets/filters.py @@ -1,13 +1,14 @@ # -*- coding: utf-8 -*- # from django.db.models import Q +from django_filters import rest_framework as drf_filters from rest_framework import filters from rest_framework.compat import coreapi, coreschema -from django_filters import rest_framework as drf_filters +from assets.utils import get_node_from_request, is_query_node_all_assets from common.drf.filters import BaseFilterSet -from assets.utils import is_query_node_all_assets, get_node_from_request -from .models import Label, Node, Account + +from .models import Account, Label, Node class AssetByNodeFilterBackend(filters.BaseFilterBackend): diff --git a/apps/assets/models/account.py b/apps/assets/models/account.py index 9aa007e53..cad5f9ded 100644 --- a/apps/assets/models/account.py +++ b/apps/assets/models/account.py @@ -3,7 +3,8 @@ from django.utils.translation import gettext_lazy as _ from simple_history.models import HistoricalRecords from common.utils import lazyproperty -from .base import BaseAccount, AbsConnectivity + +from .base import AbsConnectivity, BaseAccount __all__ = ['Account', 'AccountTemplate'] @@ -40,9 +41,10 @@ class AccountHistoricalRecords(HistoricalRecords): class Account(AbsConnectivity, BaseAccount): - class InnerAccount(models.TextChoices): - INPUT = '@INPUT', '@INPUT' - USER = '@USER', '@USER' + class AliasAccount(models.TextChoices): + ALL = '@ALL', _('All') + INPUT = '@INPUT', _('Manual input') + USER = '@USER', _('Dynamic user') asset = models.ForeignKey( 'assets.Asset', related_name='accounts', @@ -76,14 +78,14 @@ class Account(AbsConnectivity, BaseAccount): return '{}'.format(self.username) @classmethod - def get_input_account(cls): + def get_manual_account(cls): """ @INPUT 手动登录的账号(any) """ - return cls(name=cls.InnerAccount.INPUT.value, username='') + return cls(name=cls.AliasAccount.INPUT.label, username=cls.AliasAccount.INPUT.value, secret=None) @classmethod def get_user_account(cls, username): """ @USER 动态用户的账号(self) """ - return cls(name=cls.InnerAccount.USER.value, username=username) + return cls(name=cls.AliasAccount.USER.label, username=cls.AliasAccount.USER.value) class AccountTemplate(BaseAccount): diff --git a/apps/assets/serializers/platform.py b/apps/assets/serializers/platform.py index 3bc02732f..8f8dcb5a3 100644 --- a/apps/assets/serializers/platform.py +++ b/apps/assets/serializers/platform.py @@ -1,61 +1,75 @@ -from rest_framework import serializers from django.utils.translation import gettext_lazy as _ +from rest_framework import serializers from common.drf.fields import LabeledChoiceField from common.drf.serializers import WritableNestedModelSerializer -from ..models import Platform, PlatformProtocol, PlatformAutomation from ..const import Category, AllTypes +from ..models import Platform, PlatformProtocol, PlatformAutomation -__all__ = ['PlatformSerializer', 'PlatformOpsMethodSerializer'] +__all__ = ["PlatformSerializer", "PlatformOpsMethodSerializer"] class ProtocolSettingSerializer(serializers.Serializer): SECURITY_CHOICES = [ - ('any', 'Any'), - ('rdp', 'RDP'), - ('tls', 'TLS'), - ('nla', 'NLA'), + ("any", "Any"), + ("rdp", "RDP"), + ("tls", "TLS"), + ("nla", "NLA"), ] # RDP console = serializers.BooleanField(required=False) - security = serializers.ChoiceField(choices=SECURITY_CHOICES, default='any') + security = serializers.ChoiceField(choices=SECURITY_CHOICES, default="any") # SFTP sftp_enabled = serializers.BooleanField(default=True, label=_("SFTP enabled")) - sftp_home = serializers.CharField(default='/tmp', label=_("SFTP home")) + sftp_home = serializers.CharField(default="/tmp", label=_("SFTP home")) # HTTP auto_fill = serializers.BooleanField(default=False, label=_("Auto fill")) - username_selector = serializers.CharField(default='', allow_blank=True, label=_("Username selector")) - password_selector = serializers.CharField(default='', allow_blank=True, label=_("Password selector")) - submit_selector = serializers.CharField(default='', allow_blank=True, label=_("Submit selector")) + username_selector = serializers.CharField( + default="", allow_blank=True, label=_("Username selector") + ) + password_selector = serializers.CharField( + default="", allow_blank=True, label=_("Password selector") + ) + submit_selector = serializers.CharField( + default="", allow_blank=True, label=_("Submit selector") + ) class PlatformAutomationSerializer(serializers.ModelSerializer): class Meta: model = PlatformAutomation fields = [ - 'id', 'ansible_enabled', 'ansible_config', - 'ping_enabled', 'ping_method', - 'gather_facts_enabled', 'gather_facts_method', - 'push_account_enabled', 'push_account_method', - 'change_secret_enabled', 'change_secret_method', - 'verify_account_enabled', 'verify_account_method', - 'gather_accounts_enabled', 'gather_accounts_method', + "id", + "ansible_enabled", + "ansible_config", + "ping_enabled", + "ping_method", + "gather_facts_enabled", + "gather_facts_method", + "push_account_enabled", + "push_account_method", + "change_secret_enabled", + "change_secret_method", + "verify_account_enabled", + "verify_account_method", + "gather_accounts_enabled", + "gather_accounts_method", ] extra_kwargs = { - 'ping_enabled': {'label': '启用资产探测'}, - 'ping_method': {'label': '探测方式'}, - 'gather_facts_enabled': {'label': '启用收集信息'}, - 'gather_facts_method': {'label': '收集信息方式'}, - 'verify_account_enabled': {'label': '启用校验账号'}, - 'verify_account_method': {'label': '校验账号方式'}, - 'push_account_enabled': {'label': '启用推送账号'}, - 'push_account_method': {'label': '推送账号方式'}, - 'change_secret_enabled': {'label': '启用账号改密'}, - 'change_secret_method': {'label': '账号创建改密方式'}, - 'gather_accounts_enabled': {'label': '启用账号收集'}, - 'gather_accounts_method': {'label': '收集账号方式'}, + "ping_enabled": {"label": "启用资产探测"}, + "ping_method": {"label": "探测方式"}, + "gather_facts_enabled": {"label": "启用收集信息"}, + "gather_facts_method": {"label": "收集信息方式"}, + "verify_account_enabled": {"label": "启用校验账号"}, + "verify_account_method": {"label": "校验账号方式"}, + "push_account_enabled": {"label": "启用推送账号"}, + "push_account_method": {"label": "推送账号方式"}, + "change_secret_enabled": {"label": "启用账号改密"}, + "change_secret_method": {"label": "账号创建改密方式"}, + "gather_accounts_enabled": {"label": "启用账号收集"}, + "gather_accounts_method": {"label": "收集账号方式"}, } @@ -66,42 +80,62 @@ class PlatformProtocolsSerializer(serializers.ModelSerializer): class Meta: model = PlatformProtocol fields = [ - 'id', 'name', 'port', 'primary', 'default', - 'required', 'secret_types', 'setting', + "id", + "name", + "port", + "primary", + "default", + "required", + "secret_types", + "setting", ] class PlatformSerializer(WritableNestedModelSerializer): + charset = LabeledChoiceField( + choices=Platform.CharsetChoices.choices, label=_("Charset") + ) type = LabeledChoiceField(choices=AllTypes.choices(), label=_("Type")) category = LabeledChoiceField(choices=Category.choices, label=_("Category")) - protocols = PlatformProtocolsSerializer(label=_('Protocols'), many=True, required=False) - automation = PlatformAutomationSerializer(label=_('Automation'), required=False) + protocols = PlatformProtocolsSerializer( + label=_("Protocols"), many=True, required=False + ) + automation = PlatformAutomationSerializer(label=_("Automation"), required=False) su_method = LabeledChoiceField( - choices=[('sudo', 'sudo su -'), ('su', 'su - ')], - label='切换方式', required=False, default='sudo' + choices=[("sudo", "sudo su -"), ("su", "su - ")], + label="切换方式", + required=False, + default="sudo", ) class Meta: model = Platform - fields_mini = ['id', 'name', 'internal'] + fields_mini = ["id", "name", "internal"] fields_small = fields_mini + [ - 'category', 'type', 'charset', + "category", + "type", + "charset", ] fields = fields_small + [ - 'protocols_enabled', 'protocols', 'domain_enabled', - 'su_enabled', 'su_method', 'automation', 'comment', + "protocols_enabled", + "protocols", + "domain_enabled", + "su_enabled", + "su_method", + "automation", + "comment", ] extra_kwargs = { - 'su_enabled': {'label': '启用切换账号'}, - 'protocols_enabled': {'label': '启用协议'}, - 'domain_enabled': {'label': "启用网域"}, - 'domain_default': {'label': "默认网域"}, + "su_enabled": {"label": "启用切换账号"}, + "protocols_enabled": {"label": "启用协议"}, + "domain_enabled": {"label": "启用网域"}, + "domain_default": {"label": "默认网域"}, } class PlatformOpsMethodSerializer(serializers.Serializer): id = serializers.CharField(read_only=True) - name = serializers.CharField(max_length=50, label=_('Name')) - category = serializers.CharField(max_length=50, label=_('Category')) + name = serializers.CharField(max_length=50, label=_("Name")) + category = serializers.CharField(max_length=50, label=_("Category")) type = serializers.ListSerializer(child=serializers.CharField()) method = serializers.CharField() diff --git a/apps/authentication/api/connection_token.py b/apps/authentication/api/connection_token.py index 0c04531d5..0839229b8 100644 --- a/apps/authentication/api/connection_token.py +++ b/apps/authentication/api/connection_token.py @@ -16,7 +16,7 @@ from rest_framework.request import Request from common.drf.api import JMSModelViewSet from common.http import is_true from orgs.mixins.api import RootOrgViewMixin -from perms.models import Action +from perms.models import ActionChoices from terminal.models import EndpointRule from ..serializers import ( ConnectionTokenSerializer, ConnectionTokenSecretSerializer, @@ -70,8 +70,8 @@ class RDPFileClientProtocolURLMixin: # 设置磁盘挂载 drives_redirect = is_true(self.request.query_params.get('drives_redirect')) if drives_redirect: - actions = Action.choices_to_value(token.actions) - if actions & Action.UPDOWNLOAD == Action.UPDOWNLOAD: + actions = ActionChoices.choices_to_value(token.actions) + if actions & Action.TRANSFER == Action.TRANSFER: rdp_options['drivestoredirect:s'] = '*' # 设置全屏 diff --git a/apps/authentication/serializers/connection_token.py b/apps/authentication/serializers/connection_token.py index 6e1f19be1..256661882 100644 --- a/apps/authentication/serializers/connection_token.py +++ b/apps/authentication/serializers/connection_token.py @@ -7,7 +7,7 @@ from common.utils import pretty_string from common.utils.random import random_string from assets.models import Asset, Gateway, Domain, CommandFilterRule, Account from users.models import User -from perms.serializers.permission import ActionsField +from perms.serializers.permission import ActionChoicesField __all__ = [ @@ -158,14 +158,13 @@ class ConnectionTokenSecretSerializer(OrgResourceModelSerializerMixin): gateway = ConnectionTokenGatewaySerializer(read_only=True) domain = ConnectionTokenDomainSerializer(read_only=True) cmd_filter_rules = ConnectionTokenCmdFilterRuleSerializer(many=True) - actions = ActionsField() + actions = ActionChoicesField() expire_at = serializers.IntegerField() class Meta: model = ConnectionToken fields = [ - 'id', 'secret', - 'user', 'asset', 'account_username', 'account', 'protocol', - 'domain', 'gateway', 'cmd_filter_rules', - 'actions', 'expire_at', + 'id', 'secret', 'user', 'asset', 'account_username', + 'account', 'protocol', 'domain', 'gateway', + 'cmd_filter_rules', 'actions', 'expire_at', ] diff --git a/apps/authentication/views/dingtalk.py b/apps/authentication/views/dingtalk.py index 0d19d3fcd..340ece19c 100644 --- a/apps/authentication/views/dingtalk.py +++ b/apps/authentication/views/dingtalk.py @@ -1,27 +1,28 @@ +from urllib.parse import urlencode + +from django.conf import settings +from django.db.utils import IntegrityError +from django.http.request import HttpRequest from django.http.response import HttpResponseRedirect from django.utils.translation import ugettext_lazy as _ -from urllib.parse import urlencode from django.views import View -from django.conf import settings -from django.http.request import HttpRequest -from django.db.utils import IntegrityError -from rest_framework.permissions import IsAuthenticated, AllowAny from rest_framework.exceptions import APIException +from rest_framework.permissions import AllowAny, IsAuthenticated +from authentication import errors +from authentication.const import ConfirmType +from authentication.mixins import AuthMixin +from authentication.notifications import OAuthBindMessage +from common.mixins.views import PermissionsMixin, UserConfirmRequiredExceptionMixin +from common.permissions import UserConfirmation +from common.sdk.im.dingtalk import URL, DingTalk +from common.utils import FlashMessageUtil, get_logger +from common.utils.common import get_request_ip +from common.utils.django import get_object_or_none, reverse +from common.utils.random import random_string from users.models import User from users.views import UserVerifyPasswordView -from common.utils import get_logger, FlashMessageUtil -from common.utils.random import random_string -from common.utils.django import reverse, get_object_or_none -from common.sdk.im.dingtalk import URL -from common.mixins.views import UserConfirmRequiredExceptionMixin, PermissionsMixin -from common.permissions import UserConfirmation -from authentication import errors -from authentication.mixins import AuthMixin -from authentication.const import ConfirmType -from common.sdk.im.dingtalk import DingTalk -from common.utils.common import get_request_ip -from authentication.notifications import OAuthBindMessage + from .mixins import METAMixin logger = get_logger(__file__) diff --git a/apps/authentication/views/feishu.py b/apps/authentication/views/feishu.py index da7999b95..4fdf6f846 100644 --- a/apps/authentication/views/feishu.py +++ b/apps/authentication/views/feishu.py @@ -1,26 +1,27 @@ +from urllib.parse import urlencode + +from django.conf import settings +from django.db.utils import IntegrityError +from django.http.request import HttpRequest from django.http.response import HttpResponseRedirect from django.utils.translation import ugettext_lazy as _ -from urllib.parse import urlencode from django.views import View -from django.conf import settings -from django.http.request import HttpRequest -from django.db.utils import IntegrityError -from rest_framework.permissions import IsAuthenticated, AllowAny from rest_framework.exceptions import APIException +from rest_framework.permissions import AllowAny, IsAuthenticated -from users.models import User -from users.views import UserVerifyPasswordView -from common.utils import get_logger, FlashMessageUtil -from common.utils.random import random_string -from common.utils.django import reverse, get_object_or_none -from common.mixins.views import UserConfirmRequiredExceptionMixin, PermissionsMixin -from common.permissions import UserConfirmation -from common.sdk.im.feishu import FeiShu, URL -from common.utils.common import get_request_ip from authentication import errors from authentication.const import ConfirmType from authentication.mixins import AuthMixin from authentication.notifications import OAuthBindMessage +from common.mixins.views import PermissionsMixin, UserConfirmRequiredExceptionMixin +from common.permissions import UserConfirmation +from common.sdk.im.feishu import URL, FeiShu +from common.utils import FlashMessageUtil, get_logger +from common.utils.common import get_request_ip +from common.utils.django import get_object_or_none, reverse +from common.utils.random import random_string +from users.models import User +from users.views import UserVerifyPasswordView logger = get_logger(__file__) diff --git a/apps/common/db/fields.py b/apps/common/db/fields.py index 72c4df898..edca62d5b 100644 --- a/apps/common/db/fields.py +++ b/apps/common/db/fields.py @@ -1,10 +1,12 @@ # -*- coding: utf-8 -*- # import json + from django.db import models from django.utils.translation import ugettext_lazy as _ from django.utils.encoding import force_text from django.core.validators import MinValueValidator, MaxValueValidator + from common.utils import signer, crypto @@ -13,7 +15,7 @@ __all__ = [ 'JsonCharField', 'JsonTextField', 'JsonListCharField', 'JsonListTextField', 'JsonDictCharField', 'JsonDictTextField', 'EncryptCharField', 'EncryptTextField', 'EncryptMixin', 'EncryptJsonDictTextField', - 'EncryptJsonDictCharField', 'PortField' + 'EncryptJsonDictCharField', 'PortField', 'BitChoices', ] @@ -190,3 +192,37 @@ class PortField(models.IntegerField): }) super().__init__(*args, **kwargs) + +class BitChoices(models.IntegerChoices): + @classmethod + def branches(cls): + return [i for i in cls] + + @classmethod + def tree(cls): + root = [_('All'), cls.branches()] + return cls.render_node(root) + + @classmethod + def render_node(cls, node): + if isinstance(node, BitChoices): + return { + 'id': node.name, + 'label': node.label, + } + else: + name, children = node + return { + 'id': name, + 'label': name, + 'children': [cls.render_node(child) for child in children] + } + + @classmethod + def all(cls): + value = 0 + for c in cls: + value |= c.value + return value + + diff --git a/apps/common/drf/fields.py b/apps/common/drf/fields.py index 97f9785f5..1e68265ab 100644 --- a/apps/common/drf/fields.py +++ b/apps/common/drf/fields.py @@ -1,17 +1,20 @@ # -*- coding: utf-8 -*- # import six - -from rest_framework.fields import ChoiceField -from rest_framework import serializers -from django.utils.translation import gettext_lazy as _ from django.core.exceptions import ObjectDoesNotExist +from django.db.models import IntegerChoices +from django.utils.translation import gettext_lazy as _ +from rest_framework import serializers +from rest_framework.fields import ChoiceField from common.utils import decrypt_password __all__ = [ - 'ReadableHiddenField', 'EncryptedField', 'LabeledChoiceField', - 'ObjectRelatedField', + "ReadableHiddenField", + "EncryptedField", + "LabeledChoiceField", + "ObjectRelatedField", + "BitChoicesField", ] @@ -20,14 +23,15 @@ __all__ = [ class ReadableHiddenField(serializers.HiddenField): - """ 可读的 HiddenField """ + """可读的 HiddenField""" + def __init__(self, **kwargs): super().__init__(**kwargs) self.write_only = False def to_representation(self, value): - if hasattr(value, 'id'): - return getattr(value, 'id') + if hasattr(value, "id"): + return getattr(value, "id") return value @@ -35,7 +39,7 @@ class EncryptedField(serializers.CharField): def __init__(self, write_only=None, **kwargs): if write_only is None: write_only = True - kwargs['write_only'] = write_only + kwargs["write_only"] = write_only super().__init__(**kwargs) def to_internal_value(self, value): @@ -54,26 +58,26 @@ class LabeledChoiceField(ChoiceField): if value is None: return value return { - 'value': value, - 'label': self.choice_mapper.get(six.text_type(value), value), + "value": value, + "label": self.choice_mapper.get(six.text_type(value), value), } def to_internal_value(self, data): if isinstance(data, dict): - return data.get('value') + return data.get("value") return super(LabeledChoiceField, self).to_internal_value(data) class ObjectRelatedField(serializers.RelatedField): default_error_messages = { - 'required': _('This field is required.'), - 'does_not_exist': _('Invalid pk "{pk_value}" - object does not exist.'), - 'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'), + "required": _("This field is required."), + "does_not_exist": _('Invalid pk "{pk_value}" - object does not exist.'), + "incorrect_type": _("Incorrect type. Expected pk value, received {data_type}."), } def __init__(self, **kwargs): - self.attrs = kwargs.pop('attrs', None) or ('id', 'name') - self.many = kwargs.get('many', False) + self.attrs = kwargs.pop("attrs", None) or ("id", "name") + self.many = kwargs.get("many", False) super().__init__(**kwargs) def to_representation(self, value): @@ -86,13 +90,53 @@ class ObjectRelatedField(serializers.RelatedField): if not isinstance(data, dict): pk = data else: - pk = data.get('id') or data.get('pk') or data.get(self.attrs[0]) + pk = data.get("id") or data.get("pk") or data.get(self.attrs[0]) queryset = self.get_queryset() try: if isinstance(data, bool): raise TypeError return queryset.get(pk=pk) except ObjectDoesNotExist: - self.fail('does_not_exist', pk_value=pk) + self.fail("does_not_exist", pk_value=pk) except (TypeError, ValueError): - self.fail('incorrect_type', data_type=type(pk).__name__) + self.fail("incorrect_type", data_type=type(pk).__name__) + + +class BitChoicesField(serializers.MultipleChoiceField): + """ + 位字段 + """ + + def __init__(self, choice_cls, **kwargs): + assert issubclass(choice_cls, IntegerChoices) + choices = [(c.name, c.label) for c in choice_cls] + self._choice_cls = choice_cls + super().__init__(choices=choices, **kwargs) + + def to_representation(self, value): + return [ + {"value": c.name, "label": c.label} + for c in self._choice_cls + if c.value & value == c.value + ] + + def to_internal_value(self, data): + if not isinstance(data, list): + raise serializers.ValidationError(_("Invalid data type, should be list")) + value = 0 + if not data: + return value + if isinstance(data[0], dict): + data = [d["value"] for d in data] + # 所有的 + if "all" in data: + for c in self._choice_cls: + value |= c.value + return value + + name_value_map = {c.name: c.value for c in self._choice_cls} + for name in data: + if name not in name_value_map: + raise serializers.ValidationError(_("Invalid choice: {}").format(name)) + value |= name_value_map[name] + return value diff --git a/apps/common/drf/metadata.py b/apps/common/drf/metadata.py index 939c1f314..d16ab2262 100644 --- a/apps/common/drf/metadata.py +++ b/apps/common/drf/metadata.py @@ -2,17 +2,15 @@ # from __future__ import unicode_literals -from collections import OrderedDict import datetime -from itertools import chain +from collections import OrderedDict from django.core.exceptions import PermissionDenied from django.http import Http404 from django.utils.encoding import force_text -from rest_framework.fields import empty - -from rest_framework.metadata import SimpleMetadata from rest_framework import exceptions, serializers +from rest_framework.fields import empty +from rest_framework.metadata import SimpleMetadata from rest_framework.request import clone_request @@ -21,9 +19,14 @@ class SimpleMetadataWithFilters(SimpleMetadata): methods = {"PUT", "POST", "GET", "PATCH"} attrs = [ - 'read_only', 'label', 'help_text', - 'min_length', 'max_length', - 'min_value', 'max_value', "write_only", + "read_only", + "label", + "help_text", + "min_length", + "max_length", + "min_value", + "max_value", + "write_only", ] def determine_actions(self, request, view): @@ -32,18 +35,18 @@ class SimpleMetadataWithFilters(SimpleMetadata): the fields that are accepted for 'PUT' and 'POST' methods. """ actions = {} - view.raw_action = getattr(view, 'action', None) + view.raw_action = getattr(view, "action", None) for method in self.methods & set(view.allowed_methods): - if hasattr(view, 'action_map'): + if hasattr(view, "action_map"): view.action = view.action_map.get(method.lower(), view.action) view.request = clone_request(request, method) try: # Test global permissions - if hasattr(view, 'check_permissions'): + if hasattr(view, "check_permissions"): view.check_permissions(view.request) # Test object permissions - if method == 'PUT' and hasattr(view, 'get_object'): + if method == "PUT" and hasattr(view, "get_object"): view.get_object() except (exceptions.APIException, PermissionDenied, Http404): pass @@ -62,64 +65,63 @@ class SimpleMetadataWithFilters(SimpleMetadata): of metadata about it. """ field_info = OrderedDict() - field_info['type'] = self.label_lookup[field] - field_info['required'] = getattr(field, 'required', False) + field_info["type"] = self.label_lookup[field] + field_info["required"] = getattr(field, "required", False) - default = getattr(field, 'default', None) + # Default value + default = getattr(field, "default", None) if default is not None and default != empty: if isinstance(default, (str, int, bool, float, datetime.datetime, list)): - field_info['default'] = default + field_info["default"] = default for attr in self.attrs: value = getattr(field, attr, None) - if value is not None and value != '': + if value is not None and value != "": field_info[attr] = force_text(value, strings_only=True) - if getattr(field, 'child', None): - field_info['child'] = self.get_field_info(field.child) - elif getattr(field, 'fields', None): - field_info['children'] = self.get_serializer_info(field) + if getattr(field, "child", None): + field_info["child"] = self.get_field_info(field.child) + elif getattr(field, "fields", None): + field_info["children"] = self.get_serializer_info(field) - is_related_field = isinstance(field, (serializers.RelatedField, serializers.ManyRelatedField)) - if not is_related_field and hasattr(field, 'choices'): - field_info['choices'] = [ + is_choice_field = isinstance(field, (serializers.ChoiceField,)) + if is_choice_field and hasattr(field, "choices"): + field_info["choices"] = [ { - 'value': choice_value, - 'label': force_text(choice_name, strings_only=True) + "value": choice_value, + "label": force_text(choice_label, strings_only=True), } - for choice_value, choice_name in dict(field.choices).items() + for choice_value, choice_label in dict(field.choices).items() ] class_name = field.__class__.__name__ - if class_name == 'LabeledChoiceField': - field_info['type'] = 'labeled_choice' - elif class_name == 'ObjectRelatedField': - field_info['type'] = 'object_related_field' - elif class_name == 'ManyRelatedField': + if class_name == "LabeledChoiceField": + field_info["type"] = "labeled_choice" + elif class_name == "ObjectRelatedField": + field_info["type"] = "object_related_field" + elif class_name == "ManyRelatedField": child_relation_class_name = field.child_relation.__class__.__name__ - if child_relation_class_name == 'ObjectRelatedField': - field_info['type'] = 'm2m_related_field' - - # if field.label == '系统平台': - # print("Field: ", class_name, field, field_info) - + if child_relation_class_name == "ObjectRelatedField": + field_info["type"] = "m2m_related_field" return field_info - def get_filters_fields(self, request, view): + @staticmethod + def get_filters_fields(request, view): fields = [] - if hasattr(view, 'get_filter_fields'): + if hasattr(view, "get_filter_fields"): fields = view.get_filter_fields(request) - elif hasattr(view, 'filter_fields'): + elif hasattr(view, "filter_fields"): fields = view.filter_fields - elif hasattr(view, 'filterset_fields'): + elif hasattr(view, "filterset_fields"): fields = view.filterset_fields - elif hasattr(view, 'get_filterset_fields'): + elif hasattr(view, "get_filterset_fields"): fields = view.get_filterset_fields(request) - elif hasattr(view, 'filterset_class'): - fields = list(view.filterset_class.Meta.fields) + \ - list(view.filterset_class.declared_filters.keys()) + elif hasattr(view, "filterset_class"): + fields = list(view.filterset_class.Meta.fields) + list( + view.filterset_class.declared_filters.keys() + ) - if hasattr(view, 'custom_filter_fields'): + if hasattr(view, "custom_filter_fields"): # 不能写 fields += view.custom_filter_fields # 会改变 view 的 filter_fields fields = list(fields) + list(view.custom_filter_fields) @@ -130,14 +132,16 @@ class SimpleMetadataWithFilters(SimpleMetadata): def get_ordering_fields(self, request, view): fields = [] - if hasattr(view, 'get_ordering_fields'): + if hasattr(view, "get_ordering_fields"): fields = view.get_ordering_fields(request) - elif hasattr(view, 'ordering_fields'): + elif hasattr(view, "ordering_fields"): fields = view.ordering_fields return fields def determine_metadata(self, request, view): - metadata = super(SimpleMetadataWithFilters, self).determine_metadata(request, view) + metadata = super(SimpleMetadataWithFilters, self).determine_metadata( + request, view + ) filterset_fields = self.get_filters_fields(request, view) order_fields = self.get_ordering_fields(request, view) diff --git a/apps/common/utils/integer.py b/apps/common/utils/integer.py new file mode 100644 index 000000000..73f4160c0 --- /dev/null +++ b/apps/common/utils/integer.py @@ -0,0 +1,3 @@ + +def bit(x): + return 2 ** (x - 1) diff --git a/apps/perms/api/asset_permission.py b/apps/perms/api/asset_permission.py index afadc456c..de15a6c6f 100644 --- a/apps/perms/api/asset_permission.py +++ b/apps/perms/api/asset_permission.py @@ -1,10 +1,9 @@ # -*- coding: utf-8 -*- # -from perms.filters import AssetPermissionFilter -from perms.models import AssetPermission from orgs.mixins.api import OrgBulkModelViewSet from perms import serializers - +from perms.filters import AssetPermissionFilter +from perms.models import AssetPermission __all__ = ['AssetPermissionViewSet'] @@ -18,4 +17,4 @@ class AssetPermissionViewSet(OrgBulkModelViewSet): filterset_class = AssetPermissionFilter search_fields = ('name',) ordering_fields = ('name',) - ordering = ('name', ) + ordering = ('name',) diff --git a/apps/perms/api/user_permission/accounts.py b/apps/perms/api/user_permission/accounts.py index 70973d988..692dac8c8 100644 --- a/apps/perms/api/user_permission/accounts.py +++ b/apps/perms/api/user_permission/accounts.py @@ -6,7 +6,6 @@ from common.utils import get_logger, lazyproperty from assets.serializers import AccountSerializer from perms.hands import User, Asset, Account from perms import serializers -from perms.models import Action from perms.utils import PermAccountUtil from .mixin import RoleAdminMixin, RoleUserMixin @@ -80,7 +79,7 @@ class UserGrantedAssetSpecialAccountsApi(ListAPIView): def get_queryset(self): # 构造默认包含的账号,如: @INPUT @USER accounts = [ - Account.get_input_account(), + Account.get_manual_account(), Account.get_user_account(self.user.username) ] for account in accounts: diff --git a/apps/perms/api/user_permission/mixin.py b/apps/perms/api/user_permission/mixin.py index da9691f38..2a7cbe221 100644 --- a/apps/perms/api/user_permission/mixin.py +++ b/apps/perms/api/user_permission/mixin.py @@ -3,11 +3,9 @@ from rest_framework.request import Request from common.http import is_true -from common.mixins.api import RoleAdminMixin -from common.mixins.api import RoleUserMixin -from orgs.utils import tmp_to_root_org -from users.models import User +from common.mixins.api import RoleAdminMixin, RoleUserMixin from perms.utils.user_permission import UserGrantedTreeRefreshController +from users.models import User class RebuildTreeMixin: diff --git a/apps/perms/const.py b/apps/perms/const.py index ec51c5a2b..3dd7aad6a 100644 --- a/apps/perms/const.py +++ b/apps/perms/const.py @@ -1,2 +1,71 @@ # -*- coding: utf-8 -*- # +from django.db import models +from django.utils.translation import ugettext_lazy as _ + +from common.utils.integer import bit +from common.db.fields import BitChoices + + +__all__ = ['SpecialAccount', 'ActionChoices'] + + +class ActionChoices(BitChoices): + connect = bit(0), _('Connect') + upload = bit(1), _('Upload') + download = bit(2), _('Download') + copy = bit(3), _('Copy') + paste = bit(4), _('Paste') + + @classmethod + def branches(cls): + return ( + (_('Transfer'), [cls.upload, cls.download]), + (_('Clipboard'), [cls.copy, cls.paste]), + ) + + +# class Action(BitOperationChoice): +# CONNECT = 0b1 +# UPLOAD = 0b1 << 1 +# DOWNLOAD = 0b1 << 2 +# COPY = 0b1 << 3 +# PASTE = 0b1 << 4 +# ALL = 0 << 8 +# TRANSFER = UPLOAD | DOWNLOAD +# CLIPBOARD = COPY | PASTE +# +# DB_CHOICES = ( +# (ALL, _('All')), +# (CONNECT, _('Connect')), +# (UPLOAD, _('Upload file')), +# (DOWNLOAD, _('Download file')), +# (TRANSFER, _("Upload download")), +# (COPY, _('Clipboard copy')), +# (PASTE, _('Clipboard paste')), +# (CLIPBOARD, _('Clipboard copy paste')) +# ) +# +# NAME_MAP = { +# ALL: "all", +# CONNECT: "connect", +# UPLOAD: "upload", +# DOWNLOAD: "download", +# TRANSFER: "transfer", +# COPY: 'copy', +# PASTE: 'paste', +# CLIPBOARD: 'clipboard' +# } +# +# NAME_MAP_REVERSE = {v: k for k, v in NAME_MAP.items()} +# CHOICES = [] +# for i, j in DB_CHOICES: +# CHOICES.append((NAME_MAP[i], j)) +# +# @classmethod +# def choices(cls): +# pass +# + +class SpecialAccount(models.TextChoices): + ALL = '@ALL', 'All' diff --git a/apps/perms/locks.py b/apps/perms/locks.py index a6ffa6b98..96c766fb8 100644 --- a/apps/perms/locks.py +++ b/apps/perms/locks.py @@ -5,7 +5,5 @@ class UserGrantedTreeRebuildLock(DistributedLock): name_template = 'perms.user.asset.node.tree.rebuid.' def __init__(self, user_id): - name = self.name_template.format( - user_id=user_id - ) + name = self.name_template.format(user_id=user_id) super().__init__(name=name, release_on_transaction_commit=True) diff --git a/apps/perms/migrations/0011_auto_20200721_1739.py b/apps/perms/migrations/0011_auto_20200721_1739.py index df8b46cde..1dcb33633 100644 --- a/apps/perms/migrations/0011_auto_20200721_1739.py +++ b/apps/perms/migrations/0011_auto_20200721_1739.py @@ -3,13 +3,12 @@ from django.db import migrations, models from django.db.models import F -from perms.models import Action def migrate_asset_permission(apps, schema_editor): # 已有的资产权限默认拥有剪切板复制粘贴动作 - AssetPermission = apps.get_model('perms', 'AssetPermission') - AssetPermission.objects.all().update(actions=F('actions').bitor(Action.CLIPBOARD_COPY_PASTE)) + asset_permission_model = apps.get_model('perms', 'AssetPermission') + asset_permission_model.objects.all().update(actions=F('actions').bitor(24)) class Migration(migrations.Migration): diff --git a/apps/perms/models/__init__.py b/apps/perms/models/__init__.py index 9cb0efc76..ee7787f7f 100644 --- a/apps/perms/models/__init__.py +++ b/apps/perms/models/__init__.py @@ -1,5 +1,5 @@ # coding: utf-8 # +from .permed_node import * from .asset_permission import * -from .const import * diff --git a/apps/perms/models/asset_permission.py b/apps/perms/models/asset_permission.py index 20186527d..47cb1e8e6 100644 --- a/apps/perms/models/asset_permission.py +++ b/apps/perms/models/asset_permission.py @@ -1,23 +1,19 @@ -import uuid import logging +import uuid +from django.db import models +from django.db.models import Q from django.utils import timezone from django.utils.translation import ugettext_lazy as _ -from django.db import models -from django.db.models import F, Q, TextChoices -from common.utils import lazyproperty, date_expired_default -from common.db.models import BaseCreateUpdateModel, UnionQuerySet -from assets.models import Asset, Node, FamilyMixin, Account -from orgs.mixins.models import OrgModelMixin +from assets.models import Asset, Account +from common.db.models import UnionQuerySet +from common.utils import date_expired_default from orgs.mixins.models import OrgManager -from .const import Action, SpecialAccount +from orgs.mixins.models import OrgModelMixin +from perms.const import ActionChoices, SpecialAccount -__all__ = [ - 'AssetPermission', 'PermNode', - 'UserAssetGrantedTreeNodeRelation', - 'Action' -] +__all__ = ['AssetPermission', 'ActionChoices'] # 使用场景 logger = logging.getLogger(__name__) @@ -67,9 +63,7 @@ class AssetPermission(OrgModelMixin): ) # 特殊的账号: @ALL, @INPUT @USER 默认包含,将来在全局设置中进行控制. accounts = models.JSONField(default=list, verbose_name=_("Accounts")) - actions = models.IntegerField( - choices=Action.DB_CHOICES, default=Action.ALL, verbose_name=_("Actions") - ) + actions = models.IntegerField(default=ActionChoices.connect, verbose_name=_("Actions")) is_active = models.BooleanField(default=True, verbose_name=_('Active')) date_start = models.DateTimeField( default=timezone.now, db_index=True, verbose_name=_("Date start") @@ -133,145 +127,9 @@ class AssetPermission(OrgModelMixin): """ asset_ids = self.get_all_assets(flat=True) q = Q(asset_id__in=asset_ids) - if not self.is_perm_all_accounts: + if SpecialAccount.ALL in self.accounts: q &= Q(username__in=self.accounts) accounts = Account.objects.filter(q).order_by('asset__name', 'name', 'username') if not flat: return accounts return accounts.values_list('id', flat=True) - - @property - def is_perm_all_accounts(self): - return SpecialAccount.ALL in self.accounts - - @lazyproperty - def users_amount(self): - return self.users.count() - - @lazyproperty - def user_groups_amount(self): - return self.user_groups.count() - - @lazyproperty - def assets_amount(self): - return self.assets.count() - - @lazyproperty - def nodes_amount(self): - return self.nodes.count() - - def users_display(self): - names = [user.username for user in self.users.all()] - return names - - def user_groups_display(self): - names = [group.name for group in self.user_groups.all()] - return names - - def assets_display(self): - names = [asset.name for asset in self.assets.all()] - return names - - def nodes_display(self): - names = [node.full_value for node in self.nodes.all()] - return names - - -class UserAssetGrantedTreeNodeRelation(OrgModelMixin, FamilyMixin, BaseCreateUpdateModel): - class NodeFrom(TextChoices): - granted = 'granted', 'Direct node granted' - child = 'child', 'Have children node' - asset = 'asset', 'Direct asset granted' - - user = models.ForeignKey('users.User', db_constraint=False, on_delete=models.CASCADE) - node = models.ForeignKey('assets.Node', default=None, on_delete=models.CASCADE, - db_constraint=False, null=False, related_name='granted_node_rels') - node_key = models.CharField(max_length=64, verbose_name=_("Key"), db_index=True) - node_parent_key = models.CharField(max_length=64, default='', verbose_name=_('Parent key'), - db_index=True) - node_from = models.CharField(choices=NodeFrom.choices, max_length=16, db_index=True) - node_assets_amount = models.IntegerField(default=0) - - @property - def key(self): - return self.node_key - - @property - def parent_key(self): - return self.node_parent_key - - @classmethod - def get_node_granted_status(cls, user, key): - ancestor_keys = set(cls.get_node_ancestor_keys(key, with_self=True)) - ancestor_rel_nodes = cls.objects.filter(user=user, node_key__in=ancestor_keys) - - for rel_node in ancestor_rel_nodes: - if rel_node.key == key: - return rel_node.node_from, rel_node - if rel_node.node_from == cls.NodeFrom.granted: - return cls.NodeFrom.granted, None - return '', None - - -class PermNode(Node): - class Meta: - proxy = True - ordering = [] - - # 特殊节点 - UNGROUPED_NODE_KEY = 'ungrouped' - UNGROUPED_NODE_VALUE = _('Ungrouped') - FAVORITE_NODE_KEY = 'favorite' - FAVORITE_NODE_VALUE = _('Favorite') - - node_from = '' - granted_assets_amount = 0 - - annotate_granted_node_rel_fields = { - 'granted_assets_amount': F('granted_node_rels__node_assets_amount'), - 'node_from': F('granted_node_rels__node_from') - } - - def use_granted_assets_amount(self): - self.assets_amount = self.granted_assets_amount - - @classmethod - def get_ungrouped_node(cls, assets_amount): - return cls( - id=cls.UNGROUPED_NODE_KEY, - key=cls.UNGROUPED_NODE_KEY, - value=cls.UNGROUPED_NODE_VALUE, - assets_amount=assets_amount - ) - - @classmethod - def get_favorite_node(cls, assets_amount): - node = cls( - id=cls.FAVORITE_NODE_KEY, - key=cls.FAVORITE_NODE_KEY, - value=cls.FAVORITE_NODE_VALUE, - ) - node.assets_amount = assets_amount - return node - - def get_granted_status(self, user): - status, rel_node = UserAssetGrantedTreeNodeRelation.get_node_granted_status(user, self.key) - self.node_from = status - if rel_node: - self.granted_assets_amount = rel_node.node_assets_amount - return status - - def save(self): - # 这是个只读 Model - raise NotImplementedError - - -class PermedAsset(Asset): - class Meta: - proxy = True - verbose_name = _('Permed asset') - permissions = [ - ('view_myassets', _('Can view my assets')), - ('view_userassets', _('Can view user assets')), - ('view_usergroupassets', _('Can view usergroup assets')), - ] diff --git a/apps/perms/models/const.py b/apps/perms/models/const.py deleted file mode 100644 index 6128418b0..000000000 --- a/apps/perms/models/const.py +++ /dev/null @@ -1,48 +0,0 @@ -from django.db import models -from django.utils.translation import ugettext_lazy as _ -from common.db.models import BitOperationChoice - - -__all__ = ['Action', 'SpecialAccount'] - - -class Action(BitOperationChoice): - ALL = 0xff - CONNECT = 0b1 - UPLOAD = 0b1 << 1 - DOWNLOAD = 0b1 << 2 - CLIPBOARD_COPY = 0b1 << 3 - CLIPBOARD_PASTE = 0b1 << 4 - UPDOWNLOAD = UPLOAD | DOWNLOAD - CLIPBOARD_COPY_PASTE = CLIPBOARD_COPY | CLIPBOARD_PASTE - - DB_CHOICES = ( - (ALL, _('All')), - (CONNECT, _('Connect')), - (UPLOAD, _('Upload file')), - (DOWNLOAD, _('Download file')), - (UPDOWNLOAD, _("Upload download")), - (CLIPBOARD_COPY, _('Clipboard copy')), - (CLIPBOARD_PASTE, _('Clipboard paste')), - (CLIPBOARD_COPY_PASTE, _('Clipboard copy paste')) - ) - - NAME_MAP = { - ALL: "all", - CONNECT: "connect", - UPLOAD: "upload_file", - DOWNLOAD: "download_file", - UPDOWNLOAD: "updownload", - CLIPBOARD_COPY: 'clipboard_copy', - CLIPBOARD_PASTE: 'clipboard_paste', - CLIPBOARD_COPY_PASTE: 'clipboard_copy_paste' - } - - NAME_MAP_REVERSE = {v: k for k, v in NAME_MAP.items()} - CHOICES = [] - for i, j in DB_CHOICES: - CHOICES.append((NAME_MAP[i], j)) - - -class SpecialAccount(models.TextChoices): - ALL = '@ALL', 'All' diff --git a/apps/perms/models/permed_node.py b/apps/perms/models/permed_node.py new file mode 100644 index 000000000..ce061297e --- /dev/null +++ b/apps/perms/models/permed_node.py @@ -0,0 +1,119 @@ + +from django.utils.translation import ugettext_lazy as _ +from django.db import models +from django.db.models import F, TextChoices + +from common.utils import lazyproperty +from common.db.models import BaseCreateUpdateModel +from assets.models import Asset, Node, FamilyMixin, Account +from orgs.mixins.models import OrgModelMixin + + +class UserAssetGrantedTreeNodeRelation(OrgModelMixin, FamilyMixin, BaseCreateUpdateModel): + class NodeFrom(TextChoices): + granted = 'granted', 'Direct node granted' + child = 'child', 'Have children node' + asset = 'asset', 'Direct asset granted' + + user = models.ForeignKey('users.User', db_constraint=False, on_delete=models.CASCADE) + node = models.ForeignKey('assets.Node', default=None, on_delete=models.CASCADE, + db_constraint=False, null=False, related_name='granted_node_rels') + node_key = models.CharField(max_length=64, verbose_name=_("Key"), db_index=True) + node_parent_key = models.CharField(max_length=64, default='', verbose_name=_('Parent key'), + db_index=True) + node_from = models.CharField(choices=NodeFrom.choices, max_length=16, db_index=True) + node_assets_amount = models.IntegerField(default=0) + + @property + def key(self): + return self.node_key + + @property + def parent_key(self): + return self.node_parent_key + + @classmethod + def get_node_granted_status(cls, user, key): + ancestor_keys = set(cls.get_node_ancestor_keys(key, with_self=True)) + ancestor_rel_nodes = cls.objects.filter(user=user, node_key__in=ancestor_keys) + + for rel_node in ancestor_rel_nodes: + if rel_node.key == key: + return rel_node.node_from, rel_node + if rel_node.node_from == cls.NodeFrom.granted: + return cls.NodeFrom.granted, None + return '', None + + +class PermNode(Node): + class Meta: + proxy = True + ordering = [] + + # 特殊节点 + UNGROUPED_NODE_KEY = 'ungrouped' + UNGROUPED_NODE_VALUE = _('Ungrouped') + FAVORITE_NODE_KEY = 'favorite' + FAVORITE_NODE_VALUE = _('Favorite') + + node_from = '' + granted_assets_amount = 0 + + annotate_granted_node_rel_fields = { + 'granted_assets_amount': F('granted_node_rels__node_assets_amount'), + 'node_from': F('granted_node_rels__node_from') + } + + def use_granted_assets_amount(self): + self.assets_amount = self.granted_assets_amount + + @classmethod + def get_ungrouped_node(cls, assets_amount): + return cls( + id=cls.UNGROUPED_NODE_KEY, + key=cls.UNGROUPED_NODE_KEY, + value=cls.UNGROUPED_NODE_VALUE, + assets_amount=assets_amount + ) + + @classmethod + def get_favorite_node(cls, assets_amount): + node = cls( + id=cls.FAVORITE_NODE_KEY, + key=cls.FAVORITE_NODE_KEY, + value=cls.FAVORITE_NODE_VALUE, + ) + node.assets_amount = assets_amount + return node + + def get_granted_status(self, user): + status, rel_node = UserAssetGrantedTreeNodeRelation.get_node_granted_status(user, self.key) + self.node_from = status + if rel_node: + self.granted_assets_amount = rel_node.node_assets_amount + return status + + def save(self): + # 这是个只读 Model + raise NotImplementedError + + +class PermedAsset(Asset): + class Meta: + proxy = True + verbose_name = _('Permed asset') + permissions = [ + ('view_myassets', _('Can view my assets')), + ('view_userassets', _('Can view user assets')), + ('view_usergroupassets', _('Can view usergroup assets')), + ] + + +class PermedAccount(Account): + @lazyproperty + def actions(self): + return 0 + + class Meta: + proxy = True + verbose_name = _('Permed account') diff --git a/apps/perms/serializers/permission.py b/apps/perms/serializers/permission.py index ff19b9dd6..9a31058f6 100644 --- a/apps/perms/serializers/permission.py +++ b/apps/perms/serializers/permission.py @@ -1,75 +1,64 @@ # -*- coding: utf-8 -*- # -from rest_framework import serializers -from rest_framework.fields import empty -from django.utils.translation import ugettext_lazy as _ from django.db.models import Q +from django.utils.translation import ugettext_lazy as _ +from rest_framework import serializers -from common.drf.fields import ObjectRelatedField -from orgs.mixins.serializers import BulkOrgResourceModelSerializer from assets.models import Asset, Node +from common.drf.fields import BitChoicesField, ObjectRelatedField +from orgs.mixins.serializers import BulkOrgResourceModelSerializer +from perms.models import ActionChoices, AssetPermission from users.models import User, UserGroup -from perms.models import AssetPermission, Action -__all__ = ['AssetPermissionSerializer', 'ActionsField'] +__all__ = ["AssetPermissionSerializer", "ActionChoicesField"] -class ActionsField(serializers.MultipleChoiceField): +class ActionChoicesField(BitChoicesField): def __init__(self, **kwargs): - kwargs['choices'] = Action.CHOICES - super().__init__(**kwargs) - - def run_validation(self, data=empty): - data = super(ActionsField, self).run_validation(data) - if isinstance(data, list): - data = Action.choices_to_value(value=data) - return data - - def to_representation(self, value): - return Action.value_to_choices(value) - - def to_internal_value(self, data): - if not self.allow_empty and not data: - self.fail('empty') - if not data: - return data - return Action.choices_to_value(data) - - -class ActionsDisplayField(ActionsField): - def to_representation(self, value): - values = super().to_representation(value) - choices = dict(Action.CHOICES) - return [choices.get(i) for i in values] + super().__init__(ActionChoices, **kwargs) class AssetPermissionSerializer(BulkOrgResourceModelSerializer): users = ObjectRelatedField(queryset=User.objects, many=True, required=False) - user_groups = ObjectRelatedField(queryset=UserGroup.objects, many=True, required=False) + user_groups = ObjectRelatedField( + queryset=UserGroup.objects, many=True, required=False + ) assets = ObjectRelatedField(queryset=Asset.objects, many=True, required=False) nodes = ObjectRelatedField(queryset=Node.objects, many=True, required=False) - actions = ActionsField(required=False, allow_null=True, label=_("Actions")) + actions = ActionChoicesField(required=False, allow_null=True, label=_("Actions")) is_valid = serializers.BooleanField(read_only=True, label=_("Is valid")) - is_expired = serializers.BooleanField(read_only=True, label=_('Is expired')) + is_expired = serializers.BooleanField(read_only=True, label=_("Is expired")) + accounts = serializers.ListField(label=_("Accounts"), required=False) class Meta: model = AssetPermission - fields_mini = ['id', 'name'] + fields_mini = ["id", "name"] fields_small = fields_mini + [ - 'accounts', 'is_active', 'is_expired', 'is_valid', - 'actions', 'created_by', 'date_created', 'date_expired', - 'date_start', 'comment', 'from_ticket' + "accounts", + "is_active", + "is_expired", + "is_valid", + "actions", + "created_by", + "date_created", + "date_expired", + "date_start", + "comment", + "from_ticket", ] fields_m2m = [ - 'users', 'user_groups', 'assets', 'nodes', + "users", + "user_groups", + "assets", + "nodes", ] fields = fields_small + fields_m2m - read_only_fields = ['created_by', 'date_created', 'from_ticket'] + read_only_fields = ["created_by", "date_created", "from_ticket"] extra_kwargs = { - 'actions': {'label': _('Actions')}, - 'is_expired': {'label': _('Is expired')}, - 'is_valid': {'label': _('Is valid')}, + "actions": {"label": _("Actions")}, + "is_expired": {"label": _("Is expired")}, + "is_valid": {"label": _("Is valid")}, } def __init__(self, *args, **kwargs): @@ -77,7 +66,7 @@ class AssetPermissionSerializer(BulkOrgResourceModelSerializer): self.set_actions_field() def set_actions_field(self): - actions = self.fields.get('actions') + actions = self.fields.get("actions") if not actions: return choices = actions._choices @@ -86,9 +75,12 @@ class AssetPermissionSerializer(BulkOrgResourceModelSerializer): @classmethod def setup_eager_loading(cls, queryset): - """ Perform necessary eager loading of data. """ + """Perform necessary eager loading of data.""" queryset = queryset.prefetch_related( - 'users', 'user_groups', 'assets', 'nodes', + "users", + "user_groups", + "assets", + "nodes", ) return queryset @@ -96,35 +88,34 @@ class AssetPermissionSerializer(BulkOrgResourceModelSerializer): def perform_display_create(instance, **kwargs): # 用户 users_to_set = User.objects.filter( - Q(name__in=kwargs.get('users_display')) | - Q(username__in=kwargs.get('users_display')) + Q(name__in=kwargs.get("users_display")) + | Q(username__in=kwargs.get("users_display")) ).distinct() instance.users.add(*users_to_set) # 用户组 user_groups_to_set = UserGroup.objects.filter( - name__in=kwargs.get('user_groups_display') + name__in=kwargs.get("user_groups_display") ).distinct() instance.user_groups.add(*user_groups_to_set) # 资产 assets_to_set = Asset.objects.filter( - Q(address__in=kwargs.get('assets_display')) | - Q(name__in=kwargs.get('assets_display')) + Q(address__in=kwargs.get("assets_display")) + | Q(name__in=kwargs.get("assets_display")) ).distinct() instance.assets.add(*assets_to_set) # 节点 nodes_to_set = Node.objects.filter( - full_value__in=kwargs.get('nodes_display') + full_value__in=kwargs.get("nodes_display") ).distinct() instance.nodes.add(*nodes_to_set) def create(self, validated_data): display = { - 'users_display': validated_data.pop('users_display', ''), - 'user_groups_display': validated_data.pop('user_groups_display', ''), - 'assets_display': validated_data.pop('assets_display', ''), - 'nodes_display': validated_data.pop('nodes_display', '') + "users_display": validated_data.pop("users_display", ""), + "user_groups_display": validated_data.pop("user_groups_display", ""), + "assets_display": validated_data.pop("assets_display", ""), + "nodes_display": validated_data.pop("nodes_display", ""), } instance = super().create(validated_data) self.perform_display_create(instance, **display) return instance - diff --git a/apps/perms/serializers/user_permission.py b/apps/perms/serializers/user_permission.py index 8784a8abb..09eb97428 100644 --- a/apps/perms/serializers/user_permission.py +++ b/apps/perms/serializers/user_permission.py @@ -7,7 +7,7 @@ from django.utils.translation import ugettext_lazy as _ from common.drf.fields import ObjectRelatedField, LabeledChoiceField from assets.models import Node, Asset, Platform, Account from assets.const import Category, AllTypes -from perms.serializers.permission import ActionsField +from perms.serializers.permission import ActionChoicesField __all__ = [ 'NodeGrantedSerializer', 'AssetGrantedSerializer', @@ -45,7 +45,7 @@ class NodeGrantedSerializer(serializers.ModelSerializer): class ActionsSerializer(serializers.Serializer): - actions = ActionsField(read_only=True) + actions = ActionChoicesField(read_only=True) class AccountsGrantedSerializer(serializers.ModelSerializer): @@ -53,7 +53,7 @@ class AccountsGrantedSerializer(serializers.ModelSerializer): # Todo: 添加前端登录逻辑中需要的一些字段,比如:是否需要手动输入密码 # need_manual = serializers.BooleanField(label=_('Need manual input')) - actions = ActionsField(read_only=True) + actions = ActionChoicesField(read_only=True) class Meta: model = Account diff --git a/apps/perms/utils/account.py b/apps/perms/utils/account.py index 8d8f5e743..167a3060b 100644 --- a/apps/perms/utils/account.py +++ b/apps/perms/utils/account.py @@ -1,5 +1,6 @@ import time from collections import defaultdict + from assets.models import Account from .permission import AssetPermissionUtil @@ -8,54 +9,78 @@ __all__ = ['PermAccountUtil'] class PermAccountUtil(AssetPermissionUtil): """ 资产授权账号相关的工具 """ + @staticmethod + def get_permed_accounts_from_perms(perms, user, asset): + alias_action_bit_mapper = defaultdict(int) + alias_expired_mapper = defaultdict(list) - def get_perm_accounts_for_user(self, user, with_actions=False): - """ 获取授权给用户的所有账号 """ - perms = self.get_permissions_for_user(user) - accounts = self.get_perm_accounts_for_permissions(perms, with_actions=with_actions) + for perm in perms: + for alias in perm.accounts: + alias_action_bit_mapper[alias] |= perm.actions + alias_expired_mapper[alias].append(perm.date_expired) + + asset_accounts = asset.accounts.all() + username_account_mapper = {account.username: account for account in asset_accounts} + cleaned_accounts_action_bit = defaultdict(int) + cleaned_accounts_expired = defaultdict(list) + + # @ALL 账号先处理,后面的每个最多映射一个账号 + all_action_bit = alias_action_bit_mapper.pop('@ALL', None) + if all_action_bit: + for account in asset_accounts: + cleaned_accounts_action_bit[account] |= all_action_bit + cleaned_accounts_expired[account].extend(alias_expired_mapper['@ALL']) + + for alias, action_bit in alias_action_bit_mapper.items(): + if alias == '@USER': + if user.username in username_account_mapper: + account = username_account_mapper[user.username] + else: + account = Account.get_user_account(user.username) + elif alias == '@INPUT': + account = Account.get_manual_account() + elif alias in username_account_mapper: + account = username_account_mapper[alias] + else: + account = None + + if account: + cleaned_accounts_action_bit[account] |= action_bit + cleaned_accounts_expired[account].extend(alias_expired_mapper[alias]) + + accounts = [] + for account, action_bit in cleaned_accounts_action_bit.items(): + account.actions = action_bit + account.date_expired = max(cleaned_accounts_expired[account]) + accounts.append(account) return accounts - def get_perm_accounts_for_user_asset(self, user, asset, with_actions=False, with_perms=False): + def get_permed_accounts_for_user(self, user, asset): """ 获取授权给用户某个资产的账号 """ perms = self.get_permissions_for_user_asset(user, asset) - accounts = self.get_perm_accounts_for_permissions(perms, with_actions=with_actions) - if with_perms: - return perms, accounts - return accounts - - def get_perm_accounts_for_user_group_asset(self, user_group, asset, with_actions=False): - """ 获取授权给用户组某个资产的账号 """ - perms = self.get_permissions_for_user_group_asset(user_group, asset) - accounts = self.get_perm_accounts_for_permissions(perms, with_actions=with_actions) - return accounts + permed_accounts = self.get_permed_accounts_from_perms(perms, user, asset) + return permed_accounts @staticmethod - def get_perm_accounts_for_permissions(permissions, with_actions=False): + def get_accounts_for_permission(perm, with_actions=False): """ 获取授权规则包含的账号 """ aid_actions_map = defaultdict(int) - for perm in permissions: - account_ids = perm.get_all_accounts(flat=True) - actions = perm.actions - for aid in account_ids: - aid_actions_map[str(aid)] |= actions + # 这里不行,速度太慢, 别情有很多查询 + account_ids = perm.get_all_accounts(flat=True) + actions = perm.actions + for aid in account_ids: + aid_actions_map[str(aid)] |= actions account_ids = list(aid_actions_map.keys()) - accounts = Account.objects.filter(id__in=account_ids).order_by( - 'asset__name', 'name', 'username' - ) - if with_actions: - for account in accounts: - account.actions = aid_actions_map.get(str(account.id)) + accounts = Account.objects.filter(id__in=account_ids) return accounts def validate_permission(self, user, asset, account_username): """ 校验用户有某个资产下某个账号名的权限 """ - perms, accounts = self.get_perm_accounts_for_user_asset( - user, asset, with_actions=True, with_perms=True - ) - perm = perms.first() - actions = [] - for account in accounts: - if account.username == account_username: - actions = account.actions - expire_at = perm.date_expired.timestamp() if perm else time.time() - return actions, expire_at + permed_accounts = self.get_permed_accounts_for_user(user, asset) + accounts_mapper = {account.username: account for account in permed_accounts} + + account = accounts_mapper.get(account_username) + if not account: + return False, None + else: + return account.actions, account.date_expired diff --git a/apps/perms/utils/permission.py b/apps/perms/utils/permission.py index fd0ea593b..8e4cd9199 100644 --- a/apps/perms/utils/permission.py +++ b/apps/perms/utils/permission.py @@ -1,12 +1,6 @@ -import time -from collections import defaultdict - -from django.db.models import Q from common.utils import get_logger -from perms.models import AssetPermission, Action -from perms.hands import Asset, User, UserGroup, Node -from perms.utils.user_permission import get_user_all_asset_perm_ids +from perms.models import AssetPermission logger = get_logger(__file__) diff --git a/apps/perms/utils/user_permission.py b/apps/perms/utils/user_permission.py index 81a2b9a10..fe73913cf 100644 --- a/apps/perms/utils/user_permission.py +++ b/apps/perms/utils/user_permission.py @@ -19,13 +19,12 @@ from orgs.utils import ( from assets.models import ( Asset, FavoriteAsset, AssetQuerySet, NodeQuerySet ) -from orgs.models import Organization -from perms.models import ( - AssetPermission, PermNode, UserAssetGrantedTreeNodeRelation, -) from users.models import User +from orgs.models import Organization from perms.locks import UserGrantedTreeRebuildLock - +from perms.models import ( + AssetPermission, PermNode, UserAssetGrantedTreeNodeRelation +) NodeFrom = UserAssetGrantedTreeNodeRelation.NodeFrom NODE_ONLY_FIELDS = ('id', 'key', 'parent_key', 'org_id') diff --git a/apps/tickets/migrations/0017_auto_20220623_1027.py b/apps/tickets/migrations/0017_auto_20220623_1027.py index 87752a469..ba2351d65 100644 --- a/apps/tickets/migrations/0017_auto_20220623_1027.py +++ b/apps/tickets/migrations/0017_auto_20220623_1027.py @@ -7,7 +7,6 @@ from collections import defaultdict from django.utils import timezone as dj_timezone from django.db import migrations -from perms.models import Action from tickets.const import TicketType pt = re.compile(r'(\w+)\((\w+)\)') diff --git a/apps/tickets/models/ticket/apply_asset.py b/apps/tickets/models/ticket/apply_asset.py index c0ec46cc0..d5f11ee36 100644 --- a/apps/tickets/models/ticket/apply_asset.py +++ b/apps/tickets/models/ticket/apply_asset.py @@ -1,7 +1,6 @@ from django.db import models from django.utils.translation import gettext_lazy as _ -from perms.models import Action from .general import Ticket __all__ = ['ApplyAssetTicket'] @@ -15,15 +14,13 @@ class ApplyAssetTicket(Ticket): # 申请信息 apply_assets = models.ManyToManyField('assets.Asset', verbose_name=_('Apply assets')) apply_accounts = models.JSONField(default=list, verbose_name=_('Apply accounts')) - apply_actions = models.IntegerField( - choices=Action.DB_CHOICES, default=Action.ALL, verbose_name=_('Actions') - ) + apply_actions = models.IntegerField(default=1, verbose_name=_('Actions')) apply_date_start = models.DateTimeField(verbose_name=_('Date start'), null=True) apply_date_expired = models.DateTimeField(verbose_name=_('Date expired'), null=True) @property def apply_actions_display(self): - return Action.value_to_choices_display(self.apply_actions) + return 'Todo' def get_apply_actions_display(self): return ', '.join(self.apply_actions_display) diff --git a/apps/tickets/serializers/ticket/apply_asset.py b/apps/tickets/serializers/ticket/apply_asset.py index a2cb94179..26a1fe434 100644 --- a/apps/tickets/serializers/ticket/apply_asset.py +++ b/apps/tickets/serializers/ticket/apply_asset.py @@ -1,7 +1,7 @@ from django.utils.translation import ugettext_lazy as _ from rest_framework import serializers -from perms.serializers.permission import ActionsField +from perms.serializers.permission import ActionChoicesField from perms.models import AssetPermission from orgs.utils import tmp_to_org from assets.models import Asset, Node @@ -16,7 +16,7 @@ asset_or_node_help_text = _("Select at least one asset or node") class ApplyAssetSerializer(BaseApplyAssetApplicationSerializer, TicketApplySerializer): - apply_actions = ActionsField(required=True, allow_empty=False) + apply_actions = ActionChoicesField(required=True, allow_empty=False) permission_model = AssetPermission class Meta: diff --git a/apps/users/serializers/user.py b/apps/users/serializers/user.py index 833752727..d84a7baa0 100644 --- a/apps/users/serializers/user.py +++ b/apps/users/serializers/user.py @@ -15,8 +15,10 @@ from ..models import User from ..const import PasswordStrategy __all__ = [ - 'UserSerializer', 'MiniUserSerializer', - 'InviteSerializer', 'ServiceAccountSerializer', + "UserSerializer", + "MiniUserSerializer", + "InviteSerializer", + "ServiceAccountSerializer", ] logger = get_logger(__file__) @@ -25,15 +27,17 @@ logger = get_logger(__file__) class RolesSerializerMixin(serializers.Serializer): system_roles = serializers.ManyRelatedField( child_relation=serializers.PrimaryKeyRelatedField(queryset=Role.system_roles), - label=_('System roles'), + label=_("System roles"), ) org_roles = serializers.ManyRelatedField( required=False, child_relation=serializers.PrimaryKeyRelatedField(queryset=Role.org_roles), - label=_('Org roles'), + label=_("Org roles"), ) - system_roles_display = serializers.SerializerMethodField(label=_('System roles display')) - org_roles_display = serializers.SerializerMethodField(label=_('Org roles display')) + system_roles_display = serializers.SerializerMethodField( + label=_("System roles display") + ) + org_roles_display = serializers.SerializerMethodField(label=_("Org roles display")) @staticmethod def get_system_roles_display(user): @@ -44,20 +48,20 @@ class RolesSerializerMixin(serializers.Serializer): return user.org_roles.display def pop_roles_if_need(self, fields): - request = self.context.get('request') - view = self.context.get('view') + request = self.context.get("request") + view = self.context.get("view") - if not all([request, view, hasattr(view, 'action')]): + if not all([request, view, hasattr(view, "action")]): return fields if request.user.is_anonymous: return fields - action = view.action or 'list' - if action in ('partial_bulk_update', 'bulk_update', 'partial_update', 'update'): - action = 'create' + action = view.action or "list" + if action in ("partial_bulk_update", "bulk_update", "partial_update", "update"): + action = "create" model_cls_field_mapper = { - SystemRoleBinding: ['system_roles', 'system_roles_display'], - OrgRoleBinding: ['org_roles', 'system_roles_display'] + SystemRoleBinding: ["system_roles", "system_roles_display"], + OrgRoleBinding: ["org_roles", "system_roles_display"], } for model_cls, fields_names in model_cls_field_mapper.items(): @@ -75,97 +79,148 @@ class RolesSerializerMixin(serializers.Serializer): return fields -class UserSerializer(RolesSerializerMixin, CommonBulkSerializerMixin, serializers.ModelSerializer): +class UserSerializer( + RolesSerializerMixin, CommonBulkSerializerMixin, serializers.ModelSerializer +): password_strategy = serializers.ChoiceField( - choices=PasswordStrategy.choices, default=PasswordStrategy.email, required=False, - write_only=True, label=_('Password strategy') + choices=PasswordStrategy.choices, + default=PasswordStrategy.email, + required=False, + write_only=True, + label=_("Password strategy"), + ) + mfa_enabled = serializers.BooleanField(read_only=True, label=_("MFA enabled")) + mfa_force_enabled = serializers.BooleanField( + read_only=True, label=_("MFA force enabled") ) - mfa_enabled = serializers.BooleanField(read_only=True, label=_('MFA enabled')) - mfa_force_enabled = serializers.BooleanField(read_only=True, label=_('MFA force enabled')) mfa_level_display = serializers.ReadOnlyField( - source='get_mfa_level_display', label=_('MFA level display') + source="get_mfa_level_display", label=_("MFA level display") ) - login_blocked = serializers.BooleanField(read_only=True, label=_('Login blocked')) - is_expired = serializers.BooleanField(read_only=True, label=_('Is expired')) + login_blocked = serializers.BooleanField(read_only=True, label=_("Login blocked")) + is_expired = serializers.BooleanField(read_only=True, label=_("Is expired")) can_public_key_auth = serializers.ReadOnlyField( - source='can_use_ssh_key_login', label=_('Can public key authentication') + source="can_use_ssh_key_login", label=_("Can public key authentication") ) password = EncryptedField( - label=_('Password'), required=False, allow_blank=True, allow_null=True, max_length=1024 + label=_("Password"), + required=False, + allow_blank=True, + allow_null=True, + max_length=1024, ) # Todo: 这里看看该怎么搞 # can_update = serializers.SerializerMethodField(label=_('Can update')) # can_delete = serializers.SerializerMethodField(label=_('Can delete')) custom_m2m_fields = { - 'system_roles': [BuiltinRole.system_user], - 'org_roles': [BuiltinRole.org_user] + "system_roles": [BuiltinRole.system_user], + "org_roles": [BuiltinRole.org_user], } class Meta: model = User # mini 是指能识别对象的最小单元 - fields_mini = ['id', 'name', 'username'] + fields_mini = ["id", "name", "username"] # 只能写的字段, 这个虽然无法在框架上生效,但是更多对我们是提醒 fields_write_only = [ - 'password', 'public_key', + "password", + "public_key", ] # small 指的是 不需要计算的直接能从一张表中获取到的数据 - fields_small = fields_mini + fields_write_only + [ - 'email', 'wechat', 'phone', 'mfa_level', 'source', 'source_display', - 'can_public_key_auth', 'need_update_password', - 'mfa_enabled', 'is_service_account', 'is_valid', 'is_expired', 'is_active', # 布尔字段 - 'date_expired', 'date_joined', 'last_login', # 日期字段 - 'created_by', 'comment', # 通用字段 - 'is_wecom_bound', 'is_dingtalk_bound', 'is_feishu_bound', 'is_otp_secret_key_bound', - 'wecom_id', 'dingtalk_id', 'feishu_id' - ] + fields_small = ( + fields_mini + + fields_write_only + + [ + "email", + "wechat", + "phone", + "mfa_level", + "source", + "source_display", + "can_public_key_auth", + "need_update_password", + "mfa_enabled", + "is_service_account", + "is_valid", + "is_expired", + "is_active", # 布尔字段 + "date_expired", + "date_joined", + "last_login", # 日期字段 + "created_by", + "comment", # 通用字段 + "is_wecom_bound", + "is_dingtalk_bound", + "is_feishu_bound", + "is_otp_secret_key_bound", + "wecom_id", + "dingtalk_id", + "feishu_id", + ] + ) # 包含不太常用的字段,可以没有 fields_verbose = fields_small + [ - 'mfa_level_display', 'mfa_force_enabled', 'is_first_login', - 'date_password_last_updated', 'avatar_url', + "mfa_level_display", + "mfa_force_enabled", + "is_first_login", + "date_password_last_updated", + "avatar_url", ] # 外键的字段 fields_fk = [] # 多对多字段 fields_m2m = [ - 'groups', 'groups_display', 'system_roles', 'org_roles', - 'system_roles_display', 'org_roles_display' + "groups", + "groups_display", + "system_roles", + "org_roles", + "system_roles_display", + "org_roles_display", ] # 在serializer 上定义的字段 - fields_custom = ['login_blocked', 'password_strategy'] + fields_custom = ["login_blocked", "password_strategy"] fields = fields_verbose + fields_fk + fields_m2m + fields_custom read_only_fields = [ - 'date_joined', 'last_login', 'created_by', 'is_first_login', - 'wecom_id', 'dingtalk_id', 'feishu_id' + "date_joined", + "last_login", + "created_by", + "is_first_login", + "wecom_id", + "dingtalk_id", + "feishu_id", ] - disallow_self_update_fields = ['is_active'] + disallow_self_update_fields = ["is_active"] extra_kwargs = { - 'password': {'write_only': True, 'required': False, 'allow_null': True, 'allow_blank': True}, - 'public_key': {'write_only': True}, - 'is_first_login': {'label': _('Is first login'), 'read_only': True}, - 'is_active': {'label': _('Is active')}, - 'is_valid': {'label': _('Is valid')}, - 'is_service_account': {'label': _('Is service account')}, - 'is_expired': {'label': _('Is expired')}, - 'avatar_url': {'label': _('Avatar url')}, - 'created_by': {'read_only': True, 'allow_blank': True}, - 'groups_display': {'label': _('Groups name')}, - 'source_display': {'label': _('Source name')}, - 'org_role_display': {'label': _('Organization role name')}, - 'role_display': {'label': _('Super role name')}, - 'total_role_display': {'label': _('Total role name')}, - 'role': {'default': "User"}, - 'is_wecom_bound': {'label': _('Is wecom bound')}, - 'is_dingtalk_bound': {'label': _('Is dingtalk bound')}, - 'is_feishu_bound': {'label': _('Is feishu bound')}, - 'is_otp_secret_key_bound': {'label': _('Is OTP bound')}, - 'phone': {'validators': [PhoneValidator()]}, - 'system_role_display': {'label': _('System role name')}, + "password": { + "write_only": True, + "required": False, + "allow_null": True, + "allow_blank": True, + }, + "public_key": {"write_only": True}, + "is_first_login": {"label": _("Is first login"), "read_only": True}, + "is_active": {"label": _("Is active")}, + "is_valid": {"label": _("Is valid")}, + "is_service_account": {"label": _("Is service account")}, + "is_expired": {"label": _("Is expired")}, + "avatar_url": {"label": _("Avatar url")}, + "created_by": {"read_only": True, "allow_blank": True}, + "groups_display": {"label": _("Groups name")}, + "source_display": {"label": _("Source name")}, + "org_role_display": {"label": _("Organization role name")}, + "role_display": {"label": _("Super role name")}, + "total_role_display": {"label": _("Total role name")}, + "role": {"default": "User"}, + "is_wecom_bound": {"label": _("Is wecom bound")}, + "is_dingtalk_bound": {"label": _("Is dingtalk bound")}, + "is_feishu_bound": {"label": _("Is feishu bound")}, + "is_otp_secret_key_bound": {"label": _("Is OTP bound")}, + "phone": {"validators": [PhoneValidator()]}, + "system_role_display": {"label": _("System role name")}, } def validate_password(self, password): - password_strategy = self.initial_data.get('password_strategy') + password_strategy = self.initial_data.get("password_strategy") if self.instance is None and password_strategy != PasswordStrategy.custom: # 创建用户,使用邮件设置密码 return @@ -176,32 +231,34 @@ class UserSerializer(RolesSerializerMixin, CommonBulkSerializerMixin, serializer @staticmethod def change_password_to_raw(attrs): - password = attrs.pop('password', None) + password = attrs.pop("password", None) if password: - attrs['password_raw'] = password + attrs["password_raw"] = password return attrs @staticmethod def clean_auth_fields(attrs): - for field in ('password', 'public_key'): + for field in ("password", "public_key"): value = attrs.get(field) if not value: attrs.pop(field, None) return attrs def check_disallow_self_update_fields(self, attrs): - request = self.context.get('request') + request = self.context.get("request") if not request or not request.user.is_authenticated: return attrs if not self.instance: return attrs if request.user.id != self.instance.id: return attrs - disallow_fields = set(list(attrs.keys())) & set(self.Meta.disallow_self_update_fields) + disallow_fields = set(list(attrs.keys())) & set( + self.Meta.disallow_self_update_fields + ) if not disallow_fields: return attrs # 用户自己不能更新自己的一些字段 - logger.debug('Disallow update self fields: %s', disallow_fields) + logger.debug("Disallow update self fields: %s", disallow_fields) for field in disallow_fields: attrs.pop(field, None) return attrs @@ -210,7 +267,7 @@ class UserSerializer(RolesSerializerMixin, CommonBulkSerializerMixin, serializer attrs = self.check_disallow_self_update_fields(attrs) attrs = self.change_password_to_raw(attrs) attrs = self.clean_auth_fields(attrs) - attrs.pop('password_strategy', None) + attrs.pop("password_strategy", None) return attrs def save_and_set_custom_m2m_fields(self, validated_data, save_handler, created): @@ -219,8 +276,7 @@ class UserSerializer(RolesSerializerMixin, CommonBulkSerializerMixin, serializer roles = validated_data.pop(f, None) if created and not roles: roles = [ - Role.objects.filter(id=role.id).first() - for role in default_roles + Role.objects.filter(id=role.id).first() for role in default_roles ] m2m_values[f] = roles @@ -234,22 +290,26 @@ class UserSerializer(RolesSerializerMixin, CommonBulkSerializerMixin, serializer def update(self, instance, validated_data): save_handler = partial(super().update, instance) - instance = self.save_and_set_custom_m2m_fields(validated_data, save_handler, created=False) + instance = self.save_and_set_custom_m2m_fields( + validated_data, save_handler, created=False + ) return instance def create(self, validated_data): save_handler = super().create - instance = self.save_and_set_custom_m2m_fields(validated_data, save_handler, created=True) + instance = self.save_and_set_custom_m2m_fields( + validated_data, save_handler, created=True + ) return instance class UserRetrieveSerializer(UserSerializer): login_confirm_settings = serializers.PrimaryKeyRelatedField( - read_only=True, source='login_confirm_setting.reviewers', many=True + read_only=True, source="login_confirm_setting.reviewers", many=True ) class Meta(UserSerializer.Meta): - fields = UserSerializer.Meta.fields + ['login_confirm_settings'] + fields = UserSerializer.Meta.fields + ["login_confirm_settings"] class MiniUserSerializer(serializers.ModelSerializer): @@ -260,8 +320,10 @@ class MiniUserSerializer(serializers.ModelSerializer): class InviteSerializer(RolesSerializerMixin, serializers.Serializer): users = serializers.PrimaryKeyRelatedField( - queryset=User.get_nature_users(), many=True, label=_('Select users'), - help_text=_('For security, only list several users') + queryset=User.get_nature_users(), + many=True, + label=_("Select users"), + help_text=_("For security, only list several users"), ) system_roles = None system_roles_display = None @@ -271,22 +333,23 @@ class InviteSerializer(RolesSerializerMixin, serializers.Serializer): class ServiceAccountSerializer(serializers.ModelSerializer): class Meta: model = User - fields = ['id', 'name', 'access_key', 'comment'] - read_only_fields = ['access_key'] + fields = ["id", "name", "access_key", "comment"] + read_only_fields = ["access_key"] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) from authentication.serializers import AccessKeySerializer - self.fields['access_key'] = AccessKeySerializer(read_only=True) + + self.fields["access_key"] = AccessKeySerializer(read_only=True) def get_username(self): - return self.initial_data.get('name') + return self.initial_data.get("name") def get_email(self): - name = self.initial_data.get('name') + name = self.initial_data.get("name") name_max_length = 128 - len(User.service_account_email_suffix) - name = pretty_string(name, max_length=name_max_length, ellipsis_str='-') - return '{}{}'.format(name, User.service_account_email_suffix) + name = pretty_string(name, max_length=name_max_length, ellipsis_str="-") + return "{}{}".format(name, User.service_account_email_suffix) def validate_name(self, name): email = self.get_email() @@ -296,12 +359,12 @@ class ServiceAccountSerializer(serializers.ModelSerializer): else: users = User.objects.all() if users.filter(email=email) or users.filter(username=username): - raise serializers.ValidationError(_('name not unique'), code='unique') + raise serializers.ValidationError(_("name not unique"), code="unique") return name def create(self, validated_data): - name = validated_data['name'] + name = validated_data["name"] email = self.get_email() - comment = validated_data.get('comment', '') + comment = validated_data.get("comment", "") user, ak = User.create_service_account(name, email, comment) return user