refactor: Tree again, later ...

This commit is contained in:
Bai
2026-01-06 15:02:05 +08:00
parent 07bea445a5
commit 2b3942687e
20 changed files with 687 additions and 392 deletions

View File

@@ -5,25 +5,25 @@ from rest_framework.request import Request
from assets.models import Node, Platform, Protocol, MyAsset
from assets.utils import get_node_from_request, is_query_node_all_assets
from common.utils import lazyproperty, timeit
from assets.tree.asset_tree import AssetTreeNode
from assets.tree.node_tree import NodeTreeNode
class SerializeToTreeNodeMixin:
request: Request
@timeit
def serialize_nodes(self, nodes: List[AssetTreeNode], with_asset_amount=False, expand_level=1, tree_type='node'):
def serialize_nodes(self, nodes: List[NodeTreeNode], with_asset_amount=False, expand_level=1, tree_type='node'):
if not nodes:
return []
def _name(node: AssetTreeNode):
def _name(node: NodeTreeNode):
v = node.value
if not with_asset_amount:
return v
v = f'{v} ({node.assets_amount_total})'
return v
def is_parent(node: AssetTreeNode):
def is_parent(node: NodeTreeNode):
if tree_type == 'asset':
return node.assets_amount > 0 or not node.is_leaf
else: # tree_type == 'node'
@@ -45,7 +45,7 @@ class SerializeToTreeNodeMixin:
"value": node.value,
"assets_amount": node.assets_amount,
"assets_amount_total": node.assets_amount_total,
"children_count_total": node.children_count_total,
# "children_count_total": node.children_count_total,
},
}
}

View File

@@ -1,3 +1,4 @@
from abc import abstractmethod, abstractproperty
from django.db.models import Q
from rest_framework import generics
@@ -9,10 +10,10 @@ from common.utils import lazyproperty, timeit
from common.exceptions import APIException
from orgs.utils import current_org
from rbac.permissions import RBACPermission
from assets.tree.asset_tree import AssetTree
from assets.models import Node
from assets.tree.node_tree import AssetNodeTree, NodeTreeNode
from .mixin import SerializeToTreeNodeMixin
from .const import RenderTreeType, RenderTreeTypeChoices, RenderTreeView, RenderTreeViewChoices
from .const import RenderTreeView, RenderTreeViewChoices
__all__ = ['AbstractAssetTreeAPI']
@@ -20,13 +21,12 @@ __all__ = ['AbstractAssetTreeAPI']
class AbstractAssetTreeAPI(SerializeToTreeNodeMixin, generics.ListAPIView):
# TODO: 子类必须定义 rbac_perms 属性限制权限
# 子类必须指定权限 rbac_perms#
permission_classes = (RBACPermission,)
# query parameters keys
query_search_key = 'search'
query_search_key_value_sep = ':'
query_tree_type_key = 'tree_type'
query_tree_view_key = 'tree_view'
query_asset_category_key = 'category'
query_asset_type_key = 'type'
@@ -40,36 +40,22 @@ class AbstractAssetTreeAPI(SerializeToTreeNodeMixin, generics.ListAPIView):
search_assets_per_org_limit_max = 1000
search_assets_per_org_limit_min = 100
render_tree_type: RenderTreeType
@lazyproperty
def tree_with_assets(self):
with_assets = self.request.query_params.get('assets', '0') == '1'
return with_assets
tree_user: User
def initial(self, request, *args, **kwargs):
super().initial(request, *args, **kwargs)
self.render_tree_view = self.initial_render_tree_view()
self.render_tree_type = self.initial_render_tree_type()
self.tree_user = self.get_tree_user()
def initial_render_tree_view(self):
# 资产树视图
# 默认是节点视图
# 扩展支持 category 视图, label 视图等等
@lazyproperty
def tree_view(self):
# 资产树视图 # 默认是节点视图 # 扩展支持 category 视图, label 视图等等
tree_view = self.get_query_value(self.query_tree_view_key)
if not tree_view:
tree_view = RenderTreeViewChoices.node
tree_view = tree_view or RenderTreeViewChoices.node
return RenderTreeView(tree_view)
def initial_render_tree_type(self):
tree_type = self.get_query_value(self.query_tree_type_key)
if not tree_type:
# 兼容 assets=1 参数
with_assets = self.request.query_params.get('assets', '0') == '1'
if with_assets:
tree_type = RenderTreeTypeChoices.asset
else:
tree_type = RenderTreeTypeChoices.node
return RenderTreeType(tree_type)
@lazyproperty
def tree_user(self) -> User:
return self.get_tree_user()
def get_tree_user(self) -> User:
# 抽象方法: 获取为哪个用户渲染树 #
raise NotImplementedError
@@ -95,66 +81,137 @@ class AbstractAssetTreeAPI(SerializeToTreeNodeMixin, generics.ListAPIView):
query_value = _search.replace(f'{query_key}{sep}', '').strip()
return query_value
def get_org_asset_tree(self, **kwargs) -> AssetTree:
# 抽象方法: 获取组织的资产树 #
return self._get_org_asset_tree(tree_view=self.render_tree_view, **kwargs)
def _get_org_asset_tree(self, **kwargs) -> AssetTree:
@abstractmethod
def get_asset_tree(self, assets_scope_q=None, asset_category=None, asset_type=None, org=None):
raise NotImplementedError
@lazyproperty
def org_is_global(self):
return current_org.is_root()
def get_tree_user_orgs(self):
def tree_user_orgs(self):
# 重要: 获取用户有权限渲染树的组织列表 #
user = self.tree_user
if self.org_is_global:
# 如果是全局组织,返回用户所在的所有实体组织
orgs = user.orgs.all()
else:
# 如果时实体组织,从用户所在的实体组织中返回该实体组织
orgs = user.orgs.filter(id=current_org.id)
orgs = self.tree_user.orgs.all()
if not current_org.is_root():
orgs = orgs.filter(id=current_org.id)
if not orgs.exists():
raise APIException(
'No organization available for rendering the tree'
)
raise APIException('No organization available for rendering the tree')
return orgs
def get_search_asset_keyword(self):
search_asset = self.get_query_value(self.query_search_asset_key)
if self.tree_with_assets and not search_asset:
# 兼容 search 为搜索资产
search = self.get_query_value(self.query_search_key) or ''
search_asset = search if self.query_search_key_value_sep not in search else ''
return search_asset
@timeit
def list(self, request, *args, **kwargs):
# 渲染资产树 API 接口 #
# 支持渲染节点树和资产树两种类型
# 节点树: 只返回节点
# 资产树: 返回节点和节点下的资产
asset_category = self.get_query_value(self.query_asset_category_key)
asset_type = self.get_query_value(self.query_asset_type_key)
with_asset_amount = True
expand_node_key = self.get_query_value(self.query_expand_node_key)
search_node = self.get_query_value(self.query_search_node_key)
search_asset = self.get_query_value(self.query_search_asset_key)
if self.render_tree_type.is_asset_tree:
if not search_asset:
# 兼容 search 为搜索资产
search = self.get_query_value(self.query_search_key) or ''
sep = self.query_search_key_value_sep
if sep not in search:
search_asset = search
search_asset = self.get_search_asset_keyword()
data = self._list(
expand_node_key=expand_node_key,
search_node=search_node, search_asset=search_asset,
asset_category=asset_category, asset_type=asset_type,
with_asset_amount=with_asset_amount
)
data = []
for org in self.tree_user_orgs:
nodes, assets = self.get_org_nodes_assets_data(
org=org,
expand_node_key=expand_node_key,
search_node=search_node,
search_asset=search_asset,
asset_category=asset_category,
asset_type=asset_type,
with_asset_amount=with_asset_amount
)
data.extend(nodes)
data.extend(assets)
return Response(data=data)
def get_org_nodes_assets_data(
self, org, expand_node_key=None, search_node=None, search_asset=None,
asset_category=None, asset_type=None, with_asset_amount=True
):
tree = self.get_asset_tree(
search_asset=search_asset, asset_category=asset_category, asset_type=asset_type,
org=org, with_asset_amount=with_asset_amount
)
nodes = []
assets = []
if self.tree_with_assets:
if expand_node_key:
node = tree.get_node(key=expand_node_key)
if not node:
raise APIException(f'Node not found: {expand_node_key}')
nodes = node.children
assets = tree.get_tree_assets(nodes=[node])
elif search_node:
# 只展开父节点
pass
elif search_asset:
nodes = tree.get_nodes(with_empty_assets_branch=False)
assets = tree.get_tree_assets(limit=10)
for node in nodes:
if node.key == '1:0:1:1:0':
print('.........')
node: NodeTreeNode
setattr(node, 'is_parent', False)
setattr(node, 'open', True)
else:
if current_org.is_root():
nodes = [tree.root]
else:
tree_root = tree.root
setattr(tree_root, 'open', True)
nodes = [tree_root] + tree_root.children
assets = tree.get_tree_assets(nodes=[tree_root])
for node in nodes:
node: NodeTreeNode
is_parent = not node.is_leaf or node.assets_amount > 0
setattr(node, 'is_parent', is_parent)
else:
nodes = tree.get_nodes()
if not current_org.is_root():
setattr(tree.root, 'open', True)
for node in nodes:
node: NodeTreeNode
is_parent = not node.is_leaf
setattr(node, 'is_parent', is_parent)
if with_asset_amount:
for node in nodes:
node.name = f'{node.name} ({node.assets_amount_total})'
data_nodes = self.serialize_nodes(nodes=nodes)
data_assets = self.serialize_assets(assets)
return data_nodes, data_assets
def get_asset_tree(self, search_asset=None, asset_category=None, asset_type=None, org=None,
with_asset_amount=True):
assets_scope_q = None
if search_asset:
assets_scope_q = Q(name__icontains=search_asset) | Q(address__icontains=search_asset)
tree = AssetNodeTree(
assets_scope_q=assets_scope_q,
asset_category=asset_category,
asset_type=asset_type,
org=org
)
tree.init(with_assets_amount=with_asset_amount)
return tree
@timeit
def _list(self, expand_node_key=None, search_node=None, search_asset=None,
asset_category=None, asset_type=None, with_asset_amount=True):
if self.render_tree_type.is_node_tree:
if not self.tree_with_assets:
data = self.render_node_tree(
asset_category=asset_category, asset_type=asset_type,
with_asset_amount=with_asset_amount
@@ -228,7 +285,7 @@ class AbstractAssetTreeAPI(SerializeToTreeNodeMixin, generics.ListAPIView):
nodes = []
for org in orgs:
tree = self.get_org_asset_tree(
tree = self.get_asset_tree(
asset_category=asset_category, asset_type=asset_type, org=org
)
_nodes = tree.get_nodes()
@@ -254,7 +311,7 @@ class AbstractAssetTreeAPI(SerializeToTreeNodeMixin, generics.ListAPIView):
nodes = []
assets = []
for org in orgs:
tree: AssetTree = self.get_org_asset_tree(
tree = self.get_asset_tree(
asset_category=asset_category,
asset_type=asset_type,
org=org,
@@ -297,7 +354,7 @@ class AbstractAssetTreeAPI(SerializeToTreeNodeMixin, generics.ListAPIView):
node_id = node_key # 在类别视图中,节点 key 就是节点 id
with_assets_node_id = node_id
tree: AssetTree = self.get_org_asset_tree(
tree = self.get_asset_tree(
asset_category=asset_category,
asset_type=asset_type,
org=org,
@@ -332,7 +389,7 @@ class AbstractAssetTreeAPI(SerializeToTreeNodeMixin, generics.ListAPIView):
matched_nodes = []
matched_nodes_ancestors = []
for org in orgs:
tree: AssetTree = self.get_org_asset_tree(
tree = self.get_asset_tree(
asset_category=asset_category,
asset_type=asset_type,
org=org
@@ -397,7 +454,7 @@ class AbstractAssetTreeAPI(SerializeToTreeNodeMixin, generics.ListAPIView):
nodes = []
assets = []
for org in orgs:
tree: AssetTree = self.get_org_asset_tree(
tree = self.get_asset_tree(
assets_q_object=assets_q_object,
asset_category=asset_category,
asset_type=asset_type, org=org,

View File

@@ -2,21 +2,20 @@ from django.db.models import TextChoices
__all__ = [
'RenderTreeType', 'RenderTreeTypeChoices',
'RenderTreeView', 'RenderTreeViewChoices',
]
class RenderTreeViewChoices(TextChoices):
node = 'node', 'Node View'
category = 'category', 'Category View'
node = 'node', 'Node tree'
category = 'category', 'Category tree'
class RenderTreeView:
def __init__(self, view):
if view not in RenderTreeViewChoices.values:
raise ValueError(f'Invalid tree view: {view}')
view = RenderTreeViewChoices.node
self.view: RenderTreeViewChoices = view
@property
@@ -29,27 +28,3 @@ class RenderTreeView:
def __str__(self):
return self.view.value
class RenderTreeTypeChoices(TextChoices):
node = 'node', 'Node'
asset = 'asset', 'Asset'
class RenderTreeType:
def __init__(self, _type):
if _type not in RenderTreeTypeChoices.values:
raise ValueError(f'Invalid tree type: {_type}')
self._type: RenderTreeTypeChoices = _type
@property
def is_asset_tree(self):
return self._type == RenderTreeTypeChoices.asset
@property
def is_node_tree(self):
return self._type == RenderTreeTypeChoices.node
def __str__(self):
return self._type.value

View File

@@ -4,9 +4,7 @@ from rest_framework.request import Request
from assets.models import Platform, Protocol, MyAsset
from common.utils import lazyproperty, timeit
from assets.tree.asset_tree import AssetTreeNode, AssetTreeNodeAsset
from .const import RenderTreeType
from assets.tree.node_tree import NodeTreeNode, TreeAsset
__all__ = ['SerializeToTreeNodeMixin']
@@ -16,34 +14,18 @@ class SerializeToTreeNodeMixin:
request: Request
@timeit
def serialize_nodes(self, nodes: List[AssetTreeNode], tree_type: RenderTreeType,
with_asset_amount=False, expand_level=1, expand_all=False):
def serialize_nodes(self, nodes: List[NodeTreeNode]):
if not nodes:
return []
def _name(node: AssetTreeNode):
v = node.value
if not with_asset_amount:
return v
v = f'{v} ({node.assets_amount_total})'
return v
def is_parent(node: AssetTreeNode):
if tree_type.is_asset_tree:
return node.assets_amount > 0 or not node.is_leaf
elif tree_type.is_node_tree:
return not node.is_leaf
else:
return True
data = [
{
'id': node.key,
'name': _name(node),
'title': _name(node),
'name': node.name,
'title': node.name,
'pId': node.parent_key,
'isParent': is_parent(node),
'open': expand_all or node.level <= expand_level,
'isParent': node.is_parent,
'open': getattr(node, 'open', False),
'meta': {
'type': 'node',
'data': {
@@ -52,7 +34,7 @@ class SerializeToTreeNodeMixin:
"value": node.value,
"assets_amount": node.assets_amount,
"assets_amount_total": node.assets_amount_total,
"children_count_total": node.children_count_total,
# "children_count_total": node.children_count_total,
},
}
}
@@ -74,15 +56,11 @@ class SerializeToTreeNodeMixin:
return 'file'
@timeit
def serialize_assets(self, assets, node_key=None, get_pid=None):
def serialize_assets(self, assets):
if not assets:
return []
if not get_pid and not node_key:
get_pid = lambda asset, platform: getattr(asset, 'parent_key', '')
sftp_asset_ids = Protocol.objects.filter(name='sftp') \
.values_list('asset_id', flat=True)
sftp_asset_ids = Protocol.objects.filter(name='sftp').values_list('asset_id', flat=True)
sftp_asset_ids = set(sftp_asset_ids)
platform_map = {p.id: p for p in Platform.objects.all()}
@@ -90,15 +68,13 @@ class SerializeToTreeNodeMixin:
root_assets_count = 0
MyAsset.set_asset_custom_value(assets, self.request.user)
for asset in assets:
asset: AssetTreeNodeAsset
asset: TreeAsset
platform = platform_map.get(asset.platform_id)
if not platform:
continue
pid = node_key or get_pid(asset, platform)
if not pid:
continue
parent = asset.tree_node
# 根节点最多显示 1000 个资产
if pid.isdigit():
if parent.is_root:
if root_assets_count > 1000:
continue
root_assets_count += 1
@@ -106,7 +82,7 @@ class SerializeToTreeNodeMixin:
'id': str(asset.id),
'name': asset.name,
'title': f'{asset.address}\n{asset.comment}'.strip(),
'pId': pid,
'pId': parent.key,
'isParent': False,
'open': False,
'iconSkin': self.get_icon(platform),

View File

@@ -16,8 +16,8 @@ from .const import RenderTreeView
from ... import serializers
from ...const import AllTypes
from ...models import Node, Platform, Asset
from assets.tree.asset_tree import AssetTree
from assets.tree.category import AssetTreeCategoryView
from assets.tree.node_tree import AssetNodeTree
# from assets.tree.category import AssetTreeCategoryView
from .base import AbstractAssetTreeAPI
@@ -109,11 +109,12 @@ class AssetTreeAPI(AbstractAssetTreeAPI):
def get_tree_user(self):
return self.request.user
def _get_org_asset_tree(self, tree_view: RenderTreeView, **kwargs) -> AssetTree:
def _get_asset_tree(self, tree_view: RenderTreeView, **kwargs):
if tree_view.is_node_view:
tree = AssetTree(**kwargs)
tree = AssetNodeTree(**kwargs)
elif tree_view.is_category_view:
tree = AssetTreeCategoryView(**kwargs)
raise ValueError('Category tree view is not implemented yet')
# tree = AssetTreeCategoryView(**kwargs)
else:
raise ValueError(f'Unsupported tree view: {tree_view}')
return tree

View File

@@ -1 +1,2 @@
from .asset_tree import *
from .node_tree import *
from .tree import *

View File

@@ -0,0 +1,147 @@
from django.db.models import Q
from orgs.models import Organization
from orgs.utils import current_org
from assets.models import Node, Asset
from django.db.models import Count
from .tree import Tree, TreeNode
from common.utils import lazyproperty
__all__ = ['NodeTreeNode', 'AssetNodeTree', 'TreeAsset']
class NodeTreeNode(TreeNode):
model_only_fields = ['id', 'key', 'value']
def __init__(self, id, **kwargs):
super().__init__(**kwargs)
self.id = id
self.value = self.name
self.assets_amount = 0
self.assets = []
def set_assets_amount(self, amount):
self.assets_amount = amount
@lazyproperty
def assets_amount_total(self):
count = self.assets_amount
for child in self.children:
child: NodeTreeNode
count += child.assets_amount_total
return count
class TreeAsset:
model_only_fields = [
'id', 'name', 'address', 'platform__category', 'platform__type', 'node_id',
'platform_id', 'is_active', 'comment', 'org_id'
]
def __init__(self, tree_node: NodeTreeNode, **kwargs):
self.tree_node = tree_node
for k, v in kwargs.items():
setattr(self, k, v)
@property
def org_name(self):
Asset.org_name
org = Organization.get_instance(self.org_id)
return org.name if org else ''
class AssetNodeTree(Tree):
model_node_only_fields = NodeTreeNode.model_only_fields
model_asset_only_fields = TreeAsset.model_only_fields
def __init__(self, assets_scope_q=None, asset_category=None, asset_type=None, org=None):
self.assets_scope_q = assets_scope_q or Q()
self.asset_category = asset_category
self.asset_type = asset_type
self.org: Organization = org if org else current_org
super().__init__()
def construct_tree_nodes(self):
tree_nodes = []
nodes = Node.objects.filter(org_id=self.org.id).only(*self.model_node_only_fields)
for node in nodes:
key = node.key
if key.isdigit():
parent_key = None
else:
parent_key = ':'.join(key.split(':')[:-1])
tree_node = NodeTreeNode(
id=str(node.id),
key=key,
name=node.value,
parent_key=parent_key
)
tree_nodes.append(tree_node)
return tree_nodes
def init(self, with_assets_amount=True):
tree_nodes = self.construct_tree_nodes()
super().init(tree_nodes)
if with_assets_amount:
self.init_tree_nodes_assets_amount()
def assets_scope_queryset(self):
qs = Asset.objects.filter(org_id=self.org.id)
if self.assets_scope_q:
qs = qs.filter(self.assets_scope_q)
if self.asset_type:
qs = qs.filter(type=self.asset_type)
elif self.asset_category:
qs = qs.filter(category=self.asset_category)
return qs
def init_tree_nodes_assets_amount(self):
assets_amounts = self.assets_scope_queryset().values('node_id').annotate(
assets_amount=Count('id')
).values_list('node_id', 'assets_amount')
assets_amount_mapper = {str(node_id): amount for node_id, amount in assets_amounts}
for node in self.nodes.values():
assets_amount = assets_amount_mapper.get(str(node.id), 0)
node: NodeTreeNode
node.set_assets_amount(assets_amount)
def get_tree_assets(self, nodes=None, limit=None):
if nodes is None:
id_nodes_mapper = {node.id: node for node in self.get_nodes()}
filter_node_q = Q()
limit = limit or 1000
else:
id_nodes_mapper = {node.id: node for node in nodes}
node_ids = list(id_nodes_mapper.keys())
filter_node_q = Q(node_id__in=node_ids)
limit = None
tree_assets = []
assets = self.assets_scope_queryset().filter(filter_node_q).values(*self.model_asset_only_fields)
assets = assets[:limit] if limit else assets
for asset in assets:
kwargs = {k: asset[k] for k in self.model_asset_only_fields}
tree_node = id_nodes_mapper.get(str(asset['node_id']))
kwargs.update({'tree_node': tree_node})
tree_asset = TreeAsset(**kwargs)
tree_assets.append(tree_asset)
return tree_assets
def get_nodes(self, with_empty_assets_branch=True):
if not with_empty_assets_branch:
self.remove_zero_assets_amount_total_nodes()
return list(self.nodes.values())
def remove_zero_assets_amount_total_nodes(self):
for node in list(self.nodes.values()):
node: NodeTreeNode
if node.is_root:
continue
if node.assets_amount_total > 0:
continue
parent: NodeTreeNode = node.parent
parent.remove_child(node)
self.nodes.pop(node.key, None)
for descendant in node.descendants():
self.nodes.pop(descendant.key, None)

View File

@@ -1,268 +1,98 @@
from collections import deque
from common.utils import get_logger, lazyproperty, timeit
__all__ = ['TreeNode', 'Tree']
logger = get_logger(__name__)
class TreeNode:
class TreeNode(object):
def __init__(self, _id, key, value, **kwargs):
self.id = _id
def __init__(self, key, name, parent_key):
self.key = key
self.value = value
self.children = []
self.name = name
self.parent_key = parent_key
self.parent = None
self.children_count_total = 0
self.children = []
def match(self, keyword):
if not keyword:
return True
keyword = str(keyword).strip().lower()
node_value = str(self.value).strip().lower()
return keyword in node_value
def add_child(self, child: 'TreeNode'):
child.parent = self
self.children.append(child)
@lazyproperty
def parent_key(self):
if self.is_root:
return None
return ':'.join(self.key.split(':')[:-1])
def remove_child(self, child: 'TreeNode'):
self.children.remove(child)
child.parent = None
def descendants(self) -> list['TreeNode']:
nodes = []
for child in self.children:
child: TreeNode
nodes.append(child)
nodes.extend(child.descendants())
return nodes
def ancestors(self) -> list['TreeNode']:
node = self
ancestors = []
while node.parent:
ancestors.append(node.parent)
node = node.parent
ancestors.reverse()
return ancestors
@property
def is_root(self):
return self.key.isdigit()
def add_child(self, child_node: 'TreeNode'):
child_node.parent = self
self.children.append(child_node)
return self.parent is None
def remove_child(self, child_node: 'TreeNode'):
self.children.remove(child_node)
child_node.parent = None
@property
def level(self):
level = 1
node = self
while node.parent:
level += 1
node = node.parent
return level
@property
def is_leaf(self):
return len(self.children) == 0
@lazyproperty
def level(self):
return self.key.count(':') + 1
def get_ancestor_keys(self):
if self.is_root:
return []
ancestor_keys = []
parts = self.key.split(':')
for i in range(1, len(parts)):
ancestor_key = ':'.join(parts[:i])
ancestor_keys.append(ancestor_key)
return ancestor_keys
def get_descendants(self, node: 'TreeNode'):
"""
返回指定节点的所有子孙节点(不包含自身),非递归实现,按层级从近到远排序。
返回列表,空列表表示没有子孙或节点为 None。
"""
if not node:
return []
descendants = []
dq = deque(node.children)
while dq:
cur = dq.popleft()
descendants.append(cur)
# 复制 children 以防在遍历过程中被修改
for ch in list(cur.children):
dq.append(ch)
return descendants
@property
def children_count(self):
return len(self.children)
def as_dict(self, simple=True):
data = {
'key': self.key,
}
if simple:
return data
data.update({
'id': self.id,
'value': self.value,
'level': self.level,
'children_count': self.children_count,
'is_root': self.is_root,
'is_leaf': self.is_leaf,
})
return data
def print(self, simple=True, is_print_keys=False):
def info_as_string(_info):
return ' | '.join(s.ljust(25) for s in _info)
if is_print_keys:
info_keys = [k for k in self.as_dict(simple=simple).keys()]
info_keys_string = info_as_string(info_keys)
print('-' * len(info_keys_string))
print(info_keys_string)
print('-' * len(info_keys_string))
info_values = [str(v) for v in self.as_dict(simple=simple).values()]
info_values_as_string = info_as_string(info_values)
print(info_values_as_string)
print('-' * len(info_values_as_string))
class Tree(object):
class Tree:
def __init__(self):
self.root = None
# { key -> TreeNode }
self.nodes: dict[TreeNode] = {}
self.nodes = {}
@property
def size(self):
return len(self.nodes)
@property
def is_empty(self):
return self.size == 0
@property
def depth(self):
" 返回树的最大深度以及对应的节点key "
if self.is_empty:
return 0, 0
node = max(self.nodes.values(), key=lambda node: node.level)
node: TreeNode
print(f"max_depth_node_key: {node.key}")
return node.level
@property
def width(self):
" 返回树的最大宽度,以及对应的层级数 "
if self.is_empty:
return 0, 0
node = max(self.nodes.values(), key=lambda node: node.children_count)
node: TreeNode
print(f"max_width_level: {node.level + 1}")
return node.children_count
def add_node(self, node: TreeNode):
if node.is_root:
self.root = node
def init(self, nodes: list[TreeNode]) -> None:
for node in nodes:
self.nodes[node.key] = node
for node in nodes:
self.add_node(node)
def add_node(self, node: TreeNode) -> None:
if node.parent_key is None:
self.root = node
return
parent = self.get_node(node.parent_key)
if not parent:
error = f""" Cannot add node {node.key}: parent key {node.parent_key} not found.
Please ensure parent nodes are added before child nodes."""
raise ValueError(error)
parent.add_child(node)
self.nodes[node.key] = node
def get_node(self, key: str) -> TreeNode:
parent = self.nodes.get(node.parent_key)
if parent:
parent: TreeNode
parent.add_child(node)
else:
raise ValueError(f'Parent with key {node.parent_key} not found for node {node.key}')
def get_node(self, key) -> TreeNode | None:
return self.nodes.get(key)
def remove_node(self, node: TreeNode):
if node.is_root:
self.root = None
else:
parent: TreeNode = node.parent
parent.remove_child(node)
self.nodes.pop(node.key, None)
def get_nodes(self) -> list[TreeNode]:
return list(self.nodes.values())
def search_nodes(self, keyword, only_top_level=False):
if not keyword:
return []
keyword = keyword.strip().lower()
nodes = {}
for node in self.nodes.values():
node: TreeNode
if not node.match(keyword):
continue
nodes[node.key] = node
if not only_top_level:
return list(nodes.values())
# 如果匹配的节点中包含有父子关系的节点,只返回最上一级的父节点
# TODO: 优化性能
node_keys = list(nodes.keys())
children_keys = []
for node_key in node_keys:
_children_keys = [k for k in node_keys if k.startswith(f"{node_key}:")]
children_keys.extend(_children_keys)
for child_key in children_keys:
nodes.pop(child_key, None)
return list(nodes.values())
def remove_nodes_descendants(self, nodes: list[TreeNode]):
descendants = self.get_nodes_descendants(nodes)
for node in reversed(descendants):
self.remove_node(node)
def get_nodes_descendants(self, nodes: list[TreeNode]):
descendants = []
for node in nodes:
ds = node.get_descendants(node)
descendants.extend(ds)
return descendants
def get_nodes_ancestors(self, nodes: list[TreeNode]):
ancestors = set()
for node in nodes:
ancestor_keys = node.get_ancestor_keys()
_ancestors = self.get_nodes_by_keys(ancestor_keys)
ancestors.update(_ancestors)
return list(ancestors)
def get_nodes_by_keys(self, keys):
nodes = []
for key in keys:
node = self.get_node(key)
if node:
nodes.append(node)
return nodes
def get_nodes(self, levels=None):
nodes = list(self.nodes.values())
if levels:
nodes = [ n for n in nodes if n.level in levels ]
return nodes
def get_node_children(self, key, with_self=False):
node = self.get_node(key)
def pre_order_traversal(self, node: TreeNode = None) -> list[TreeNode]:
if node is None:
node = self.root
if not node:
return []
nodes = []
if with_self:
nodes.append(node)
nodes.extend(node.children)
return nodes
@timeit
def _compute_children_count_total(self):
for node in reversed(list(self.nodes.values())):
total = 0
for child in node.children:
child: TreeNode
total += child.children_count_total + 1
node: TreeNode
node.children_count_total = total
def print(self, count=10, simple=True):
print('tree_root_key: ', getattr(self.root, 'key', 'No-root'))
print('tree_size: ', self.size)
print('tree_depth: ', self.depth)
print('tree_width: ', self.width)
is_print_key = True
for n in list(self.nodes.values())[:count]:
n: TreeNode
n.print(simple=simple, is_print_keys=is_print_key)
is_print_key = False
result = [node]
for child in node.children:
result.extend(self.pre_order_traversal(child))
return result

View File

@@ -0,0 +1 @@
from .node_tree import *

View File

@@ -6,7 +6,7 @@ from assets.api import node
# 类别视图的资产树
from .tree import Tree, TreeNode
from .asset_tree import AssetTree, AssetTreeNode, AssetTreeNodeAsset
from .node_tree import AssetTree, AssetTreeNode, AssetTreeNodeAsset
from assets.models import Platform, Asset
from common.utils import timeit, get_logger

268
apps/assets/tree1/tree.py Normal file
View File

@@ -0,0 +1,268 @@
from collections import deque
from common.utils import get_logger, lazyproperty, timeit
__all__ = ['TreeNode', 'Tree']
logger = get_logger(__name__)
class TreeNode(object):
def __init__(self, _id, key, value, **kwargs):
self.id = _id
self.key = key
self.value = value
self.children = []
self.parent = None
self.children_count_total = 0
def match(self, keyword):
if not keyword:
return True
keyword = str(keyword).strip().lower()
node_value = str(self.value).strip().lower()
return keyword in node_value
@lazyproperty
def parent_key(self):
if self.is_root:
return None
return ':'.join(self.key.split(':')[:-1])
@property
def is_root(self):
return self.key.isdigit()
def add_child(self, child_node: 'TreeNode'):
child_node.parent = self
self.children.append(child_node)
def remove_child(self, child_node: 'TreeNode'):
self.children.remove(child_node)
child_node.parent = None
@property
def is_leaf(self):
return len(self.children) == 0
@lazyproperty
def level(self):
return self.key.count(':') + 1
def get_ancestor_keys(self):
if self.is_root:
return []
ancestor_keys = []
parts = self.key.split(':')
for i in range(1, len(parts)):
ancestor_key = ':'.join(parts[:i])
ancestor_keys.append(ancestor_key)
return ancestor_keys
def get_descendants(self, node: 'TreeNode'):
"""
返回指定节点的所有子孙节点(不包含自身),非递归实现,按层级从近到远排序。
返回列表,空列表表示没有子孙或节点为 None。
"""
if not node:
return []
descendants = []
dq = deque(node.children)
while dq:
cur = dq.popleft()
descendants.append(cur)
# 复制 children 以防在遍历过程中被修改
for ch in list(cur.children):
dq.append(ch)
return descendants
@property
def children_count(self):
return len(self.children)
def as_dict(self, simple=True):
data = {
'key': self.key,
}
if simple:
return data
data.update({
'id': self.id,
'value': self.value,
'level': self.level,
'children_count': self.children_count,
'is_root': self.is_root,
'is_leaf': self.is_leaf,
})
return data
def print(self, simple=True, is_print_keys=False):
def info_as_string(_info):
return ' | '.join(s.ljust(25) for s in _info)
if is_print_keys:
info_keys = [k for k in self.as_dict(simple=simple).keys()]
info_keys_string = info_as_string(info_keys)
print('-' * len(info_keys_string))
print(info_keys_string)
print('-' * len(info_keys_string))
info_values = [str(v) for v in self.as_dict(simple=simple).values()]
info_values_as_string = info_as_string(info_values)
print(info_values_as_string)
print('-' * len(info_values_as_string))
class Tree(object):
def __init__(self):
self.root = None
# { key -> TreeNode }
self.nodes: dict[TreeNode] = {}
@property
def size(self):
return len(self.nodes)
@property
def is_empty(self):
return self.size == 0
@property
def depth(self):
" 返回树的最大深度以及对应的节点key "
if self.is_empty:
return 0, 0
node = max(self.nodes.values(), key=lambda node: node.level)
node: TreeNode
print(f"max_depth_node_key: {node.key}")
return node.level
@property
def width(self):
" 返回树的最大宽度,以及对应的层级数 "
if self.is_empty:
return 0, 0
node = max(self.nodes.values(), key=lambda node: node.children_count)
node: TreeNode
print(f"max_width_level: {node.level + 1}")
return node.children_count
def add_node(self, node: TreeNode):
if node.is_root:
self.root = node
self.nodes[node.key] = node
return
parent = self.get_node(node.parent_key)
if not parent:
error = f""" Cannot add node {node.key}: parent key {node.parent_key} not found.
Please ensure parent nodes are added before child nodes."""
raise ValueError(error)
parent.add_child(node)
self.nodes[node.key] = node
def get_node(self, key: str) -> TreeNode:
return self.nodes.get(key)
def remove_node(self, node: TreeNode):
if node.is_root:
self.root = None
else:
parent: TreeNode = node.parent
parent.remove_child(node)
self.nodes.pop(node.key, None)
def search_nodes(self, keyword, only_top_level=False):
if not keyword:
return []
keyword = keyword.strip().lower()
nodes = {}
for node in self.nodes.values():
node: TreeNode
if not node.match(keyword):
continue
nodes[node.key] = node
if not only_top_level:
return list(nodes.values())
# 如果匹配的节点中包含有父子关系的节点,只返回最上一级的父节点
# TODO: 优化性能
node_keys = list(nodes.keys())
children_keys = []
for node_key in node_keys:
_children_keys = [k for k in node_keys if k.startswith(f"{node_key}:")]
children_keys.extend(_children_keys)
for child_key in children_keys:
nodes.pop(child_key, None)
return list(nodes.values())
def remove_nodes_descendants(self, nodes: list[TreeNode]):
descendants = self.get_nodes_descendants(nodes)
for node in reversed(descendants):
self.remove_node(node)
def get_nodes_descendants(self, nodes: list[TreeNode]):
descendants = []
for node in nodes:
ds = node.get_descendants(node)
descendants.extend(ds)
return descendants
def get_nodes_ancestors(self, nodes: list[TreeNode]):
ancestors = set()
for node in nodes:
ancestor_keys = node.get_ancestor_keys()
_ancestors = self.get_nodes_by_keys(ancestor_keys)
ancestors.update(_ancestors)
return list(ancestors)
def get_nodes_by_keys(self, keys):
nodes = []
for key in keys:
node = self.get_node(key)
if node:
nodes.append(node)
return nodes
def get_nodes(self, levels=None):
nodes = list(self.nodes.values())
if levels:
nodes = [ n for n in nodes if n.level in levels ]
return nodes
def get_node_children(self, key, with_self=False):
node = self.get_node(key)
if not node:
return []
nodes = []
if with_self:
nodes.append(node)
nodes.extend(node.children)
return nodes
@timeit
def _compute_children_count_total(self):
for node in reversed(list(self.nodes.values())):
total = 0
for child in node.children:
child: TreeNode
total += child.children_count_total + 1
node: TreeNode
node.children_count_total = total
def print(self, count=10, simple=True):
print('tree_root_key: ', getattr(self.root, 'key', 'No-root'))
print('tree_size: ', self.size)
print('tree_depth: ', self.depth)
print('tree_width: ', self.width)
is_print_key = True
for n in list(self.nodes.values())[:count]:
n: TreeNode
n.print(simple=simple, is_print_keys=is_print_key)
is_print_key = False

39
apps/assets/xtree/tree.py Normal file
View File

@@ -0,0 +1,39 @@
__all__ = ['TreeNode', 'Tree']
class TreeNode:
def __init__(self, key, name, parent_key=None):
self.key = key
self.name = name
self.parent_key = parent_key
self.parent = None
self.children = []
def add_child(self, child: 'TreeNode') -> None:
child.parent = self
self.children.append(child)
class Tree:
def __init__(self):
self.root = None
self.nodes = {}
def init(self, nodes: list[TreeNode]) -> None:
for node in nodes:
self.nodes[node.key] = node
for node in nodes:
if node.parent_key is None:
self.root = node
continue
parent = self.nodes.get(node.parent_key)
if not parent:
raise ValueError(f'Parent with key {node.parent_key} not found for node {node.key}')
parent: TreeNode
parent.add_child(node)
def get_node(self, key) -> TreeNode | None:
return self.nodes.get(key)

View File

@@ -5,7 +5,7 @@ from django.conf import settings
from common.utils import get_logger, timeit
from assets.api.tree import AbstractAssetTreeAPI
from assets.tree.asset_tree import AssetTreeNodeAsset
from assets.tree.node_tree import TreeAsset
from perms.tree import UserPermAssetTree, UserPermAssetTreeNode
from perms.utils.utils import UserPermedAssetUtil
@@ -24,7 +24,7 @@ class UserPermedAssetTreeAPI(SelfOrPKUserMixin, AbstractAssetTreeAPI):
def get_tree_user(self):
return self.user
def get_org_asset_tree(self, **kwargs) -> UserPermAssetTree:
def get_asset_tree(self, **kwargs) -> UserPermAssetTree:
# 重写父类方法,返回用户授权的组织资产树
return self.get_user_org_asset_tree(**kwargs)
@@ -116,7 +116,7 @@ class UserPermedAssetTreeAPI(SelfOrPKUserMixin, AbstractAssetTreeAPI):
if not with_assets:
return node, []
assets = assets.values(*AssetTreeNodeAsset.model_values)
assets = assets.values(*TreeAsset.model_only_fields)
if search_asset:
assets = assets[:self.search_special_node_asset_limit_max]
@@ -158,7 +158,7 @@ class UserPermedAssetTreeAPI(SelfOrPKUserMixin, AbstractAssetTreeAPI):
if not with_assets:
return node, []
assets = assets.values(*AssetTreeNodeAsset.model_values)
assets = assets.values(*TreeAsset.model_only_fields)
if search_asset:
assets = assets[:self.search_special_node_asset_limit_max]

View File

@@ -6,7 +6,7 @@ from django.core.cache import cache
from common.utils import get_logger
from users.models import User
from assets.models import FavoriteAsset, Asset
from assets.tree.asset_tree import AssetTree, AssetTreeNode, AssetTreeNodeAsset
from assets.tree.node_tree import AssetNodeTree, NodeTreeNode, TreeAsset
from perms.utils.utils import UserPermedAssetUtil
@@ -16,7 +16,7 @@ __all__ = ['UserPermAssetTree']
logger = get_logger(__name__)
class UserPermAssetTreeNode(AssetTreeNode):
class UserPermAssetTreeNode(NodeTreeNode):
class Type(TextChoices):
DN = 'direct_node', 'Direct Node'
@@ -62,7 +62,7 @@ class UserPermAssetTreeNode(AssetTreeNode):
class UserPermAssetTree(AssetTree):
class UserPermAssetTree(AssetNodeTree):
TreeNode = UserPermAssetTreeNode