Files
jumpserver/apps/assets/tree/asset_tree.py
2025-12-19 19:53:08 +08:00

163 lines
5.1 KiB
Python

from collections import defaultdict
from django.db.models import Count, Q
from orgs.utils import current_org
from orgs.models import Organization
from assets.models import Asset, Node, Platform
from assets.const.category import Category
from common.utils import get_logger, timeit
from .tree import TreeNode, Tree
# Module-level logger, created through the project's get_logger helper.
logger = get_logger(__name__)
# Names exported by this module; AssetTreeNode is intentionally not listed.
__all__ = ['AssetTree', 'AssetSearchTree']
class AssetTreeNode(TreeNode):
    """Tree node that additionally carries asset counters.

    ``assets_count`` is the value passed to the constructor;
    ``assets_count_total`` starts at 0 and is assigned later by the tree
    that owns the node.
    """

    def __init__(self, _id, key: str, value: str, assets_count: int = 0):
        """Initialise the base node and set both asset counters."""
        super().__init__(_id, key, value)
        self.assets_count_total = 0
        self.assets_count = assets_count

    def as_dict(self, simple=True):
        """Return the base node dict extended with the two asset counters."""
        payload = super().as_dict(simple=simple)
        payload.update(
            assets_count_total=self.assets_count_total,
            assets_count=self.assets_count,
        )
        return payload
class AssetTree(Tree):
    """Tree of an organization's nodes annotated with asset counts.

    ``build()`` loads node attributes and per-node asset counts from the
    database, adds one ``TreeNode`` per loaded node, and finally fills in
    ``assets_count_total`` on every node (its own count plus the totals of
    its children).
    """

    TreeNode = AssetTreeNode

    def __init__(self, org=None):
        """Create an empty tree for *org*; a falsy *org* falls back to ``current_org()``."""
        super().__init__()
        # NOTE(review): current_org is invoked as a callable here — confirm
        # that orgs.utils.current_org supports being called.
        self._org: Organization = org if org else current_org()
        # node id (str) -> {'id', 'key', 'value'} row
        self._nodes_attr_mapper = defaultdict(dict)
        # node id (str) -> number of assets counted for that node
        self._nodes_assets_count_mapper = defaultdict(int)

    @timeit
    def build(self):
        """Run every build stage in order, wrapped by the pre/after hooks."""
        self._pre_build()
        # Load the raw data first ...
        self._load_nodes_attr_mapper()
        self._load_nodes_assets_count()
        # ... then materialise the nodes and aggregate the counters.
        self._init_tree()
        self._compute_assets_count_total()
        self._after_build()

    def _pre_build(self):
        """Hook executed before the build starts; does nothing here."""
        pass

    def _after_build(self):
        """Hook executed once the build has finished; does nothing here."""
        pass

    @timeit
    def _load_nodes_attr_mapper(self):
        """Load id/key/value of the organization's nodes, keyed by string id."""
        rows = Node.objects.filter(org_id=self._org.id).values('id', 'key', 'value')

        def key_path(row):
            # A key such as '1:2:10' is compared as the integer list [1, 2, 10].
            return [int(part) for part in row['key'].split(':')]

        # Insert in key order so that, when the tree is built later, a parent
        # node always comes before its children.
        for row in sorted(rows, key=key_path):
            node_id = str(row['id'])
            row['id'] = node_id
            self._nodes_attr_mapper[node_id] = row

    @timeit
    def _load_nodes_assets_count(self):
        """Count assets grouped by ``node_id`` and store the counts by string id."""
        assets_q = self._make_assets_q_object()
        per_node = (
            Asset.objects.filter(assets_q)
            .values('node_id')
            .annotate(count=Count('id'))
            .values('node_id', 'count')
        )
        for row in per_node:
            self._nodes_assets_count_mapper[str(row['node_id'])] = row['count']

    @timeit
    def _make_assets_q_object(self) -> Q:
        """Return the filter selecting the assets that are counted (this org's assets)."""
        return Q(org_id=self._org.id)

    @timeit
    def _init_tree(self):
        """Create a ``TreeNode`` for every loaded node and add it to the tree."""
        for node_id in self._nodes_attr_mapper:
            tree_node = self.TreeNode(**self._get_tree_node_data(node_id))
            self.add_node(tree_node)

    def _get_tree_node_data(self, node_id):
        """Return the ``TreeNode`` constructor kwargs for *node_id*."""
        attr = self._nodes_attr_mapper[node_id]
        # .get() so that a node without assets does not create a mapper entry.
        count = self._nodes_assets_count_mapper.get(node_id, 0)
        return dict(
            _id=node_id,
            key=attr['key'],
            value=attr['value'],
            assets_count=count,
        )

    @timeit
    def _compute_assets_count_total(self):
        """Set ``assets_count_total`` on each node from its own and its children's counts."""
        # Walk the nodes in reverse order so that a child's total is already
        # final when its parent is processed.
        for node in reversed(list(self.nodes.values())):
            node.assets_count_total = sum(
                (child.assets_count_total for child in node.children),
                node.assets_count,
            )
class AssetSearchTree(AssetTree):
    """Asset tree whose counts are limited to assets matching a search.

    The counted assets are narrowed by an optional ``Q`` object and an
    optional category (translated into the ids of the platforms of that
    category). After the build, every non-root node whose total asset count
    is 0 is removed.
    """

    def __init__(self, assets_q_object: Q = None, category=None, org=None):
        """Store the search constraints.

        :param assets_q_object: extra asset filter; a falsy value means no extra filter.
        :param category: category to restrict to; ``None`` or a value not in
            ``Category.values`` disables the category restriction.
        :param org: forwarded unchanged to ``AssetTree.__init__``.
        """
        super().__init__(org)
        self._q_assets: Q = assets_q_object or Q()
        self._category = self._check_category(category)
        self._platform_ids = set()

    def _check_category(self, category):
        """Return *category* if it is in ``Category.values``, otherwise ``None``.

        An invalid (non-``None``) category is logged as a warning.
        """
        if category is None:
            return None
        if category in Category.values:
            return category
        logger.warning(f"Invalid category '{category}' for AssetSearchTree.")
        return None

    def _after_build(self):
        """Run the parent hook, then prune the nodes that ended up without assets."""
        super()._after_build()
        # A search tree generally keeps only the nodes that contain assets.
        self._remove_nodes_with_zero_assets()

    def _make_assets_q_object(self) -> Q:
        """Combine the parent filter with the platform and search filters."""
        q_org = super()._make_assets_q_object()
        self._load_category_platforms_if_needed()
        # The platform constraint must also be applied when a category is
        # active but resolved to no platforms: an empty ``__in`` list then
        # matches no assets. Skipping the constraint in that case would
        # silently drop the category filter and count assets of every category.
        if self._category is not None or self._platform_ids:
            q_platform = Q(platform_id__in=self._platform_ids)
        else:
            q_platform = Q()
        return q_org & q_platform & self._q_assets

    @timeit
    def _load_category_platforms_if_needed(self):
        """Load the ids of the platforms of the active category, if there is one."""
        if self._category is None:
            return
        ids = Platform.objects.filter(category=self._category).values_list('id', flat=True)
        self._platform_ids = self._uuids_to_string(ids)

    @timeit
    def _remove_nodes_with_zero_assets(self):
        """Remove every non-root node whose ``assets_count_total`` is 0."""
        # Snapshot the nodes first so that removal does not happen while
        # iterating over ``self.nodes``.
        nodes: list[AssetTreeNode] = list(self.nodes.values())
        nodes_to_remove = [
            node for node in nodes
            # The root node is never removed.
            if not node.is_root and node.assets_count_total == 0
        ]
        for node in nodes_to_remove:
            self.remove_node(node)

    def _uuids_to_string(self, uuids):
        """Return a list with ``str()`` applied to every item of *uuids*."""
        return [str(u) for u in uuids]