Merge pull request #12565 from jumpserver/dev

v3.10.2
This commit is contained in:
Bryan 2024-01-17 07:23:30 -04:00 committed by GitHub
commit 8a9f0436b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
85 changed files with 2305 additions and 1409 deletions

View File

@ -113,7 +113,7 @@ JumpServer是一款安全产品请参考 [基本安全建议](https://docs.ju
## License & Copyright ## License & Copyright
Copyright (c) 2014-2023 飞致云 FIT2CLOUD, All rights reserved. Copyright (c) 2014-2024 飞致云 FIT2CLOUD, All rights reserved.
Licensed under The GNU General Public License version 3 (GPLv3) (the "License"); you may not use this file except in Licensed under The GNU General Public License version 3 (GPLv3) (the "License"); you may not use this file except in
compliance with the License. You may obtain a copy of the License at compliance with the License. You may obtain a copy of the License at

View File

@ -145,9 +145,9 @@ class AccountBackupHandler:
wb = Workbook(filename) wb = Workbook(filename)
for sheet, data in data_map.items(): for sheet, data in data_map.items():
ws = wb.add_worksheet(str(sheet)) ws = wb.add_worksheet(str(sheet))
for row in data: for row_index, row_data in enumerate(data):
for col, _data in enumerate(row): for col_index, col_data in enumerate(row_data):
ws.write_string(0, col, _data) ws.write_string(row_index, col_index, col_data)
wb.close() wb.close()
files.append(filename) files.append(filename)
timedelta = round((time.time() - time_start), 2) timedelta = round((time.time() - time_start), 2)

View File

@ -161,7 +161,8 @@ class ChangeSecretManager(AccountBasePlaybookManager):
print("Account not found, deleted ?") print("Account not found, deleted ?")
return return
account.secret = recorder.new_secret account.secret = recorder.new_secret
account.save(update_fields=['secret']) account.date_updated = timezone.now()
account.save(update_fields=['secret', 'date_updated'])
def on_host_error(self, host, error, result): def on_host_error(self, host, error, result):
recorder = self.name_recorder_mapper.get(host) recorder = self.name_recorder_mapper.get(host)
@ -228,8 +229,8 @@ class ChangeSecretManager(AccountBasePlaybookManager):
rows.insert(0, header) rows.insert(0, header)
wb = Workbook(filename) wb = Workbook(filename)
ws = wb.add_worksheet('Sheet1') ws = wb.add_worksheet('Sheet1')
for row in rows: for row_index, row_data in enumerate(rows):
for col, data in enumerate(row): for col_index, col_data in enumerate(row_data):
ws.write_string(0, col, data) ws.write_string(row_index, col_index, col_data)
wb.close() wb.close()
return True return True

View File

@ -21,7 +21,8 @@ def on_account_pre_save(sender, instance, **kwargs):
if instance.version == 0: if instance.version == 0:
instance.version = 1 instance.version = 1
else: else:
instance.version = instance.history.count() history_account = instance.history.first()
instance.version = history_account.version + 1 if history_account else 0
@merge_delay_run(ttl=5) @merge_delay_run(ttl=5)

View File

@ -1,9 +1,19 @@
from celery import shared_task import uuid
from collections import defaultdict
from celery import shared_task, current_task
from django.conf import settings
from django.db.models import Count
from django.utils.translation import gettext_noop, gettext_lazy as _ from django.utils.translation import gettext_noop, gettext_lazy as _
from accounts.const import AutomationTypes from accounts.const import AutomationTypes
from accounts.models import Account
from accounts.tasks.common import quickstart_automation_by_snapshot from accounts.tasks.common import quickstart_automation_by_snapshot
from audits.const import ActivityChoices
from common.const.crontab import CRONTAB_AT_AM_TWO
from common.utils import get_logger from common.utils import get_logger
from ops.celery.decorator import register_as_period_task
from orgs.utils import tmp_to_root_org
logger = get_logger(__file__) logger = get_logger(__file__)
@ -29,3 +39,39 @@ def remove_accounts_task(gather_account_ids):
tp = AutomationTypes.remove_account tp = AutomationTypes.remove_account
quickstart_automation_by_snapshot(task_name, tp, task_snapshot) quickstart_automation_by_snapshot(task_name, tp, task_snapshot)
@shared_task(verbose_name=_('Clean historical accounts'))
@register_as_period_task(crontab=CRONTAB_AT_AM_TWO)
@tmp_to_root_org()
def clean_historical_accounts():
from audits.signal_handlers import create_activities
print("Clean historical accounts start.")
if settings.HISTORY_ACCOUNT_CLEAN_LIMIT >= 999:
return
limit = settings.HISTORY_ACCOUNT_CLEAN_LIMIT
history_ids_to_be_deleted = []
history_model = Account.history.model
history_id_mapper = defaultdict(list)
ids = history_model.objects.values('id').annotate(count=Count('id')) \
.filter(count__gte=limit).values_list('id', flat=True)
if not ids:
return
for i in history_model.objects.filter(id__in=ids):
_id = str(i.id)
history_id_mapper[_id].append(i.history_id)
for history_ids in history_id_mapper.values():
history_ids_to_be_deleted.extend(history_ids[limit:])
history_qs = history_model.objects.filter(history_id__in=history_ids_to_be_deleted)
resource_ids = list(history_qs.values_list('history_id', flat=True))
history_qs.delete()
task_id = current_task.request.id if current_task else str(uuid.uuid4())
detail = gettext_noop('Remove historical accounts that are out of range.')
create_activities(resource_ids, detail, task_id, action=ActivityChoices.task, org_id='')

View File

@ -21,7 +21,6 @@ from common.drf.filters import BaseFilterSet, AttrRulesFilterBackend
from common.utils import get_logger, is_uuid from common.utils import get_logger, is_uuid
from orgs.mixins import generics from orgs.mixins import generics
from orgs.mixins.api import OrgBulkModelViewSet from orgs.mixins.api import OrgBulkModelViewSet
from ..mixin import NodeFilterMixin
from ...notifications import BulkUpdatePlatformSkipAssetUserMsg from ...notifications import BulkUpdatePlatformSkipAssetUserMsg
logger = get_logger(__file__) logger = get_logger(__file__)
@ -86,7 +85,7 @@ class AssetFilterSet(BaseFilterSet):
return queryset.filter(protocols__name__in=value).distinct() return queryset.filter(protocols__name__in=value).distinct()
class AssetViewSet(SuggestionMixin, NodeFilterMixin, OrgBulkModelViewSet): class AssetViewSet(SuggestionMixin, OrgBulkModelViewSet):
""" """
API endpoint that allows Asset to be viewed or edited. API endpoint that allows Asset to be viewed or edited.
""" """
@ -114,9 +113,7 @@ class AssetViewSet(SuggestionMixin, NodeFilterMixin, OrgBulkModelViewSet):
] ]
def get_queryset(self): def get_queryset(self):
queryset = super().get_queryset() \ queryset = super().get_queryset()
.prefetch_related('nodes', 'protocols') \
.select_related('platform', 'domain')
if queryset.model is not Asset: if queryset.model is not Asset:
queryset = queryset.select_related('asset_ptr') queryset = queryset.select_related('asset_ptr')
return queryset return queryset

View File

@ -20,14 +20,15 @@ class DomainViewSet(OrgBulkModelViewSet):
filterset_fields = ("name",) filterset_fields = ("name",)
search_fields = filterset_fields search_fields = filterset_fields
ordering = ('name',) ordering = ('name',)
serializer_classes = {
'default': serializers.DomainSerializer,
'list': serializers.DomainListSerializer,
}
def get_serializer_class(self): def get_serializer_class(self):
if self.request.query_params.get('gateway'): if self.request.query_params.get('gateway'):
return serializers.DomainWithGatewaySerializer return serializers.DomainWithGatewaySerializer
return serializers.DomainSerializer return super().get_serializer_class()
def get_queryset(self):
return super().get_queryset().prefetch_related('assets')
class GatewayViewSet(HostViewSet): class GatewayViewSet(HostViewSet):

View File

@ -2,7 +2,7 @@ from typing import List
from rest_framework.request import Request from rest_framework.request import Request
from assets.models import Node, Protocol from assets.models import Node, Platform, Protocol
from assets.utils import get_node_from_request, is_query_node_all_assets from assets.utils import get_node_from_request, is_query_node_all_assets
from common.utils import lazyproperty, timeit from common.utils import lazyproperty, timeit
@ -71,37 +71,49 @@ class SerializeToTreeNodeMixin:
return 'file' return 'file'
@timeit @timeit
def serialize_assets(self, assets, node_key=None, pid=None): def serialize_assets(self, assets, node_key=None, get_pid=None):
if node_key is None: if not get_pid and not node_key:
get_pid = lambda asset: getattr(asset, 'parent_key', '') get_pid = lambda asset, platform: getattr(asset, 'parent_key', '')
else:
get_pid = lambda asset: node_key
sftp_asset_ids = Protocol.objects.filter(name='sftp') \ sftp_asset_ids = Protocol.objects.filter(name='sftp') \
.values_list('asset_id', flat=True) .values_list('asset_id', flat=True)
sftp_asset_ids = list(sftp_asset_ids) sftp_asset_ids = set(sftp_asset_ids)
data = [ platform_map = {p.id: p for p in Platform.objects.all()}
{
data = []
root_assets_count = 0
for asset in assets:
platform = platform_map.get(asset.platform_id)
if not platform:
continue
pid = node_key or get_pid(asset, platform)
if not pid:
continue
# 根节点最多显示 1000 个资产
if pid.isdigit():
if root_assets_count > 1000:
continue
root_assets_count += 1
data.append({
'id': str(asset.id), 'id': str(asset.id),
'name': asset.name, 'name': asset.name,
'title': f'{asset.address}\n{asset.comment}', 'title': f'{asset.address}\n{asset.comment}'.strip(),
'pId': pid or get_pid(asset), 'pId': pid,
'isParent': False, 'isParent': False,
'open': False, 'open': False,
'iconSkin': self.get_icon(asset), 'iconSkin': self.get_icon(platform),
'chkDisabled': not asset.is_active, 'chkDisabled': not asset.is_active,
'meta': { 'meta': {
'type': 'asset', 'type': 'asset',
'data': { 'data': {
'platform_type': asset.platform.type, 'platform_type': platform.type,
'org_name': asset.org_name, 'org_name': asset.org_name,
'sftp': asset.id in sftp_asset_ids, 'sftp': asset.id in sftp_asset_ids,
'name': asset.name, 'name': asset.name,
'address': asset.address 'address': asset.address
}, },
} }
} })
for asset in assets
]
return data return data

View File

@ -29,7 +29,10 @@ class AssetPlatformViewSet(JMSModelViewSet):
} }
def get_queryset(self): def get_queryset(self):
queryset = super().get_queryset() # 因为没有走分页逻辑,所以需要这里 prefetch
queryset = super().get_queryset().prefetch_related(
'protocols', 'automation', 'labels', 'labels__label',
)
queryset = queryset.filter(type__in=AllTypes.get_types_values()) queryset = queryset.filter(type__in=AllTypes.get_types_values())
return queryset return queryset

View File

@ -126,6 +126,8 @@ class NodeChildrenAsTreeApi(SerializeToTreeNodeMixin, NodeChildrenApi):
include_assets = self.request.query_params.get('assets', '0') == '1' include_assets = self.request.query_params.get('assets', '0') == '1'
if not self.instance or not include_assets: if not self.instance or not include_assets:
return Asset.objects.none() return Asset.objects.none()
if self.instance.is_org_root():
return Asset.objects.none()
if query_all: if query_all:
assets = self.instance.get_all_assets() assets = self.instance.get_all_assets()
else: else:

View File

@ -268,7 +268,7 @@ class AllTypes(ChoicesMixin):
meta = {'type': 'category', 'category': category.value, '_type': category.value} meta = {'type': 'category', 'category': category.value, '_type': category.value}
category_node = cls.choice_to_node(category, 'ROOT', meta=meta) category_node = cls.choice_to_node(category, 'ROOT', meta=meta)
category_count = category_type_mapper.get(category, 0) category_count = category_type_mapper.get(category, 0)
category_node['name'] += f'({category_count})' category_node['name'] += f' ({category_count})'
nodes.append(category_node) nodes.append(category_node)
# Type 格式化 # Type 格式化
@ -277,7 +277,7 @@ class AllTypes(ChoicesMixin):
meta = {'type': 'type', 'category': category.value, '_type': tp.value} meta = {'type': 'type', 'category': category.value, '_type': tp.value}
tp_node = cls.choice_to_node(tp, category_node['id'], opened=False, meta=meta) tp_node = cls.choice_to_node(tp, category_node['id'], opened=False, meta=meta)
tp_count = category_type_mapper.get(category + '_' + tp, 0) tp_count = category_type_mapper.get(category + '_' + tp, 0)
tp_node['name'] += f'({tp_count})' tp_node['name'] += f' ({tp_count})'
platforms = tp_platforms.get(category + '_' + tp, []) platforms = tp_platforms.get(category + '_' + tp, [])
if not platforms: if not platforms:
tp_node['isParent'] = False tp_node['isParent'] = False
@ -286,7 +286,7 @@ class AllTypes(ChoicesMixin):
# Platform 格式化 # Platform 格式化
for p in platforms: for p in platforms:
platform_node = cls.platform_to_node(p, tp_node['id'], include_asset) platform_node = cls.platform_to_node(p, tp_node['id'], include_asset)
platform_node['name'] += f'({platform_count.get(p.id, 0)})' platform_node['name'] += f' ({platform_count.get(p.id, 0)})'
nodes.append(platform_node) nodes.append(platform_node)
return nodes return nodes

View File

@ -63,11 +63,10 @@ class NodeFilterBackend(filters.BaseFilterBackend):
query_all = is_query_node_all_assets(request) query_all = is_query_node_all_assets(request)
if query_all: if query_all:
return queryset.filter( return queryset.filter(
Q(nodes__key__istartswith=f'{node.key}:') | Q(nodes__key__startswith=f'{node.key}:') |
Q(nodes__key=node.key) Q(nodes__key=node.key)
).distinct() ).distinct()
else: else:
print("Query query origin: ", queryset.count())
return queryset.filter(nodes__key=node.key).distinct() return queryset.filter(nodes__key=node.key).distinct()

View File

@ -13,7 +13,7 @@ from django.db.transaction import atomic
from django.utils.translation import gettext_lazy as _, gettext from django.utils.translation import gettext_lazy as _, gettext
from common.db.models import output_as_string from common.db.models import output_as_string
from common.utils import get_logger from common.utils import get_logger, timeit
from common.utils.lock import DistributedLock from common.utils.lock import DistributedLock
from orgs.mixins.models import OrgManager, JMSOrgBaseModel from orgs.mixins.models import OrgManager, JMSOrgBaseModel
from orgs.models import Organization from orgs.models import Organization
@ -195,11 +195,6 @@ class FamilyMixin:
ancestor_keys = self.get_ancestor_keys(with_self=with_self) ancestor_keys = self.get_ancestor_keys(with_self=with_self)
return self.__class__.objects.filter(key__in=ancestor_keys) return self.__class__.objects.filter(key__in=ancestor_keys)
# @property
# def parent_key(self):
# parent_key = ":".join(self.key.split(":")[:-1])
# return parent_key
def compute_parent_key(self): def compute_parent_key(self):
return compute_parent_key(self.key) return compute_parent_key(self.key)
@ -349,29 +344,26 @@ class NodeAllAssetsMappingMixin:
return 'ASSETS_ORG_NODE_ALL_ASSET_ids_MAPPING_{}'.format(org_id) return 'ASSETS_ORG_NODE_ALL_ASSET_ids_MAPPING_{}'.format(org_id)
@classmethod @classmethod
@timeit
def generate_node_all_asset_ids_mapping(cls, org_id): def generate_node_all_asset_ids_mapping(cls, org_id):
from .asset import Asset logger.info(f'Generate node asset mapping: org_id={org_id}')
logger.info(f'Generate node asset mapping: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
t1 = time.time() t1 = time.time()
with tmp_to_org(org_id): with tmp_to_org(org_id):
node_ids_key = Node.objects.annotate( node_ids_key = Node.objects.annotate(
char_id=output_as_string('id') char_id=output_as_string('id')
).values_list('char_id', 'key') ).values_list('char_id', 'key')
# * 直接取出全部. filter(node__org_id=org_id)(大规模下会更慢)
nodes_asset_ids = Asset.nodes.through.objects.all() \
.annotate(char_node_id=output_as_string('node_id')) \
.annotate(char_asset_id=output_as_string('asset_id')) \
.values_list('char_node_id', 'char_asset_id')
node_id_ancestor_keys_mapping = { node_id_ancestor_keys_mapping = {
node_id: cls.get_node_ancestor_keys(node_key, with_self=True) node_id: cls.get_node_ancestor_keys(node_key, with_self=True)
for node_id, node_key in node_ids_key for node_id, node_key in node_ids_key
} }
# * 直接取出全部. filter(node__org_id=org_id)(大规模下会更慢)
nodes_asset_ids = cls.assets.through.objects.all() \
.annotate(char_node_id=output_as_string('node_id')) \
.annotate(char_asset_id=output_as_string('asset_id')) \
.values_list('char_node_id', 'char_asset_id')
nodeid_assetsid_mapping = defaultdict(set) nodeid_assetsid_mapping = defaultdict(set)
for node_id, asset_id in nodes_asset_ids: for node_id, asset_id in nodes_asset_ids:
nodeid_assetsid_mapping[node_id].add(asset_id) nodeid_assetsid_mapping[node_id].add(asset_id)
@ -386,7 +378,7 @@ class NodeAllAssetsMappingMixin:
mapping[ancestor_key].update(asset_ids) mapping[ancestor_key].update(asset_ids)
t3 = time.time() t3 = time.time()
logger.info('t1-t2(DB Query): {} s, t3-t2(Generate mapping): {} s'.format(t2 - t1, t3 - t2)) logger.info('Generate asset nodes mapping, DB query: {:.2f}s, mapping: {:.2f}s'.format(t2 - t1, t3 - t2))
return mapping return mapping
@ -436,6 +428,7 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
return asset_ids return asset_ids
@classmethod @classmethod
@timeit
def get_nodes_all_assets(cls, *nodes): def get_nodes_all_assets(cls, *nodes):
from .asset import Asset from .asset import Asset
node_ids = set() node_ids = set()
@ -559,11 +552,6 @@ class Node(JMSOrgBaseModel, SomeNodesMixin, FamilyMixin, NodeAssetsMixin):
def __str__(self): def __str__(self):
return self.full_value return self.full_value
# def __eq__(self, other):
# if not other:
# return False
# return self.id == other.id
#
def __gt__(self, other): def __gt__(self, other):
self_key = [int(k) for k in self.key.split(':')] self_key = [int(k) for k in self.key.split(':')]
other_key = [int(k) for k in other.key.split(':')] other_key = [int(k) for k in other.key.split(':')]

View File

@ -1,8 +1,8 @@
from rest_framework.pagination import LimitOffsetPagination from rest_framework.pagination import LimitOffsetPagination
from rest_framework.request import Request from rest_framework.request import Request
from common.utils import get_logger
from assets.models import Node from assets.models import Node
from common.utils import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
@ -28,6 +28,7 @@ class AssetPaginationBase(LimitOffsetPagination):
'key', 'all', 'show_current_asset', 'key', 'all', 'show_current_asset',
'cache_policy', 'display', 'draw', 'cache_policy', 'display', 'draw',
'order', 'node', 'node_id', 'fields_size', 'order', 'node', 'node_id', 'fields_size',
'asset'
} }
for k, v in self._request.query_params.items(): for k, v in self._request.query_params.items():
if k not in exclude_query_params and v is not None: if k not in exclude_query_params and v is not None:

View File

@ -100,7 +100,10 @@ class AssetAccountSerializer(AccountSerializer):
class Meta(AccountSerializer.Meta): class Meta(AccountSerializer.Meta):
fields = [ fields = [
f for f in AccountSerializer.Meta.fields f for f in AccountSerializer.Meta.fields
if f not in ['spec_info'] if f not in [
'spec_info', 'connectivity', 'labels', 'created_by',
'date_update', 'date_created'
]
] ]
extra_kwargs = { extra_kwargs = {
**AccountSerializer.Meta.extra_kwargs, **AccountSerializer.Meta.extra_kwargs,
@ -203,9 +206,12 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
""" Perform necessary eager loading of data. """ """ Perform necessary eager loading of data. """
queryset = queryset.prefetch_related('domain', 'nodes', 'protocols', ) \ queryset = queryset.prefetch_related('domain', 'nodes', 'protocols', ) \
.prefetch_related('platform', 'platform__automation') \ .prefetch_related('platform', 'platform__automation') \
.prefetch_related('labels', 'labels__label') \
.annotate(category=F("platform__category")) \ .annotate(category=F("platform__category")) \
.annotate(type=F("platform__type")) .annotate(type=F("platform__type"))
if queryset.model is Asset:
queryset = queryset.prefetch_related('labels__label', 'labels')
else:
queryset = queryset.prefetch_related('asset_ptr__labels__label', 'asset_ptr__labels')
return queryset return queryset
@staticmethod @staticmethod
@ -375,7 +381,6 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
class DetailMixin(serializers.Serializer): class DetailMixin(serializers.Serializer):
accounts = AssetAccountSerializer(many=True, required=False, label=_('Accounts'))
spec_info = MethodSerializer(label=_('Spec info'), read_only=True) spec_info = MethodSerializer(label=_('Spec info'), read_only=True)
gathered_info = MethodSerializer(label=_('Gathered info'), read_only=True) gathered_info = MethodSerializer(label=_('Gathered info'), read_only=True)
auto_config = serializers.DictField(read_only=True, label=_('Auto info')) auto_config = serializers.DictField(read_only=True, label=_('Auto info'))
@ -390,8 +395,7 @@ class DetailMixin(serializers.Serializer):
def get_field_names(self, declared_fields, info): def get_field_names(self, declared_fields, info):
names = super().get_field_names(declared_fields, info) names = super().get_field_names(declared_fields, info)
names.extend([ names.extend([
'accounts', 'gathered_info', 'spec_info', 'gathered_info', 'spec_info', 'auto_config',
'auto_config',
]) ])
return names return names

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from django.db.models import Count
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework import serializers from rest_framework import serializers
@ -7,18 +8,15 @@ from common.serializers import ResourceLabelsMixin
from common.serializers.fields import ObjectRelatedField from common.serializers.fields import ObjectRelatedField
from orgs.mixins.serializers import BulkOrgResourceModelSerializer from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from .gateway import GatewayWithAccountSecretSerializer from .gateway import GatewayWithAccountSecretSerializer
from ..models import Domain, Asset from ..models import Domain
__all__ = ['DomainSerializer', 'DomainWithGatewaySerializer'] __all__ = ['DomainSerializer', 'DomainWithGatewaySerializer', 'DomainListSerializer']
class DomainSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer): class DomainSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
gateways = ObjectRelatedField( gateways = ObjectRelatedField(
many=True, required=False, label=_('Gateway'), read_only=True, many=True, required=False, label=_('Gateway'), read_only=True,
) )
assets = ObjectRelatedField(
many=True, required=False, queryset=Asset.objects, label=_('Asset')
)
class Meta: class Meta:
model = Domain model = Domain
@ -30,7 +28,9 @@ class DomainSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
def to_representation(self, instance): def to_representation(self, instance):
data = super().to_representation(instance) data = super().to_representation(instance)
assets = data['assets'] assets = data.get('assets')
if assets is None:
return data
gateway_ids = [str(i['id']) for i in data['gateways']] gateway_ids = [str(i['id']) for i in data['gateways']]
data['assets'] = [i for i in assets if str(i['id']) not in gateway_ids] data['assets'] = [i for i in assets if str(i['id']) not in gateway_ids]
return data return data
@ -49,6 +49,20 @@ class DomainSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
return queryset return queryset
class DomainListSerializer(DomainSerializer):
assets_amount = serializers.IntegerField(label=_('Assets amount'), read_only=True)
class Meta(DomainSerializer.Meta):
fields = list(set(DomainSerializer.Meta.fields + ['assets_amount']) - {'assets'})
@classmethod
def setup_eager_loading(cls, queryset):
queryset = queryset.annotate(
assets_amount=Count('assets'),
)
return queryset
class DomainWithGatewaySerializer(serializers.ModelSerializer): class DomainWithGatewaySerializer(serializers.ModelSerializer):
gateways = GatewayWithAccountSecretSerializer(many=True, read_only=True) gateways = GatewayWithAccountSecretSerializer(many=True, read_only=True)

View File

@ -191,7 +191,6 @@ class PlatformSerializer(ResourceLabelsMixin, WritableNestedModelSerializer):
def add_type_choices(self, name, label): def add_type_choices(self, name, label):
tp = self.fields['type'] tp = self.fields['type']
tp.choices[name] = label tp.choices[name] = label
tp.choice_mapper[name] = label
tp.choice_strings_to_values[name] = label tp.choice_strings_to_values[name] = label
@lazyproperty @lazyproperty
@ -200,12 +199,6 @@ class PlatformSerializer(ResourceLabelsMixin, WritableNestedModelSerializer):
constraints = AllTypes.get_constraints(category, tp) constraints = AllTypes.get_constraints(category, tp)
return constraints return constraints
@classmethod
def setup_eager_loading(cls, queryset):
queryset = queryset.prefetch_related('protocols', 'automation') \
.prefetch_related('labels', 'labels__label')
return queryset
def validate_protocols(self, protocols): def validate_protocols(self, protocols):
if not protocols: if not protocols:
raise serializers.ValidationError(_("Protocols is required")) raise serializers.ValidationError(_("Protocols is required"))

View File

@ -80,10 +80,11 @@ RELATED_NODE_IDS = '_related_node_ids'
@receiver(pre_delete, sender=Asset) @receiver(pre_delete, sender=Asset)
def on_asset_delete(instance: Asset, using, **kwargs): def on_asset_delete(instance: Asset, using, **kwargs):
logger.debug("Asset pre delete signal recv: {}".format(instance))
node_ids = Node.objects.filter(assets=instance) \ node_ids = Node.objects.filter(assets=instance) \
.distinct().values_list('id', flat=True) .distinct().values_list('id', flat=True)
setattr(instance, RELATED_NODE_IDS, node_ids) node_ids = list(node_ids)
logger.debug("Asset pre delete signal recv: {}, node_ids: {}".format(instance, node_ids))
setattr(instance, RELATED_NODE_IDS, list(node_ids))
m2m_changed.send( m2m_changed.send(
sender=Asset.nodes.through, instance=instance, sender=Asset.nodes.through, instance=instance,
reverse=False, model=Node, pk_set=node_ids, reverse=False, model=Node, pk_set=node_ids,
@ -93,8 +94,8 @@ def on_asset_delete(instance: Asset, using, **kwargs):
@receiver(post_delete, sender=Asset) @receiver(post_delete, sender=Asset)
def on_asset_post_delete(instance: Asset, using, **kwargs): def on_asset_post_delete(instance: Asset, using, **kwargs):
logger.debug("Asset post delete signal recv: {}".format(instance))
node_ids = getattr(instance, RELATED_NODE_IDS, []) node_ids = getattr(instance, RELATED_NODE_IDS, [])
logger.debug("Asset post delete signal recv: {}, node_ids: {}".format(instance, node_ids))
if node_ids: if node_ids:
m2m_changed.send( m2m_changed.send(
sender=Asset.nodes.through, instance=instance, reverse=False, sender=Asset.nodes.through, instance=instance, reverse=False,

View File

@ -15,8 +15,8 @@ from ..tasks import check_node_assets_amount_task
logger = get_logger(__file__) logger = get_logger(__file__)
@on_transaction_commit
@receiver(m2m_changed, sender=Asset.nodes.through) @receiver(m2m_changed, sender=Asset.nodes.through)
@on_transaction_commit
def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs): def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
# 不允许 `pre_clear` ,因为该信号没有 `pk_set` # 不允许 `pre_clear` ,因为该信号没有 `pk_set`
# [官网](https://docs.djangoproject.com/en/3.1/ref/signals/#m2m-changed) # [官网](https://docs.djangoproject.com/en/3.1/ref/signals/#m2m-changed)
@ -37,7 +37,7 @@ def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
update_nodes_assets_amount(node_ids=node_ids) update_nodes_assets_amount(node_ids=node_ids)
@merge_delay_run(ttl=5) @merge_delay_run(ttl=30)
def update_nodes_assets_amount(node_ids=()): def update_nodes_assets_amount(node_ids=()):
nodes = Node.objects.filter(id__in=node_ids) nodes = Node.objects.filter(id__in=node_ids)
nodes = Node.get_ancestor_queryset(nodes) nodes = Node.get_ancestor_queryset(nodes)

View File

@ -21,7 +21,7 @@ logger = get_logger(__name__)
node_assets_mapping_pub_sub = lazy(lambda: RedisPubSub('fm.node_asset_mapping'), RedisPubSub)() node_assets_mapping_pub_sub = lazy(lambda: RedisPubSub('fm.node_asset_mapping'), RedisPubSub)()
@merge_delay_run(ttl=5) @merge_delay_run(ttl=30)
def expire_node_assets_mapping(org_ids=()): def expire_node_assets_mapping(org_ids=()):
logger.debug("Recv asset nodes changed signal, expire memery node asset mapping") logger.debug("Recv asset nodes changed signal, expire memery node asset mapping")
# 所有进程清除(自己的 memory 数据) # 所有进程清除(自己的 memory 数据)
@ -53,8 +53,9 @@ def on_node_post_delete(sender, instance, **kwargs):
@receiver(m2m_changed, sender=Asset.nodes.through) @receiver(m2m_changed, sender=Asset.nodes.through)
def on_node_asset_change(sender, instance, **kwargs): def on_node_asset_change(sender, instance, action='pre_remove', **kwargs):
expire_node_assets_mapping(org_ids=(instance.org_id,)) if action.startswith('post'):
expire_node_assets_mapping(org_ids=(instance.org_id,))
@receiver(django_ready) @receiver(django_ready)

View File

@ -2,6 +2,7 @@
from django.urls import path from django.urls import path
from rest_framework_bulk.routes import BulkRouter from rest_framework_bulk.routes import BulkRouter
from labels.api import LabelViewSet
from .. import api from .. import api
app_name = 'assets' app_name = 'assets'
@ -22,6 +23,7 @@ router.register(r'domains', api.DomainViewSet, 'domain')
router.register(r'gateways', api.GatewayViewSet, 'gateway') router.register(r'gateways', api.GatewayViewSet, 'gateway')
router.register(r'favorite-assets', api.FavoriteAssetViewSet, 'favorite-asset') router.register(r'favorite-assets', api.FavoriteAssetViewSet, 'favorite-asset')
router.register(r'protocol-settings', api.PlatformProtocolViewSet, 'protocol-setting') router.register(r'protocol-settings', api.PlatformProtocolViewSet, 'protocol-setting')
router.register(r'labels', LabelViewSet, 'label')
urlpatterns = [ urlpatterns = [
# path('assets/<uuid:pk>/gateways/', api.AssetGatewayListApi.as_view(), name='asset-gateway-list'), # path('assets/<uuid:pk>/gateways/', api.AssetGatewayListApi.as_view(), name='asset-gateway-list'),

View File

@ -4,7 +4,6 @@ from urllib.parse import urlencode, urlparse
from kubernetes import client from kubernetes import client
from kubernetes.client import api_client from kubernetes.client import api_client
from kubernetes.client.api import core_v1_api from kubernetes.client.api import core_v1_api
from kubernetes.client.exceptions import ApiException
from sshtunnel import SSHTunnelForwarder, BaseSSHTunnelForwarderError from sshtunnel import SSHTunnelForwarder, BaseSSHTunnelForwarderError
from common.utils import get_logger from common.utils import get_logger
@ -88,8 +87,9 @@ class KubernetesClient:
if hasattr(self, func_name): if hasattr(self, func_name):
try: try:
data = getattr(self, func_name)(*args) data = getattr(self, func_name)(*args)
except ApiException as e: except Exception as e:
logger.error(e.reason) logger.error(e)
raise e
if self.server: if self.server:
self.server.stop() self.server.stop()

View File

@ -5,6 +5,7 @@ from importlib import import_module
from django.conf import settings from django.conf import settings
from django.db.models import F, Value, CharField, Q from django.db.models import F, Value, CharField, Q
from django.db.models.functions import Cast
from django.http import HttpResponse, FileResponse from django.http import HttpResponse, FileResponse
from django.utils.encoding import escape_uri_path from django.utils.encoding import escape_uri_path
from rest_framework import generics from rest_framework import generics
@ -40,6 +41,7 @@ from .serializers import (
PasswordChangeLogSerializer, ActivityUnionLogSerializer, PasswordChangeLogSerializer, ActivityUnionLogSerializer,
FileSerializer, UserSessionSerializer FileSerializer, UserSessionSerializer
) )
from .utils import construct_userlogin_usernames
logger = get_logger(__name__) logger = get_logger(__name__)
@ -125,15 +127,16 @@ class UserLoginCommonMixin:
class UserLoginLogViewSet(UserLoginCommonMixin, OrgReadonlyModelViewSet): class UserLoginLogViewSet(UserLoginCommonMixin, OrgReadonlyModelViewSet):
@staticmethod @staticmethod
def get_org_members(): def get_org_member_usernames():
users = current_org.get_members().values_list('username', flat=True) user_queryset = current_org.get_members()
users = construct_userlogin_usernames(user_queryset)
return users return users
def get_queryset(self): def get_queryset(self):
queryset = super().get_queryset() queryset = super().get_queryset()
if current_org.is_root(): if current_org.is_root():
return queryset return queryset
users = self.get_org_members() users = self.get_org_member_usernames()
queryset = queryset.filter(username__in=users) queryset = queryset.filter(username__in=users)
return queryset return queryset
@ -163,7 +166,7 @@ class ResourceActivityAPIView(generics.ListAPIView):
q |= Q(user=str(user)) q |= Q(user=str(user))
queryset = OperateLog.objects.filter(q, org_q).annotate( queryset = OperateLog.objects.filter(q, org_q).annotate(
r_type=Value(ActivityChoices.operate_log, CharField()), r_type=Value(ActivityChoices.operate_log, CharField()),
r_detail_id=F('id'), r_detail=Value(None, CharField()), r_detail_id=Cast(F('id'), CharField()), r_detail=Value(None, CharField()),
r_user=F('user'), r_action=F('action'), r_user=F('user'), r_action=F('action'),
).values(*fields)[:limit] ).values(*fields)[:limit]
return queryset return queryset

View File

@ -4,6 +4,8 @@ from itertools import chain
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
from django.db import models from django.db import models
from django.db.models import F, Value, CharField
from django.db.models.functions import Concat
from common.db.fields import RelatedManager from common.db.fields import RelatedManager
from common.utils import validate_ip, get_ip_city, get_logger from common.utils import validate_ip, get_ip_city, get_logger
@ -115,3 +117,12 @@ def model_to_dict_for_operate_log(
get_related_values(f) get_related_values(f)
return data return data
def construct_userlogin_usernames(user_queryset):
usernames_original = user_queryset.values_list('username', flat=True)
usernames_combined = user_queryset.annotate(
usernames_combined_field=Concat(F('name'), Value('('), F('username'), Value(')'), output_field=CharField())
).values_list("usernames_combined_field", flat=True)
usernames = list(chain(usernames_original, usernames_combined))
return usernames

View File

@ -90,6 +90,6 @@ class MFAChallengeVerifyApi(AuthMixin, CreateAPIView):
return Response({'msg': 'ok'}) return Response({'msg': 'ok'})
except errors.AuthFailedError as e: except errors.AuthFailedError as e:
data = {"error": e.error, "msg": e.msg} data = {"error": e.error, "msg": e.msg}
raise ValidationError(data) return Response(data, status=401)
except errors.NeedMoreInfoError as e: except errors.NeedMoreInfoError as e:
return Response(e.as_data(), status=200) return Response(e.as_data(), status=200)

View File

@ -15,12 +15,11 @@ from authentication.mixins import authenticate
from authentication.serializers import ( from authentication.serializers import (
PasswordVerifySerializer, ResetPasswordCodeSerializer PasswordVerifySerializer, ResetPasswordCodeSerializer
) )
from authentication.utils import check_user_property_is_correct
from common.permissions import IsValidUser from common.permissions import IsValidUser
from common.utils import get_object_or_none
from common.utils.random import random_string from common.utils.random import random_string
from common.utils.verify_code import SendAndVerifyCodeUtil from common.utils.verify_code import SendAndVerifyCodeUtil
from settings.utils import get_login_title from settings.utils import get_login_title
from users.models import User
class UserResetPasswordSendCodeApi(CreateAPIView): class UserResetPasswordSendCodeApi(CreateAPIView):
@ -28,8 +27,8 @@ class UserResetPasswordSendCodeApi(CreateAPIView):
serializer_class = ResetPasswordCodeSerializer serializer_class = ResetPasswordCodeSerializer
@staticmethod @staticmethod
def is_valid_user(**kwargs): def is_valid_user(username, **properties):
user = get_object_or_none(User, **kwargs) user = check_user_property_is_correct(username, **properties)
if not user: if not user:
err_msg = _('User does not exist: {}').format(_("No user matched")) err_msg = _('User does not exist: {}').format(_("No user matched"))
return None, err_msg return None, err_msg
@ -56,7 +55,6 @@ class UserResetPasswordSendCodeApi(CreateAPIView):
target = serializer.validated_data[form_type] target = serializer.validated_data[form_type]
if form_type == 'sms': if form_type == 'sms':
query_key = 'phone' query_key = 'phone'
target = target.lstrip('+')
else: else:
query_key = form_type query_key = form_type
user, err = self.is_valid_user(username=username, **{query_key: target}) user, err = self.is_valid_user(username=username, **{query_key: target})

View File

@ -7,8 +7,9 @@ from django.conf import settings
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from audits.const import DEFAULT_CITY from audits.const import DEFAULT_CITY
from users.models import User
from audits.models import UserLoginLog from audits.models import UserLoginLog
from common.utils import get_logger from common.utils import get_logger, get_object_or_none
from common.utils import validate_ip, get_ip_city, get_request_ip from common.utils import validate_ip, get_ip_city, get_request_ip
from .notifications import DifferentCityLoginMessage from .notifications import DifferentCityLoginMessage
@ -24,9 +25,10 @@ def check_different_city_login_if_need(user, request):
is_private = ipaddress.ip_address(ip).is_private is_private = ipaddress.ip_address(ip).is_private
if is_private: if is_private:
return return
usernames = [user.username, f"{user.name}({user.username})"]
last_user_login = UserLoginLog.objects.exclude( last_user_login = UserLoginLog.objects.exclude(
city__in=city_white city__in=city_white
).filter(username=user.username, status=True).first() ).filter(username__in=usernames, status=True).first()
if not last_user_login: if not last_user_login:
return return
@ -59,3 +61,12 @@ def build_absolute_uri_for_oidc(request, path=None):
redirect_uri = urljoin(settings.BASE_SITE_URL, path) redirect_uri = urljoin(settings.BASE_SITE_URL, path)
return redirect_uri return redirect_uri
return build_absolute_uri(request, path=path) return build_absolute_uri(request, path=path)
def check_user_property_is_correct(username, **properties):
user = get_object_or_none(User, username=username)
for attr, value in properties.items():
if getattr(user, attr, None) != value:
user = None
break
return user

View File

@ -98,12 +98,19 @@ class QuerySetMixin:
return queryset return queryset
if self.action == 'metadata': if self.action == 'metadata':
queryset = queryset.none() queryset = queryset.none()
if self.action in ['list', 'metadata']:
serializer_class = self.get_serializer_class()
if serializer_class and hasattr(serializer_class, 'setup_eager_loading'):
queryset = serializer_class.setup_eager_loading(queryset)
return queryset return queryset
def paginate_queryset(self, queryset):
page = super().paginate_queryset(queryset)
serializer_class = self.get_serializer_class()
if page and serializer_class and hasattr(serializer_class, 'setup_eager_loading'):
ids = [str(obj.id) for obj in page]
page = self.get_queryset().filter(id__in=ids)
page = serializer_class.setup_eager_loading(page)
page_mapper = {str(obj.id): obj for obj in page}
page = [page_mapper.get(_id) for _id in ids if _id in page_mapper]
return page
class ExtraFilterFieldsMixin: class ExtraFilterFieldsMixin:
""" """

View File

@ -65,7 +65,7 @@ class EventLoopThread(threading.Thread):
_loop_thread = EventLoopThread() _loop_thread = EventLoopThread()
_loop_thread.setDaemon(True) _loop_thread.daemon = True
_loop_thread.start() _loop_thread.start()
executor = ThreadPoolExecutor( executor = ThreadPoolExecutor(
max_workers=10, max_workers=10,

View File

@ -219,11 +219,11 @@ class LabelFilterBackend(filters.BaseFilterBackend):
if not hasattr(queryset, 'model'): if not hasattr(queryset, 'model'):
return queryset return queryset
if not hasattr(queryset.model, 'labels'): if not hasattr(queryset.model, 'label_model'):
return queryset return queryset
model = queryset.model model = queryset.model.label_model()
labeled_resource_cls = model._labels.field.related_model labeled_resource_cls = model.labels.field.related_model
app_label = model._meta.app_label app_label = model._meta.app_label
model_name = model._meta.model_name model_name = model._meta.model_name

View File

@ -14,6 +14,7 @@ class CeleryBaseService(BaseService):
print('\n- Start Celery as Distributed Task Queue: {}'.format(self.queue.capitalize())) print('\n- Start Celery as Distributed Task Queue: {}'.format(self.queue.capitalize()))
ansible_config_path = os.path.join(settings.APPS_DIR, 'ops', 'ansible', 'ansible.cfg') ansible_config_path = os.path.join(settings.APPS_DIR, 'ops', 'ansible', 'ansible.cfg')
ansible_modules_path = os.path.join(settings.APPS_DIR, 'ops', 'ansible', 'modules') ansible_modules_path = os.path.join(settings.APPS_DIR, 'ops', 'ansible', 'modules')
os.environ.setdefault('LC_ALL', 'C.UTF-8')
os.environ.setdefault('PYTHONOPTIMIZE', '1') os.environ.setdefault('PYTHONOPTIMIZE', '1')
os.environ.setdefault('ANSIBLE_FORCE_COLOR', 'True') os.environ.setdefault('ANSIBLE_FORCE_COLOR', 'True')
os.environ.setdefault('ANSIBLE_CONFIG', ansible_config_path) os.environ.setdefault('ANSIBLE_CONFIG', ansible_config_path)

View File

@ -1,16 +1,14 @@
import requests
import mistune import mistune
import requests
from rest_framework.exceptions import APIException
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework.exceptions import APIException
from users.utils import construct_user_email
from common.utils.common import get_logger from common.utils.common import get_logger
from jumpserver.utils import get_current_request from jumpserver.utils import get_current_request
from users.utils import construct_user_email
logger = get_logger(__name__) logger = get_logger(__name__)
SLACK_REDIRECT_URI_SESSION_KEY = '_slack_redirect_uri' SLACK_REDIRECT_URI_SESSION_KEY = '_slack_redirect_uri'
@ -22,15 +20,15 @@ class URL:
AUTH_TEST = 'https://slack.com/api/auth.test' AUTH_TEST = 'https://slack.com/api/auth.test'
class SlackRenderer(mistune.Renderer): class SlackRenderer(mistune.HTMLRenderer):
def header(self, text, level, raw=None): def heading(self, text, level):
return '*' + text + '*\n' return '*' + text + '*\n'
def double_emphasis(self, text): def strong(self, text):
return '*' + text + '*' return '*' + text + '*'
def list(self, body, ordered=True): def list(self, text, **kwargs):
lines = body.split('\n') lines = text.split('\n')
for i, line in enumerate(lines): for i, line in enumerate(lines):
if not line: if not line:
continue continue
@ -41,9 +39,9 @@ class SlackRenderer(mistune.Renderer):
def block_code(self, code, lang=None): def block_code(self, code, lang=None):
return f'`{code}`' return f'`{code}`'
def link(self, link, title, content): def link(self, link, text=None, title=None):
if title or content: if title or text:
label = str(title or content).strip() label = str(title or text).strip()
return f'<{link}|{label}>' return f'<{link}|{label}>'
return f'<{link}>' return f'<{link}>'

View File

@ -394,20 +394,20 @@ class CommonBulkModelSerializer(CommonBulkSerializerMixin, serializers.ModelSeri
class ResourceLabelsMixin(serializers.Serializer): class ResourceLabelsMixin(serializers.Serializer):
labels = LabelRelatedField(many=True, label=_('Labels'), required=False, allow_null=True) labels = LabelRelatedField(many=True, label=_('Labels'), required=False, allow_null=True, source='res_labels')
def update(self, instance, validated_data): def update(self, instance, validated_data):
labels = validated_data.pop('labels', None) labels = validated_data.pop('res_labels', None)
res = super().update(instance, validated_data) res = super().update(instance, validated_data)
if labels is not None: if labels is not None:
instance.labels.set(labels, bulk=False) instance.res_labels.set(labels, bulk=False)
return res return res
def create(self, validated_data): def create(self, validated_data):
labels = validated_data.pop('labels', None) labels = validated_data.pop('res_labels', None)
instance = super().create(validated_data) instance = super().create(validated_data)
if labels is not None: if labels is not None:
instance.labels.set(labels, bulk=False) instance.res_labels.set(labels, bulk=False)
return instance return instance
@classmethod @classmethod

View File

@ -62,14 +62,14 @@ def digest_sql_query():
method = current_request.method method = current_request.method
path = current_request.get_full_path() path = current_request.get_full_path()
print(">>> [{}] {}".format(method, path)) print(">>>. [{}] {}".format(method, path))
for table_name, queries in table_queries.items(): for table_name, queries in table_queries.items():
if table_name.startswith('rbac_') or table_name.startswith('auth_permission'): if table_name.startswith('rbac_') or table_name.startswith('auth_permission'):
continue continue
for query in queries: for query in queries:
sql = query['sql'] sql = query['sql']
print(" # {}: {}".format(query['time'], sql)) print(" # {}: {}".format(query['time'], sql[:1000]))
if len(queries) < 3: if len(queries) < 3:
continue continue
print("- Table: {}".format(table_name)) print("- Table: {}".format(table_name))
@ -77,9 +77,9 @@ def digest_sql_query():
sql = query['sql'] sql = query['sql']
if not sql or not sql.startswith('SELECT'): if not sql or not sql.startswith('SELECT'):
continue continue
print('\t{}. {}'.format(i, sql)) print('\t{}.[{}s] {}'.format(i, round(float(query['time']), 2), sql[:1000]))
logger.debug(">>> [{}] {}".format(method, path)) # logger.debug(">>> [{}] {}".format(method, path))
for name, counter in counters: for name, counter in counters:
logger.debug("Query {:3} times using {:.2f}s {}".format( logger.debug("Query {:3} times using {:.2f}s {}".format(
counter.counter, counter.time, name) counter.counter, counter.time, name)

View File

@ -2,7 +2,7 @@ import os
from celery import shared_task from celery import shared_task
from django.conf import settings from django.conf import settings
from django.core.mail import send_mail, EmailMultiAlternatives from django.core.mail import send_mail, EmailMultiAlternatives, get_connection
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
import jms_storage import jms_storage
@ -11,6 +11,16 @@ from .utils import get_logger
logger = get_logger(__file__) logger = get_logger(__file__)
def get_email_connection(**kwargs):
email_backend_map = {
'smtp': 'django.core.mail.backends.smtp.EmailBackend',
'exchange': 'jumpserver.rewriting.exchange.EmailBackend'
}
return get_connection(
backend=email_backend_map.get(settings.EMAIL_PROTOCOL), **kwargs
)
def task_activity_callback(self, subject, message, recipient_list, *args, **kwargs): def task_activity_callback(self, subject, message, recipient_list, *args, **kwargs):
from users.models import User from users.models import User
email_list = recipient_list email_list = recipient_list
@ -40,7 +50,7 @@ def send_mail_async(*args, **kwargs):
args = tuple(args) args = tuple(args)
try: try:
return send_mail(*args, **kwargs) return send_mail(connection=get_email_connection(), *args, **kwargs)
except Exception as e: except Exception as e:
logger.error("Sending mail error: {}".format(e)) logger.error("Sending mail error: {}".format(e))
@ -55,7 +65,8 @@ def send_mail_attachment_async(subject, message, recipient_list, attachment_list
subject=subject, subject=subject,
body=message, body=message,
from_email=from_email, from_email=from_email,
to=recipient_list to=recipient_list,
connection=get_email_connection(),
) )
for attachment in attachment_list: for attachment in attachment_list:
email.attach_file(attachment) email.attach_file(attachment)

View File

@ -220,7 +220,7 @@ def timeit(func):
now = time.time() now = time.time()
result = func(*args, **kwargs) result = func(*args, **kwargs)
using = (time.time() - now) * 1000 using = (time.time() - now) * 1000
msg = "End call {}, using: {:.1f}ms".format(name, using) msg = "Ends call: {}, using: {:.1f}ms".format(name, using)
logger.debug(msg) logger.debug(msg)
return result return result

View File

@ -1,18 +1,16 @@
from functools import wraps
import threading import threading
from functools import wraps
from django.db import transaction
from redis_lock import ( from redis_lock import (
Lock as RedisLock, NotAcquired, UNLOCK_SCRIPT, Lock as RedisLock, NotAcquired, UNLOCK_SCRIPT,
EXTEND_SCRIPT, RESET_SCRIPT, RESET_ALL_SCRIPT EXTEND_SCRIPT, RESET_SCRIPT, RESET_ALL_SCRIPT
) )
from redis import Redis
from django.db import transaction
from common.utils import get_logger
from common.utils.inspect import copy_function_args
from common.utils.connection import get_redis_client
from jumpserver.const import CONFIG
from common.local import thread_local from common.local import thread_local
from common.utils import get_logger
from common.utils.connection import get_redis_client
from common.utils.inspect import copy_function_args
logger = get_logger(__file__) logger = get_logger(__file__)
@ -76,6 +74,7 @@ class DistributedLock(RedisLock):
# 要创建一个新的锁对象 # 要创建一个新的锁对象
with self.__class__(**self.kwargs_copy): with self.__class__(**self.kwargs_copy):
return func(*args, **kwds) return func(*args, **kwds)
return inner return inner
@classmethod @classmethod
@ -95,7 +94,6 @@ class DistributedLock(RedisLock):
if self.locked(): if self.locked():
owner_id = self.get_owner_id() owner_id = self.get_owner_id()
local_owner_id = getattr(thread_local, self.name, None) local_owner_id = getattr(thread_local, self.name, None)
if local_owner_id and owner_id == local_owner_id: if local_owner_id and owner_id == local_owner_id:
return True return True
return False return False
@ -140,14 +138,16 @@ class DistributedLock(RedisLock):
logger.debug(f'Released reentrant-lock: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}') logger.debug(f'Released reentrant-lock: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}')
return return
else: else:
self._raise_exc_with_log(f'Reentrant-lock is not acquired: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}') self._raise_exc_with_log(
f'Reentrant-lock is not acquired: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}')
def _release_on_reentrant_locked_by_me(self): def _release_on_reentrant_locked_by_me(self):
logger.debug(f'Release reentrant-lock locked by me: lock_id={self.id} lock={self.name}') logger.debug(f'Release reentrant-lock locked by me: lock_id={self.id} lock={self.name}')
id = getattr(thread_local, self.name, None) id = getattr(thread_local, self.name, None)
if id != self.id: if id != self.id:
raise PermissionError(f'Reentrant-lock is not locked by me: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}') raise PermissionError(
f'Reentrant-lock is not locked by me: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}')
try: try:
# 这里要保证先删除 thread_local 的标记, # 这里要保证先删除 thread_local 的标记,
delattr(thread_local, self.name) delattr(thread_local, self.name)
@ -191,7 +191,7 @@ class DistributedLock(RedisLock):
# 处理是否在事务提交时才释放锁 # 处理是否在事务提交时才释放锁
if self._release_on_transaction_commit: if self._release_on_transaction_commit:
logger.debug( logger.debug(
f'Release lock on transaction commit ... :lock_id={self.id} lock={self.name}') f'Release lock on transaction commit:lock_id={self.id} lock={self.name}')
transaction.on_commit(_release) transaction.on_commit(_release)
else: else:
_release() _release()

View File

@ -17,6 +17,7 @@ from assets.models import Asset
from audits.api import OperateLogViewSet from audits.api import OperateLogViewSet
from audits.const import LoginStatusChoices from audits.const import LoginStatusChoices
from audits.models import UserLoginLog, PasswordChangeLog, OperateLog, FTPLog, JobLog from audits.models import UserLoginLog, PasswordChangeLog, OperateLog, FTPLog, JobLog
from audits.utils import construct_userlogin_usernames
from common.utils import lazyproperty from common.utils import lazyproperty
from common.utils.timezone import local_now, local_zero_hour from common.utils.timezone import local_now, local_zero_hour
from ops.const import JobStatus from ops.const import JobStatus
@ -79,7 +80,7 @@ class DateTimeMixin:
if not self.org.is_root(): if not self.org.is_root():
if query_params == 'username': if query_params == 'username':
query = { query = {
f'{query_params}__in': users.values_list('username', flat=True) f'{query_params}__in': construct_userlogin_usernames(users)
} }
else: else:
query = { query = {

View File

@ -17,7 +17,7 @@ import re
import sys import sys
import types import types
from importlib import import_module from importlib import import_module
from urllib.parse import urljoin, urlparse from urllib.parse import urljoin, urlparse, quote
import yaml import yaml
from django.urls import reverse_lazy from django.urls import reverse_lazy
@ -261,6 +261,8 @@ class Config(dict):
'VAULT_HCP_TOKEN': '', 'VAULT_HCP_TOKEN': '',
'VAULT_HCP_MOUNT_POINT': 'jumpserver', 'VAULT_HCP_MOUNT_POINT': 'jumpserver',
'HISTORY_ACCOUNT_CLEAN_LIMIT': 999,
# Cache login password # Cache login password
'CACHE_LOGIN_PASSWORD_ENABLED': False, 'CACHE_LOGIN_PASSWORD_ENABLED': False,
'CACHE_LOGIN_PASSWORD_TTL': 60 * 60 * 24, 'CACHE_LOGIN_PASSWORD_TTL': 60 * 60 * 24,
@ -280,6 +282,7 @@ class Config(dict):
'AUTH_LDAP_SYNC_INTERVAL': None, 'AUTH_LDAP_SYNC_INTERVAL': None,
'AUTH_LDAP_SYNC_CRONTAB': None, 'AUTH_LDAP_SYNC_CRONTAB': None,
'AUTH_LDAP_SYNC_ORG_IDS': ['00000000-0000-0000-0000-000000000002'], 'AUTH_LDAP_SYNC_ORG_IDS': ['00000000-0000-0000-0000-000000000002'],
'AUTH_LDAP_SYNC_RECEIVERS': [],
'AUTH_LDAP_USER_LOGIN_ONLY_IN_USERS': False, 'AUTH_LDAP_USER_LOGIN_ONLY_IN_USERS': False,
'AUTH_LDAP_OPTIONS_OPT_REFERRALS': -1, 'AUTH_LDAP_OPTIONS_OPT_REFERRALS': -1,
@ -325,6 +328,7 @@ class Config(dict):
'RADIUS_SERVER': 'localhost', 'RADIUS_SERVER': 'localhost',
'RADIUS_PORT': 1812, 'RADIUS_PORT': 1812,
'RADIUS_SECRET': '', 'RADIUS_SECRET': '',
'RADIUS_ATTRIBUTES': {},
'RADIUS_ENCRYPT_PASSWORD': True, 'RADIUS_ENCRYPT_PASSWORD': True,
'OTP_IN_RADIUS': False, 'OTP_IN_RADIUS': False,
@ -451,6 +455,7 @@ class Config(dict):
'CUSTOM_SMS_REQUEST_METHOD': 'get', 'CUSTOM_SMS_REQUEST_METHOD': 'get',
# Email # Email
'EMAIL_PROTOCOL': 'smtp',
'EMAIL_CUSTOM_USER_CREATED_SUBJECT': _('Create account successfully'), 'EMAIL_CUSTOM_USER_CREATED_SUBJECT': _('Create account successfully'),
'EMAIL_CUSTOM_USER_CREATED_HONORIFIC': _('Hello'), 'EMAIL_CUSTOM_USER_CREATED_HONORIFIC': _('Hello'),
'EMAIL_CUSTOM_USER_CREATED_BODY': _('Your account has been created successfully'), 'EMAIL_CUSTOM_USER_CREATED_BODY': _('Your account has been created successfully'),
@ -531,6 +536,7 @@ class Config(dict):
'SYSLOG_SOCKTYPE': 2, 'SYSLOG_SOCKTYPE': 2,
'PERM_EXPIRED_CHECK_PERIODIC': 60 * 60, 'PERM_EXPIRED_CHECK_PERIODIC': 60 * 60,
'PERM_TREE_REGEN_INTERVAL': 1,
'FLOWER_URL': "127.0.0.1:5555", 'FLOWER_URL': "127.0.0.1:5555",
'LANGUAGE_CODE': 'zh', 'LANGUAGE_CODE': 'zh',
'TIME_ZONE': 'Asia/Shanghai', 'TIME_ZONE': 'Asia/Shanghai',
@ -693,6 +699,13 @@ class Config(dict):
if openid_config: if openid_config:
self.set_openid_config(openid_config) self.set_openid_config(openid_config)
def compatible_redis(self):
redis_config = {
'REDIS_PASSWORD': quote(str(self.REDIS_PASSWORD)),
}
for key, value in redis_config.items():
self[key] = value
def compatible(self): def compatible(self):
""" """
对配置做兼容处理 对配置做兼容处理
@ -704,6 +717,8 @@ class Config(dict):
""" """
# 兼容 OpenID 配置 # 兼容 OpenID 配置
self.compatible_auth_openid() self.compatible_auth_openid()
# 兼容 Redis 配置
self.compatible_redis()
def convert_type(self, k, v): def convert_type(self, k, v):
default_value = self.defaults.get(k) default_value = self.defaults.get(k)

View File

@ -0,0 +1,104 @@
import urllib3
from urllib3.exceptions import InsecureRequestWarning
from django.core.mail.backends.base import BaseEmailBackend
from django.core.mail.message import sanitize_address
from django.conf import settings
from exchangelib import Account, Credentials, Configuration, DELEGATE
from exchangelib import Mailbox, Message, HTMLBody, FileAttachment
from exchangelib import BaseProtocol, NoVerifyHTTPAdapter
from exchangelib.errors import TransportError
urllib3.disable_warnings(InsecureRequestWarning)
BaseProtocol.HTTP_ADAPTER_CLS = NoVerifyHTTPAdapter
class EmailBackend(BaseEmailBackend):
def __init__(
self,
service_endpoint=None,
username=None,
password=None,
fail_silently=False,
**kwargs,
):
super().__init__(fail_silently=fail_silently)
self.service_endpoint = service_endpoint or settings.EMAIL_HOST
self.username = settings.EMAIL_HOST_USER if username is None else username
self.password = settings.EMAIL_HOST_PASSWORD if password is None else password
self._connection = None
def open(self):
if self._connection:
return False
try:
config = Configuration(
service_endpoint=self.service_endpoint, credentials=Credentials(
username=self.username, password=self.password
)
)
self._connection = Account(self.username, config=config, access_type=DELEGATE)
return True
except TransportError:
if not self.fail_silently:
raise
def close(self):
self._connection = None
def send_messages(self, email_messages):
if not email_messages:
return 0
new_conn_created = self.open()
if not self._connection or new_conn_created is None:
return 0
num_sent = 0
for message in email_messages:
sent = self._send(message)
if sent:
num_sent += 1
if new_conn_created:
self.close()
return num_sent
def _send(self, email_message):
if not email_message.recipients():
return False
encoding = settings.DEFAULT_CHARSET
from_email = sanitize_address(email_message.from_email, encoding)
recipients = [
Mailbox(email_address=sanitize_address(addr, encoding)) for addr in email_message.recipients()
]
try:
message_body = email_message.body
alternatives = email_message.alternatives or []
attachments = []
for attachment in email_message.attachments or []:
name, content, mimetype = attachment
if isinstance(content, str):
content = content.encode(encoding)
attachments.append(
FileAttachment(name=name, content=content, content_type=mimetype)
)
for alternative in alternatives:
if alternative[1] == 'text/html':
message_body = HTMLBody(alternative[0])
break
email_message = Message(
account=self._connection, subject=email_message.subject,
body=message_body, to_recipients=recipients, sender=from_email,
attachments=[]
)
email_message.attach(attachments)
email_message.send_and_save()
except Exception as error:
if not self.fail_silently:
raise error
return False
return True

View File

@ -0,0 +1,14 @@
from private_storage.servers import NginxXAccelRedirectServer, DjangoServer
class StaticFileServer(object):
@staticmethod
def serve(private_file):
full_path = private_file.full_path
# todo: gzip 文件录像 nginx 处理后,浏览器无法正常解析内容
# 造成在线播放失败,暂时仅使用 nginx 处理 mp4 录像文件
if full_path.endswith('.mp4'):
return NginxXAccelRedirectServer.serve(private_file)
else:
return DjangoServer.serve(private_file)

View File

@ -50,6 +50,7 @@ AUTH_LDAP_SYNC_IS_PERIODIC = CONFIG.AUTH_LDAP_SYNC_IS_PERIODIC
AUTH_LDAP_SYNC_INTERVAL = CONFIG.AUTH_LDAP_SYNC_INTERVAL AUTH_LDAP_SYNC_INTERVAL = CONFIG.AUTH_LDAP_SYNC_INTERVAL
AUTH_LDAP_SYNC_CRONTAB = CONFIG.AUTH_LDAP_SYNC_CRONTAB AUTH_LDAP_SYNC_CRONTAB = CONFIG.AUTH_LDAP_SYNC_CRONTAB
AUTH_LDAP_SYNC_ORG_IDS = CONFIG.AUTH_LDAP_SYNC_ORG_IDS AUTH_LDAP_SYNC_ORG_IDS = CONFIG.AUTH_LDAP_SYNC_ORG_IDS
AUTH_LDAP_SYNC_RECEIVERS = CONFIG.AUTH_LDAP_SYNC_RECEIVERS
AUTH_LDAP_USER_LOGIN_ONLY_IN_USERS = CONFIG.AUTH_LDAP_USER_LOGIN_ONLY_IN_USERS AUTH_LDAP_USER_LOGIN_ONLY_IN_USERS = CONFIG.AUTH_LDAP_USER_LOGIN_ONLY_IN_USERS
# ============================================================================== # ==============================================================================
@ -99,6 +100,8 @@ AUTH_RADIUS_BACKEND = 'authentication.backends.radius.RadiusBackend'
RADIUS_SERVER = CONFIG.RADIUS_SERVER RADIUS_SERVER = CONFIG.RADIUS_SERVER
RADIUS_PORT = CONFIG.RADIUS_PORT RADIUS_PORT = CONFIG.RADIUS_PORT
RADIUS_SECRET = CONFIG.RADIUS_SECRET RADIUS_SECRET = CONFIG.RADIUS_SECRET
# https://github.com/robgolding/django-radius/blob/develop/radiusauth/backends/radius.py#L15-L52
RADIUS_ATTRIBUTES = CONFIG.RADIUS_ATTRIBUTES
# CAS Auth # CAS Auth
AUTH_CAS = CONFIG.AUTH_CAS AUTH_CAS = CONFIG.AUTH_CAS
@ -190,6 +193,8 @@ VAULT_HCP_HOST = CONFIG.VAULT_HCP_HOST
VAULT_HCP_TOKEN = CONFIG.VAULT_HCP_TOKEN VAULT_HCP_TOKEN = CONFIG.VAULT_HCP_TOKEN
VAULT_HCP_MOUNT_POINT = CONFIG.VAULT_HCP_MOUNT_POINT VAULT_HCP_MOUNT_POINT = CONFIG.VAULT_HCP_MOUNT_POINT
HISTORY_ACCOUNT_CLEAN_LIMIT = CONFIG.HISTORY_ACCOUNT_CLEAN_LIMIT
# Other setting # Other setting
# 这个是 User Login Private Token # 这个是 User Login Private Token
TOKEN_EXPIRATION = CONFIG.TOKEN_EXPIRATION TOKEN_EXPIRATION = CONFIG.TOKEN_EXPIRATION

View File

@ -312,12 +312,15 @@ STATICFILES_DIRS = (
os.path.join(BASE_DIR, "static"), os.path.join(BASE_DIR, "static"),
) )
# Media files (File, ImageField) will be save these # Media files (File, ImageField) will be safe these
MEDIA_URL = '/media/' MEDIA_URL = '/media/'
MEDIA_ROOT = os.path.join(PROJECT_DIR, 'data', 'media').replace('\\', '/') + '/' MEDIA_ROOT = os.path.join(PROJECT_DIR, 'data', 'media').replace('\\', '/') + '/'
PRIVATE_STORAGE_ROOT = MEDIA_ROOT PRIVATE_STORAGE_ROOT = MEDIA_ROOT
PRIVATE_STORAGE_AUTH_FUNCTION = 'jumpserver.rewriting.storage.permissions.allow_access' PRIVATE_STORAGE_AUTH_FUNCTION = 'jumpserver.rewriting.storage.permissions.allow_access'
PRIVATE_STORAGE_INTERNAL_URL = '/private-media/'
PRIVATE_STORAGE_SERVER = 'jumpserver.rewriting.storage.servers.StaticFileServer'
# Use django-bootstrap-form to format template, input max width arg # Use django-bootstrap-form to format template, input max width arg
# BOOTSTRAP_COLUMN_COUNT = 11 # BOOTSTRAP_COLUMN_COUNT = 11
@ -326,6 +329,7 @@ PRIVATE_STORAGE_AUTH_FUNCTION = 'jumpserver.rewriting.storage.permissions.allow_
FIXTURE_DIRS = [os.path.join(BASE_DIR, 'fixtures'), ] FIXTURE_DIRS = [os.path.join(BASE_DIR, 'fixtures'), ]
# Email config # Email config
EMAIL_PROTOCOL = CONFIG.EMAIL_PROTOCOL
EMAIL_HOST = CONFIG.EMAIL_HOST EMAIL_HOST = CONFIG.EMAIL_HOST
EMAIL_PORT = CONFIG.EMAIL_PORT EMAIL_PORT = CONFIG.EMAIL_PORT
EMAIL_HOST_USER = CONFIG.EMAIL_HOST_USER EMAIL_HOST_USER = CONFIG.EMAIL_HOST_USER

View File

@ -208,6 +208,7 @@ OPERATE_LOG_ELASTICSEARCH_CONFIG = CONFIG.OPERATE_LOG_ELASTICSEARCH_CONFIG
MAX_LIMIT_PER_PAGE = CONFIG.MAX_LIMIT_PER_PAGE MAX_LIMIT_PER_PAGE = CONFIG.MAX_LIMIT_PER_PAGE
DEFAULT_PAGE_SIZE = CONFIG.DEFAULT_PAGE_SIZE DEFAULT_PAGE_SIZE = CONFIG.DEFAULT_PAGE_SIZE
PERM_TREE_REGEN_INTERVAL = CONFIG.PERM_TREE_REGEN_INTERVAL
# Magnus DB Port # Magnus DB Port
MAGNUS_ORACLE_PORTS = CONFIG.MAGNUS_ORACLE_PORTS MAGNUS_ORACLE_PORTS = CONFIG.MAGNUS_ORACLE_PORTS

View File

@ -21,7 +21,7 @@ LOGGING = {
}, },
'main': { 'main': {
'datefmt': '%Y-%m-%d %H:%M:%S', 'datefmt': '%Y-%m-%d %H:%M:%S',
'format': '%(asctime)s [%(module)s %(levelname)s] %(message)s', 'format': '%(asctime)s [%(levelname).4s] %(message)s',
}, },
'exception': { 'exception': {
'datefmt': '%Y-%m-%d %H:%M:%S', 'datefmt': '%Y-%m-%d %H:%M:%S',

View File

@ -73,7 +73,7 @@ class LabelContentTypeResourceViewSet(JMSModelViewSet):
queryset = model.objects.all() queryset = model.objects.all()
if bound == '1': if bound == '1':
queryset = queryset.filter(id__in=list(res_ids)) queryset = queryset.filter(id__in=list(res_ids))
elif bound == '0': else:
queryset = queryset.exclude(id__in=list(res_ids)) queryset = queryset.exclude(id__in=list(res_ids))
keyword = self.request.query_params.get('search') keyword = self.request.query_params.get('search')
if keyword: if keyword:
@ -90,9 +90,10 @@ class LabelContentTypeResourceViewSet(JMSModelViewSet):
LabeledResource.objects \ LabeledResource.objects \
.filter(res_type=content_type, label=label) \ .filter(res_type=content_type, label=label) \
.exclude(res_id__in=res_ids).delete() .exclude(res_id__in=res_ids).delete()
resources = [] resources = [
for res_id in res_ids: LabeledResource(res_type=content_type, res_id=res_id, label=label, org_id=current_org.id)
resources.append(LabeledResource(res_type=content_type, res_id=res_id, label=label, org_id=current_org.id)) for res_id in res_ids
]
LabeledResource.objects.bulk_create(resources, ignore_conflicts=True) LabeledResource.objects.bulk_create(resources, ignore_conflicts=True)
return Response({"total": len(res_ids)}) return Response({"total": len(res_ids)})
@ -129,15 +130,22 @@ class LabeledResourceViewSet(OrgBulkModelViewSet):
} }
ordering_fields = ('res_type', 'date_created') ordering_fields = ('res_type', 'date_created')
# Todo: 这里需要优化,查询 sql 太多
def filter_search(self, queryset): def filter_search(self, queryset):
keyword = self.request.query_params.get('search') keyword = self.request.query_params.get('search')
if not keyword: if not keyword:
return queryset return queryset
keyword = keyword.strip().lower()
matched = [] matched = []
for instance in queryset: offset = 0
if keyword.lower() in str(instance.resource).lower(): limit = 10000
matched.append(instance.id) while True:
page = queryset[offset:offset + limit]
if not page:
break
offset += limit
for instance in page:
if keyword in str(instance.resource).lower():
matched.append(instance.id)
return queryset.filter(id__in=matched) return queryset.filter(id__in=matched)
def get_queryset(self): def get_queryset(self):

View File

@ -1,21 +1,38 @@
from django.contrib.contenttypes.fields import GenericRelation from django.contrib.contenttypes.fields import GenericRelation
from django.db import models from django.db import models
from django.db.models import OneToOneField
from common.utils import lazyproperty
from .models import LabeledResource from .models import LabeledResource
__all__ = ['LabeledMixin'] __all__ = ['LabeledMixin']
class LabeledMixin(models.Model): class LabeledMixin(models.Model):
_labels = GenericRelation(LabeledResource, object_id_field='res_id', content_type_field='res_type') labels = GenericRelation(LabeledResource, object_id_field='res_id', content_type_field='res_type')
class Meta: class Meta:
abstract = True abstract = True
@property @classmethod
def labels(self): def label_model(cls):
return self._labels pk_field = cls._meta.pk
model = cls
if isinstance(pk_field, OneToOneField):
model = pk_field.related_model
return model
@labels.setter @lazyproperty
def labels(self, value): def real(self):
self._labels.set(value, bulk=False) pk_field = self._meta.pk
if isinstance(pk_field, OneToOneField):
return getattr(self, pk_field.name)
return self
@property
def res_labels(self):
return self.real.labels
@res_labels.setter
def res_labels(self, value):
self.real.labels.set(value, bulk=False)

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:71d292647cf751c002b459449c7bebf4d2bf5a3933748387e7c2f80a7111302e oid sha256:7879f4eeb499e920ad6c4bfdb0b1f334936ca344c275be056f12fcf7485f2bf6
size 169602 size 170948

File diff suppressed because it is too large Load Diff

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:80dd11dde678e4f9b64df18906175125218fd9f719bfe9aaa667ad6e2d055d40 oid sha256:19d3a111cc245f9a9d36b860fd95447df916ad66c918bef672bacdad6bc77a8f
size 139012 size 140119

File diff suppressed because it is too large Load Diff

View File

@ -4,6 +4,21 @@ import time
import paramiko import paramiko
from sshtunnel import SSHTunnelForwarder from sshtunnel import SSHTunnelForwarder
from packaging import version
if version.parse(paramiko.__version__) > version.parse("2.8.1"):
_preferred_pubkeys = (
"ssh-ed25519",
"ecdsa-sha2-nistp256",
"ecdsa-sha2-nistp384",
"ecdsa-sha2-nistp521",
"ssh-rsa",
"rsa-sha2-256",
"rsa-sha2-512",
"ssh-dss",
)
paramiko.transport.Transport._preferred_pubkeys = _preferred_pubkeys
def common_argument_spec(): def common_argument_spec():
options = dict( options = dict(

View File

@ -75,7 +75,7 @@ model_cache_field_mapper = {
class OrgResourceStatisticsRefreshUtil: class OrgResourceStatisticsRefreshUtil:
@staticmethod @staticmethod
@merge_delay_run(ttl=5) @merge_delay_run(ttl=30)
def refresh_org_fields(org_fields=()): def refresh_org_fields(org_fields=()):
for org, cache_field_name in org_fields: for org, cache_field_name in org_fields:
OrgResourceStatisticsCache(org).expire(*cache_field_name) OrgResourceStatisticsCache(org).expire(*cache_field_name)
@ -104,7 +104,7 @@ def on_post_delete_refresh_org_resource_statistics_cache(sender, instance, **kwa
def _refresh_session_org_resource_statistics_cache(instance: Session): def _refresh_session_org_resource_statistics_cache(instance: Session):
cache_field_name = [ cache_field_name = [
'total_count_online_users', 'total_count_online_sessions', 'total_count_online_users', 'total_count_online_sessions',
'total_count_today_active_assets','total_count_today_failed_sessions' 'total_count_today_active_assets', 'total_count_today_failed_sessions'
] ]
org_cache = OrgResourceStatisticsCache(instance.org) org_cache = OrgResourceStatisticsCache(instance.org)

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from orgs.mixins.api import OrgBulkModelViewSet from orgs.mixins.api import OrgBulkModelViewSet
from perms import serializers from perms import serializers
from perms.filters import AssetPermissionFilter from perms.filters import AssetPermissionFilter
@ -13,7 +14,10 @@ class AssetPermissionViewSet(OrgBulkModelViewSet):
资产授权列表的增删改查api 资产授权列表的增删改查api
""" """
model = AssetPermission model = AssetPermission
serializer_class = serializers.AssetPermissionSerializer serializer_classes = {
'default': serializers.AssetPermissionSerializer,
'list': serializers.AssetPermissionListSerializer,
}
filterset_class = AssetPermissionFilter filterset_class = AssetPermissionFilter
search_fields = ('name',) search_fields = ('name',)
ordering = ('name',) ordering = ('name',)

View File

@ -7,8 +7,7 @@ from assets.models import Asset, Node
from common.utils import get_logger, lazyproperty, is_uuid from common.utils import get_logger, lazyproperty, is_uuid
from orgs.utils import tmp_to_root_org from orgs.utils import tmp_to_root_org
from perms import serializers from perms import serializers
from perms.pagination import AllPermedAssetPagination from perms.pagination import NodePermedAssetPagination, AllPermedAssetPagination
from perms.pagination import NodePermedAssetPagination
from perms.utils import UserPermAssetUtil, PermAssetDetailUtil from perms.utils import UserPermAssetUtil, PermAssetDetailUtil
from .mixin import ( from .mixin import (
SelfOrPKUserMixin SelfOrPKUserMixin

View File

@ -1,16 +1,14 @@
from django.conf import settings from django.conf import settings
from rest_framework.response import Response from rest_framework.response import Response
from assets.models import Asset
from assets.api import SerializeToTreeNodeMixin from assets.api import SerializeToTreeNodeMixin
from assets.models import Asset
from common.utils import get_logger from common.utils import get_logger
from ..assets import UserAllPermedAssetsApi
from .mixin import RebuildTreeMixin from .mixin import RebuildTreeMixin
from ..assets import UserAllPermedAssetsApi
logger = get_logger(__name__) logger = get_logger(__name__)
__all__ = [ __all__ = [
'UserAllPermedAssetsAsTreeApi', 'UserAllPermedAssetsAsTreeApi',
'UserUngroupAssetsAsTreeApi', 'UserUngroupAssetsAsTreeApi',
@ -31,7 +29,7 @@ class AssetTreeMixin(RebuildTreeMixin, SerializeToTreeNodeMixin):
if request.query_params.get('search'): if request.query_params.get('search'):
""" 限制返回数量, 搜索的条件不精准时,会返回大量的无意义数据 """ """ 限制返回数量, 搜索的条件不精准时,会返回大量的无意义数据 """
assets = assets[:999] assets = assets[:999]
data = self.serialize_assets(assets, None) data = self.serialize_assets(assets, 'root')
return Response(data=data) return Response(data=data)
@ -42,6 +40,7 @@ class UserAllPermedAssetsAsTreeApi(AssetTreeMixin, UserAllPermedAssetsApi):
class UserUngroupAssetsAsTreeApi(UserAllPermedAssetsAsTreeApi): class UserUngroupAssetsAsTreeApi(UserAllPermedAssetsAsTreeApi):
""" 用户 '未分组节点的资产(直接授权的资产)' 作为树 """ """ 用户 '未分组节点的资产(直接授权的资产)' 作为树 """
def get_assets(self): def get_assets(self):
if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE: if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
return super().get_assets() return super().get_assets()

View File

@ -1,6 +1,4 @@
import abc import abc
import re
from collections import defaultdict
from urllib.parse import parse_qsl from urllib.parse import parse_qsl
from django.conf import settings from django.conf import settings
@ -13,10 +11,10 @@ from rest_framework.response import Response
from accounts.const import AliasAccount from accounts.const import AliasAccount
from assets.api import SerializeToTreeNodeMixin from assets.api import SerializeToTreeNodeMixin
from assets.const import AllTypes
from assets.models import Asset from assets.models import Asset
from assets.utils import KubernetesTree from assets.utils import KubernetesTree
from authentication.models import ConnectionToken from authentication.models import ConnectionToken
from common.exceptions import JMSException
from common.utils import get_object_or_none, lazyproperty from common.utils import get_object_or_none, lazyproperty
from common.utils.common import timeit from common.utils.common import timeit
from perms.hands import Node from perms.hands import Node
@ -38,21 +36,36 @@ class BaseUserNodeWithAssetAsTreeApi(
SelfOrPKUserMixin, RebuildTreeMixin, SelfOrPKUserMixin, RebuildTreeMixin,
SerializeToTreeNodeMixin, ListAPIView SerializeToTreeNodeMixin, ListAPIView
): ):
page_limit = 10000
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
nodes, assets = self.get_nodes_assets() offset = int(request.query_params.get('offset', 0))
tree_nodes = self.serialize_nodes(nodes, with_asset_amount=True) page_assets = self.get_page_assets()
tree_assets = self.serialize_assets(assets, node_key=self.node_key_for_serialize_assets)
data = list(tree_nodes) + list(tree_assets) if not offset:
return Response(data=data) nodes, assets = self.get_nodes_assets()
page = page_assets[:self.page_limit]
assets = [*assets, *page]
tree_nodes = self.serialize_nodes(nodes, with_asset_amount=True)
tree_assets = self.serialize_assets(assets, **self.serialize_asset_kwargs)
data = list(tree_nodes) + list(tree_assets)
else:
page = page_assets[offset:(offset + self.page_limit)]
data = self.serialize_assets(page, **self.serialize_asset_kwargs) if page else []
offset += len(page)
headers = {'X-JMS-TREE-OFFSET': offset} if offset else {}
return Response(data=data, headers=headers)
@abc.abstractmethod @abc.abstractmethod
def get_nodes_assets(self): def get_nodes_assets(self):
return [], [] return [], []
@lazyproperty def get_page_assets(self):
def node_key_for_serialize_assets(self): return []
return None
@property
def serialize_asset_kwargs(self):
return {}
class UserPermedNodesWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi): class UserPermedNodesWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi):
@ -61,7 +74,6 @@ class UserPermedNodesWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi):
def get_nodes_assets(self): def get_nodes_assets(self):
self.query_node_util = UserPermNodeUtil(self.request.user) self.query_node_util = UserPermNodeUtil(self.request.user)
self.query_asset_util = UserPermAssetUtil(self.request.user)
ung_nodes, ung_assets = self._get_nodes_assets_for_ungrouped() ung_nodes, ung_assets = self._get_nodes_assets_for_ungrouped()
fav_nodes, fav_assets = self._get_nodes_assets_for_favorite() fav_nodes, fav_assets = self._get_nodes_assets_for_favorite()
all_nodes, all_assets = self._get_nodes_assets_for_all() all_nodes, all_assets = self._get_nodes_assets_for_all()
@ -69,31 +81,37 @@ class UserPermedNodesWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi):
assets = list(ung_assets) + list(fav_assets) + list(all_assets) assets = list(ung_assets) + list(fav_assets) + list(all_assets)
return nodes, assets return nodes, assets
def get_page_assets(self):
return self.query_asset_util.get_all_assets().annotate(parent_key=F('nodes__key'))
@timeit @timeit
def _get_nodes_assets_for_ungrouped(self): def _get_nodes_assets_for_ungrouped(self):
if not settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE: if not settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
return [], [] return [], []
node = self.query_node_util.get_ungrouped_node() node = self.query_node_util.get_ungrouped_node()
assets = self.query_asset_util.get_ungroup_assets() assets = self.query_asset_util.get_ungroup_assets()
assets = assets.annotate(parent_key=Value(node.key, output_field=CharField())) \ assets = assets.annotate(parent_key=Value(node.key, output_field=CharField()))
.prefetch_related('platform')
return [node], assets return [node], assets
@lazyproperty
def query_asset_util(self):
return UserPermAssetUtil(self.user)
@timeit @timeit
def _get_nodes_assets_for_favorite(self): def _get_nodes_assets_for_favorite(self):
node = self.query_node_util.get_favorite_node() node = self.query_node_util.get_favorite_node()
assets = self.query_asset_util.get_favorite_assets() assets = self.query_asset_util.get_favorite_assets()
assets = assets.annotate(parent_key=Value(node.key, output_field=CharField())) \ assets = assets.annotate(parent_key=Value(node.key, output_field=CharField()))
.prefetch_related('platform')
return [node], assets return [node], assets
@timeit
def _get_nodes_assets_for_all(self): def _get_nodes_assets_for_all(self):
nodes = self.query_node_util.get_whole_tree_nodes(with_special=False) nodes = self.query_node_util.get_whole_tree_nodes(with_special=False)
if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE: if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
assets = self.query_asset_util.get_perm_nodes_assets() assets = self.query_asset_util.get_perm_nodes_assets()
else: else:
assets = self.query_asset_util.get_all_assets() assets = Asset.objects.none()
assets = assets.annotate(parent_key=F('nodes__key')).prefetch_related('platform') assets = assets.annotate(parent_key=F('nodes__key'))
return nodes, assets return nodes, assets
@ -103,6 +121,7 @@ class UserPermedNodeChildrenWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi):
# 默认展开的节点key # 默认展开的节点key
default_unfolded_node_key = None default_unfolded_node_key = None
@timeit
def get_nodes_assets(self): def get_nodes_assets(self):
query_node_util = UserPermNodeUtil(self.user) query_node_util = UserPermNodeUtil(self.user)
query_asset_util = UserPermAssetUtil(self.user) query_asset_util = UserPermAssetUtil(self.user)
@ -136,14 +155,14 @@ class UserPermedNodeChildrenWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi):
node_key = getattr(node, 'key', None) node_key = getattr(node, 'key', None)
return node_key return node_key
@lazyproperty @property
def node_key_for_serialize_assets(self): def serialize_asset_kwargs(self):
return self.query_node_key or self.default_unfolded_node_key return {
'node_key': self.query_node_key or self.default_unfolded_node_key
}
class UserPermedNodeChildrenWithAssetsAsCategoryTreeApi( class UserPermedNodeChildrenWithAssetsAsCategoryTreeApi(BaseUserNodeWithAssetAsTreeApi):
SelfOrPKUserMixin, SerializeToTreeNodeMixin, ListAPIView
):
@property @property
def is_sync(self): def is_sync(self):
sync = self.request.query_params.get('sync', 0) sync = self.request.query_params.get('sync', 0)
@ -151,66 +170,54 @@ class UserPermedNodeChildrenWithAssetsAsCategoryTreeApi(
@property @property
def tp(self): def tp(self):
return self.request.query_params.get('type')
def get_assets(self):
query_asset_util = UserPermAssetUtil(self.user)
node = PermNode.objects.filter(
granted_node_rels__user=self.user, parent_key='').first()
if node:
__, assets = query_asset_util.get_node_all_assets(node.id)
else:
assets = Asset.objects.none()
return assets
def to_tree_nodes(self, assets):
if not assets:
return []
assets = assets.annotate(tp=F('platform__type'))
asset_type_map = defaultdict(list)
for asset in assets:
asset_type_map[asset.tp].append(asset)
tp = self.tp
if tp:
assets = asset_type_map.get(tp, [])
if not assets:
return []
pid = f'ROOT_{str(assets[0].category).upper()}_{tp}'
return self.serialize_assets(assets, pid=pid)
params = self.request.query_params params = self.request.query_params
get_root = not list(filter(lambda x: params.get(x), ('type', 'n'))) return [params.get('category'), params.get('type')]
resource_platforms = assets.order_by('id').values_list('platform_id', flat=True)
node_all = AllTypes.get_tree_nodes(resource_platforms, get_root=get_root)
pattern = re.compile(r'\(0\)?')
nodes = []
for node in node_all:
meta = node.get('meta', {})
if pattern.search(node['name']) or meta.get('type') == 'platform':
continue
_type = meta.get('_type')
if _type:
node['type'] = _type
meta.setdefault('data', {})
node['meta'] = meta
nodes.append(node)
if not self.is_sync: @lazyproperty
return nodes def query_asset_util(self):
return UserPermAssetUtil(self.user)
asset_nodes = [] @timeit
for node in nodes: def get_assets(self):
node['open'] = True return self.query_asset_util.get_all_assets()
tp = node.get('meta', {}).get('_type')
if not tp:
continue
assets = asset_type_map.get(tp, [])
asset_nodes += self.serialize_assets(assets, pid=node['id'])
return nodes + asset_nodes
def list(self, request, *args, **kwargs): def _get_tree_nodes_async(self):
assets = self.get_assets() if self.request.query_params.get('lv') == '0':
nodes = self.to_tree_nodes(assets) return [], []
return Response(data=nodes) if not self.tp or not all(self.tp):
nodes = UserPermAssetUtil.get_type_nodes_tree_or_cached(self.user)
return nodes, []
category, tp = self.tp
assets = self.get_assets().filter(platform__type=tp, platform__category=category)
return [], assets
def _get_tree_nodes_sync(self):
if self.request.query_params.get('lv'):
return []
nodes = self.query_asset_util.get_type_nodes_tree()
return nodes, []
@property
def serialize_asset_kwargs(self):
return {
'get_pid': lambda asset, platform: 'ROOT_{}_{}'.format(platform.category.upper(), platform.type),
}
def serialize_nodes(self, nodes, with_asset_amount=False):
return nodes
def get_nodes_assets(self):
if self.is_sync:
return self._get_tree_nodes_sync()
else:
return self._get_tree_nodes_async()
def get_page_assets(self):
if self.is_sync:
return self.get_assets()
else:
return []
class UserGrantedK8sAsTreeApi(SelfOrPKUserMixin, ListAPIView): class UserGrantedK8sAsTreeApi(SelfOrPKUserMixin, ListAPIView):
@ -258,5 +265,8 @@ class UserGrantedK8sAsTreeApi(SelfOrPKUserMixin, ListAPIView):
if not any([namespace, pod]) and not key: if not any([namespace, pod]) and not key:
asset_node = k8s_tree_instance.as_asset_tree_node() asset_node = k8s_tree_instance.as_asset_tree_node()
tree.append(asset_node) tree.append(asset_node)
tree.extend(k8s_tree_instance.async_tree_node(namespace, pod)) try:
return Response(data=tree) tree.extend(k8s_tree_instance.async_tree_node(namespace, pod))
return Response(data=tree)
except Exception as e:
raise JMSException(e)

View File

@ -8,7 +8,7 @@ from django.utils.translation import gettext_lazy as _
from accounts.const import AliasAccount from accounts.const import AliasAccount
from accounts.models import Account from accounts.models import Account
from assets.models import Asset from assets.models import Asset
from common.utils import date_expired_default from common.utils import date_expired_default, lazyproperty
from common.utils.timezone import local_now from common.utils.timezone import local_now
from labels.mixins import LabeledMixin from labels.mixins import LabeledMixin
from orgs.mixins.models import JMSOrgBaseModel from orgs.mixins.models import JMSOrgBaseModel
@ -105,6 +105,22 @@ class AssetPermission(LabeledMixin, JMSOrgBaseModel):
return True return True
return False return False
@lazyproperty
def users_amount(self):
return self.users.count()
@lazyproperty
def user_groups_amount(self):
return self.user_groups.count()
@lazyproperty
def assets_amount(self):
return self.assets.count()
@lazyproperty
def nodes_amount(self):
return self.nodes.count()
def get_all_users(self): def get_all_users(self):
from users.models import User from users.models import User
user_ids = self.users.all().values_list('id', flat=True) user_ids = self.users.all().values_list('id', flat=True)
@ -114,7 +130,7 @@ class AssetPermission(LabeledMixin, JMSOrgBaseModel):
qs1_ids = User.objects.filter(id__in=user_ids).distinct().values_list('id', flat=True) qs1_ids = User.objects.filter(id__in=user_ids).distinct().values_list('id', flat=True)
qs2_ids = User.objects.filter(groups__id__in=group_ids).distinct().values_list('id', flat=True) qs2_ids = User.objects.filter(groups__id__in=group_ids).distinct().values_list('id', flat=True)
qs_ids = list(qs1_ids) + list(qs2_ids) qs_ids = list(qs1_ids) + list(qs2_ids)
qs = User.objects.filter(id__in=qs_ids) qs = User.objects.filter(id__in=qs_ids, is_service_account=False)
return qs return qs
def get_all_assets(self, flat=False): def get_all_assets(self, flat=False):
@ -143,11 +159,14 @@ class AssetPermission(LabeledMixin, JMSOrgBaseModel):
@classmethod @classmethod
def get_all_users_for_perms(cls, perm_ids, flat=False): def get_all_users_for_perms(cls, perm_ids, flat=False):
user_ids = cls.users.through.objects.filter(assetpermission_id__in=perm_ids) \ user_ids = cls.users.through.objects \
.filter(assetpermission_id__in=perm_ids) \
.values_list('user_id', flat=True).distinct() .values_list('user_id', flat=True).distinct()
group_ids = cls.user_groups.through.objects.filter(assetpermission_id__in=perm_ids) \ group_ids = cls.user_groups.through.objects \
.filter(assetpermission_id__in=perm_ids) \
.values_list('usergroup_id', flat=True).distinct() .values_list('usergroup_id', flat=True).distinct()
group_user_ids = User.groups.through.objects.filter(usergroup_id__in=group_ids) \ group_user_ids = User.groups.through.objects \
.filter(usergroup_id__in=group_ids) \
.values_list('user_id', flat=True).distinct() .values_list('user_id', flat=True).distinct()
user_ids = set(user_ids) | set(group_user_ids) user_ids = set(user_ids) | set(group_user_ids)
if flat: if flat:

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from django.db.models import Q from django.db.models import Q, Count
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework import serializers from rest_framework import serializers
@ -14,7 +14,7 @@ from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from perms.models import ActionChoices, AssetPermission from perms.models import ActionChoices, AssetPermission
from users.models import User, UserGroup from users.models import User, UserGroup
__all__ = ["AssetPermissionSerializer", "ActionChoicesField"] __all__ = ["AssetPermissionSerializer", "ActionChoicesField", "AssetPermissionListSerializer"]
class ActionChoicesField(BitChoicesField): class ActionChoicesField(BitChoicesField):
@ -142,8 +142,8 @@ class AssetPermissionSerializer(ResourceLabelsMixin, BulkOrgResourceModelSeriali
def perform_display_create(instance, **kwargs): def perform_display_create(instance, **kwargs):
# 用户 # 用户
users_to_set = User.objects.filter( users_to_set = User.objects.filter(
Q(name__in=kwargs.get("users_display")) Q(name__in=kwargs.get("users_display")) |
| Q(username__in=kwargs.get("users_display")) Q(username__in=kwargs.get("users_display"))
).distinct() ).distinct()
instance.users.add(*users_to_set) instance.users.add(*users_to_set)
# 用户组 # 用户组
@ -153,8 +153,8 @@ class AssetPermissionSerializer(ResourceLabelsMixin, BulkOrgResourceModelSeriali
instance.user_groups.add(*user_groups_to_set) instance.user_groups.add(*user_groups_to_set)
# 资产 # 资产
assets_to_set = Asset.objects.filter( assets_to_set = Asset.objects.filter(
Q(address__in=kwargs.get("assets_display")) Q(address__in=kwargs.get("assets_display")) |
| Q(name__in=kwargs.get("assets_display")) Q(name__in=kwargs.get("assets_display"))
).distinct() ).distinct()
instance.assets.add(*assets_to_set) instance.assets.add(*assets_to_set)
# 节点 # 节点
@ -180,3 +180,27 @@ class AssetPermissionSerializer(ResourceLabelsMixin, BulkOrgResourceModelSeriali
instance = super().create(validated_data) instance = super().create(validated_data)
self.perform_display_create(instance, **display) self.perform_display_create(instance, **display)
return instance return instance
class AssetPermissionListSerializer(AssetPermissionSerializer):
users_amount = serializers.IntegerField(read_only=True, label=_("Users amount"))
user_groups_amount = serializers.IntegerField(read_only=True, label=_("User groups amount"))
assets_amount = serializers.IntegerField(read_only=True, label=_("Assets amount"))
nodes_amount = serializers.IntegerField(read_only=True, label=_("Nodes amount"))
class Meta(AssetPermissionSerializer.Meta):
amount_fields = ["users_amount", "user_groups_amount", "assets_amount", "nodes_amount"]
remove_fields = {"users", "assets", "nodes", "user_groups"}
fields = list(set(AssetPermissionSerializer.Meta.fields + amount_fields) - remove_fields)
@classmethod
def setup_eager_loading(cls, queryset):
"""Perform necessary eager loading of data."""
queryset = queryset \
.prefetch_related('labels', 'labels__label') \
.annotate(users_amount=Count("users"),
user_groups_amount=Count("user_groups"),
assets_amount=Count("assets"),
nodes_amount=Count("nodes"),
)
return queryset

View File

@ -3,15 +3,13 @@
from django.db.models.signals import m2m_changed, pre_delete, pre_save, post_save from django.db.models.signals import m2m_changed, pre_delete, pre_save, post_save
from django.dispatch import receiver from django.dispatch import receiver
from users.models import User, UserGroup
from assets.models import Asset from assets.models import Asset
from common.utils import get_logger, get_object_or_none
from common.exceptions import M2MReverseNotAllowed
from common.const.signals import POST_ADD, POST_REMOVE, POST_CLEAR from common.const.signals import POST_ADD, POST_REMOVE, POST_CLEAR
from common.exceptions import M2MReverseNotAllowed
from common.utils import get_logger, get_object_or_none
from perms.models import AssetPermission from perms.models import AssetPermission
from perms.utils import UserPermTreeExpireUtil from perms.utils import UserPermTreeExpireUtil
from users.models import User, UserGroup
logger = get_logger(__file__) logger = get_logger(__file__)
@ -38,7 +36,7 @@ def on_user_groups_change(sender, instance, action, reverse, pk_set, **kwargs):
group = UserGroup.objects.get(id=list(group_ids)[0]) group = UserGroup.objects.get(id=list(group_ids)[0])
org_id = group.org_id org_id = group.org_id
has_group_perm = AssetPermission.user_groups.through.objects\ has_group_perm = AssetPermission.user_groups.through.objects \
.filter(usergroup_id__in=group_ids).exists() .filter(usergroup_id__in=group_ids).exists()
if not has_group_perm: if not has_group_perm:
return return
@ -115,6 +113,7 @@ def on_asset_permission_user_groups_changed(sender, instance, action, pk_set, re
def on_node_asset_change(action, instance, reverse, pk_set, **kwargs): def on_node_asset_change(action, instance, reverse, pk_set, **kwargs):
if not need_rebuild_mapping_node(action): if not need_rebuild_mapping_node(action):
return return
print("Asset node changed: ", action)
if reverse: if reverse:
asset_ids = pk_set asset_ids = pk_set
node_ids = [instance.id] node_ids = [instance.id]

View File

@ -1,8 +1,7 @@
from django.db.models import QuerySet from django.db.models import QuerySet
from assets.models import Node, Asset from assets.models import Node, Asset
from common.utils import get_logger from common.utils import get_logger, timeit
from perms.models import AssetPermission from perms.models import AssetPermission
logger = get_logger(__file__) logger = get_logger(__file__)
@ -13,6 +12,7 @@ __all__ = ['AssetPermissionUtil']
class AssetPermissionUtil(object): class AssetPermissionUtil(object):
""" 资产授权相关的方法工具 """ """ 资产授权相关的方法工具 """
@timeit
def get_permissions_for_user(self, user, with_group=True, flat=False): def get_permissions_for_user(self, user, with_group=True, flat=False):
""" 获取用户的授权规则 """ """ 获取用户的授权规则 """
perm_ids = set() perm_ids = set()

View File

@ -1,13 +1,22 @@
from django.conf import settings import json
from django.db.models import Q import re
from django.conf import settings
from django.core.cache import cache
from django.db.models import Q
from rest_framework.utils.encoders import JSONEncoder
from assets.const import AllTypes
from assets.models import FavoriteAsset, Asset from assets.models import FavoriteAsset, Asset
from common.utils.common import timeit from common.utils.common import timeit, get_logger
from orgs.utils import current_org, tmp_to_root_org
from perms.models import PermNode, UserAssetGrantedTreeNodeRelation from perms.models import PermNode, UserAssetGrantedTreeNodeRelation
from .permission import AssetPermissionUtil from .permission import AssetPermissionUtil
__all__ = ['AssetPermissionPermAssetUtil', 'UserPermAssetUtil', 'UserPermNodeUtil'] __all__ = ['AssetPermissionPermAssetUtil', 'UserPermAssetUtil', 'UserPermNodeUtil']
logger = get_logger(__name__)
class AssetPermissionPermAssetUtil: class AssetPermissionPermAssetUtil:
@ -15,30 +24,35 @@ class AssetPermissionPermAssetUtil:
self.perm_ids = perm_ids self.perm_ids = perm_ids
def get_all_assets(self): def get_all_assets(self):
""" 获取所有授权的资产 """ node_assets = self.get_perm_nodes_assets()
node_asset_ids = self.get_perm_nodes_assets(flat=True) direct_assets = self.get_direct_assets()
direct_asset_ids = self.get_direct_assets(flat=True) # 比原来的查到所有 asset id 再搜索块很多,因为当资产量大的时候,搜索会很慢
asset_ids = list(node_asset_ids) + list(direct_asset_ids) return (node_assets | direct_assets).distinct()
assets = Asset.objects.filter(id__in=asset_ids)
return assets
@timeit
def get_perm_nodes_assets(self, flat=False): def get_perm_nodes_assets(self, flat=False):
""" 获取所有授权节点下的资产 """ """ 获取所有授权节点下的资产 """
from assets.models import Node from assets.models import Node
nodes = Node.objects.prefetch_related('granted_by_permissions').filter( from ..models import AssetPermission
granted_by_permissions__in=self.perm_ids).only('id', 'key') nodes_ids = AssetPermission.objects \
.filter(id__in=self.perm_ids) \
.values_list('nodes', flat=True)
nodes = Node.objects.filter(id__in=nodes_ids).only('id', 'key')
assets = PermNode.get_nodes_all_assets(*nodes) assets = PermNode.get_nodes_all_assets(*nodes)
if flat: if flat:
return assets.values_list('id', flat=True) return set(assets.values_list('id', flat=True))
return assets return assets
@timeit
def get_direct_assets(self, flat=False): def get_direct_assets(self, flat=False):
""" 获取直接授权的资产 """ """ 获取直接授权的资产 """
assets = Asset.objects.order_by() \ from ..models import AssetPermission
.filter(granted_by_permissions__id__in=self.perm_ids) \ asset_ids = AssetPermission.objects \
.distinct() .filter(id__in=self.perm_ids) \
.values_list('assets', flat=True)
assets = Asset.objects.filter(id__in=asset_ids).distinct()
if flat: if flat:
return assets.values_list('id', flat=True) return set(assets.values_list('id', flat=True))
return assets return assets
@ -52,12 +66,62 @@ class UserPermAssetUtil(AssetPermissionPermAssetUtil):
def get_ungroup_assets(self): def get_ungroup_assets(self):
return self.get_direct_assets() return self.get_direct_assets()
@timeit
def get_favorite_assets(self): def get_favorite_assets(self):
assets = self.get_all_assets() assets = Asset.objects.all().valid()
asset_ids = FavoriteAsset.objects.filter(user=self.user).values_list('asset_id', flat=True) asset_ids = FavoriteAsset.objects.filter(user=self.user).values_list('asset_id', flat=True)
assets = assets.filter(id__in=list(asset_ids)) assets = assets.filter(id__in=list(asset_ids))
return assets return assets
def get_type_nodes_tree(self):
assets = self.get_all_assets()
resource_platforms = assets.order_by('id').values_list('platform_id', flat=True)
node_all = AllTypes.get_tree_nodes(resource_platforms, get_root=True)
pattern = re.compile(r'\(0\)?')
nodes = []
for node in node_all:
meta = node.get('meta', {})
if pattern.search(node['name']) or meta.get('type') == 'platform':
continue
_type = meta.get('_type')
if _type:
node['type'] = _type
node['category'] = meta.get('category')
meta.setdefault('data', {})
node['meta'] = meta
nodes.append(node)
return nodes
@classmethod
def get_type_nodes_tree_or_cached(cls, user):
key = f'perms:type-nodes-tree:{user.id}:{current_org.id}'
nodes = cache.get(key)
if nodes is None:
nodes = cls(user).get_type_nodes_tree()
nodes_json = json.dumps(nodes, cls=JSONEncoder)
cache.set(key, nodes_json, 60 * 60 * 24)
else:
nodes = json.loads(nodes)
return nodes
def refresh_type_nodes_tree_cache(self):
logger.debug("Refresh type nodes tree cache")
key = f'perms:type-nodes-tree:{self.user.id}:{current_org.id}'
cache.delete(key)
def refresh_favorite_assets(self):
favor_ids = FavoriteAsset.objects.filter(user=self.user).values_list('asset_id', flat=True)
favor_ids = set(favor_ids)
with tmp_to_root_org():
valid_ids = self.get_all_assets() \
.filter(id__in=favor_ids) \
.values_list('id', flat=True)
valid_ids = set(valid_ids)
invalid_ids = favor_ids - valid_ids
FavoriteAsset.objects.filter(user=self.user, asset_id__in=invalid_ids).delete()
def get_node_assets(self, key): def get_node_assets(self, key):
node = PermNode.objects.get(key=key) node = PermNode.objects.get(key=key)
node.compute_node_from_and_assets_amount(self.user) node.compute_node_from_and_assets_amount(self.user)
@ -90,6 +154,7 @@ class UserPermAssetUtil(AssetPermissionPermAssetUtil):
assets = assets.filter(nodes__id=node.id).order_by().distinct() assets = assets.filter(nodes__id=node.id).order_by().distinct()
return assets return assets
@timeit
def _get_indirect_perm_node_all_assets(self, node): def _get_indirect_perm_node_all_assets(self, node):
""" 获取间接授权节点下的所有资产 """ 获取间接授权节点下的所有资产
此算法依据 `UserAssetGrantedTreeNodeRelation` 的数据查询 此算法依据 `UserAssetGrantedTreeNodeRelation` 的数据查询
@ -134,7 +199,11 @@ class UserPermNodeUtil:
self.perm_ids = AssetPermissionUtil().get_permissions_for_user(self.user, flat=True) self.perm_ids = AssetPermissionUtil().get_permissions_for_user(self.user, flat=True)
def get_favorite_node(self): def get_favorite_node(self):
assets_amount = UserPermAssetUtil(self.user).get_favorite_assets().count() favor_ids = FavoriteAsset.objects \
.filter(user=self.user) \
.values_list('asset_id') \
.distinct()
assets_amount = Asset.objects.all().valid().filter(id__in=favor_ids).count()
return PermNode.get_favorite_node(assets_amount) return PermNode.get_favorite_node(assets_amount)
def get_ungrouped_node(self): def get_ungrouped_node(self):

View File

@ -3,11 +3,12 @@ from collections import defaultdict
from django.conf import settings from django.conf import settings
from django.core.cache import cache from django.core.cache import cache
from django.db import transaction
from assets.models import Asset from assets.models import Asset
from assets.utils import NodeAssetsUtil from assets.utils import NodeAssetsUtil
from common.db.models import output_as_string from common.db.models import output_as_string
from common.decorators import on_transaction_commit from common.decorators import on_transaction_commit, merge_delay_run
from common.utils import get_logger from common.utils import get_logger
from common.utils.common import lazyproperty, timeit from common.utils.common import lazyproperty, timeit
from orgs.models import Organization from orgs.models import Organization
@ -23,6 +24,7 @@ from perms.models import (
PermNode PermNode
) )
from users.models import User from users.models import User
from . import UserPermAssetUtil
from .permission import AssetPermissionUtil from .permission import AssetPermissionUtil
logger = get_logger(__name__) logger = get_logger(__name__)
@ -50,24 +52,74 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
def __init__(self, user): def __init__(self, user):
self.user = user self.user = user
self.orgs = self.user.orgs.distinct()
self.org_ids = [str(o.id) for o in self.orgs] @lazyproperty
def orgs(self):
return self.user.orgs.distinct()
@lazyproperty
def org_ids(self):
return [str(o.id) for o in self.orgs]
@lazyproperty @lazyproperty
def cache_key_user(self): def cache_key_user(self):
return self.get_cache_key(self.user.id) return self.get_cache_key(self.user.id)
@lazyproperty
def cache_key_time(self):
key = 'perms.user.node_tree.built_time.{}'.format(self.user.id)
return key
@timeit @timeit
def refresh_if_need(self, force=False): def refresh_if_need(self, force=False):
self._clean_user_perm_tree_for_legacy_org() built_just_now = cache.get(self.cache_key_time)
if built_just_now:
logger.info('Refresh user perm tree just now, pass: {}'.format(built_just_now))
return
to_refresh_orgs = self.orgs if force else self._get_user_need_refresh_orgs() to_refresh_orgs = self.orgs if force else self._get_user_need_refresh_orgs()
if not to_refresh_orgs: if not to_refresh_orgs:
logger.info('Not have to refresh orgs') logger.info('Not have to refresh orgs')
return return
with UserGrantedTreeRebuildLock(self.user.id): logger.info("Delay refresh user orgs: {} {}".format(self.user, [o.name for o in to_refresh_orgs]))
refresh_user_orgs_perm_tree(user_orgs=((self.user, tuple(to_refresh_orgs)),))
refresh_user_favorite_assets(users=(self.user,))
@timeit
def refresh_tree_manual(self):
built_just_now = cache.get(self.cache_key_time)
if built_just_now:
logger.info('Refresh just now, pass: {}'.format(built_just_now))
return
to_refresh_orgs = self._get_user_need_refresh_orgs()
if not to_refresh_orgs:
logger.info('Not have to refresh orgs for user: {}'.format(self.user))
return
self.perform_refresh_user_tree(to_refresh_orgs)
@timeit
def perform_refresh_user_tree(self, to_refresh_orgs):
# 再判断一次,毕竟构建树比较慢
built_just_now = cache.get(self.cache_key_time)
if built_just_now:
logger.info('Refresh user perm tree just now, pass: {}'.format(built_just_now))
return
self._clean_user_perm_tree_for_legacy_org()
ttl = settings.PERM_TREE_REGEN_INTERVAL
cache.set(self.cache_key_time, int(time.time()), ttl)
lock = UserGrantedTreeRebuildLock(self.user.id)
got = lock.acquire(blocking=False)
if not got:
logger.info('User perm tree rebuild lock not acquired, pass')
return
try:
for org in to_refresh_orgs: for org in to_refresh_orgs:
self._rebuild_user_perm_tree_for_org(org) self._rebuild_user_perm_tree_for_org(org)
self._mark_user_orgs_refresh_finished(to_refresh_orgs) self._mark_user_orgs_refresh_finished(to_refresh_orgs)
finally:
lock.release()
def _rebuild_user_perm_tree_for_org(self, org): def _rebuild_user_perm_tree_for_org(self, org):
with tmp_to_org(org): with tmp_to_org(org):
@ -75,7 +127,7 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
UserPermTreeBuildUtil(self.user).rebuild_user_perm_tree() UserPermTreeBuildUtil(self.user).rebuild_user_perm_tree()
end = time.time() end = time.time()
logger.info( logger.info(
'Refresh user [{user}] org [{org}] perm tree, user {use_time:.2f}s' 'Refresh user perm tree: [{user}] org [{org}] {use_time:.2f}s'
''.format(user=self.user, org=org, use_time=end - start) ''.format(user=self.user, org=org, use_time=end - start)
) )
@ -90,7 +142,7 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
cached_org_ids = self.client.smembers(self.cache_key_user) cached_org_ids = self.client.smembers(self.cache_key_user)
cached_org_ids = {oid.decode() for oid in cached_org_ids} cached_org_ids = {oid.decode() for oid in cached_org_ids}
to_refresh_org_ids = set(self.org_ids) - cached_org_ids to_refresh_org_ids = set(self.org_ids) - cached_org_ids
to_refresh_orgs = Organization.objects.filter(id__in=to_refresh_org_ids) to_refresh_orgs = list(Organization.objects.filter(id__in=to_refresh_org_ids))
logger.info(f'Need to refresh orgs: {to_refresh_orgs}') logger.info(f'Need to refresh orgs: {to_refresh_orgs}')
return to_refresh_orgs return to_refresh_orgs
@ -128,7 +180,8 @@ class UserPermTreeExpireUtil(_UserPermTreeCacheMixin):
self.expire_perm_tree_for_user_groups_orgs(group_ids, org_ids) self.expire_perm_tree_for_user_groups_orgs(group_ids, org_ids)
def expire_perm_tree_for_user_groups_orgs(self, group_ids, org_ids): def expire_perm_tree_for_user_groups_orgs(self, group_ids, org_ids):
user_ids = User.groups.through.objects.filter(usergroup_id__in=group_ids) \ user_ids = User.groups.through.objects \
.filter(usergroup_id__in=group_ids) \
.values_list('user_id', flat=True).distinct() .values_list('user_id', flat=True).distinct()
self.expire_perm_tree_for_users_orgs(user_ids, org_ids) self.expire_perm_tree_for_users_orgs(user_ids, org_ids)
@ -151,6 +204,21 @@ class UserPermTreeExpireUtil(_UserPermTreeCacheMixin):
logger.info('Expire all user perm tree') logger.info('Expire all user perm tree')
@merge_delay_run(ttl=20)
def refresh_user_orgs_perm_tree(user_orgs=()):
for user, orgs in user_orgs:
util = UserPermTreeRefreshUtil(user)
util.perform_refresh_user_tree(orgs)
@merge_delay_run(ttl=20)
def refresh_user_favorite_assets(users=()):
for user in users:
util = UserPermAssetUtil(user)
util.refresh_favorite_assets()
util.refresh_type_nodes_tree_cache()
class UserPermTreeBuildUtil(object): class UserPermTreeBuildUtil(object):
node_only_fields = ('id', 'key', 'parent_key', 'org_id') node_only_fields = ('id', 'key', 'parent_key', 'org_id')
@ -161,13 +229,14 @@ class UserPermTreeBuildUtil(object):
self._perm_nodes_key_node_mapper = {} self._perm_nodes_key_node_mapper = {}
def rebuild_user_perm_tree(self): def rebuild_user_perm_tree(self):
self.clean_user_perm_tree() with transaction.atomic():
if not self.user_perm_ids: self.clean_user_perm_tree()
logger.info('User({}) not have permissions'.format(self.user)) if not self.user_perm_ids:
return logger.info('User({}) not have permissions'.format(self.user))
self.compute_perm_nodes() return
self.compute_perm_nodes_asset_amount() self.compute_perm_nodes()
self.create_mapping_nodes() self.compute_perm_nodes_asset_amount()
self.create_mapping_nodes()
def clean_user_perm_tree(self): def clean_user_perm_tree(self):
UserAssetGrantedTreeNodeRelation.objects.filter(user=self.user).delete() UserAssetGrantedTreeNodeRelation.objects.filter(user=self.user).delete()

View File

@ -139,7 +139,7 @@ class RBACPermission(permissions.DjangoModelPermissions):
if isinstance(perms, str): if isinstance(perms, str):
perms = [perms] perms = [perms]
has = request.user.has_perms(perms) has = request.user.has_perms(perms)
logger.debug('View require perms: {}, result: {}'.format(perms, has)) logger.debug('Api require perms: {}, result: {}'.format(perms, has))
return has return has
def has_object_permission(self, request, view, obj): def has_object_permission(self, request, view, obj):

View File

@ -4,11 +4,12 @@
from smtplib import SMTPSenderRefused from smtplib import SMTPSenderRefused
from django.conf import settings from django.conf import settings
from django.core.mail import send_mail, get_connection from django.core.mail import send_mail
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework.views import Response, APIView from rest_framework.views import Response, APIView
from common.utils import get_logger from common.utils import get_logger
from common.tasks import get_email_connection as get_connection
from .. import serializers from .. import serializers
logger = get_logger(__file__) logger = get_logger(__file__)

View File

@ -137,7 +137,7 @@ class LDAPUserImportAPI(APIView):
return Response({'msg': _('Get ldap users is None')}, status=400) return Response({'msg': _('Get ldap users is None')}, status=400)
orgs = self.get_orgs() orgs = self.get_orgs()
errors = LDAPImportUtil().perform_import(users, orgs) new_users, errors = LDAPImportUtil().perform_import(users, orgs)
if errors: if errors:
return Response({'errors': errors}, status=400) return Response({'errors': errors}, status=400)

View File

@ -0,0 +1,36 @@
from django.template.loader import render_to_string
from django.utils.translation import gettext as _
from common.utils import get_logger
from common.utils.timezone import local_now_display
from notifications.notifications import UserMessage
logger = get_logger(__file__)
class LDAPImportMessage(UserMessage):
def __init__(self, user, extra_kwargs):
super().__init__(user)
self.orgs = extra_kwargs.pop('orgs', [])
self.end_time = extra_kwargs.pop('end_time', '')
self.start_time = extra_kwargs.pop('start_time', '')
self.time_start_display = extra_kwargs.pop('time_start_display', '')
self.new_users = extra_kwargs.pop('new_users', [])
self.errors = extra_kwargs.pop('errors', [])
self.cost_time = extra_kwargs.pop('cost_time', '')
def get_html_msg(self) -> dict:
subject = _('Notification of Synchronized LDAP User Task Results')
context = {
'orgs': self.orgs,
'start_time': self.time_start_display,
'end_time': local_now_display(),
'cost_time': self.cost_time,
'users': self.new_users,
'errors': self.errors
}
message = render_to_string('ldap/_msg_import_ldap_user.html', context)
return {
'subject': subject,
'message': message
}

View File

@ -77,6 +77,9 @@ class LDAPSettingSerializer(serializers.Serializer):
required=False, label=_('Connect timeout (s)'), required=False, label=_('Connect timeout (s)'),
) )
AUTH_LDAP_SEARCH_PAGED_SIZE = serializers.IntegerField(required=False, label=_('Search paged size (piece)')) AUTH_LDAP_SEARCH_PAGED_SIZE = serializers.IntegerField(required=False, label=_('Search paged size (piece)'))
AUTH_LDAP_SYNC_RECEIVERS = serializers.ListField(
required=False, label=_('Recipient'), max_length=36
)
AUTH_LDAP = serializers.BooleanField(required=False, label=_('Enable LDAP auth')) AUTH_LDAP = serializers.BooleanField(required=False, label=_('Enable LDAP auth'))

View File

@ -55,6 +55,17 @@ class VaultSettingSerializer(serializers.Serializer):
max_length=256, allow_blank=True, required=False, label=_('Mount Point') max_length=256, allow_blank=True, required=False, label=_('Mount Point')
) )
HISTORY_ACCOUNT_CLEAN_LIMIT = serializers.IntegerField(
default=999, max_value=999, min_value=1,
required=False, label=_('Historical accounts retained count'),
help_text=_(
'If the specific value is less than 999, '
'the system will automatically perform a task every night: '
'check and delete historical accounts that exceed the predetermined number. '
'If the value reaches or exceeds 999, no historical account deletion will be performed.'
)
)
class ChatAISettingSerializer(serializers.Serializer): class ChatAISettingSerializer(serializers.Serializer):
PREFIX_TITLE = _('Chat AI') PREFIX_TITLE = _('Chat AI')

View File

@ -1,11 +1,12 @@
# coding: utf-8 # coding: utf-8
# #
from django.db import models
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework import serializers from rest_framework import serializers
from common.serializers.fields import EncryptedField from common.serializers.fields import EncryptedField
__all__ = [ __all__ = [
'MailTestSerializer', 'EmailSettingSerializer', 'MailTestSerializer', 'EmailSettingSerializer',
'EmailContentSettingSerializer', 'SMSBackendSerializer', 'EmailContentSettingSerializer', 'SMSBackendSerializer',
@ -18,14 +19,20 @@ class MailTestSerializer(serializers.Serializer):
class EmailSettingSerializer(serializers.Serializer): class EmailSettingSerializer(serializers.Serializer):
# encrypt_fields 现在使用 write_only 来判断了
PREFIX_TITLE = _('Email') PREFIX_TITLE = _('Email')
EMAIL_HOST = serializers.CharField(max_length=1024, required=True, label=_("SMTP host")) class EmailProtocol(models.TextChoices):
EMAIL_PORT = serializers.CharField(max_length=5, required=True, label=_("SMTP port")) smtp = 'smtp', _('SMTP')
EMAIL_HOST_USER = serializers.CharField(max_length=128, required=True, label=_("SMTP account")) exchange = 'exchange', _('EXCHANGE')
EMAIL_PROTOCOL = serializers.ChoiceField(
choices=EmailProtocol.choices, label=_("Protocol"), default=EmailProtocol.smtp
)
EMAIL_HOST = serializers.CharField(max_length=1024, required=True, label=_("Host"))
EMAIL_PORT = serializers.CharField(max_length=5, required=True, label=_("Port"))
EMAIL_HOST_USER = serializers.CharField(max_length=128, required=True, label=_("Account"))
EMAIL_HOST_PASSWORD = EncryptedField( EMAIL_HOST_PASSWORD = EncryptedField(
max_length=1024, required=False, label=_("SMTP password"), max_length=1024, required=False, label=_("Password"),
help_text=_("Tips: Some provider use token except password") help_text=_("Tips: Some provider use token except password")
) )
EMAIL_FROM = serializers.CharField( EMAIL_FROM = serializers.CharField(

View File

@ -1,15 +1,19 @@
# coding: utf-8 # coding: utf-8
# #
import time
from celery import shared_task from celery import shared_task
from django.conf import settings from django.conf import settings
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from common.utils import get_logger from common.utils import get_logger
from common.utils.timezone import local_now_display
from ops.celery.decorator import after_app_ready_start from ops.celery.decorator import after_app_ready_start
from ops.celery.utils import ( from ops.celery.utils import (
create_or_update_celery_periodic_tasks, disable_celery_periodic_task create_or_update_celery_periodic_tasks, disable_celery_periodic_task
) )
from orgs.models import Organization from orgs.models import Organization
from settings.notifications import LDAPImportMessage
from users.models import User
from ..utils import LDAPSyncUtil, LDAPServerUtil, LDAPImportUtil from ..utils import LDAPSyncUtil, LDAPServerUtil, LDAPImportUtil
__all__ = ['sync_ldap_user', 'import_ldap_user_periodic', 'import_ldap_user'] __all__ = ['sync_ldap_user', 'import_ldap_user_periodic', 'import_ldap_user']
@ -23,6 +27,8 @@ def sync_ldap_user():
@shared_task(verbose_name=_('Periodic import ldap user')) @shared_task(verbose_name=_('Periodic import ldap user'))
def import_ldap_user(): def import_ldap_user():
start_time = time.time()
time_start_display = local_now_display()
logger.info("Start import ldap user task") logger.info("Start import ldap user task")
util_server = LDAPServerUtil() util_server = LDAPServerUtil()
util_import = LDAPImportUtil() util_import = LDAPImportUtil()
@ -35,11 +41,26 @@ def import_ldap_user():
org_ids = [Organization.DEFAULT_ID] org_ids = [Organization.DEFAULT_ID]
default_org = Organization.default() default_org = Organization.default()
orgs = list(set([Organization.get_instance(org_id, default=default_org) for org_id in org_ids])) orgs = list(set([Organization.get_instance(org_id, default=default_org) for org_id in org_ids]))
errors = util_import.perform_import(users, orgs) new_users, errors = util_import.perform_import(users, orgs)
if errors: if errors:
logger.error("Imported LDAP users errors: {}".format(errors)) logger.error("Imported LDAP users errors: {}".format(errors))
else: else:
logger.info('Imported {} users successfully'.format(len(users))) logger.info('Imported {} users successfully'.format(len(users)))
if settings.AUTH_LDAP_SYNC_RECEIVERS:
user_ids = settings.AUTH_LDAP_SYNC_RECEIVERS
recipient_list = User.objects.filter(id__in=list(user_ids))
end_time = time.time()
extra_kwargs = {
'orgs': orgs,
'end_time': end_time,
'start_time': start_time,
'time_start_display': time_start_display,
'new_users': new_users,
'errors': errors,
'cost_time': end_time - start_time,
}
for user in recipient_list:
LDAPImportMessage(user, extra_kwargs).publish()
@shared_task(verbose_name=_('Registration periodic import ldap user task')) @shared_task(verbose_name=_('Registration periodic import ldap user task'))

View File

@ -0,0 +1,30 @@
{% load i18n %}
<p>{% trans "Sync task Finish" %}</p>
<b>{% trans 'Time' %}:</b>
<ul>
<li>{% trans 'Date start' %}: {{ start_time }}</li>
<li>{% trans 'Date end' %}: {{ end_time }}</li>
<li>{% trans 'Time cost' %}: {{ cost_time| floatformat:0 }}s</li>
</ul>
<b>{% trans "Synced Organization" %}:</b>
<ul>
{% for org in orgs %}
<li>{{ org }}</li>
{% endfor %}
</ul>
<b>{% trans "Synced User" %}:</b>
<ul>
{% for user in users %}
<li>{{ user }}</li>
{% endfor %}
</ul>
{% if errors %}
<b>{% trans 'Error' %}:</b>
<ul>
{% for error in errors %}
<li>{{ error }}</li>
{% endfor %}
</ul>
{% endif %}

View File

@ -400,11 +400,14 @@ class LDAPImportUtil(object):
logger.info('Start perform import ldap users, count: {}'.format(len(users))) logger.info('Start perform import ldap users, count: {}'.format(len(users)))
errors = [] errors = []
objs = [] objs = []
new_users = []
group_users_mapper = defaultdict(set) group_users_mapper = defaultdict(set)
for user in users: for user in users:
groups = user.pop('groups', []) groups = user.pop('groups', [])
try: try:
obj, created = self.update_or_create(user) obj, created = self.update_or_create(user)
if created:
new_users.append(obj)
objs.append(obj) objs.append(obj)
except Exception as e: except Exception as e:
errors.append({user['username']: str(e)}) errors.append({user['username']: str(e)})
@ -421,14 +424,13 @@ class LDAPImportUtil(object):
for org in orgs: for org in orgs:
self.bind_org(org, objs, group_users_mapper) self.bind_org(org, objs, group_users_mapper)
logger.info('End perform import ldap users') logger.info('End perform import ldap users')
return errors return new_users, errors
@staticmethod def exit_user_group(self, user_groups_mapper):
def exit_user_group(user_groups_mapper):
# 通过对比查询本次导入用户需要移除的用户组 # 通过对比查询本次导入用户需要移除的用户组
group_remove_users_mapper = defaultdict(set) group_remove_users_mapper = defaultdict(set)
for user, current_groups in user_groups_mapper.items(): for user, current_groups in user_groups_mapper.items():
old_groups = set(user.groups.all()) old_groups = set(user.groups.filter(name__startswith=self.user_group_name_prefix))
exit_groups = old_groups - current_groups exit_groups = old_groups - current_groups
logger.debug(f'Ldap user {user} exits user groups {exit_groups}') logger.debug(f'Ldap user {user} exits user groups {exit_groups}')
for g in exit_groups: for g in exit_groups:

View File

@ -4,6 +4,7 @@ from django.utils.translation import gettext_lazy as _
from rest_framework import viewsets from rest_framework import viewsets
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.exceptions import MethodNotAllowed from rest_framework.exceptions import MethodNotAllowed
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response from rest_framework.response import Response
from audits.handler import create_or_update_operate_log from audits.handler import create_or_update_operate_log
@ -41,7 +42,6 @@ class TicketViewSet(CommonApiMixin, viewsets.ModelViewSet):
ordering = ('-date_created',) ordering = ('-date_created',)
rbac_perms = { rbac_perms = {
'open': 'tickets.view_ticket', 'open': 'tickets.view_ticket',
'bulk': 'tickets.change_ticket',
} }
def retrieve(self, request, *args, **kwargs): def retrieve(self, request, *args, **kwargs):
@ -122,7 +122,7 @@ class TicketViewSet(CommonApiMixin, viewsets.ModelViewSet):
self._record_operate_log(instance, TicketAction.close) self._record_operate_log(instance, TicketAction.close)
return Response('ok') return Response('ok')
@action(detail=False, methods=[PUT], permission_classes=[RBACPermission, ]) @action(detail=False, methods=[PUT], permission_classes=[IsAuthenticated, ])
def bulk(self, request, *args, **kwargs): def bulk(self, request, *args, **kwargs):
self.ticket_not_allowed() self.ticket_not_allowed()

View File

@ -6,7 +6,7 @@ from rest_framework.response import Response
from orgs.mixins.api import OrgBulkModelViewSet from orgs.mixins.api import OrgBulkModelViewSet
from ..models import UserGroup, User from ..models import UserGroup, User
from ..serializers import UserGroupSerializer from ..serializers import UserGroupSerializer, UserGroupListSerializer
__all__ = ['UserGroupViewSet'] __all__ = ['UserGroupViewSet']
@ -15,7 +15,10 @@ class UserGroupViewSet(OrgBulkModelViewSet):
model = UserGroup model = UserGroup
filterset_fields = ("name",) filterset_fields = ("name",)
search_fields = filterset_fields search_fields = filterset_fields
serializer_class = UserGroupSerializer serializer_classes = {
'default': UserGroupSerializer,
'list': UserGroupListSerializer,
}
ordering = ('name',) ordering = ('name',)
rbac_perms = ( rbac_perms = (
("add_all_users", "users.add_usergroup"), ("add_all_users", "users.add_usergroup"),

View File

@ -1,7 +1,8 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from django.db.models import Count from django.db.models import Count, Q
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from common.serializers.fields import ObjectRelatedField from common.serializers.fields import ObjectRelatedField
from common.serializers.mixin import ResourceLabelsMixin from common.serializers.mixin import ResourceLabelsMixin
@ -10,7 +11,7 @@ from .. import utils
from ..models import User, UserGroup from ..models import User, UserGroup
__all__ = [ __all__ = [
'UserGroupSerializer', 'UserGroupSerializer', 'UserGroupListSerializer',
] ]
@ -29,7 +30,6 @@ class UserGroupSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
fields = fields_mini + fields_small + ['users', 'labels'] fields = fields_mini + fields_small + ['users', 'labels']
extra_kwargs = { extra_kwargs = {
'created_by': {'label': _('Created by'), 'read_only': True}, 'created_by': {'label': _('Created by'), 'read_only': True},
'users_amount': {'label': _('Users amount')},
'id': {'label': _('ID')}, 'id': {'label': _('ID')},
} }
@ -45,6 +45,17 @@ class UserGroupSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
@classmethod @classmethod
def setup_eager_loading(cls, queryset): def setup_eager_loading(cls, queryset):
""" Perform necessary eager loading of data. """ """ Perform necessary eager loading of data. """
queryset = queryset.prefetch_related('users', 'labels', 'labels__label') \ queryset = queryset.prefetch_related('labels', 'labels__label') \
.annotate(users_amount=Count('users')) .annotate(users_amount=Count('users', filter=Q(users__is_service_account=False)))
return queryset return queryset
class UserGroupListSerializer(UserGroupSerializer):
users_amount = serializers.IntegerField(label=_('Users amount'), read_only=True)
class Meta(UserGroupSerializer.Meta):
fields = list(set(UserGroupSerializer.Meta.fields + ['users_amount']) - {'users'})
extra_kwargs = {
**UserGroupSerializer.Meta.extra_kwargs,
'users_amount': {'label': _('Users amount')},
}

View File

@ -163,9 +163,9 @@ def on_openid_create_or_update_user(sender, request, user, created, name, userna
user.save() user.save()
@shared_task(verbose_name=_('Clean audits session task log')) @shared_task(verbose_name=_('Clean up expired user sessions'))
@register_as_period_task(crontab=CRONTAB_AT_PM_TWO) @register_as_period_task(crontab=CRONTAB_AT_PM_TWO)
def clean_audits_log_period(): def clean_expired_user_session_period():
UserSession.clear_expired_sessions() UserSession.clear_expired_sessions()

View File

@ -12,6 +12,7 @@ from django.utils.translation import gettext as _
from django.views.generic import FormView, RedirectView from django.views.generic import FormView, RedirectView
from authentication.errors import IntervalTooShort from authentication.errors import IntervalTooShort
from authentication.utils import check_user_property_is_correct
from common.utils import FlashMessageUtil, get_object_or_none, random_string from common.utils import FlashMessageUtil, get_object_or_none, random_string
from common.utils.verify_code import SendAndVerifyCodeUtil from common.utils.verify_code import SendAndVerifyCodeUtil
from users.notifications import ResetPasswordSuccessMsg from users.notifications import ResetPasswordSuccessMsg
@ -148,7 +149,6 @@ class UserForgotPasswordView(FormView):
query_key = form_type query_key = form_type
if form_type == 'sms': if form_type == 'sms':
query_key = 'phone' query_key = 'phone'
target = target.lstrip('+')
try: try:
self.safe_verify_code(token, target, form_type, code) self.safe_verify_code(token, target, form_type, code)
@ -158,7 +158,7 @@ class UserForgotPasswordView(FormView):
form.add_error('code', str(e)) form.add_error('code', str(e))
return super().form_invalid(form) return super().form_invalid(form)
user = get_object_or_none(User, **{'username': username, query_key: target}) user = check_user_property_is_correct(username, **{query_key: target})
if not user: if not user:
form.add_error('code', _('No user matched')) form.add_error('code', _('No user matched'))
return super().form_invalid(form) return super().form_invalid(form)

1501
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -78,7 +78,7 @@ geoip2 = "4.7.0"
ipip-ipdb = "1.6.1" ipip-ipdb = "1.6.1"
pywinrm = "0.4.3" pywinrm = "0.4.3"
python-nmap = "0.7.1" python-nmap = "0.7.1"
django = "4.1.10" django = "4.1.13"
django-bootstrap3 = "23.4" django-bootstrap3 = "23.4"
django-filter = "23.2" django-filter = "23.2"
django-formtools = "2.4.1" django-formtools = "2.4.1"
@ -97,7 +97,7 @@ drf-yasg = "1.21.7"
coreapi = "2.3.3" coreapi = "2.3.3"
coreschema = "0.0.4" coreschema = "0.0.4"
openapi-codec = "1.3.2" openapi-codec = "1.3.2"
pillow = "10.0.0" pillow = "10.0.1"
pytz = "2023.3" pytz = "2023.3"
django-proxy = "1.2.2" django-proxy = "1.2.2"
python-daemon = "3.0.1" python-daemon = "3.0.1"
@ -127,7 +127,7 @@ python-redis-lock = "4.0.0"
pyopenssl = "23.2.0" pyopenssl = "23.2.0"
redis = "4.6.0" redis = "4.6.0"
pymongo = "4.4.1" pymongo = "4.4.1"
pyfreerdp = "0.0.1" pyfreerdp = "0.0.2"
ipython = "8.14.0" ipython = "8.14.0"
forgerypy3 = "0.3.1" forgerypy3 = "0.3.1"
django-debug-toolbar = "4.1.0" django-debug-toolbar = "4.1.0"
@ -143,9 +143,10 @@ fido2 = "^1.1.2"
ua-parser = "^0.18.0" ua-parser = "^0.18.0"
user-agents = "^2.2.0" user-agents = "^2.2.0"
django-cors-headers = "^4.3.0" django-cors-headers = "^4.3.0"
mistune = "0.8.4" mistune = "2.0.3"
openai = "^1.3.7" openai = "^1.3.7"
xlsxwriter = "^3.1.9" xlsxwriter = "^3.1.9"
exchangelib = "^5.1.0"
[tool.poetry.group.xpack.dependencies] [tool.poetry.group.xpack.dependencies]
@ -154,8 +155,7 @@ azure-mgmt-subscription = "3.1.1"
azure-identity = "1.13.0" azure-identity = "1.13.0"
azure-mgmt-compute = "30.0.0" azure-mgmt-compute = "30.0.0"
azure-mgmt-network = "23.1.0" azure-mgmt-network = "23.1.0"
google-cloud-compute = "1.13.0" google-cloud-compute = "1.15.0"
grpcio = "1.56.2"
alibabacloud-dysmsapi20170525 = "2.0.24" alibabacloud-dysmsapi20170525 = "2.0.24"
python-novaclient = "18.3.0" python-novaclient = "18.3.0"
python-keystoneclient = "5.1.0" python-keystoneclient = "5.1.0"

View File

@ -17,6 +17,7 @@ from resources.assets import AssetsGenerator, NodesGenerator, PlatformGenerator
from resources.users import UserGroupGenerator, UserGenerator from resources.users import UserGroupGenerator, UserGenerator
from resources.perms import AssetPermissionGenerator from resources.perms import AssetPermissionGenerator
from resources.terminal import CommandGenerator, SessionGenerator from resources.terminal import CommandGenerator, SessionGenerator
from resources.accounts import AccountGenerator
resource_generator_mapper = { resource_generator_mapper = {
'asset': AssetsGenerator, 'asset': AssetsGenerator,
@ -27,6 +28,7 @@ resource_generator_mapper = {
'asset_permission': AssetPermissionGenerator, 'asset_permission': AssetPermissionGenerator,
'command': CommandGenerator, 'command': CommandGenerator,
'session': SessionGenerator, 'session': SessionGenerator,
'account': AccountGenerator,
'all': None 'all': None
# 'stat': StatGenerator # 'stat': StatGenerator
} }
@ -45,6 +47,7 @@ def main():
parser.add_argument('-o', '--org', type=str, default='') parser.add_argument('-o', '--org', type=str, default='')
args = parser.parse_args() args = parser.parse_args()
resource, count, batch_size, org_id = args.resource, args.count, args.batch_size, args.org resource, count, batch_size, org_id = args.resource, args.count, args.batch_size, args.org
resource = resource.lower().rstrip('s')
generator_cls = [] generator_cls = []
if resource == 'all': if resource == 'all':

View File

@ -0,0 +1,32 @@
import random
import forgery_py
from accounts.models import Account
from assets.models import Asset
from .base import FakeDataGenerator
class AccountGenerator(FakeDataGenerator):
resource = 'account'
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.assets = list(list(Asset.objects.all()[:5000]))
def do_generate(self, batch, batch_size):
accounts = []
for i in batch:
asset = random.choice(self.assets)
name = forgery_py.internet.user_name(True) + '-' + str(i)
d = {
'username': name,
'name': name,
'asset': asset,
'secret': name,
'secret_type': 'password',
'is_active': True,
'privileged': False,
}
accounts.append(Account(**d))
Account.objects.bulk_create(accounts, ignore_conflicts=True)

View File

@ -48,7 +48,7 @@ class AssetsGenerator(FakeDataGenerator):
def pre_generate(self): def pre_generate(self):
self.node_ids = list(Node.objects.all().values_list('id', flat=True)) self.node_ids = list(Node.objects.all().values_list('id', flat=True))
self.platform_ids = list(Platform.objects.all().values_list('id', flat=True)) self.platform_ids = list(Platform.objects.filter(category='host').values_list('id', flat=True))
def set_assets_nodes(self, assets): def set_assets_nodes(self, assets):
for asset in assets: for asset in assets:
@ -72,6 +72,17 @@ class AssetsGenerator(FakeDataGenerator):
assets.append(Asset(**data)) assets.append(Asset(**data))
creates = Asset.objects.bulk_create(assets, ignore_conflicts=True) creates = Asset.objects.bulk_create(assets, ignore_conflicts=True)
self.set_assets_nodes(creates) self.set_assets_nodes(creates)
self.set_asset_platform(creates)
@staticmethod
def set_asset_platform(assets):
protocol = random.choice(['ssh', 'rdp', 'telnet', 'vnc'])
protocols = []
for asset in assets:
port = 22 if protocol == 'ssh' else 3389
protocols.append(Protocol(asset=asset, name=protocol, port=port))
Protocol.objects.bulk_create(protocols, ignore_conflicts=True)
def after_generate(self): def after_generate(self):
pass pass

View File

@ -41,7 +41,7 @@ class FakeDataGenerator:
start = time.time() start = time.time()
self.do_generate(batch, self.batch_size) self.do_generate(batch, self.batch_size)
end = time.time() end = time.time()
using = end - start using = round(end - start, 3)
from_size = created from_size = created
created += len(batch) created += len(batch)
print('Generate %s: %s-%s [%s]' % (self.resource, from_size, created, using)) print('Generate %s: %s-%s [%s]' % (self.resource, from_size, created, using))

View File

@ -1,9 +1,11 @@
from random import choice, sample from random import sample
import forgery_py import forgery_py
from .base import FakeDataGenerator from orgs.utils import current_org
from rbac.models import RoleBinding, Role
from users.models import * from users.models import *
from .base import FakeDataGenerator
class UserGroupGenerator(FakeDataGenerator): class UserGroupGenerator(FakeDataGenerator):
@ -47,3 +49,12 @@ class UserGenerator(FakeDataGenerator):
users.append(u) users.append(u)
users = User.objects.bulk_create(users, ignore_conflicts=True) users = User.objects.bulk_create(users, ignore_conflicts=True)
self.set_groups(users) self.set_groups(users)
self.set_to_org(users)
def set_to_org(self, users):
bindings = []
role = Role.objects.get(name='OrgUser')
for u in users:
b = RoleBinding(user=u, role=role, org_id=current_org.id, scope='org')
bindings.append(b)
RoleBinding.objects.bulk_create(bindings, ignore_conflicts=True)