From 1a0ff422fe4e1957d9515639b4ee48af15cf56aa Mon Sep 17 00:00:00 2001 From: ibuler Date: Thu, 27 Jun 2019 21:43:10 +0800 Subject: [PATCH] =?UTF-8?q?[Update]=20=E4=BC=98=E5=8C=96=E6=A0=91=E7=BB=93?= =?UTF-8?q?=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/assets/api/asset_user.py | 1 + apps/assets/api/node.py | 25 +- apps/assets/models/asset.py | 17 +- apps/assets/models/node.py | 391 +++++++++++++++++---------- apps/assets/utils.py | 166 +++++++++++- apps/common/struct.py | 25 ++ apps/jumpserver/urls.py | 3 +- apps/jumpserver/views.py | 9 + apps/ops/inventory.py | 1 - apps/orgs/mixins.py | 13 +- apps/perms/api/user_permission.py | 2 + apps/perms/forms/asset_permission.py | 2 +- apps/perms/utils/asset_permission.py | 346 ++++++++++++------------ apps/settings/api.py | 19 -- apps/settings/urls/api_urls.py | 1 - 15 files changed, 644 insertions(+), 377 deletions(-) create mode 100644 apps/common/struct.py diff --git a/apps/assets/api/asset_user.py b/apps/assets/api/asset_user.py index 7bbca679d..237517790 100644 --- a/apps/assets/api/asset_user.py +++ b/apps/assets/api/asset_user.py @@ -148,6 +148,7 @@ class AssetUserTestConnectiveApi(generics.RetrieveAPIView): Test asset users connective """ permission_classes = (IsOrgAdminOrAppUser,) + serializer_class = serializers.TaskIDSerializer def get_asset_users(self): username = self.request.GET.get('username') diff --git a/apps/assets/api/node.py b/apps/assets/api/node.py index f772a2ace..32b6c33bb 100644 --- a/apps/assets/api/node.py +++ b/apps/assets/api/node.py @@ -26,6 +26,7 @@ from ..hands import IsOrgAdmin from ..models import Node from ..tasks import update_assets_hardware_info_util, test_asset_connectivity_util from .. import serializers +from ..utils import NodeUtil logger = get_logger(__file__) @@ -79,12 +80,10 @@ class NodeListAsTreeApi(generics.ListAPIView): serializer_class = TreeNodeSerializer def get_queryset(self): - queryset = [node.as_tree_node() for node in Node.objects.all()] - return queryset - - def filter_queryset(self, queryset): - if self.request.query_params.get('refresh', '0') == '1': - queryset = self.refresh_nodes(queryset) + queryset = Node.objects.all() + util = NodeUtil() + nodes = util.get_nodes_by_queryset(queryset) + queryset = [node.as_tree_node() for node in nodes] return queryset @staticmethod @@ -114,15 +113,11 @@ class NodeChildrenAsTreeApi(generics.ListAPIView): def get_queryset(self): node_key = self.request.query_params.get('key') - if node_key: - self.node = Node.objects.get(key=node_key) - queryset = self.node.get_children(with_self=False) - else: - self.is_root = True - self.node = Node.root() - queryset = list(self.node.get_children(with_self=True)) - nodes_invalid = Node.objects.exclude(key__startswith=self.node.key) - queryset.extend(list(nodes_invalid)) + util = NodeUtil() + if not node_key: + node_key = Node.root().key + self.node = util.get_node_by_key(node_key) + queryset = self.node.get_children(with_self=True) queryset = [node.as_tree_node() for node in queryset] queryset = sorted(queryset) return queryset diff --git a/apps/assets/models/asset.py b/apps/assets/models/asset.py index 5029245cf..704abf404 100644 --- a/apps/assets/models/asset.py +++ b/apps/assets/models/asset.py @@ -46,12 +46,6 @@ class AssetQuerySet(models.QuerySet): return self.active() -class AssetManager(OrgManager): - def get_queryset(self): - queryset = super().get_queryset().prefetch_related("nodes", "protocols") - return queryset - - class Protocol(models.Model): PROTOCOL_SSH = 'ssh' PROTOCOL_RDP = 'rdp' @@ -131,7 +125,7 @@ class Asset(OrgModelMixin): date_created = models.DateTimeField(auto_now_add=True, null=True, blank=True, verbose_name=_('Date created')) comment = models.TextField(max_length=128, default='', blank=True, verbose_name=_('Comment')) - objects = AssetManager.from_queryset(AssetQuerySet)() + objects = OrgManager.from_queryset(AssetQuerySet)() def __str__(self): return '{0.hostname}({0.ip})'.format(self) @@ -300,15 +294,20 @@ class Asset(OrgModelMixin): @classmethod def generate_fake(cls, count=100): from random import seed, choice - import forgery_py from django.db import IntegrityError from .node import Node + from orgs.utils import get_current_org + from orgs.models import Organization + org = get_current_org() + if not org or not org.is_real(): + Organization.default().change_to() + nodes = list(Node.objects.all()) seed() for i in range(count): ip = [str(i) for i in random.sample(range(255), 4)] asset = cls(ip='.'.join(ip), - hostname=forgery_py.internet.user_name(True), + hostname='.'.join(ip), admin_user=choice(AdminUser.objects.all()), created_by='Fake') try: diff --git a/apps/assets/models/node.py b/apps/assets/models/node.py index 082273165..5ef664c85 100644 --- a/apps/assets/models/node.py +++ b/apps/assets/models/node.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # import uuid +import re from django.db import models, transaction from django.db.models import Q @@ -15,54 +16,185 @@ from orgs.models import Organization __all__ = ['Node'] -class Node(OrgModelMixin): - id = models.UUIDField(default=uuid.uuid4, primary_key=True) - key = models.CharField(unique=True, max_length=64, verbose_name=_("Key")) # '1:1:1:1' - value = models.CharField(max_length=128, verbose_name=_("Value")) - child_mark = models.IntegerField(default=0) - date_create = models.DateTimeField(auto_now_add=True) - +class FamilyMixin: + _parents = None + _children = None + _all_children = None is_node = True - _assets_amount = None - _full_value_cache_key = '_NODE_VALUE_{}' - _assets_amount_cache_key = '_NODE_ASSETS_AMOUNT_{}' - - class Meta: - verbose_name = _("Node") - ordering = ['key'] - - def __str__(self): - return self.full_value - - def __eq__(self, other): - if not other: - return False - return self.id == other.id - - def __gt__(self, other): - if self.is_root() and not other.is_root(): - return True - elif not self.is_root() and other.is_root(): - return False - self_key = [int(k) for k in self.key.split(':')] - other_key = [int(k) for k in other.key.split(':')] - self_parent_key = self_key[:-1] - other_parent_key = other_key[:-1] - - if self_parent_key == other_parent_key: - return self.name > other.name - if len(self_parent_key) < len(other_parent_key): - return True - elif len(self_parent_key) > len(other_parent_key): - return False - return self_key > other_key - - def __lt__(self, other): - return not self.__gt__(other) @property - def name(self): - return self.value + def children(self): + if self._children: + return self._children + pattern = r'^{0}:[0-9]+$'.format(self.key) + return Node.objects.filter(key__regex=pattern) + + @children.setter + def children(self, value): + self._children = value + + @property + def all_children(self): + if self._all_children: + return self._all_children + pattern = r'^{0}:'.format(self.key) + return Node.objects.filter( + key__regex=pattern + ) + + def get_children(self, with_self=False): + children = list(self.children) + if with_self: + children.append(self) + return children + + def get_all_children(self, with_self=False): + children = self.all_children + if with_self: + children = list(children) + children.append(self) + return children + + @property + def parents(self): + if self._parents: + return self._parents + ancestor_keys = self.get_ancestor_keys() + ancestor = Node.objects.filter( + key__in=ancestor_keys + ).order_by('key') + return ancestor + + @parents.setter + def parents(self, value): + self._parents = value + + def get_ancestor(self, with_self=False): + parents = self.parents + if with_self: + parents = list(parents) + parents.append(self) + return parents + + @property + def parent(self): + if self._parents: + return self._parents[0] + if self.is_root(): + return self + try: + parent = Node.objects.get(key=self.parent_key) + return parent + except Node.DoesNotExist: + return Node.root() + + @parent.setter + def parent(self, parent): + if not self.is_node: + self.key = parent.key + ':fake' + return + children = self.get_all_children() + old_key = self.key + with transaction.atomic(): + self.key = parent.get_next_child_key() + for child in children: + child.key = child.key.replace(old_key, self.key, 1) + child.save() + self.save() + + def get_sibling(self, with_self=False): + key = ':'.join(self.key.split(':')[:-1]) + pattern = r'^{}:[0-9]+$'.format(key) + sibling = Node.objects.filter( + key__regex=pattern.format(self.key) + ) + if not with_self: + sibling = sibling.exclude(key=self.key) + return sibling + + def get_family(self): + ancestor = self.get_ancestor() + children = self.get_all_children() + return [*tuple(ancestor), self, *tuple(children)] + + def get_ancestor_keys(self, with_self=False): + parent_keys = [] + key_list = self.key.split(":") + if not with_self: + key_list.pop() + for i in range(len(key_list)): + parent_keys.append(":".join(key_list)) + key_list.pop() + return parent_keys + + def is_children(self, other): + pattern = re.compile(r'^{0}:[0-9]+$'.format(self.key)) + return pattern.match(other.key) + + def is_parent(self, other): + pattern = re.compile(r'^{0}:[0-9]+$'.format(other.key)) + return pattern.match(self.key) + + @property + def parent_key(self): + parent_key = ":".join(self.key.split(":")[:-1]) + return parent_key + + @property + def parents_keys(self, with_self=False): + keys = [] + key_list = self.key.split(":") + if not with_self: + key_list.pop() + for i in range(len(key_list)): + keys.append(':'.join(key_list)) + key_list.pop() + return keys + + +class FullValueMixin: + _full_value_cache_key = '_NODE_VALUE_{}' + _full_value = '' + key = '' + + @property + def full_value(self): + if self._full_value: + return self._full_value + key = self._full_value_cache_key.format(self.key) + cached = cache.get(key) + if cached: + return cached + if self.is_root(): + return self.value + parent_full_value = self.parent.full_value + value = parent_full_value + ' / ' + self.value + self.full_value = value + return value + + @full_value.setter + def full_value(self, value): + self._full_value = value + key = self._full_value_cache_key.format(self.key) + cache.set(key, value, 3600*24) + + def expire_full_value(self): + key = self._full_value_cache_key.format(self.key) + cache.delete_pattern(key+'*') + + @classmethod + def expire_nodes_full_value(cls, nodes=None): + key = cls._full_value_cache_key.format('*') + cache.delete_pattern(key+'*') + from ..utils import NodeUtil + util = NodeUtil() + util.set_full_value() + + +class AssetsAmountMixin: + _assets_amount_cache_key = '_NODE_ASSETS_AMOUNT_{}' + _assets_amount = None + key = '' @property def assets_amount(self): @@ -77,53 +209,77 @@ class Node(OrgModelMixin): if cached is not None: return cached assets_amount = self.get_all_assets().count() - cache.set(cache_key, assets_amount, 3600) + self.assets_amount = assets_amount return assets_amount @assets_amount.setter def assets_amount(self, value): self._assets_amount = value + cache_key = self._assets_amount_cache_key.format(self.key) + cache.set(cache_key, value, 3600 * 24) def expire_assets_amount(self): ancestor_keys = self.get_ancestor_keys(with_self=True) - cache_keys = [self._assets_amount_cache_key.format(k) for k in ancestor_keys] + cache_keys = [self._assets_amount_cache_key.format(k) for k in + ancestor_keys] cache.delete_many(cache_keys) @classmethod def expire_nodes_assets_amount(cls, nodes=None): - if nodes: - for node in nodes: - node.expire_assets_amount() - return + from ..utils import NodeUtil key = cls._assets_amount_cache_key.format('*') cache.delete_pattern(key) + util = NodeUtil(with_assets_amount=True) + util.set_assets_amount() + + +class Node(OrgModelMixin, FamilyMixin, FullValueMixin, AssetsAmountMixin): + id = models.UUIDField(default=uuid.uuid4, primary_key=True) + key = models.CharField(unique=True, max_length=64, verbose_name=_("Key")) # '1:1:1:1' + value = models.CharField(max_length=128, verbose_name=_("Value")) + child_mark = models.IntegerField(default=0) + date_create = models.DateTimeField(auto_now_add=True) + + is_node = True + _parents = None + + class Meta: + verbose_name = _("Node") + ordering = ['key'] + + def __str__(self): + return self.full_value + + def __eq__(self, other): + if not other: + return False + return self.id == other.id + + def __gt__(self, other): + # if self.is_root() and not other.is_root(): + # return False + # elif not self.is_root() and other.is_root(): + # return True + self_key = [int(k) for k in self.key.split(':')] + other_key = [int(k) for k in other.key.split(':')] + self_parent_key = self_key[:-1] + other_parent_key = other_key[:-1] + + if self_parent_key and other_parent_key and \ + self_parent_key == other_parent_key: + return self.value > other.value + # if len(self_parent_key) < len(other_parent_key): + # return True + # elif len(self_parent_key) > len(other_parent_key): + # return False + return self_key > other_key + + def __lt__(self, other): + return not self.__gt__(other) @property - def full_value(self): - key = self._full_value_cache_key.format(self.key) - cached = cache.get(key) - if cached: - return cached - if self.is_root(): - return self.value - parent_full_value = self.parent.full_value - value = parent_full_value + ' / ' + self.value - key = self._full_value_cache_key.format(self.key) - cache.set(key, value, 3600) - return value - - def expire_full_value(self): - key = self._full_value_cache_key.format(self.key) - cache.delete_pattern(key+'*') - - @classmethod - def expire_nodes_full_value(cls, nodes=None): - if nodes: - for node in nodes: - node.expire_full_value() - return - key = cls._full_value_cache_key.format('*') - cache.delete_pattern(key+'*') + def name(self): + return self.value @property def level(self): @@ -152,33 +308,6 @@ class Node(OrgModelMixin): child = self.__class__.objects.create(id=_id, key=child_key, value=value) return child - def get_children(self, with_self=False): - pattern = r'^{0}$|^{0}:[0-9]+$' if with_self else r'^{0}:[0-9]+$' - return self.__class__.objects.filter( - key__regex=pattern.format(self.key) - ) - - def get_all_children(self, with_self=False): - pattern = r'^{0}$|^{0}:' if with_self else r'^{0}:' - return self.__class__.objects.filter( - key__regex=pattern.format(self.key) - ) - - def get_sibling(self, with_self=False): - key = ':'.join(self.key.split(':')[:-1]) - pattern = r'^{}:[0-9]+$'.format(key) - sibling = self.__class__.objects.filter( - key__regex=pattern.format(self.key) - ) - if not with_self: - sibling = sibling.exclude(key=self.key) - return sibling - - def get_family(self): - ancestor = self.get_ancestor() - children = self.get_all_children() - return [*tuple(ancestor), self, *tuple(children)] - def get_assets(self): from .asset import Asset if self.is_default_node(): @@ -214,52 +343,6 @@ class Node(OrgModelMixin): else: return False - @property - def parent_key(self): - parent_key = ":".join(self.key.split(":")[:-1]) - return parent_key - - @property - def parent(self): - if self.is_root(): - return self - try: - parent = self.__class__.objects.get(key=self.parent_key) - return parent - except Node.DoesNotExist: - return self.__class__.root() - - @parent.setter - def parent(self, parent): - if not self.is_node: - self.key = parent.key + ':fake' - return - children = self.get_all_children() - old_key = self.key - with transaction.atomic(): - self.key = parent.get_next_child_key() - for child in children: - child.key = child.key.replace(old_key, self.key, 1) - child.save() - self.save() - - def get_ancestor_keys(self, with_self=False): - parent_keys = [] - key_list = self.key.split(":") - if not with_self: - key_list.pop() - for i in range(len(key_list)): - parent_keys.append(":".join(key_list)) - key_list.pop() - return parent_keys - - def get_ancestor(self, with_self=False): - ancestor_keys = self.get_ancestor_keys(with_self=with_self) - ancestor = self.__class__.objects.filter( - key__in=ancestor_keys - ).order_by('key') - return ancestor - @classmethod def create_root_node(cls): # 如果使用current_org 在set_current_org时会死循环 @@ -310,9 +393,19 @@ class Node(OrgModelMixin): tree_node = TreeNode(**data) return tree_node + @classmethod + def get_queryset(cls): + from ..utils import NodeUtil + util = NodeUtil() + return util.nodes + @classmethod def generate_fake(cls, count=100): import random + org = get_current_org() + if not org or not org.is_real(): + Organization.default().change_to() + for i in range(count): node = random.choice(cls.objects.all()) node.create_child('Node {}'.format(i)) diff --git a/apps/assets/utils.py b/apps/assets/utils.py index a796e210e..7ad6cd8e8 100644 --- a/apps/assets/utils.py +++ b/apps/assets/utils.py @@ -1,19 +1,13 @@ # ~*~ coding: utf-8 ~*~ # -from django.utils.translation import ugettext_lazy as _ -from django.core.cache import cache -from django.utils import timezone +from django.db.models import Prefetch -from common.utils import get_object_or_none -from .models import SystemUser, Label +from common.utils import get_object_or_none, get_logger +from common.struct import Stack +from .models import SystemUser, Label, Node, Asset -def get_assets_by_id_list(id_list): - return Asset.objects.filter(id__in=id_list).filter(is_active=True) - - -def get_system_users_by_id_list(id_list): - return SystemUser.objects.filter(id__in=id_list) +logger = get_logger(__file__) def get_system_user_by_name(name): @@ -47,4 +41,154 @@ class LabelFilter: return queryset +class NodeUtil: + def __init__(self, with_assets_amount=False, debug=False): + self.stack = Stack() + self._nodes = {} + self.with_assets_amount = with_assets_amount + self._debug = debug + self.init() + + @staticmethod + def sorted_by(node): + return [int(i) for i in node.key.split(':')] + + def get_all_nodes(self): + all_nodes = Node.objects.all() + if self.with_assets_amount: + all_nodes = all_nodes.prefetch_related( + Prefetch('assets', queryset=Asset.objects.all().only('id')) + ) + for node in all_nodes: + node._assets = set(node.assets.all()) + all_nodes = sorted(all_nodes, key=self.sorted_by) + + guarder = Node(key='', value='Guarder') + guarder._assets = [] + all_nodes.append(guarder) + return all_nodes + + def push_to_stack(self, node): + # 入栈之前检查 + # 如果栈是空的,证明是一颗树的根部 + if self.stack.is_empty(): + node._full_value = node.value + node._parents = [] + else: + # 如果不是根节点, + # 该节点的祖先应该是父节点的祖先加上父节点 + # 该节点的名字是父节点的名字+自己的名字 + node._parents = [self.stack.top] + self.stack.top._parents + node._full_value = ' / '.join( + [self.stack.top._full_value, node.value] + ) + node._children = [] + node._all_children = [] + self.debug("入栈: {}".format(node.key)) + self.stack.push(node) + + # 出栈 + def pop_from_stack(self): + _node = self.stack.pop() + self.debug("出栈: {} 栈顶: {}".format(_node.key, self.stack.top.key if self.stack.top else None)) + self._nodes[_node.key] = _node + if not self.stack.top: + return + if self.with_assets_amount: + self.stack.top._assets.update(_node._assets) + _node._assets_amount = len(_node._assets) + delattr(_node, '_assets') + self.stack.top._children.append(_node) + self.stack.top._all_children.extend([_node] + _node._children) + + def init(self): + all_nodes = self.get_all_nodes() + for node in all_nodes: + self.debug("准备: {} 栈顶: {}".format(node.key, self.stack.top.key if self.stack.top else None)) + # 入栈之前检查,该节点是不是栈顶节点的子节点 + # 如果不是,则栈顶出栈 + while self.stack.top and not self.stack.top.is_children(node): + self.pop_from_stack() + self.push_to_stack(node) + # 出栈最后一个 + self.debug("剩余: {}".format(', '.join([n.key for n in self.stack]))) + + def get_nodes_by_queryset(self, queryset): + nodes = [] + for n in queryset: + node = self._nodes.get(n.key) + if not node: + continue + nodes.append(nodes) + return [self] + + def get_node_by_key(self, key): + return self._nodes.get(key) + + def debug(self, msg): + self._debug and logger.debug(msg) + + def set_assets_amount(self): + for node in self._nodes.values(): + node.assets_amount = node._assets_amount + + def set_full_value(self): + for node in self._nodes.values(): + node.full_value = node._full_value + + @property + def nodes(self): + return list(self._nodes.values()) + + # 使用给定节点生成一颗树 + # 找到他们的祖先节点 + # 可选找到他们的子孙节点 + def get_family(self, nodes, with_children=False): + tree_nodes = set() + for n in nodes: + node = self.get_node_by_key(n.key) + if not node: + continue + tree_nodes.update(node._parents) + tree_nodes.add(node) + if with_children: + tree_nodes.update(node._children) + for n in tree_nodes: + delattr(n, '_children') + delattr(n, '_parents') + return list(tree_nodes) + + +def test_node_tree(): + tree = NodeUtil() + for node in tree._nodes.values(): + print("Check {}".format(node.key)) + children_wanted = node.get_all_children().count() + children = len(node._children) + if children != children_wanted: + print("{} children not equal: {} != {}".format(node.key, children, children_wanted)) + + assets_amount_wanted = node.get_all_assets().count() + if node._assets_amount != assets_amount_wanted: + print("{} assets amount not equal: {} != {}".format( + node.key, node._assets_amount, assets_amount_wanted) + ) + + full_value_wanted = node.full_value + if node._full_value != full_value_wanted: + print("{} full value not equal: {} != {}".format( + node.key, node._full_value, full_value_wanted) + ) + + parents_wanted = node.get_ancestor().count() + parents = len(node._parents) + if parents != parents_wanted: + print("{} parents count not equal: {} != {}".format( + node.key, parents, parents_wanted) + ) + + + + + diff --git a/apps/common/struct.py b/apps/common/struct.py new file mode 100644 index 000000000..88bace4a7 --- /dev/null +++ b/apps/common/struct.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +# + + +class Stack(list): + def is_empty(self): + return len(self) == 0 + + @property + def top(self): + if self.is_empty(): + return None + return self[-1] + + @property + def bottom(self): + if self.is_empty(): + return None + return self[0] + + def size(self): + return len(self) + + def push(self, item): + self.append(item) diff --git a/apps/jumpserver/urls.py b/apps/jumpserver/urls.py index 37b30172b..10e2d2547 100644 --- a/apps/jumpserver/urls.py +++ b/apps/jumpserver/urls.py @@ -7,7 +7,7 @@ from django.conf.urls.static import static from django.conf.urls.i18n import i18n_patterns from django.views.i18n import JavaScriptCatalog -from .views import IndexView, LunaView, I18NView +from .views import IndexView, LunaView, I18NView, HealthCheckView from .swagger import get_swagger_view api_v1 = [ @@ -63,6 +63,7 @@ urlpatterns = [ path('', IndexView.as_view(), name='index'), path('', include(api_v2_patterns)), path('', include(api_v1_patterns)), + path('api/health/', HealthCheckView.as_view(), name="health"), path('luna/', LunaView.as_view(), name='luna-view'), path('i18n//', I18NView.as_view(), name='i18n-switch'), path('settings/', include('settings.urls.view_urls', namespace='settings')), diff --git a/apps/jumpserver/views.py b/apps/jumpserver/views.py index f3272fb17..7f7662add 100644 --- a/apps/jumpserver/views.py +++ b/apps/jumpserver/views.py @@ -1,5 +1,6 @@ import datetime import re +import time from django.http import HttpResponse, HttpResponseRedirect from django.conf import settings @@ -9,6 +10,7 @@ from django.utils.translation import ugettext_lazy as _ from django.db.models import Count from django.shortcuts import redirect from rest_framework.response import Response +from rest_framework.views import APIView from django.views.decorators.csrf import csrf_exempt from django.http import HttpResponse from django.utils.encoding import iri_to_uri @@ -222,3 +224,10 @@ def redirect_format_api(request, *args, **kwargs): return HttpResponseTemporaryRedirect(_path) else: return Response({"msg": "Redirect url failed: {}".format(_path)}, status=404) + + +class HealthCheckView(APIView): + permission_classes = () + + def get(self, request): + return Response({"status": 1, "time": int(time.time())}) diff --git a/apps/ops/inventory.py b/apps/ops/inventory.py index ed3b9a057..9b6d3e183 100644 --- a/apps/ops/inventory.py +++ b/apps/ops/inventory.py @@ -2,7 +2,6 @@ # from .ansible.inventory import BaseInventory -from assets.utils import get_assets_by_id_list, get_system_user_by_id from common.utils import get_logger diff --git a/apps/orgs/mixins.py b/apps/orgs/mixins.py index ad15977a1..6f2d48c24 100644 --- a/apps/orgs/mixins.py +++ b/apps/orgs/mixins.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # - +import traceback from django.db import models from django.utils.translation import ugettext_lazy as _ from django.shortcuts import redirect, get_object_or_404 @@ -33,8 +33,8 @@ class OrgManager(models.Manager): def get_queryset(self): queryset = super(OrgManager, self).get_queryset() kwargs = {} - _current_org = get_current_org() + _current_org = get_current_org() if _current_org is None: kwargs['id'] = None elif _current_org.is_real(): @@ -42,12 +42,17 @@ class OrgManager(models.Manager): elif _current_org.is_default(): queryset = queryset.filter(org_id="") + # lines = traceback.format_stack() + # print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>") + # for line in lines[-10:-5]: + # print(line) + # print("<<<<<<<<<<<<<<<<<<<<<<<<<<<<") + queryset = queryset.filter(**kwargs) return queryset def all(self): - _current_org = get_current_org() - if _current_org is None: + if not current_org: msg = 'You can `objects.set_current_org(org).all()` then run it' return self else: diff --git a/apps/perms/api/user_permission.py b/apps/perms/api/user_permission.py index 152e17936..49475e122 100644 --- a/apps/perms/api/user_permission.py +++ b/apps/perms/api/user_permission.py @@ -258,7 +258,9 @@ class UserGrantedNodesWithAssetsAsTreeApi(UserPermissionCacheMixin, ListAPIView) util.filter_permissions( system_users=self.system_user_id ) + print("111111111111") nodes = util.get_nodes_with_assets() + print("22222222222222") for node, assets in nodes.items(): data = parse_node_to_tree_node(node) queryset.append(data) diff --git a/apps/perms/forms/asset_permission.py b/apps/perms/forms/asset_permission.py index 81845fb77..c1feea5b3 100644 --- a/apps/perms/forms/asset_permission.py +++ b/apps/perms/forms/asset_permission.py @@ -7,7 +7,7 @@ from django.utils.translation import ugettext_lazy as _ from orgs.mixins import OrgModelForm from orgs.utils import current_org from perms.models import AssetPermission -from assets.models import Asset +from assets.models import Asset, Node __all__ = [ 'AssetPermissionForm', diff --git a/apps/perms/utils/asset_permission.py b/apps/perms/utils/asset_permission.py index 4fbb136d7..b799314ec 100644 --- a/apps/perms/utils/asset_permission.py +++ b/apps/perms/utils/asset_permission.py @@ -4,6 +4,7 @@ import uuid from collections import defaultdict import json from hashlib import md5 +import time from django.utils import timezone from django.db.models import Q @@ -17,6 +18,7 @@ from common.tree import TreeNode from .. import const from ..models import AssetPermission, Action from ..hands import Node +from assets.utils import NodeUtil logger = get_logger(__file__) @@ -35,9 +37,8 @@ class GenerateTree: "asset_instance": set("system_user") } """ - self.__all_nodes = list(Node.objects.all()) + self.node_util = NodeUtil() self.nodes = defaultdict(dict) - self.direct_nodes = [] self._root_node = None self._ungroup_node = None @@ -48,10 +49,8 @@ class GenerateTree: all_nodes = self.nodes.keys() # 如果没有授权节点,就放到默认的根节点下 if not all_nodes: - root_node = Node.root() - self.add_node(root_node) - else: - root_node = max(all_nodes) + return None + root_node = min(all_nodes) self._root_node = root_node return root_node @@ -60,7 +59,10 @@ class GenerateTree: if self._ungroup_node: return self._ungroup_node node_id = const.UNGROUPED_NODE_ID - node_key = self.root_node.get_next_child_key() + if self.root_node: + node_key = self.root_node.get_next_child_key() + else: + node_key = '0:0' node_value = _("Default") node = Node(id=node_id, key=node_key, value=node_value) self.add_node(node) @@ -69,11 +71,11 @@ class GenerateTree: def add_asset(self, asset, system_users): nodes = asset.nodes.all() - in_nodes = set(self.direct_nodes) & set(nodes) - for node in in_nodes: - self.nodes[node][asset].update(system_users) - if not in_nodes: - self.nodes[self.ungrouped_node][asset].update(system_users) + for node in nodes: + if node in self.nodes: + self.nodes[node][asset].update(system_users) + else: + self.nodes[self.ungrouped_node][asset].update(system_users) def get_nodes(self): for node in self.nodes: @@ -84,26 +86,14 @@ class GenerateTree: node.assets_amount = len(assets) return self.nodes - # 添加节点时,追溯到根节点 def add_node(self, node): - if node in self.nodes: - return - else: - self.nodes[node] = defaultdict(set) - if node.is_root(): - return - for n in self.__all_nodes: - if n.key == node.parent_key: - self.add_node(n) - break + self.nodes[node] = defaultdict(set) # 添加树节点 def add_nodes(self, nodes): - for node in nodes: + need_nodes = self.node_util.get_family(nodes, with_children=True) + for node in need_nodes: self.add_node(node) - self.add_nodes(node.get_all_children(with_self=False)) - # 如果是直接授权的节点,则放到direct_nodes中 - self.direct_nodes.append(node) def get_user_permissions(user, include_group=True): @@ -140,35 +130,28 @@ def get_system_user_permissions(system_user): ) -class AssetPermissionUtil: - get_permissions_map = { - "User": get_user_permissions, - "UserGroup": get_user_group_permissions, - "Asset": get_asset_permissions, - "Node": get_node_permissions, - "SystemUser": get_system_user_permissions, - } +def timeit(func): + def wrapper(*args, **kwargs): + logger.debug("Start call: {}".format(func.__name__)) + now = time.time() + result = func(*args, **kwargs) + using = time.time() - now + logger.debug("Call {} end, using: {:.2}".format(func.__name__, using)) + return result + return wrapper + +class AssetGranted: + def __init__(self): + self.system_users = {} + + +class AssetPermissionCacheMixin: CACHE_KEY_PREFIX = '_ASSET_PERM_CACHE_' CACHE_META_KEY_PREFIX = '_ASSET_PERM_META_KEY_' CACHE_TIME = settings.ASSETS_PERM_CACHE_TIME CACHE_POLICY_MAP = (('0', 'never'), ('1', 'using'), ('2', 'refresh')) - def __init__(self, obj, cache_policy='0'): - self.object = obj - self.obj_id = str(obj.id) - self._permissions = None - self._permissions_id = None # 标记_permission的唯一值 - self._assets = None - self._filter_id = 'None' # 当通过filter更改 permission是标记 - self.cache_policy = cache_policy - self.tree = GenerateTree() - self.change_org_if_need() - - @staticmethod - def change_org_if_need(): - set_to_root_org() - @classmethod def is_not_using_cache(cls, cache_policy): return cls.CACHE_TIME == 0 or cache_policy in cls.CACHE_POLICY_MAP[0] @@ -190,94 +173,7 @@ class AssetPermissionUtil: def _is_refresh_cache(self): return self.is_refresh_cache(self.cache_policy) - @property - def permissions(self): - if self._permissions: - return self._permissions - object_cls = self.object.__class__.__name__ - func = self.get_permissions_map[object_cls] - permissions = func(self.object) - self._permissions = permissions - return permissions - - def filter_permissions(self, **filters): - filters_json = json.dumps(filters, sort_keys=True) - self._permissions = self.permissions.filter(**filters) - self._filter_id = md5(filters_json.encode()).hexdigest() - - @staticmethod - def _structured_system_user(system_users, actions): - """ - 结构化系统用户 - :param system_users: - :param actions: - :return: {system_user1: {'actions': set(), }, } - """ - _attr = {'actions': set(actions)} - _system_users = {system_user: _attr for system_user in system_users} - return _system_users - - def get_nodes_direct(self): - """ - 返回用户/组授权规则直接关联的节点 - :return: {asset1: {system_user1: {'actions': set()},}} - """ - nodes = defaultdict(dict) - permissions = self.permissions.prefetch_related('nodes', 'system_users') - for perm in permissions: - actions = perm.actions.all() - self.tree.add_nodes(perm.nodes.all()) - for node in perm.nodes.all(): - system_users = perm.system_users.all() - system_users = self._structured_system_user(system_users, actions) - nodes[node].update(system_users) - return nodes - - def get_assets_direct(self): - """ - - 返回用户授权规则直接关联的资产 - :return: {asset1: {system_user1: {'actions': set()},}} - """ - assets = defaultdict(dict) - permissions = self.permissions.prefetch_related('assets', 'system_users') - for perm in permissions: - actions = perm.actions.all() - for asset in perm.assets.all().valid().prefetch_related('nodes'): - system_users = perm.system_users.filter(protocol__in=asset.protocols_name) - system_users = self._structured_system_user(system_users, actions) - assets[asset].update(system_users) - return assets - - def get_assets_without_cache(self): - """ - :return: {asset1: set(system_user1,)} - """ - if self._assets: - return self._assets - assets = self.get_assets_direct() - nodes = self.get_nodes_direct() - for node, system_users in nodes.items(): - _assets = node.get_all_assets().valid().prefetch_related('nodes') - for asset in _assets: - for system_user, attr_dict in system_users.items(): - if not asset.has_protocol(system_user.protocol): - continue - if system_user in assets[asset]: - actions = assets[asset][system_user]['actions'] - attr_dict['actions'].update(actions) - system_users.update({system_user: attr_dict}) - assets[asset].update(system_users) - - __assets = defaultdict(set) - for asset, system_users in assets.items(): - for system_user, attr_dict in system_users.items(): - setattr(system_user, 'actions', attr_dict['actions']) - __assets[asset] = set(system_users.keys()) - - self._assets = __assets - return self._assets - + @timeit def get_cache_key(self, resource): cache_key = self.CACHE_KEY_PREFIX + '{obj_id}_{filter_id}_{resource}' return cache_key.format( @@ -301,27 +197,6 @@ class AssetPermissionUtil: cached = cache.get(self.asset_key) return cached - def get_assets(self): - if self._is_not_using_cache(): - return self.get_assets_from_cache() - elif self._is_refresh_cache(): - self.expire_cache() - return self.get_assets_from_cache() - else: - self.expire_cache() - return self.get_assets_without_cache() - - def get_nodes_with_assets_without_cache(self): - """ - 返回节点并且包含资产 - {"node": {"assets": set("system_user")}} - :return: - """ - assets = self.get_assets_without_cache() - for asset, system_users in assets.items(): - self.tree.add_asset(asset, system_users) - return self.tree.get_nodes() - def get_nodes_with_assets_from_cache(self): cached = cache.get(self.node_key) if not cached: @@ -338,13 +213,6 @@ class AssetPermissionUtil: else: return self.get_nodes_with_assets_without_cache() - def get_system_user_without_cache(self): - system_users = set() - permissions = self.permissions.prefetch_related('system_users') - for perm in permissions: - system_users.update(perm.system_users.all()) - return system_users - def get_system_user_from_cache(self): cached = cache.get(self.system_key) if not cached: @@ -418,6 +286,152 @@ class AssetPermissionUtil: cache.delete_pattern(key) +class AssetPermissionUtil(AssetPermissionCacheMixin): + get_permissions_map = { + "User": get_user_permissions, + "UserGroup": get_user_group_permissions, + "Asset": get_asset_permissions, + "Node": get_node_permissions, + "SystemUser": get_system_user_permissions, + } + + def __init__(self, obj, cache_policy='0'): + self.object = obj + self.obj_id = str(obj.id) + self._permissions = None + self._permissions_id = None # 标记_permission的唯一值 + self._assets = None + self._filter_id = 'None' # 当通过filter更改 permission是标记 + self.cache_policy = cache_policy + self.tree = GenerateTree() + self.change_org_if_need() + self.nodes = None + + @staticmethod + def change_org_if_need(): + set_to_root_org() + + @property + def permissions(self): + if self._permissions: + return self._permissions + object_cls = self.object.__class__.__name__ + func = self.get_permissions_map[object_cls] + permissions = func(self.object) + self._permissions = permissions + return permissions + + @timeit + def filter_permissions(self, **filters): + filters_json = json.dumps(filters, sort_keys=True) + self._permissions = self.permissions.filter(**filters) + self._filter_id = md5(filters_json.encode()).hexdigest() + + @staticmethod + @timeit + def _structured_system_user(system_users, actions): + """ + 结构化系统用户 + :param system_users: + :param actions: + :return: {system_user1: {'actions': set(), }, } + """ + _attr = {'actions': set(actions)} + _system_users = {system_user: _attr for system_user in system_users} + return _system_users + + @timeit + def get_nodes_direct(self): + """ + 返回用户/组授权规则直接关联的节点 + :return: {asset1: {system_user1: {'actions': set()},}} + """ + nodes = defaultdict(dict) + permissions = self.permissions.prefetch_related('nodes', 'system_users', 'actions') + for perm in permissions: + actions = perm.actions.all() + for node in perm.nodes.all(): + system_users = perm.system_users.all() + system_users = self._structured_system_user(system_users, actions) + nodes[node].update(system_users) + self.tree.add_nodes(nodes.keys()) + # 替换成优化过的node + nodes = {self.tree.node_util.get_node_by_key(k.key): v for k, v in nodes.items()} + return nodes + + @timeit + def get_assets_direct(self): + """ + 返回用户授权规则直接关联的资产 + :return: {asset1: {system_user1: {'actions': set()},}} + """ + assets = defaultdict(dict) + permissions = self.permissions.prefetch_related('assets', 'system_users') + for perm in permissions: + actions = perm.actions.all() + for asset in perm.assets.all().valid().prefetch_related('nodes'): + system_users = perm.system_users.filter(protocol__in=asset.protocols_name) + system_users = self._structured_system_user(system_users, actions) + assets[asset].update(system_users) + return assets + + @timeit + def get_assets_without_cache(self): + """ + :return: {asset1: set(system_user1,)} + """ + if self._assets: + return self._assets + assets = self.get_assets_direct() + nodes = self.get_nodes_direct() + # for node, system_users in nodes.items(): + # print(9999, node) + # _assets = node.get_all_valid_assets() + # print(".......... end .......") + # for asset in _assets: + # print(">>asset") + # for system_user, attr_dict in system_users.items(): + # print(">>>system user") + # if not asset.has_protocol(system_user.protocol): + # continue + # if system_user in assets[asset]: + # actions = assets[asset][system_user]['actions'] + # attr_dict['actions'].update(actions) + # system_users.update({system_user: attr_dict}) + # print("<<>>>>>") + # + __assets = defaultdict(set) + for asset, system_users in assets.items(): + for system_user, attr_dict in system_users.items(): + setattr(system_user, 'actions', attr_dict['actions']) + __assets[asset] = set(system_users.keys()) + + self._assets = __assets + return self._assets + + @timeit + def get_nodes_with_assets_without_cache(self): + """ + 返回节点并且包含资产 + {"node": {"assets": set("system_user")}} + :return: + """ + assets = self.get_assets_without_cache() + for asset, system_users in assets.items(): + self.tree.add_asset(asset, system_users) + return self.tree.get_nodes() + + def get_system_user_without_cache(self): + system_users = set() + permissions = self.permissions.prefetch_related('system_users') + for perm in permissions: + system_users.update(perm.system_users.all()) + return system_users + + def is_obj_attr_has(obj, val, attrs=("hostname", "ip", "comment")): if not attrs: vals = [val for val in obj.__dict__.values() if isinstance(val, (str, int))] diff --git a/apps/settings/api.py b/apps/settings/api.py index 37d63bef1..9a10f6fcc 100644 --- a/apps/settings/api.py +++ b/apps/settings/api.py @@ -242,22 +242,3 @@ class CommandStorageDeleteAPI(APIView): storage_name = str(request.data.get('name')) Setting.delete_storage('TERMINAL_COMMAND_STORAGE', storage_name) return Response({"msg": _('Delete succeed')}, status=200) - - -class DjangoSettingsAPI(APIView): - def get(self, request): - if not settings.DEBUG: - return Response("Not in debug mode") - - data = {} - for i in [settings, getattr(settings, '_wrapped')]: - if not i: - continue - for k, v in i.__dict__.items(): - if k and k.isupper(): - try: - json.dumps(v) - data[k] = v - except (json.JSONDecodeError, TypeError): - data[k] = str(v) - return Response(data) \ No newline at end of file diff --git a/apps/settings/urls/api_urls.py b/apps/settings/urls/api_urls.py index df8eede8a..bc2e4731f 100644 --- a/apps/settings/urls/api_urls.py +++ b/apps/settings/urls/api_urls.py @@ -15,5 +15,4 @@ urlpatterns = [ path('terminal/replay-storage/delete/', api.ReplayStorageDeleteAPI.as_view(), name='replay-storage-delete'), path('terminal/command-storage/create/', api.CommandStorageCreateAPI.as_view(), name='command-storage-create'), path('terminal/command-storage/delete/', api.CommandStorageDeleteAPI.as_view(), name='command-storage-delete'), - path('django-settings/', api.DjangoSettingsAPI.as_view(), name='django-settings'), ]