# Mirror of https://github.com/jumpserver/jumpserver.git
# (synced 2025-12-20 19:12:54 +00:00) — 163 lines, 5.1 KiB, Python
from collections import defaultdict
|
|
from django.db.models import Count, Q
|
|
|
|
from orgs.utils import current_org
|
|
from orgs.models import Organization
|
|
from assets.models import Asset, Node, Platform
|
|
from assets.const.category import Category
|
|
from common.utils import get_logger, timeit
|
|
|
|
from .tree import TreeNode, Tree
|
|
|
|
# Module-level logger for this file.
logger = get_logger(__name__)

# Names exported by ``from ... import *``.
__all__ = [
    'AssetTree',
    'AssetSearchTree',
]
|
|
|
|
|
|
class AssetTreeNode(TreeNode):
    """Tree node that additionally carries asset counters.

    ``assets_count`` is the count passed in at construction;
    ``assets_count_total`` starts at 0 and is filled in later by the owning tree.
    """

    def __init__(self, _id, key: str, value: str, assets_count: int = 0):
        super().__init__(_id, key, value)
        self.assets_count, self.assets_count_total = assets_count, 0

    def as_dict(self, simple=True):
        """Return the base node's dict extended with both asset counters."""
        data = super().as_dict(simple=simple)
        data['assets_count_total'] = self.assets_count_total
        data['assets_count'] = self.assets_count
        return data
|
|
|
|
|
|
class AssetTree(Tree):
    """Node tree of one organization, annotated with per-node asset counts.

    Every tree node carries:

    * ``assets_count``: assets grouped under that node's id (``node_id``),
      restricted by ``_make_assets_q_object``;
    * ``assets_count_total``: its own count plus the totals of its children.

    Subclasses narrow the asset filter by overriding ``_make_assets_q_object``.
    They can hook into ``_pre_build`` / ``_after_build``.
    """

    TreeNode = AssetTreeNode

    def __init__(self, org=None):
        """
        :param org: organization whose nodes and assets are loaded; when not
            given (or falsy), ``current_org()`` is used.
        """
        super().__init__()
        self._org: Organization = org or current_org()
        # node id (str) -> {'id', 'key', 'value'} row, in parent-before-child order
        self._nodes_attr_mapper = defaultdict(dict)
        # node id (str) -> number of matching assets grouped under that node id
        self._nodes_assets_count_mapper = defaultdict(int)

    @timeit
    def build(self):
        """Load nodes and asset counts from the database, then build the tree."""
        self._pre_build()
        self._load_nodes_attr_mapper()
        self._load_nodes_assets_count()
        self._init_tree()
        self._compute_assets_count_total()
        self._after_build()

    def _pre_build(self):
        """Pre-processing hook run before anything is loaded; no-op by default."""
        pass

    def _after_build(self):
        """Post-build hook run once the tree and its totals exist; no-op by default."""
        pass

    @timeit
    def _load_nodes_attr_mapper(self):
        """Load this organization's nodes, ordered so parents precede children."""
        nodes = Node.objects.filter(org_id=self._org.id).values('id', 'key', 'value')
        # Sort numerically by key segments ('1:2' < '1:10' < '2'). A parent key
        # is a prefix of its children's keys, so the parent always sorts first.
        # Tree construction relies on parents being added before children.
        nodes = sorted(nodes, key=lambda n: [int(i) for i in n['key'].split(':')])
        for node in nodes:
            node['id'] = str(node['id'])
            self._nodes_attr_mapper[node['id']] = node

    @timeit
    def _load_nodes_assets_count(self):
        """Count matching assets per ``node_id`` with a single GROUP BY query."""
        q_ = self._make_assets_q_object()
        nodes_count = Asset.objects.filter(q_).values('node_id').annotate(
            count=Count('id')
        ).values('node_id', 'count')
        for nc in nodes_count:
            # Keys are stringified to match the ids in ``_nodes_attr_mapper``.
            self._nodes_assets_count_mapper[str(nc['node_id'])] = nc['count']

    @timeit
    def _make_assets_q_object(self) -> Q:
        """Return the asset filter; the base tree restricts by organization only."""
        return Q(org_id=self._org.id)

    @timeit
    def _init_tree(self):
        """Create and add one tree node per loaded node, in load order."""
        for nid in self._nodes_attr_mapper:
            data = self._get_tree_node_data(nid)
            self.add_node(self.TreeNode(**data))

    def _get_tree_node_data(self, node_id):
        """Return the ``TreeNode`` constructor kwargs for ``node_id``.

        Nodes without matching assets get ``assets_count`` 0.
        """
        attr = self._nodes_attr_mapper[node_id]
        return {
            '_id': node_id,
            'key': attr['key'],
            'value': attr['value'],
            'assets_count': self._nodes_assets_count_mapper.get(node_id, 0),
        }

    @timeit
    def _compute_assets_count_total(self):
        """Set ``assets_count_total`` = own count + the children's totals.

        Nodes are visited in reverse insertion order, so every child is
        finalised before its parent.

        NOTE(review): this assumes ``self.nodes`` preserves the
        parent-before-child order in which ``_init_tree`` added the nodes.
        Confirm this against ``Tree``.
        """
        node: AssetTreeNode
        for node in reversed(list(self.nodes.values())):
            node.assets_count_total = node.assets_count + sum(
                child.assets_count_total for child in node.children
            )
|
|
|
|
|
|
|
|
class AssetSearchTree(AssetTree):
    """Asset tree restricted to assets matching a query and/or a category.

    After the build, every non-root node whose subtree holds no matching
    assets is removed.
    """

    def __init__(self, assets_q_object: Q = None, category=None, org=None):
        """
        :param assets_q_object: extra asset filter; defaults to no filter.
        :param category: a ``Category`` value to restrict assets to. Unknown
            values are logged and ignored (no category filter).
        :param org: organization to load; defaults to ``current_org()``.
        """
        super().__init__(org)
        self._q_assets: Q = assets_q_object or Q()
        self._category = self._check_category(category)
        # str ids of the platforms in ``self._category``; loaded lazily
        self._platform_ids = []

    def _check_category(self, category):
        """Return ``category`` if it is a known ``Category`` value, else ``None``."""
        if category is None:
            return None
        if category in Category.values:
            return category
        logger.warning(f"Invalid category '{category}' for AssetSearchTree.")
        return None

    def _after_build(self):
        super()._after_build()
        # A search tree should keep only nodes that contain assets, so drop
        # every node whose asset total is 0.
        self._remove_nodes_with_zero_assets()

    def _make_assets_q_object(self) -> Q:
        """AND together the org filter, the category's platform filter and the caller's query."""
        q_org = super()._make_assets_q_object()
        self._load_category_platforms_if_needed()
        if self._category is None:
            q_platform = Q()
        else:
            # Apply the filter even when the category has no platforms. An
            # empty ``__in`` list matches no assets, whereas skipping the filter
            # (the previous behaviour) wrongly counted every asset.
            q_platform = Q(platform_id__in=self._platform_ids)
        return q_org & q_platform & self._q_assets

    @timeit
    def _load_category_platforms_if_needed(self):
        """Load the ids of the platforms in ``self._category``; no-op without a category."""
        if self._category is None:
            return
        ids = Platform.objects.filter(category=self._category).values_list('id', flat=True)
        self._platform_ids = self._uuids_to_string(ids)

    @timeit
    def _remove_nodes_with_zero_assets(self):
        """Remove every non-root node whose ``assets_count_total`` is 0."""
        # Snapshot the candidates before removing anything.
        # NOTE(review): presumably ``remove_node`` mutates ``self.nodes``.
        nodes: list[AssetTreeNode] = list(self.nodes.values())
        nodes_to_remove = [
            node for node in nodes
            # never remove the root node
            if not node.is_root and node.assets_count_total == 0
        ]
        # A zero-total node's descendants are also zero-total, so they appear
        # in this list as well.
        # NOTE(review): if ``Tree.remove_node`` drops whole subtrees, confirm it
        # tolerates removing an already-removed descendant.
        for node in nodes_to_remove:
            self.remove_node(node)

    def _uuids_to_string(self, uuids):
        """Return ``uuids`` as a list of strings."""
        return [str(u) for u in uuids]