Compare commits

...

25 Commits

Author SHA1 Message Date
Bai
4b29928d9b refactor: build asset tree and perm tree, do not need add node_key into assets_asset_nodes table 2025-12-14 12:08:20 +08:00
Bai
21cf94493c refactor: build asset tree and perm tree, do not need add node_key into assets_asset_nodes table 2025-12-14 10:32:22 +08:00
Bai
39caee6a2b refactor: add build asset tree test script for x build add cache: 0.001ms 2025-12-14 03:18:15 +08:00
Bai
9c1a36c573 refactor: add build user perm tree test script for x build add cache 2025-12-14 02:46:58 +08:00
Bai
5ca5234274 refactor: add build user perm tree test script for x build 2025-12-14 02:14:19 +08:00
Bai
83b7ccf225 refactor: add build user perm tree test script and log: finished 2025-12-13 22:08:17 +08:00
Bai
8ee3f9935a refactor: add build user perm tree test script and log 2025-12-13 18:46:53 +08:00
Bai
6ec0bee77d refactor: add build user perm tree test script and log 2025-12-13 18:16:45 +08:00
Bai
afd1cd4542 refactor: add build user perm tree test script and log 2025-12-13 15:05:15 +08:00
Bai
7c39f9f43e refactor: add build user perm tree test script and log 2025-12-13 14:14:37 +08:00
Bai
ab9e10791b refactor: add build user perm tree test script 2025-12-13 13:21:22 +08:00
Bai
878974ffbd refactor: add build asset tree test script 2025-12-13 12:04:32 +08:00
Bai
3052c5055f refactor: add build asset tree test script 2025-12-13 12:02:20 +08:00
Bai
4be301c2dc refactor: add build asset tree test script 2025-12-13 11:57:14 +08:00
Bai
7f90027754 refactor: finished through migrations and add fake generate through data 2025-12-13 10:41:15 +08:00
Bai
db3cd0bcc7 refactor: finished through migrations and add fake generate through data 2025-12-13 10:40:27 +08:00
Bai
6995754fd9 refactor: migrate assets_asset_nodes table add node_key field 2025-12-12 21:32:41 +08:00
Bai
8bd116e955 refactor: use 1 sql query 1 node assets_amount(exactly) 2025-12-12 21:30:11 +08:00
Bai
41884d054d refactor: add query_3_result demo can select raw sql 2025-12-11 19:31:53 +08:00
Bai
0ef78fb545 refactor: add query_3_result demo 2025-12-11 19:18:41 +08:00
Bai
98218e814b refactor: support cache_tree by ttl 2025-12-10 18:05:45 +08:00
Bai
167267067f refactor: support method get_node_all_assets, get_node_children(with_assets) 2025-12-10 14:22:08 +08:00
Bai
8126d52a8b refactor: generate complete perm tree; refactor compute nodes assets amount algorithm; 2025-12-10 13:06:42 +08:00
Bai
8b53a21659 refactor: finished generate user perm tree (include comupte node assets amount) 2025-12-09 19:41:40 +08:00
Bai
3496a31e1f refactor: finished generate user perm tree (only nodes) 2025-12-09 18:18:18 +08:00
14 changed files with 2578 additions and 3 deletions

View File

@@ -0,0 +1,25 @@
# Generated by Django 4.1.13 on 2025-12-13 02:18
from django.db import migrations, models
class Migration(migrations.Migration):
    """Create the AssetNodeBackup table.

    The table persists (asset_id, node_id) pairs so that the follow-up data
    migration (which rebuilds the asset-node through table) can be retried
    safely if it is interrupted.
    """

    dependencies = [
        ('assets', '0019_alter_asset_connectivity'),
    ]

    operations = [
        migrations.CreateModel(
            name='AssetNodeBackup',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                # Plain CharFields (not FKs): backup rows must survive even if
                # the referenced asset/node rows change during the migration.
                ('asset_id', models.CharField(max_length=1024, verbose_name='Asset ID')),
                ('node_id', models.CharField(max_length=1024, verbose_name='Node ID')),
            ],
            options={
                'verbose_name': 'Asset Node Backup',
                'db_table': 'assets_asset_nodes_backup',
            },
        ),
    ]

View File

@@ -0,0 +1,343 @@
# Generated by Django 4.1.13 on 2025-12-12 03:55
"""
【数据迁移流程】
本迁移将 Asset.nodes 从自动生成的 M2M through 表迁移到自定义的 AssetNode 模型,并添加 node_key 字段。
五阶段迁移流程:
【阶段1】读取 through 表并备份到 AssetNodeBackup 中
- 创建 AssetNodeBackup 表用于备份
- 读取原 Asset.nodes.through 表中所有 (asset_id, node_id) 数据
- 将数据保存到全局变量 through_old_data 中(内存缓存)
- 同时将数据备份到 AssetNodeBackup 表中(持久化,支持重试)
- 好处:即使迁移中断,再次执行时也可以从 backup 恢复
【阶段2】数据库表结构修改
- 删除 Asset.nodes 的 M2M 关系字段
- 创建新的 AssetNode 自定义 through 模型
- 重新添加 Asset.nodes M2M 字段,指向新的 AssetNode
- 创建优化后的索引和修改一个联合唯一索引:
* idx_node_key_asset_id: 支持按 node_key 范围查询资产
* idx_node_id_asset_id: 支持按 node_id 查询资产
* idx_asset_id_node_id_key: 支持按 asset_id 反向查询节点
* unique_together (asset, node): 保证每个资产和节点组合唯一
【阶段3】恢复数据并填充 node_key
- 优先使用内存缓存中的数据through_old_data
- 如果内存为空,说明之前可能迁移失败过,则从 AssetNodeBackup 表加载数据
- 预加载 Node.key 映射,为每条数据填充 node_key 字段
- 预加载已存在的 (asset_id, node_id) 对,避免重复插入
- 批量插入到 AssetNode 表中50k/batch
- 如果批插入失败,降级为单条插入
- 统计插入和跳过的记录数
【阶段4】清理 through 表中的重复数据
- 找出原 through 表中 node_key 为空的数据(这些是重复或无效的)
- 显示前 100 条要删除的数据的 (asset_id, node_id)
- 分批删除50k/batch
- 输出删除的总数
【阶段5】删除备份表 (或用户手动删除也可以)
- 验证 AssetNodeBackup 表中的数据(显示记录数)
- 清空备份表中的所有数据
- 删除 AssetNodeBackup 表
- 意义:迁移完成后,备份表已无用,清理数据库空间
【数据一致性保证】
- 备份表AssetNodeBackup 在阶段1中持久化所有原始数据支持恢复
- 去重阶段3 中使用 set 预检测避免重复
- 容错阶段3 批插入失败时自动降级到单条插入
- 清理阶段4 只删除 node_key 为空的无效数据
- 清理阶段5 删除已完成使命的备份表
【字段映射】
AssetNode.node_key 来自 Node.key
"""
import time
from datetime import datetime
from django.db import migrations, models, transaction
from django.db.models import Count, Q
import django.db.models.deletion
import assets.models.asset.common
# ============== Module-level state shared across migration phases ==============
# (asset_id, node_id) tuples read from the old through table in phase 1 and
# replayed in phase 3; only survives within a single migration process run.
through_old_data = []
# Counters reported in the phase logs.
migration_stats = {'backed_up': 0, 'restored': 0}
def log(msg=''):
    """Print *msg* prefixed with the current wall-clock time as [HH:MM:SS].

    A default of '' was added because phase 3 calls ``log()`` with no
    argument to emit a blank timestamped line; without the default that
    call raises TypeError.
    """
    print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}")
def load_data_from_backup(AssetNodeBackup):
    """Load all (asset_id, node_id) pairs from the backup table into the
    module-level ``through_old_data`` cache.

    Returns True when data was loaded, False when the backup table is empty.
    """
    global through_old_data
    total = AssetNodeBackup.objects.count()
    if total == 0:
        log("⚠ backup 表为空,无数据可恢复")
        return False
    log(f"从 backup 表加载 {total:,} 条数据...")
    batch_size = 50000
    start = time.time()
    # Bug fix: order by pk — OFFSET/LIMIT pagination over an unordered
    # queryset is nondeterministic on most databases and can skip or
    # duplicate rows between pages.
    qs = AssetNodeBackup.objects.order_by('id').values_list('asset_id', 'node_id')
    for offset in range(0, total, batch_size):
        batch = list(qs[offset:offset + batch_size])
        through_old_data.extend(batch)
        log(f" 已加载 {len(through_old_data):,}/{total:,}")
    log(f"✓ 从 backup 加载完成! 耗时 {time.time()-start:.1f}s")
    return True
def phase1_save_and_backup(apps, schema_editor):
    """Phase 1: read every (asset_id, node_id) row of the old through table
    into the in-memory cache and persist a copy into AssetNodeBackup.

    The persisted copy lets the migration be retried from phase 3 even if
    the process dies after the schema change in phase 2.
    """
    global through_old_data, migration_stats
    Asset = apps.get_model('assets', 'Asset')
    AssetNodeBackup = apps.get_model('assets', 'AssetNodeBackup')
    asset_node_through = Asset.nodes.through
    total = asset_node_through.objects.count()
    log(f"\n{'='*50}")
    log("【阶段1】读取 through 数据并备份")
    log(f"{'='*50}")
    log(f"从 through 表读取 {total:,} 条数据...")
    batch_size = 50000
    start = time.time()
    backup_batch = []
    # Phase 1-1: read everything into memory.
    # Bug fix: order by pk so the OFFSET/LIMIT pagination is deterministic;
    # unordered pages can skip or duplicate rows on some databases.
    qs = asset_node_through.objects.order_by('pk').values_list('asset_id', 'node_id')
    for offset in range(0, total, batch_size):
        batch = list(qs[offset:offset + batch_size])
        through_old_data.extend(batch)
        # Stage the rows for the backup table.
        backup_objs = [AssetNodeBackup(asset_id=aid, node_id=nid) for aid, nid in batch]
        backup_batch.extend(backup_objs)
        log(f" 已读取 {len(through_old_data):,}/{total:,} ({len(through_old_data)/total*100:.1f}%)")
    # Phase 1-2: write the backup, committing each chunk immediately so a
    # later crash loses at most one chunk.
    log(f"\n写入 {len(backup_batch):,} 条备份数据到数据库...")
    backup_start = time.time()
    backup_batch_size = 50000
    for i in range(0, len(backup_batch), backup_batch_size):
        batch = backup_batch[i:i + backup_batch_size]
        with transaction.atomic():
            created = AssetNodeBackup.objects.bulk_create(batch, batch_size=backup_batch_size, ignore_conflicts=True)
            migration_stats['backed_up'] += len(created)
        log(f" 已备份 {min(i+backup_batch_size, len(backup_batch)):,}/{len(backup_batch):,}")
    log(f"✓ 阶段1完成! 读取耗时 {time.time()-start:.1f}s, 备份耗时 {time.time()-backup_start:.1f}s")
    log(f" 内存缓存: {len(through_old_data):,}")
    log(f" 数据库备份: {migration_stats['backed_up']:,}\n")
def phase3_restore_data_and_set_node_key(apps, schema_editor):
    """Phase 3: repopulate the new AssetNode table and fill node_key.

    Prefers the in-memory cache filled by phase 1; if the process was
    restarted in between (cache empty), the pairs are reloaded from the
    AssetNodeBackup table. Each restored row gets node_key copied from the
    current Node.key.
    """
    global through_old_data, migration_stats
    Node = apps.get_model('assets', 'Node')
    AssetNode = apps.get_model('assets', 'AssetNode')
    AssetNodeBackup = apps.get_model('assets', 'AssetNodeBackup')
    log(f"\n{'='*50}")
    log("【阶段3】恢复数据并设置 node_key")
    log(f"{'='*50}")
    # Fall back to the persisted backup when the in-memory cache is empty.
    if not through_old_data:
        log("内存缓存为空,从 backup 表加载数据...")
        if not load_data_from_backup(AssetNodeBackup):
            log("✗ 无法恢复数据backup 表也为空")
            return
        # Bug fix: the original called ``log()`` with no argument, which
        # raises TypeError because log() has no default parameter.
        log("")
    else:
        log(f"使用内存缓存的 {len(through_old_data):,} 条数据\n")
    total = len(through_old_data)
    log(f"开始恢复 {total:,} 条数据到 AssetNode 表...")
    # Preload node_id -> node_key so each row can be filled without a query.
    id_key_map = {str(item['id']): item['key'] for item in Node.objects.values('id', 'key')}
    # Preload existing (asset_id, node_id) pairs to avoid duplicate inserts.
    existing = set(AssetNode.objects.values_list('asset_id', 'node_id'))
    log(f"数据库中已存在 {len(existing):,} 条记录\n")
    batch_size = 50000
    start = time.time()
    skipped = 0
    for i in range(0, total, batch_size):
        batch = through_old_data[i:i + batch_size]
        # De-duplicate: keep only pairs not present yet.
        objs = []
        for aid, nid in batch:
            if (aid, nid) not in existing:
                objs.append(AssetNode(asset_id=aid, node_id=nid, node_key=id_key_map.get(str(nid), '')))
                existing.add((aid, nid))
            else:
                skipped += 1
        # Bulk insert; on failure fall back to row-by-row inserts.
        if objs:
            try:
                AssetNode.objects.bulk_create(objs, batch_size=batch_size, ignore_conflicts=True)
                migration_stats['restored'] += len(objs)
            except Exception as e:
                log(f" ✗ 批插入失败: {str(e)}")
                # Degrade: insert one by one, skipping the bad rows.
                for obj in objs:
                    try:
                        obj.save()
                        migration_stats['restored'] += 1
                    except Exception as save_err:
                        log(f" ✗ 跳过 asset_id={obj.asset_id}, node_id={obj.node_id}: {str(save_err)}")
                        skipped += 1
        progress = min(i + batch_size, total)
        log(f" 已恢复 {progress:,}/{total:,} (插入{migration_stats['restored']:,} 跳过{skipped:,})")
    log(f"✓ 阶段3完成! 耗时 {time.time()-start:.1f}s")
    log(f" 插入: {migration_stats['restored']:,}")
    log(f" 跳过: {skipped:,}\n")
def phase4_cleanup_duplicates(apps, schema_editor):
    """Phase 4: delete rows with an empty node_key from the through table.

    After phase 3 every restored row has node_key filled; rows with an
    empty node_key are considered invalid/duplicate leftovers.
    NOTE(review): this loads every matching (asset_id, node_id, id) triple
    into memory before deleting — confirm acceptable for the expected size.
    """
    Asset = apps.get_model('assets', 'Asset')
    asset_node_through = Asset.nodes.through
    log(f"\n{'='*50}")
    log("【阶段4】清理 through 表中 node_key 为空的数据")
    log(f"{'='*50}")
    # Rows whose node_key is '' or NULL.
    empty_node_key = asset_node_through.objects.filter(Q(node_key='') | Q(node_key__isnull=True))
    total = empty_node_key.count()
    if total == 0:
        log("✓ 没有 node_key 为空的数据,无需清理\n")
        return
    log(f"发现 {total:,} 条 node_key 为空的数据")
    start = time.time()
    batch_size = 50000
    deleted = 0
    # Snapshot the rows to delete (so deletion works on stable id lists).
    to_delete_records = list(
        empty_node_key.values_list('asset_id', 'node_id', 'id')
    )
    log("删除详情:")
    for aid, nid, record_id in to_delete_records[:100]:  # show the first 100
        log(f" asset_id={aid}, node_id={nid}")
    if len(to_delete_records) > 100:
        log(f" ... 还有 {len(to_delete_records)-100:,}")
    # Delete in batches by primary key.
    for offset in range(0, len(to_delete_records), batch_size):
        batch_ids = [record_id for _, _, record_id in to_delete_records[offset:offset + batch_size]]
        if batch_ids:
            delete_count, _ = asset_node_through.objects.filter(id__in=batch_ids).delete()
            deleted += delete_count
            log(f" 已删除 {deleted:,}/{total:,}")
    log(f"✓ 阶段4完成! 耗时 {time.time()-start:.1f}s")
    log(f" 删除: {deleted:,} 条 node_key 为空的数据\n")
def phase5_cleanup_backup_table(apps, schema_editor):
    """Phase 5: wipe the AssetNodeBackup rows once the migration succeeded."""
    log(f"\n{'='*50}")
    log("【阶段5】删除 AssetNodeBackup 备份表")
    log(f"{'='*50}")
    AssetNodeBackup = apps.get_model('assets', 'AssetNodeBackup')
    # Report how many rows are about to be dropped.
    total = AssetNodeBackup.objects.count()
    log(f"备份表中有 {total:,} 条数据")
    start = time.time()
    # One bulk DELETE over the whole table.
    removed, _per_model = AssetNodeBackup.objects.all().delete()
    log(f"✓ 删除 {removed:,} 条备份数据")
    log(f"✓ 阶段5完成! 耗时 {time.time()-start:.1f}s\n")
class Migration(migrations.Migration):
    """Five-phase migration to a custom AssetNode through model.

    Phases 1/3/4 run as RunPython data operations; phase 2 is the schema
    change between them; phase 5 (dropping the backup table) is left
    commented out so it can also be done manually.
    """

    dependencies = [
        ('assets', '0020_assetnodebackup'),
    ]

    operations = [
        # Phase 1: read the old through rows and back them up
        migrations.RunPython(phase1_save_and_backup),
        # Phase 2: schema changes
        migrations.RemoveField(
            model_name='asset',
            name='nodes',
        ),
        migrations.CreateModel(
            name='AssetNode',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                # Denormalized copy of Node.key, filled in phase 3.
                ('node_key', models.CharField(db_index=True, default='', max_length=64, verbose_name='Node key')),
                ('asset', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='assets.asset', verbose_name='Asset')),
                ('node', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='assets.node', verbose_name='Node')),
            ],
            options={
                'verbose_name': 'Asset Node',
                # Same table name as the old auto-generated through table.
                'db_table': 'assets_asset_nodes',
            },
        ),
        migrations.AddField(
            model_name='asset',
            name='nodes',
            field=models.ManyToManyField(default=assets.models.asset.common.default_node, related_name='assets', through='assets.AssetNode', to='assets.node', verbose_name='Nodes'),
        ),
        migrations.AlterUniqueTogether(
            name='assetnode',
            unique_together={('asset', 'node_key'), ('asset', 'node')},
        ),
        migrations.AddIndex(
            model_name='assetnode',
            index=models.Index(fields=['node_key', 'asset_id'], name='idx_node_key_asset_id'),
        ),
        migrations.AddIndex(
            model_name='assetnode',
            index=models.Index(fields=['node_id', 'asset_id'], name='idx_node_id_asset_id'),
        ),
        migrations.AddIndex(
            model_name='assetnode',
            index=models.Index(fields=['asset_id', 'node_id', 'node_key'], name='idx_asset_id_node_id_key'),
        ),
        # Phase 3: restore the data into AssetNode
        migrations.RunPython(phase3_restore_data_and_set_node_key),
        # Phase 4: clean up leftover rows with an empty node_key
        migrations.RunPython(phase4_cleanup_duplicates),
        # Phase 5: drop the backup table (may also be done manually)
        # migrations.RunPython(phase5_cleanup_backup_table),
        # migrations.DeleteModel(
        #     name='AssetNodeBackup',
        # ),
    ]

View File

@@ -9,3 +9,4 @@ from .node import *
from .favorite_asset import *
from .automations import *
from .my_asset import *
from .asset_node import *

View File

@@ -173,7 +173,12 @@ class Asset(NodesRelationMixin, LabeledMixin, AbsConnectivity, JSONFilterMixin,
verbose_name=_("Zone"), on_delete=models.SET_NULL
)
nodes = models.ManyToManyField(
'assets.Node', default=default_node, related_name='assets', verbose_name=_("Nodes")
'assets.Node',
default=default_node,
related_name='assets',
verbose_name=_("Nodes"),
through='assets.AssetNode', # 使用自定义 through 表
through_fields=('asset', 'node')
)
directory_services = models.ManyToManyField(
'assets.DirectoryService', related_name='assets',

View File

@@ -0,0 +1,67 @@
from django.db import models
from django.db import models
from django.utils.translation import gettext_lazy as _
__all__ = ['AssetNode', 'AssetNodeBackup']
class AssetNode(models.Model):
    """Custom through model for ``Asset.nodes``.

    Besides the two foreign keys, it denormalizes ``Node.key`` into
    ``node_key`` so "all assets under a subtree" can be answered with a
    node_key prefix query instead of walking the node tree.
    """

    asset = models.ForeignKey(
        'assets.Asset',
        on_delete=models.CASCADE,
        verbose_name=_('Asset'),
        db_index=True
    )
    node = models.ForeignKey(
        'assets.Node',
        on_delete=models.CASCADE,
        verbose_name=_('Node'),
        db_index=True
    )
    # Denormalized Node.key; defaults to '' until filled by the m2m signal
    # handler or the data migration.
    node_key = models.CharField(
        max_length=64,
        verbose_name=_('Node key'),
        db_index=True,
        default=''
    )

    class Meta:
        db_table = 'assets_asset_nodes'
        verbose_name = _('Asset Node')
        # unique_together: column combinations that must be unique
        unique_together = [
            ('asset', 'node'),      # asset_id + node_id must be unique
            ('asset', 'node_key'),  # node_key included to keep rows unique
        ]
        # indexes: query-optimization indexes
        indexes = [
            # Index 1: all assets under a subtree, by node_key
            # query: WHERE node_key LIKE '1.12:%'
            models.Index(fields=['node_key', 'asset_id'], name='idx_node_key_asset_id'),
            # Index 2: direct assets of a node
            # query: WHERE node_id = ?
            models.Index(fields=['node_id', 'asset_id'], name='idx_node_id_asset_id'),
            # Index 3: reverse lookup of node_key / node_id by asset
            # query: WHERE asset_id = ? fetching node_key or node_id
            models.Index(fields=['asset_id', 'node_id', 'node_key'], name='idx_asset_id_node_id_key'),
        ]
class AssetNodeBackup(models.Model):
    """Temporary backup of (asset_id, node_id) pairs used by the through
    table migration; plain CharFields so rows survive asset/node changes."""

    asset_id = models.CharField(
        max_length=1024,
        verbose_name=_('Asset ID'),
    )
    node_id = models.CharField(
        max_length=1024,
        verbose_name=_('Node ID'),
    )

    class Meta:
        db_table = 'assets_asset_nodes_backup'
        verbose_name = _('Asset Node Backup')

View File

@@ -65,7 +65,7 @@ class FamilyMixin:
def get_nodes_children_key_pattern(cls, nodes, with_self=True):
keys = [i.key for i in nodes]
keys = cls.clean_children_keys(keys)
patterns = [cls.get_node_all_children_key_pattern(key) for key in keys]
patterns = [cls.get_node_all_children_key_pattern(key, with_self=with_self) for key in keys]
patterns = '|'.join(patterns)
return patterns

View File

@@ -1,3 +1,4 @@
from .asset import *
from .node_assets_amount import *
from .node_assets_mapping import *
from .node_asset import *

View File

@@ -0,0 +1,24 @@
from django.db.models.signals import pre_save, m2m_changed, post_save
from django.dispatch import receiver
from assets.models import AssetNode, Asset, Node
@receiver(m2m_changed, sender=Asset.nodes.through)
def fill_node_key_on_m2m_change(sender, instance, action, pk_set, **kwargs):
    """Backfill ``AssetNode.node_key`` after the asset<->node m2m gains rows.

    Django's m2m machinery creates the through rows itself, so ``node_key``
    is left at its default ('') and must be copied from ``Node.key`` here
    on 'post_add'.
    """
    if action != 'post_add':
        return
    # Robustness: pk_set can be None/empty for some senders; the original
    # would raise TypeError iterating None.
    if not pk_set:
        return
    if isinstance(instance, Asset):
        asset_ids = [str(instance.id)]
        node_ids = [str(pk) for pk in pk_set]
    elif isinstance(instance, Node):
        asset_ids = [str(pk) for pk in pk_set]
        node_ids = [str(instance.id)]
    else:
        return
    id_key_pairs = Node.objects.filter(id__in=node_ids).values_list('id', 'key')
    id_key_map = {str(id_): key for id_, key in id_key_pairs}
    # Materialize once: we mutate the instances and then pass the same
    # objects to bulk_update, so the queryset must not be re-evaluated.
    rs = list(AssetNode.objects.filter(asset_id__in=asset_ids, node_id__in=node_ids))
    for r in rs:
        r.node_key = id_key_map.get(str(r.node_id), '')
    if rs:
        AssetNode.objects.bulk_update(rs, ['node_key'], batch_size=5000)

View File

@@ -0,0 +1,477 @@
import re, time
from typing import Optional
from collections import defaultdict
from django.db.models import F
from users.models import User
from assets.models import Asset, Node
from perms.models import AssetPermission
from common.utils import lazyproperty, timeit
from orgs.utils import current_org
from django.core.cache import cache
from functools import wraps
def cache_tree(ttl=30):
    """Decorator factory caching a perm-tree method's result for *ttl* seconds.

    The cache key is derived from the instance's ``_org_id`` and
    ``_user_id`` attributes, so all calls for the same user/org share one
    cached value regardless of extra arguments.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            key = f"perm_tree:{self._org_id}:{self._user_id}"
            # Cache hit: return the stored tree as-is.
            hit = cache.get(key)
            if hit is not None:
                return hit
            # Cache miss: build, store with the configured TTL, return.
            value = func(self, *args, **kwargs)
            cache.set(key, value, ttl)
            return value
        return wrapper
    return decorator
class TreeNode:
    """A single node of the permission tree.

    Holds only the assets attached directly to this node (descendants keep
    their own sets) plus a running assets total used for display.
    """

    separator = ':'

    class Type:
        BRIDGE = 'bridge'  # auto-generated ancestor, nothing granted on it
        OWNER = 'owner'    # node granted directly
        DA = 'da'          # only assets under the node were granted

    def __init__(self, key, tp, assets=None):
        self.key = key
        self.type = tp
        # Direct assets only, never including descendants' assets.
        self._assets = set(assets) if assets is not None else set()
        self._assets_amount = 0

    def add_assets(self, asset_ids):
        self._assets.update(asset_ids)

    @property
    def assets(self):
        return self._assets

    @property
    def assets_amount(self):
        return self._assets_amount

    def assets_amount_increment(self, amount=1):
        self._assets_amount += amount

    def is_children(self, other: 'TreeNode'):
        """True when *self* is a direct child of *other* (one extra segment)."""
        return re.match(r'^{0}:[0-9]+$'.format(other.key), self.key) is not None

    def is_all_children(self, other: 'TreeNode'):
        """True when *self* is any descendant of *other*."""
        return re.match(r'^{0}:'.format(other.key), self.key) is not None

    def can_be_overridden(self, other: 'TreeNode'):
        """Whether *other* may replace *self* when added with the same key.

        Rules: nodes with different keys never override; an owner node is
        never replaced; a da node is not replaced by a bridge; every other
        combination is allowed.
        """
        if self.key != other.key:
            return False
        if self.type == self.Type.OWNER:
            return False
        if self.type == self.Type.DA and other.type == self.Type.BRIDGE:
            return False
        return True
class Tree:
    """Permission tree: a flat ``{node_key: TreeNode}`` mapping plus the
    operations used to build, merge, prune and query it.

    Keys are colon-separated paths like ``'1:2:3'``; parent/child relations
    are derived from the keys, not stored explicitly.
    """

    def __init__(self, nodes: Optional[list[TreeNode]] = None, org_id=None):
        # {node_key: TreeNode}
        # NOTE(review): TreeNode cannot be constructed with zero args, so the
        # default_factory would raise on a missing-key read; the class only
        # ever uses .get() and explicit assignment, so this never triggers.
        self._nodes = defaultdict(TreeNode)
        self._org_id = org_id
        self.init(nodes)

    def init(self, nodes: Optional[list[TreeNode]]):
        """Add *nodes*, then auto-generate their ancestor nodes."""
        if nodes is None:
            return
        for node in nodes:
            self.add_node(node)
        self._reverse_generated()

    def _reverse_generated(self):
        """Generate the tree bottom-up: ensure every ancestor key exists."""
        for key in list(self._nodes.keys()):
            ancestor_keys = Node.get_node_ancestor_keys(key)
            for ancestor_key in ancestor_keys:
                # Auto-generated ancestors default to 'bridge'; add_node
                # decides whether an existing node may be overridden.
                tree_node = TreeNode(key=ancestor_key, tp=TreeNode.Type.BRIDGE)
                self.add_node(tree_node)

    def merge(self, other: 'Tree') -> 'Tree':
        """Return a NEW finalized tree containing both trees' nodes."""
        merged_tree = Tree()
        for node in self._nodes.values():
            merged_tree.add_node(node)
        for node in other._nodes.values():
            merged_tree.add_node(node)
        merged_tree._finalize()
        return merged_tree

    def _finalize(self):
        # Order matters: prune owner branches first, re-expand the kept
        # owner nodes from the DB, then count, then sort for display.
        self._prune()
        self._init_owner_nodes_children()
        self._compute_assets_amount()
        self._sorted()

    def _sorted(self):
        """Re-order the mapping numerically by key parts (1:2 < 1:10)."""
        self._nodes = defaultdict(
            TreeNode,
            sorted(self._nodes.items(), key=lambda item: [int(i) for i in item[0].split(':')])
        )

    @timeit
    def _init_owner_nodes_children(self):
        """Load every descendant of the owner nodes plus their direct assets."""
        t1 = time.time()
        owner_nodes = self._owner_nodes
        if not owner_nodes:
            return
        nodes = Node.get_nodes_all_children(owner_nodes, with_self=True)
        # annotate(char_id=F('id')) keeps the id in its raw column form.
        node_id_key_sets = nodes.annotate(char_id=F('id')).values_list('char_id', 'key')
        node_id_key_mapper = dict(node_id_key_sets)
        node_ids = node_id_key_mapper.keys()
        t2 = time.time()
        nid_aid_sets = Node.assets.through.objects.filter(node_id__in=node_ids).annotate(
            char_nid=F('node_id'), char_aid=F('asset_id')).values_list('char_nid', 'char_aid')
        nid_aid_sets = list(nid_aid_sets)
        t3 = time.time()
        print('Fetch node-assets sets time by node_id__in: {:.1f}ms'.format((t3 - t2) * 1000))
        t4 = time.time()
        print('Fetch all node-assets sets time: {:.1f}ms'.format((t4 - t3) * 1000))
        for nid, aid in nid_aid_sets:
            key = node_id_key_mapper.get(nid)
            if not key:
                continue
            tree_node = self._nodes.get(key)
            if tree_node:
                tree_node.add_assets({aid})
            else:
                # Descendants of an owner node inherit the 'owner' type.
                tree_node = self.wrap_as_tree_node(node_key=key, tp=TreeNode.Type.OWNER, assets={aid})
                self.add_node(tree_node)

    def _compute_assets_amount(self):
        """
        Build:
            {"asset_id": {"node_key1", "node_key2", ...}}  # direct nodes per asset
        then, for every asset, take the de-duplicated union of the ancestor
        keys of all its nodes and increment each ancestor's amount by 1.
        """
        aid_node_keys_mapper = defaultdict(set)
        for node in self._nodes.values():
            for aid in node.assets:
                aid_node_keys_mapper[aid].add(node.key)
        for aid, node_keys in aid_node_keys_mapper.items():
            ancestor_keys = set(self.get_ancestor_keys(node_keys))  # must de-duplicate
            for ancestor_key in ancestor_keys:
                tree_node = self._nodes.get(ancestor_key)
                if not tree_node:
                    continue
                tree_node.assets_amount_increment()

    def get_ancestor_keys(self, keys, with_self=True):
        """Union of the ancestor keys of every key in *keys*."""
        ancestor_keys = set()
        for k in keys:
            _ancestor_keys = Node.get_node_ancestor_keys(k, with_self=with_self)
            ancestor_keys.update(_ancestor_keys)
        return ancestor_keys

    def _prune(self):
        self._prune_owner_nodes_branch()

    def _prune_owner_nodes_branch(self):
        # Prune every owner branch: keep only the topmost node of each owner
        # branch and drop all its descendants (they are reloaded from the DB
        # in _init_owner_nodes_children).
        owner_node_keys = [n.key for n in self._owner_nodes]
        for node in list(self._nodes.values()):
            ancestor_keys = Node.get_node_ancestor_keys(node.key)
            if set(ancestor_keys) & set(owner_node_keys):
                self.remove_node(node)

    @property
    def _owner_nodes(self):
        return [node for node in self._nodes.values() if node.type == TreeNode.Type.OWNER]

    def get_node_all_assets(self, node: TreeNode):
        """All assets below *node*.

        NOTE(review): get_node_all_children excludes *node* itself, so the
        node's own direct assets are NOT included — confirm this is intended.
        """
        nodes = self.get_node_all_children(node)
        assets = set()
        for n in nodes:
            assets.update(n.assets)
        return assets

    def get_node_all_children(self, node: TreeNode):
        """All descendants of *node* (excludes *node* itself)."""
        children = [
            n for n in self._nodes.values() if n.is_all_children(node)
        ]
        return children

    def get_node_children(self, node: TreeNode):
        """Direct children of *node*."""
        children = [
            n for n in self._nodes.values() if n.is_children(node)
        ]
        return children

    def get_node(self, key) -> Optional[TreeNode]:
        return self._nodes.get(key)

    def add_node(self, node: TreeNode):
        """Insert *node*; replace an existing node with the same key only
        when TreeNode.can_be_overridden allows it."""
        _node = self._nodes.get(node.key)
        if _node is None:
            self._nodes[node.key] = node
            return
        if _node.can_be_overridden(node):
            self._nodes[node.key] = node
            return

    def remove_node(self, node_or_key: 'TreeNode | str'):
        """Remove a node by instance or by key; missing keys are ignored."""
        if isinstance(node_or_key, TreeNode):
            key = node_or_key.key
        else:
            key = node_or_key
        self._nodes.pop(key, None)

    @classmethod
    def wrap_as_tree_node(cls, node_key, tp, assets=None):
        return TreeNode(key=node_key, tp=tp, assets=assets)

    @classmethod
    def wrap_as_tree_nodes(cls, node_keys, tp):
        return [cls.wrap_as_tree_node(nk, tp) for nk in node_keys]

    def print_nodes(self):
        """Debug helper: dump 'key(amount) - type' for every node."""
        print('--- Tree Nodes ---')
        for n in self._nodes.values():
            print(f'{n.key}({n.assets_amount}) - {n.type}')
class UserPermTreeEngine(object):
    """Build a user's final permission tree.

    Glossary:
        DA: Directly permed Asset
        DN: Directly permed Node
        DA-Tree: tree generated from the directly granted assets
        DN-Tree: tree generated from the directly granted nodes
        Perm-Tree: the final permission tree, merged from DA-Tree and
            DN-Tree; bridge and da nodes are all kept, owner branches keep
            only their topmost level
    Tree node types:
        bridge: linking ancestor node; neither the node nor its assets granted
        owner:  node granted directly
        da:     only assets under the node were granted
    """

    def __init__(self, user, org_id=None):
        self.user = user
        self._user_id = str(user.id)
        self._org_id = org_id or current_org.id
        # Build once at construction; cached by @cache_tree for 5 seconds.
        self._tree = self.tree()

    @cache_tree(ttl=5)
    def tree(self):
        """Build (or fetch from cache) the merged permission tree."""
        da_tree = self._generate_da_tree()
        dn_tree = self._generate_dn_tree()
        tree = self._merge_trees(da_tree, dn_tree)
        return tree

    def _generate_da_tree(self):
        """Tree of the nodes containing directly granted assets."""
        node_assets_mapper = self._get_da_node_key_asset_ids_mapper()
        tree_nodes = [
            TreeNode(key=key, tp=TreeNode.Type.DA, assets=asset_ids)
            for key, asset_ids in node_assets_mapper.items()
        ]
        tree = Tree(nodes=tree_nodes)
        return tree

    def _get_da_node_key_asset_ids_mapper(self):
        """Map node_key -> {asset_id} for the directly granted assets."""
        direct_asset_ids = AssetPermission.assets.through.objects \
            .filter(assetpermission_id__in=self._perm_ids) \
            .annotate(char_id=F('asset_id')).values_list('char_id', flat=True)
        nid_aid_set = Asset.nodes.through.objects.filter(asset_id__in=direct_asset_ids) \
            .annotate(char_nid=F('node_id'), char_aid=F('asset_id')).values_list('char_nid', 'char_aid')
        # Only used to collect the distinct node ids for the key lookup.
        nid_aid_mapper = dict(nid_aid_set)
        node_ids = list(nid_aid_mapper.keys())
        node_id_key_set = Node.objects.filter(id__in=node_ids).annotate(char_id=F('id')).values_list('id', 'key')
        node_id_key_mapper = dict(node_id_key_set)
        mapper = defaultdict(set)
        for nid, aid in nid_aid_set:
            key = node_id_key_mapper.get(nid)
            if key:
                mapper[key].add(aid)
        return mapper

    def _generate_dn_tree(self):
        """Tree of the directly granted nodes (type 'owner')."""
        node_keys = self._get_dn_node_keys()
        nodes = Tree.wrap_as_tree_nodes(node_keys, TreeNode.Type.OWNER)
        tree = Tree(nodes=nodes)
        return tree

    def _get_dn_node_keys(self):
        """Distinct keys of the directly granted nodes."""
        node_ids = AssetPermission.nodes.through.objects.filter(assetpermission_id__in=self._perm_ids) \
            .annotate(char_id=F('node_id')).values_list('char_id', flat=True)
        node_keys = Node.objects.filter(id__in=node_ids).values_list('key', flat=True)
        return list(set(node_keys))

    def _merge_trees(self, da_tree: Tree, dn_tree: Tree) -> Tree:
        tree = da_tree.merge(dn_tree)
        return tree

    @lazyproperty
    def _perm_ids(self):
        return self._get_permission_ids()

    def _get_permission_ids(self):
        """Permission ids granted to the user directly or via their groups."""
        user_perm_ids = AssetPermission.users.through.objects.filter(user_id=self._user_id).annotate(
            char_id=F('assetpermission_id')).values_list('char_id', flat=True)
        group_ids = User.groups.through.objects.filter(user_id=self._user_id).annotate(
            char_id=F('usergroup_id')).values_list('char_id', flat=True)
        group_perm_ids = AssetPermission.user_groups.through.objects.filter(usergroup_id__in=group_ids).annotate(
            char_id=F('assetpermission_id')).values_list('char_id', flat=True)
        perm_ids = set(user_perm_ids).union(set(group_perm_ids))
        return perm_ids

    def get_node_children(self, node_key, with_assets=False):
        """Direct children of a node; called by the Luna page.

        Bug fix: the original used ``self.tree.get_node(...)``, but
        ``self.tree`` is the bound method object (the built tree lives in
        ``self._tree``), so every call raised AttributeError.
        """
        tree_node = self._tree.get_node(node_key)
        if not tree_node:
            return None
        children = self._tree.get_node_children(tree_node)
        data = {"children": children}
        if with_assets:
            data.update({"assets": tree_node.assets})
        return data

    def get_node_all_assets(self, node_key):
        """All assets under a node; called by the user detail page.

        Bug fix: uses ``self._tree`` instead of the bound method
        ``self.tree`` (same defect as get_node_children).
        """
        node = self._tree.get_node(node_key)
        if not node:
            return None
        assets = self._tree.get_node_all_assets(node)
        return {
            "assets": assets
        }
from common.utils import timeit
@timeit
def query_3_result(guessed_asset_amount=50000, view_old=False, sql2_raw=False, sql1_raw=False, sql3_raw=False):
    """Benchmark/demo of three node-asset queries, raw SQL vs the ORM.

    1) asset count per node, 2) assets belonging to multiple nodes,
    3) (asset_id, node_id) pairs for the step-2 assets.
    Returns (node_id_asset_amount_rows, asset_ids, aid_nid_set).
    NOTE(review): sets settings.DEBUG = True globally — demo-only code.
    """
    from django.db import connection, connections
    from django.conf import settings
    from django.db.models import OuterRef, Subquery, Count
    import time
    settings.DEBUG = True
    connections.close_all()
    with connection.cursor() as cursor:
        # 1. asset count per node
        # cursor.execute(sql1)
        # node_key_asset_amount_tuple = cursor.fetchall()
        t1 = time.time()
        if sql1_raw:
            # Raw SQL cannot be used here: nodes without assets have no row
            # in the through table, so the result would be incomplete.
            sql1 = """
            SELECT node_id, COUNT(*) AS assets_count
            FROM assets_asset_nodes
            GROUP BY node_id
            """
            cursor.execute(sql1)
            node_id_asset_amount_rows = cursor.fetchall()
        else:
            # Subquery keeps nodes with zero assets (NULL count) in the result.
            count_sub = Node.assets.through.objects.filter(
                node_id=OuterRef("id")
            ).values("node_id").annotate(c=Count("id")).values("c")
            node_id_asset_amount_rows = Node.objects.annotate(
                assets_count=Subquery(count_sub)
            ).values("id", "assets_count")
            node_id_asset_amount_rows = list(node_id_asset_amount_rows)
        t2 = time.time()
        # 2. ids of assets that belong to more than one node
        if sql2_raw:
            sql2 = """
            SELECT asset_id FROM assets_asset_nodes
            GROUP BY asset_id HAVING COUNT(*) > 1
            """
            cursor.execute(sql2)
            rows = cursor.fetchall()
            asset_ids = [row[0] for row in rows]
        else:
            count_sub = Asset.nodes.through.objects.filter(
                asset_id=OuterRef("id")
            ).values("asset_id").annotate(c=Count("id")).values("c")
            asset_id_node_amount_row = Asset.objects.annotate(
                nodes_count=Subquery(count_sub)
            ).values_list('id', 'nodes_count')
            print(asset_id_node_amount_row[0])
            asset_ids = [str(row[0]) for row in asset_id_node_amount_row if row[1] and row[1] > 1]
        print('Assets belong to multiple nodes:', len(asset_ids))
        t3 = time.time()
        # 3. (asset_id, node_id) pairs, only for the step-2 assets
        # assume asset_ids is small; only take the first guessed_asset_amount
        guessed_asset_ids = asset_ids[:guessed_asset_amount]
        if sql3_raw:
            print('Guessed asset ids count:', len(guessed_asset_ids))
            sql3 = """
            SELECT asset_id, node_id FROM assets_asset_nodes
            WHERE asset_id IN ({})
            """.format(','.join(['%s'] * len(guessed_asset_ids)))
            cursor.execute(sql3, guessed_asset_ids)
            aid_nid_set = cursor.fetchall()
        else:
            aid_nid_set = Node.assets.through.objects.filter(asset_id__in=guessed_asset_ids).values_list('asset_id', 'node_id')
            aid_nid_set = list(aid_nid_set)
        t4 = time.time()
        # Original note: from aid_nid_set, fetch each aid's parent_ids and
        # intersect the pairwise parents' ancestor sets, minus 1.
        print('Query times: sql1 {:.2f}s, sql2 {:.2f}s, sql3 {:.2f}s'.format(t2 - t1, t3 - t2, t4 - t3),
              len(node_id_asset_amount_rows), len(asset_ids), len(aid_nid_set))
        print('New ORM query time: {:.2f}s, total rows: {}'.format(t4 - t1, len(aid_nid_set)))
        # Old full-table load, for comparison only.
        if view_old:
            t1 = time.time()
            old = list(Node.assets.through.objects.all())
            t2 = time.time()
            print('Old ORM query time: {:.2f}s, total rows: {}'.format(t2 - t1, len(old)))
        return node_id_asset_amount_rows, asset_ids, aid_nid_set

View File

@@ -0,0 +1,926 @@
#!/usr/bin/env python
"""
测试查询资产树和授权树下指定节点的资产总数
"""
import os, sys, django, json, time
from datetime import datetime
import copy
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'jumpserver.settings')
sys.path.insert(0, os.path.join(os.getcwd(), 'apps'))
django.setup()
from functools import reduce
from operator import or_
from operator import or_
from django.db.models import Q
from collections import defaultdict
from orgs.models import Organization
from users.models import User, UserGroup
from rbac.models import OrgRoleBinding
from assets.models import Node, Asset
from perms.models import AssetPermission
from django.core.cache import cache
# Module-level handles and tunables for this test script.
AssetNodeThrough = Asset.nodes.through  # the custom m2m through model
# Feature toggles: flip the commented line to disable a build step.
BUILD_ASSET_TREE = True
# BUILD_ASSET_TREE = False
BUILD_USER_PERM_TREE = True
# BUILD_USER_PERM_TREE = False
# How many users to build perm trees for.
TEST_USER_AMOUNT = 1
# Output directory for the generated mapper files (created on first run).
OUTPUT_DIR = os.path.join(os.getcwd(), 'mapper_output')
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)
#
# =========================== 辅助函数 ===============================
#
def log(msg=''):
    """Print *msg* prefixed with a wall-clock [HH:MM:SS] timestamp."""
    stamp = datetime.now().strftime('%H:%M:%S')
    print('[' + stamp + '] ' + str(msg))
def remove_asset_belong_many_node_record_from_asset_node_through():
    """De-duplicate the through table: if an asset is linked to several
    different node_keys, keep only the oldest row (smallest id).

    Returns the number of deleted rows.

    Improvements vs. the original: rows are deleted with one bulk
    ``DELETE ... WHERE id IN (...)`` per asset instead of one query per
    row, the extra ``.count()`` round-trip is gone, and we no longer
    delete rows while iterating a sliced queryset.
    """
    log("开始清理 AssetNodeThrough 表中一个资产属于多个节点的重复记录...")
    from django.db.models import Count
    # asset_ids that are linked to more than one distinct node_key
    duplicates = AssetNodeThrough.objects.values('asset_id').annotate(
        node_key_count=Count('node_key', distinct=True)
    ).filter(node_key_count__gt=1)
    total_deleted = 0
    for dup in duplicates:
        asset_id = dup['asset_id']
        # Snapshot this asset's relation rows once, oldest first.
        rows = list(
            AssetNodeThrough.objects.filter(asset_id=asset_id)
            .order_by('id').values_list('id', 'node_key')
        )
        if len(rows) <= 1:
            continue
        _keep_id, keep_node_key = rows[0]
        delete_ids = [pk for pk, _ in rows[1:]]
        # Single bulk DELETE instead of row-by-row instance deletes.
        delete_count, _ = AssetNodeThrough.objects.filter(id__in=delete_ids).delete()
        total_deleted += delete_count
        if delete_count > 0:
            log(f" 资产 {asset_id} 从多个节点删除到只属于一个: {keep_node_key}, 删除数量: {delete_count}")
    log(f"✓ 清理完成,共删除 {total_deleted} 条重复记录")
    return total_deleted
def get_node_level(node_key):
    """Return the depth of *node_key*: the number of colon-separated
    segments. Empty or None keys have level 0."""
    return len(node_key.split(':')) if node_key else 0
def write_asset_mapper_to_file(org_id, mapper, node_times=None, asset_count=0, through_count=0, total_time=0):
    """Write the {node_key: asset_count} mapper to a tree-shaped text file.

    Children are nested under their parents; each node line shows its own
    compute time and the summed time of its direct children.

    Bug fix vs. the original: when a parent key was first created as an
    intermediate entry (while inserting one of its children) and only later
    visited in *mapper*, its count stayed None and its line was never
    written. The count is now always assigned on the final path segment.
    """
    output_file = os.path.join(OUTPUT_DIR, f'mapper_org_{org_id}.txt')
    if node_times is None:
        node_times = {}
    # Tree depth and widest asset count, for the file header.
    max_count = 0
    level_mapper = defaultdict(list)
    for key, count in mapper.items():
        level = get_node_level(key)
        level_mapper[level].append((key, count))
        max_count = max(max_count, count)
    tree_depth = max(level_mapper.keys()) if level_mapper else 0
    max_width = len(f"{max_count:,}")
    # Build a nested dict tree: {key: {'count': int|None, 'children': {...}}}
    tree = {}
    for key, count in mapper.items():
        parts = key.split(':')
        current = tree
        for i, part in enumerate(parts):
            node_key = ':'.join(parts[:i+1])
            if node_key not in current:
                current[node_key] = {'count': None, 'children': {}}
            if i == len(parts) - 1:
                # Always set the count on the final segment, even when the
                # entry already exists as an intermediate node (bug fix).
                current[node_key]['count'] = count
            current = current[node_key]['children']
    # Recursively render the tree into the open file handle.
    def write_tree(node_dict, f, prefix='', node_times_dict=None):
        if node_times_dict is None:
            node_times_dict = {}
        for i, (key, node_data) in enumerate(sorted(node_dict.items())):
            is_last = (i == len(node_dict) - 1)
            # This node's own compute time.
            current_time = node_times_dict.get(key, 0)
            # Summed time of the direct children.
            next_level_time = 0
            if node_data['children']:
                for child_key in node_data['children'].keys():
                    next_level_time += node_times_dict.get(child_key, 0)
            # Tree-drawing prefix.
            current_prefix = prefix + ('└── ' if is_last else '├── ')
            # Number of direct children.
            child_count = len(node_data['children']) if node_data['children'] else 0
            # Node line: key (assets) [time info]
            if node_data['count'] is not None:
                time_info = f"当前: {current_time:.3f}s, 下级: {next_level_time:.3f}s, 下级节点: {child_count}"
                f.write(f"{current_prefix}{key} ({node_data['count']:,}) [{time_info}]\n")
                # Blank spacer line between node entries.
                blank_prefix = prefix + (' ' if is_last else '')
                f.write(f"{blank_prefix}\n")
            # Recurse into children.
            if node_data['children']:
                next_prefix = prefix + (' ' if is_last else '')
                write_tree(node_data['children'], f, next_prefix, node_times_dict)
    # Header + tree body.
    with open(output_file, 'w') as f:
        f.write(f"{'='*80}\n")
        f.write(f"组织 ID: {org_id}\n")
        f.write(f"生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"总耗时: {total_time:.2f}s\n")
        f.write(f"总节点数: {len(mapper):,}\n")
        f.write(f"总资产数: {asset_count:,}\n")
        f.write(f"总关系表数: {through_count:,}\n")
        f.write(f"树深度: {tree_depth}\n")
        f.write(f"节点数最大宽度: {max_width}\n")
        f.write(f"{'='*80}\n\n")
        write_tree(tree, f, node_times_dict=node_times)
    log(f"✓ mapper 已写入文件:")
    log(f" {output_file}")
def write_user_perm_mapper_to_file(user_name, user_id, org_id, mapper, node_times, tree_nodes=None):
    """Write a user's permission mapper to a file as a tree with box-drawing
    connectors; each node line shows its computation time and query metadata
    (da_asset_ids / dn_node_keys counts) when a TreeNode is available.
    """
    if tree_nodes is None:
        tree_nodes = {}
    output_file = os.path.join(OUTPUT_DIR, f'mapper_tree_{user_name}_{user_id}_{org_id}.txt')
    # Build the nested dict tree - key point: each node lives only under its
    # direct parent, not at the top level.
    tree = {}
    for key, count in mapper.items():
        parts = key.split(':')
        current = tree
        # Build level by level starting from the root.
        for i, part in enumerate(parts):
            node_key = ':'.join(parts[:i+1])
            # Create the node if it does not exist yet.
            if node_key not in current:
                # Only the deepest level carries the real count; intermediate levels get None.
                is_leaf = (i == len(parts) - 1)
                current[node_key] = {
                    'count': count if is_leaf else None,
                    'children': {}
                }
            # Descend into this node's children dict for the next level.
            current = current[node_key]['children']
    # Timing helper over a node's descendants.
    def get_children_time_info(node_dict):
        """Return (sum of direct-child times, direct-child count, sum of deeper descendant times)."""
        direct_children_time = 0
        direct_children_count = 0
        deeper_descendants_time = 0
        for child_key, child_data in node_dict.items():
            # Direct child time.
            direct_children_time += node_times.get(child_key, 0)
            direct_children_count += 1
            # Recurse to accumulate grandchildren and beyond.
            if child_data['children']:
                deeper_direct_time, deeper_direct_count, deeper_deeper_time = get_children_time_info(child_data['children'])
                deeper_descendants_time += deeper_direct_time + deeper_deeper_time
        return direct_children_time, direct_children_count, deeper_descendants_time
    # Recursively render the tree structure.
    def write_tree_recursive(node_dict, f, prefix=''):
        items = sorted(node_dict.items())
        # Blank line at the start of every level.
        if items:
            extension = ''
            f.write(f"{prefix}{extension}\n")
        for i, (key, node_data) in enumerate(items):
            is_last = (i == len(items) - 1)
            # The real count always comes from the mapper.
            count_value = mapper.get(key, None)
            if count_value is not None:
                count_str = f"({count_value})"
            else:
                count_str = "(None)"
            # This node's computation time.
            node_time = node_times.get(key, 0)
            # Direct-children and deeper-descendant times.
            direct_children_time, direct_children_count, deeper_descendants_time = get_children_time_info(node_data['children'])
            # Format the timing info.
            time_str = f" [当前: {node_time:.3f}s"
            if direct_children_count > 0:
                time_str += f", 直接子节点: {direct_children_time:.3f}s ({direct_children_count}个)"
            time_str += "]"
            # da_asset_ids / dn_node_keys metadata for this node, if available.
            tree_node = tree_nodes.get(key)
            node_info_str = ""
            if tree_node:
                da_count = len(tree_node.da_asset_ids) if hasattr(tree_node, 'da_asset_ids') else 0
                dn_count = len(tree_node.dn_node_keys) if hasattr(tree_node, 'dn_node_keys') else 0
                if da_count > 0 or dn_count > 0:
                    node_info_str = f" [da_assets: {da_count}, dn_nodes: {dn_count}]"
            # Connector for the current row.
            connector = '└── ' if is_last else '├── '
            line = f"{prefix}{connector}{key} {count_str}{time_str}{node_info_str}\n"
            f.write(line)
            # Recurse into the children.
            if node_data['children']:
                # Prefix for the next level.
                extension = ' ' if is_last else ''
                new_prefix = prefix + extension
                write_tree_recursive(node_data['children'], f, new_prefix)
            # Blank line after every node except the last sibling.
            if not is_last:
                extension = ''
                f.write(f"{prefix}{extension}\n")
    # Write the file: header, statistics section, then each root subtree.
    with open(output_file, 'w') as f:
        f.write(f"{'='*80}\n")
        f.write(f"用户: {user_name}\n")
        f.write(f"用户ID: {user_id}\n")
        f.write(f"组织ID: {org_id}\n\n")
        # Statistics header (wrapped in ===== rules).
        f.write(f"{'='*80}\n")
        f.write(f"【统计信息】\n")
        f.write(f"{'='*80}\n\n")
        # Root node keys.
        root_keys = sorted(tree.keys())
        # Emit each root node followed by its subtree.
        for key in root_keys:
            node_data = tree[key]
            count_value = mapper.get(key, None)
            if count_value is not None:
                count_str = f"({count_value})"
            else:
                count_str = "(None)"
            # Root node's computation time.
            node_time = node_times.get(key, 0)
            # Direct-children and deeper-descendant times.
            direct_children_time, direct_children_count, deeper_descendants_time = get_children_time_info(node_data['children'])
            # Format the timing info.
            time_str = f" [当前: {node_time:.3f}s"
            if direct_children_count > 0:
                time_str += f", 直接子节点: {direct_children_time:.3f}s ({direct_children_count}个)"
            time_str += "]"
            # da_asset_ids / dn_node_keys metadata.
            tree_node = tree_nodes.get(key)
            node_info_str = ""
            if tree_node:
                da_count = len(tree_node.da_asset_ids) if hasattr(tree_node, 'da_asset_ids') else 0
                dn_count = len(tree_node.dn_node_keys) if hasattr(tree_node, 'dn_node_keys') else 0
                if da_count > 0 or dn_count > 0:
                    node_info_str = f" [da_assets: {da_count}, dn_nodes: {dn_count}]"
            f.write(f"{key} {count_str}{time_str}{node_info_str}\n")
            # Render this root's subtree.
            if node_data['children']:
                write_tree_recursive(node_data['children'], f, prefix='')
    log(f"✓ mapper_tree 已写入文件:")
    log(f" {output_file}")
def query_node_assets_count(node_key):
    """Count distinct assets mounted on *node_key* or on any descendant node."""
    subtree_q = Q(node_key=node_key) | Q(node_key__startswith=node_key + ':')
    qs = AssetNodeThrough.objects.filter(subtree_q)
    return qs.distinct('asset_id').count()
def build_org_asset_tree(org_id):
    """Build the asset tree for a single org: for every node, count the
    distinct assets in its subtree (one SQL query per node), then dump the
    tree to a file.

    Returns a stats dict: org_id, node/asset/through counts and elapsed time.
    """
    log(f"开始构建组织 {org_id} 的资产树...")
    # Org-level statistics.
    nodes = Node.objects.filter(org_id=org_id).order_by('key')
    node_count = nodes.count()
    assets = Asset.objects.filter(org_id=org_id)
    asset_count = assets.count()
    through_count = AssetNodeThrough.objects.filter(asset__org_id=org_id).count()
    log(f" 节点数: {node_count:,}")
    log(f" 资产数: {asset_count:,}")
    log(f" 关系表总数: {through_count:,}")
    node_times = {}  # per-node-key computation time in seconds
    max_count = 0
    level_mapper = defaultdict(list)
    # ======================= core code ==========================
    step_start = datetime.now()
    mapper = defaultdict(int)
    for i, node in enumerate(nodes):
        node_start = datetime.now()
        # One query per node: distinct assets in the node's subtree.
        count = query_node_assets_count(node.key)
        mapper[node.key] = count
        max_count = max(max_count, count)
        level = get_node_level(node.key)
        level_mapper[level].append(count)
        # Record this node's computation time.
        node_times[node.key] = (datetime.now() - node_start).total_seconds()
        # ======================= core code ==========================
        if (i + 1) % 1000 == 0:
            log(f" 已处理 {i+1:,}/{node_count:,} 个节点")
    step_time = (datetime.now() - step_start).total_seconds()
    log(f" 耗时: {step_time:.2f}s")
    # Tree depth and the print width of the largest count.
    tree_depth = max(level_mapper.keys()) if level_mapper else 0
    max_width = len(f"{max_count:,}")
    log(f" 树深度: {tree_depth}")
    log(f" 节点数最大宽度: {max_width}")
    # Dump the tree to a file.
    write_asset_mapper_to_file(org_id, mapper, node_times, asset_count, through_count, step_time)
    return {
        'org_id': org_id,
        'node_count': node_count,
        'asset_count': asset_count,
        'through_count': through_count,
        'time': step_time
    }
#
# =========================== 构建资产树 ===============================
#
def build_orgs_asset_tree():
    """Build the asset tree for every organization, then print summary stats.

    NOTE(review): the `[:1]` slice below limits the run to the first org —
    presumably a debugging leftover; confirm before running against all orgs.
    """
    log(f"\n{'='*60}")
    log("【开始构建所有组织的资产树】")
    log(f"{'='*60}\n")
    total_start = datetime.now()
    org_ids = list(Organization.objects.all().values_list('id', flat=True))[:1]
    org_count = len(org_ids)
    log(f"发现 {org_count:,} 个组织\n")
    # Per-org stats accumulation.
    org_stats = []
    total_nodes = 0
    total_assets = 0
    total_through = 0
    # ======================= core code ==========================
    for org_id in org_ids:
        stats = build_org_asset_tree(str(org_id))
        # ======================= core code ==========================
        org_stats.append(stats)
        total_nodes += stats['node_count']
        total_assets += stats['asset_count']
        total_through += stats['through_count']
        log()  # blank separator line
    total_time = (datetime.now() - total_start).total_seconds()
    # Print the overall statistics table.
    log(f"{'='*60}")
    log("【组织统计详情】")
    log(f"{'='*60}")
    log(f"{'组织ID':<12} {'节点数':<12} {'资产数':<12} {'关系总数':<12} {'耗时(s)':<10}")
    log(f"{'-'*60}")
    for stats in org_stats:
        log(f"{str(stats['org_id']):<12} {stats['node_count']:<12,} {stats['asset_count']:<12,} {stats['through_count']:<12,} {stats['time']:<10.2f}")
    log(f"{'-'*60}")
    log(f"{'合计':<12} {total_nodes:<12,} {total_assets:<12,} {total_through:<12,}")
    log(f"\n{'='*60}")
    log("【全局统计】")
    log(f"{'='*60}")
    log(f"总组织数: {org_count:,}")
    log(f"总节点数: {total_nodes:,}")
    log(f"总资产数: {total_assets:,}")
    log(f"关系表总数: {total_through:,}")
    log(f"总耗时: {total_time:.2f}s")
    if org_count > 0:
        log(f"平均耗时/组织: {total_time/org_count:.2f}s")
    log(f"输出目录: {OUTPUT_DIR}")
    log(f"{'='*60}\n")
#
# =========================== 构建用户授权树 ===========================
#
class TreeNode:
    """One node of the in-memory user permission tree."""

    class Type:
        BRIDGE = 'bridge'
        DA = 'da-node'
        OWNER = 'owner-node'

    def __init__(self, key, type=None):
        self.key = key
        self.type = type
        self.assets_count = 0
        # Directly-authorized asset ids gathered from this node's subtree (used for queries).
        self.da_asset_ids = set()
        # Directly-authorized node keys gathered from this node's subtree (used for queries).
        self.dn_node_keys = set()

    def overridden_da_from_dn(self, dn):
        """Resolve a key collision with *dn*: an owner-node wins outright;
        otherwise a da-node merging into a da/bridge node contributes its keys.
        """
        # combinations: bridge/owner, da/owner, bridge/bridge
        if dn.type == self.Type.OWNER:
            return dn
        if self.type in [self.Type.DA, self.Type.BRIDGE] and dn.type == self.Type.DA:
            self.dn_node_keys.update(dn.dn_node_keys)
        return self

    def get_ancestor_keys(self, with_self=False):
        """Ancestor keys from closest parent up to the root (optionally including self)."""
        segments = self.key.split(":")
        if not with_self:
            segments.pop()
        ancestors = []
        while segments:
            ancestors.append(":".join(segments))
            segments.pop()
        return ancestors

    def get_all_children_keys(self, with_self=False):
        """Keys of every descendant node in the DB (optionally including self)."""
        key_prefix = self.key + ":"
        children_keys = [self.key] if with_self else []
        descendant_keys = Node.objects.filter(key__startswith=key_prefix).values_list('key', flat=True)
        children_keys.extend(descendant_keys)
        return children_keys
def reverse_build_da_perm_tree(da_nodes):
    """Reverse-build the complete direct-asset (da) tree: insert every da-node,
    create any missing bridge ancestors, and accumulate each subtree's
    da_asset_ids onto all of its ancestors.

    Bug fix: was `defaultdict(TreeNode)` — that default factory could never
    work (TreeNode requires a `key` argument and would raise TypeError if
    ever triggered) and was never actually used, since all access goes
    through .get()/assignment. A plain dict is equivalent and honest.
    """
    complete_tree = {}
    for tn in da_nodes.values():
        # Insert the da-node itself first.
        complete_tree[tn.key] = tn
        # Walk every ancestor (excluding the node itself): create missing
        # bridge nodes and hang this subtree's asset ids on them.
        ancestor_keys = tn.get_ancestor_keys(with_self=False)
        for key in ancestor_keys:
            an = complete_tree.get(key)
            if not an:
                an = TreeNode(key, type=TreeNode.Type.BRIDGE)
            an.da_asset_ids.update(tn.da_asset_ids)
            complete_tree[key] = an
    return complete_tree
def prune_owner_nodes(tree):
    """Keep only the topmost owner-nodes: any owner-node that has another
    owner-node among its ancestors is removed from *tree* in place."""
    owner_keys = {k for k, v in tree.items() if v.type == TreeNode.Type.OWNER}
    for key in list(owner_keys):
        ancestors = set(tree[key].get_ancestor_keys(with_self=False))
        if ancestors & owner_keys:
            tree.pop(key)
def reverse_build_dn_perm_tree(dn_nodes):
    """Reverse-build the complete direct-node (dn) tree: deep-copy the owner
    nodes, fill in missing bridge ancestors, register every owner key on each
    of its ancestors, then prune nested owner-nodes."""
    complete_tree = copy.deepcopy(dn_nodes)
    # Pass 1: make sure every ancestor exists, creating bridge nodes as needed.
    for tn in dn_nodes.values():
        for key in tn.get_ancestor_keys(with_self=False):  # excludes the node itself
            if key not in complete_tree:
                complete_tree[key] = TreeNode(key, type=TreeNode.Type.BRIDGE)
    # Pass 2: hang each owner-node's key on all of its ancestors.
    for tn in dn_nodes.values():
        for key in tn.get_ancestor_keys(with_self=False):  # excludes the node itself
            complete_tree[key].dn_node_keys.add(tn.key)
    prune_owner_nodes(complete_tree)
    return complete_tree
def merge_trees(da_tree, dn_tree):
    """Merge the direct-node tree into (a shallow copy of) the direct-asset tree.

    Owner-nodes always replace whatever the da-tree holds at their key;
    bridge-nodes contribute their collected dn keys. Finally every node
    strictly below an owner-node is dropped, since the owner already covers
    its whole subtree.
    """
    tree = {k: v for k, v in da_tree.items()}
    # Fold dn_tree into the merged tree.
    for dn in dn_tree.values():
        if dn.type == TreeNode.Type.OWNER:
            tree[dn.key] = dn
            continue
        if dn.type == TreeNode.Type.BRIDGE:
            da = tree.get(dn.key)
            if da is None:
                # Bug fix: the da-tree may not contain this key at all (e.g.
                # the user only has node grants in this branch). Previously
                # this crashed with AttributeError on None; adopt the dn
                # bridge node instead.
                tree[dn.key] = dn
            else:
                da.dn_node_keys.update(dn.dn_node_keys)
            continue
    # Remove every node underneath an owner node.
    owner_nodes = {v for v in tree.values() if v.type == TreeNode.Type.OWNER}
    for on in owner_nodes:
        ac_keys = [k for k in tree.keys() if k.startswith(on.key + ':')]
        for k in ac_keys:
            tree.pop(k)
    return tree
def get_tree_node_all_da_asset_ids(tree, tree_node):
    """Collect the da asset ids from every da-node at *tree_node* or in its subtree."""
    prefix = tree_node.key + ':'
    asset_ids = set()
    for n in tree.values():
        if n.type != TreeNode.Type.DA:
            continue
        if n.key == tree_node.key or n.key.startswith(prefix):
            asset_ids.update(n.da_asset_ids)
    return asset_ids
def get_tree_node_all_dn_nodes(tree, tree_node):
    """Return the set of owner-nodes found at *tree_node* or anywhere in its subtree."""
    prefix = tree_node.key + ':'
    return {
        n for n in tree.values()
        if n.type == TreeNode.Type.OWNER and (n.key == tree_node.key or n.key.startswith(prefix))
    }
def complete_all_dn_children_nodes(tree, org_node_keys):
    """Return a deep copy of *tree* in which every org node key underneath an
    owner-node is materialized as an owner-node as well."""
    complete_tree = copy.deepcopy(tree)
    owner_nodes = [n for n in tree.values() if n.type == TreeNode.Type.OWNER]
    for node in owner_nodes:
        prefix = node.key + ':'
        for key in (k for k in org_node_keys if k.startswith(prefix)):
            complete_tree[key] = TreeNode(key, type=TreeNode.Type.OWNER)
    return complete_tree
def query_user_perm_node_assets_count(da_ids, tree_node):
    """Count the distinct assets granted to the user under *tree_node*.

    For an owner-node the query is identical to the plain asset-tree node
    count. For da/bridge nodes the relation rows are first narrowed to the
    union of the node's collected direct asset ids and direct node subtrees,
    which keeps the SQL working set small.
    """
    # Subtree filter: the node itself plus every descendant key.
    q_count = Q(node_key=tree_node.key) | Q(node_key__startswith=tree_node.key + ':')
    if tree_node.type == TreeNode.Type.OWNER:
        # dn-nodes are counted exactly like asset-tree nodes.
        return AssetNodeThrough.objects.filter(q_count).distinct('asset_id').count()
    # Narrowing filter for da/bridge nodes.
    q_assets = Q()
    if tree_node.da_asset_ids:
        q_assets |= Q(asset_id__in=tree_node.da_asset_ids)
    if tree_node.dn_node_keys:
        subtree_qs = [
            Q(node_key=n_key) | Q(node_key__startswith=n_key + ':')
            for n_key in tree_node.dn_node_keys
        ]
        q_assets |= reduce(or_, subtree_qs)
    # TODO: is it more efficient to apply q_assets first, or q_count first?
    return AssetNodeThrough.objects.filter(q_assets).filter(q_count).distinct('asset_id').count()
def build_user_org_perm_tree(user, org_id, org=None):
    """Build one user's permission tree inside one org and dump it to a file.

    Steps: collect the user's AssetPermission ids (direct + via groups),
    build a da-tree from directly-authorized assets and a dn-tree from
    directly-authorized nodes, merge them, materialize owner subtrees, then
    count assets per node. Returns a stats dict.
    """
    org_name = org.name if org else "没有组织"
    log(f" 构建用户 {user.name}({user.id}) 在组织 {org_name}({org_id}) 的授权树")
    step_start = datetime.now()
    node_times = {}
    # ======================= core code - data construction ==========================
    build_start = datetime.now()
    # All asset-permission ids the user holds in this org.
    group_ids = User.groups.through.objects.filter(user_id=user.id).values_list('usergroup_id', flat=True)
    group_perm_ids = AssetPermission.user_groups.through.objects.filter(usergroup_id__in=group_ids).values_list('assetpermission_id', flat=True)
    user_perm_ids = AssetPermission.users.through.objects.filter(user_id=user.id).values_list('assetpermission_id', flat=True)
    orgs_perm_ids = set(list(group_perm_ids) + list(user_perm_ids))
    org_perm_ids = AssetPermission.objects.filter(id__in=orgs_perm_ids, org_id=org_id).values_list('id', flat=True)
    # Node tree for directly-authorized assets.
    # NOTE(review): distinct('asset_id') is PostgreSQL DISTINCT ON — it keeps
    # one (asset_id, node_key) row per asset; confirm that is intended here.
    da_asset_ids = AssetPermission.assets.through.objects.filter(assetpermission_id__in=org_perm_ids).distinct('asset_id').values_list('asset_id', flat=True)
    da_asset_parent_node_keys = AssetNodeThrough.objects.filter(asset_id__in=da_asset_ids).distinct('asset_id').values_list('asset_id', 'node_key')
    # Hang the direct asset ids on their parent nodes.
    da_tree_nodes = defaultdict(TreeNode)
    for asset_id, node_key in da_asset_parent_node_keys:  # (asset_id, node_key) tuples; different asset_ids may share one node_key
        tn = da_tree_nodes.get(node_key)
        if not tn:
            tn = TreeNode(node_key, type=TreeNode.Type.DA)
        tn.da_asset_ids.add(asset_id)
        da_tree_nodes[node_key] = tn
    # Reverse-build the da tree, filling in bridge nodes.
    da_tree = reverse_build_da_perm_tree(da_tree_nodes)
    # Node tree for directly-authorized nodes.
    dn_ids = AssetPermission.nodes.through.objects.filter(assetpermission_id__in=org_perm_ids).distinct().values_list('node_id', flat=True)
    dn_node_keys = Node.objects.filter(id__in=dn_ids).values_list('key', flat=True)
    dn_tree_nodes = {k: TreeNode(k, type=TreeNode.Type.OWNER) for k in dn_node_keys}
    dn_tree = reverse_build_dn_perm_tree(dn_tree_nodes)
    # Merge into the upper half of the user's complete permission tree.
    _tree = merge_trees(da_tree, dn_tree)
    # Materialize all dn descendant nodes.
    org_node_keys = set(Node.objects.filter(org_id=org_id).values_list('key', flat=True))
    tree = complete_all_dn_children_nodes(_tree, org_node_keys)
    build_time = (datetime.now() - build_start).total_seconds()
    calc_start = datetime.now()
    node_details = {}  # per-node query metadata
    # Count the assets for every node.
    # ======================= core code - mapper computation ==========================
    mapper = {}
    for tn in tree.values():
        node_start = datetime.now()
        count = query_user_perm_node_assets_count(da_asset_ids, tn)
        mapper[tn.key] = count
        # ======================= core code =========================================
        node_times[tn.key] = (datetime.now() - node_start).total_seconds()
        # Record the da_assets / dn_nodes this node used when querying.
        node_details[tn.key] = {
            'query_da_assets': len(tn.da_asset_ids),
            'query_dn_nodes': len(tn.dn_node_keys),
            'node_type': tn.type
        }
    calc_time = (datetime.now() - calc_start).total_seconds()
    step_time = (datetime.now() - step_start).total_seconds()
    log(f" 节点数: {len(mapper):,}")
    log(f" 数据构造时间: {build_time:.2f}s")
    log(f" Mapper计算时间: {calc_time:.2f}s")
    log(f" 总耗时: {step_time:.2f}s")
    # Dump the tree to a file.
    write_user_perm_mapper_to_file(user.name, user.id, org_id, mapper, node_times, tree_nodes=tree)
    return {
        'user_id': user.id,
        'org_id': org_id,
        'node_count': len(mapper),
        'time': step_time
    }
def build_users_perm_tree():
    """Build the permission tree for every test user in each org they belong to."""
    orgs_mapper = { str(org.id): org for org in Organization.objects.all() }
    log(f"\n{'='*60}")
    log("【开始构建所有用户的授权树】")
    log(f"{'='*60}\n")
    total_start = datetime.now()
    # TEST_USER_AMOUNT is a module-level test limit defined elsewhere in this script.
    users = User.objects.all()[:TEST_USER_AMOUNT]
    user_count = users.count()
    log(f"发现 {user_count:,} 个用户\n")
    user_index = 1
    # ======================= core code ==========================
    for user in users:
        org_ids = OrgRoleBinding.objects.filter(user=user).distinct('org_id').values_list('org_id', flat=True)
        log(f"用户 {user_index}: {user.name}({user.id}) - {len(org_ids)} 个组织")
        for org_id in org_ids:
            org = orgs_mapper.get(str(org_id))
            build_user_org_perm_tree(user, str(org_id), org)
        # ======================= core code ==========================
        log()  # blank line between users
        user_index += 1
    total_time = (datetime.now() - total_start).total_seconds()
    log(f"\n{'='*60}")
    log("【用户授权树构建完成】")
    log(f"总耗时: {total_time:.2f}s")
    log(f"{'='*60}\n")
def get_ancestor_keys(key):
    """Return *key* followed by each of its ancestor keys up to the root.

    NOTE: despite the name, the key itself is included as the first element.
    """
    segments = key.split(":")
    keys = []
    while segments:
        keys.append(":".join(segments))
        segments.pop()
    return keys
#
# ============== X 方案 ==================
#
def x_build_user_org_perm_tree(user, org_id, use_cache=False):
    """X scheme: build the user's permission tree for *org_id* fully in memory.

    Loads all (asset_id, node_id) relation pairs covered by the user's direct
    asset grants plus direct node grants (expanded to descendants in Python),
    then propagates each asset id up the ancestor chain to get per-node
    counts. The resulting {node_key: asset_count} dict is cached for an hour.
    """
    print('构建用户授权树....')
    t1 = time.time()
    user_perm_tree = cache.get("user_perm_tree") if use_cache else None
    if not user_perm_tree:
        # Permission ids the user holds in this org, directly or via groups.
        group_ids = User.groups.through.objects.filter(user_id=user.id).values_list('usergroup_id', flat=True)
        group_perm_ids = AssetPermission.user_groups.through.objects.filter(usergroup_id__in=group_ids).values_list('assetpermission_id', flat=True)
        user_perm_ids = AssetPermission.users.through.objects.filter(user_id=user.id).values_list('assetpermission_id', flat=True)
        orgs_perm_ids = set(list(group_perm_ids) + list(user_perm_ids))
        org_perm_ids = AssetPermission.objects.filter(id__in=orgs_perm_ids, org_id=org_id).values_list('id', flat=True)
        da_ids = AssetPermission.assets.through.objects.filter(assetpermission_id__in=org_perm_ids).distinct('asset_id').values_list('asset_id', flat=True)
        dn_ids = AssetPermission.nodes.through.objects.filter(assetpermission_id__in=org_perm_ids).distinct().values_list('node_id', flat=True)
        dn_keys = Node.objects.filter(id__in=dn_ids).values_list('key', flat=True)
        print("未命中缓存,查询所有资产节点关系对...")
        t11 = time.time()
        # Non-redundant node_key scheme: expand every directly-authorized node
        # to all of its descendants in Python, then fetch the relation pairs.
        # (Bug fix: a leftover timing print here referenced aid_nk_pairs
        # before it was defined — the redundant-key query producing it was
        # commented out — raising NameError; it has been removed.)
        node_ids = set()
        node_id_key_pairs = dict(Node.objects.filter(org_id=org_id).values_list('id', 'key'))
        for nid in dn_ids:
            node_ids.add(nid)
            nk = node_id_key_pairs[nid]
            children_ids = [_id for _id, key in node_id_key_pairs.items() if key.startswith(nk + ':')]
            node_ids.update(children_ids)
        aid_nid_pairs = list(AssetNodeThrough.objects.filter(Q(asset__id__in=da_ids) | Q(node_id__in=node_ids)).values_list('asset_id', 'node_id'))
        aid_nk_pairs = [(aid, node_id_key_pairs[nid]) for aid, nid in aid_nid_pairs]
        t13 = time.time()
        print(f"非冗余 key: 查询资产节点关系对耗时: {t13 - t11:.2f}s, 关系对数量: {len(aid_nk_pairs):,}")
        # Propagate each asset id to its node and every ancestor.
        mapper = defaultdict(set)
        for aid, nk in aid_nk_pairs:
            mapper[nk].add(aid)
            an_ks = get_ancestor_keys(nk)
            for ak in an_ks:
                mapper[ak].add(aid)
        user_perm_tree = {k: len(v) for k, v in mapper.items()}
        cache.set("user_perm_tree", user_perm_tree, 3600)
    else:
        print("命中缓存,直接使用缓存的资产节点关系对...")
    t2 = time.time()
    # Bug fix: use .get('1') so a tree without the root key '1' prints None
    # instead of raising KeyError.
    print(
        f"""
    ===============用户授权树构建 - 基本信息=============================
    用户: {user.name}({user.id})
    组织: {org_id}
    资产总数: {Asset.objects.filter(org_id=org_id).count():,}
    节点总数: {Node.objects.filter(org_id=org_id).count():,}
    总耗时: {t2 - t1:.2f}s
    测试数据: root-node - {user_perm_tree.get('1')}
    ===================================================================
    """
    )
def x_build_org_asset_tree(org_id, use_cache=False):
    """X scheme: build the whole-org asset tree {node_key: asset_count} in
    memory (optionally from cache) and print basic stats.

    Each (asset_id, node_id) pair contributes the asset to its node and to
    every ancestor key; the result is cached for an hour.
    """
    print('构建资产树....')
    t1 = time.time()
    asset_tree = cache.get("asset_tree") if use_cache else None
    if not asset_tree:
        print("未命中缓存,查询所有资产节点关系对...")
        node_id_key_pairs = dict(Node.objects.filter(org_id=org_id).values_list('id', 'key'))
        node_ids = list(node_id_key_pairs.keys())
        aid_nid_pairs = list(AssetNodeThrough.objects.filter(node_id__in=node_ids).values_list('asset_id', 'node_id'))
        # Propagate each asset id to its node and every ancestor.
        mapper = defaultdict(set)
        for aid, nid in aid_nid_pairs:
            nk = node_id_key_pairs[nid]
            mapper[nk].add(aid)
            an_ks = get_ancestor_keys(nk)
            for ak in an_ks:
                mapper[ak].add(aid)
        asset_tree = {k: len(v) for k, v in mapper.items()}
        cache.set("asset_tree", asset_tree, 3600)
    else:
        print("命中缓存,直接使用缓存的资产树...")
    t2 = time.time()
    # Bug fix: use .get('1') so an org without the root node key '1' prints
    # None instead of raising KeyError.
    print('..........', asset_tree.get('1'))
    print(
        f"""
    ===============构建信息=============================
    组织: {org_id}
    资产总数: {Asset.objects.filter(org_id=org_id).count():,}
    节点总数: {Node.objects.filter(org_id=org_id).count():,}
    总耗时: {t2 - t1:.2f}s
    测试数据: root-node - {asset_tree.get('1')}
    ===================================================================
    """
    )
def main():
    """Script entry point: build the admin user's perm tree, then the org asset tree (X scheme)."""
    u = User.objects.get(username='admin')
    # NOTE(review): Organization instances normally expose `.id`; verify that
    # `.org_id` actually exists on the default Organization — possible typo.
    org_id = Organization.default().org_id
    try:
        # ======================= core code ==========================
        # if BUILD_ASSET_TREE:
        #     build_orgs_asset_tree()
        # if BUILD_USER_PERM_TREE:
        #     build_users_perm_tree()
        t1 = time.time()
        x_build_user_org_perm_tree(u, org_id)
        # ======================= core code ==========================
        t2 = time.time()
        x_build_org_asset_tree(org_id)
        # NOTE(review): t2 is taken before x_build_org_asset_tree, so the
        # "总耗时" below covers only the perm-tree build.
        log(f"\n总耗时: {t2 - t1:.2f}s")
        log(f"{'='*60}")
        log("✓ 所有操作完成!")
        log(f"{'='*60}\n")
    except Exception as e:
        log(f"\n✗ 发生错误: {str(e)}")
        import traceback
        traceback.print_exc()
if __name__ == '__main__':
    # Optional one-off cleanup of duplicate asset-node rows before building:
    # remove_asset_belong_many_node_record_from_asset_node_through()
    main()

View File

@@ -0,0 +1,243 @@
#!/usr/bin/env python
"""
测试查询资产树和授权树下指定节点的资产总数
"""
import os, sys, django, json, time
from datetime import datetime
import copy
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'jumpserver.settings')
sys.path.insert(0, os.path.join(os.getcwd(), 'apps'))
django.setup()
from functools import reduce
from operator import or_
from operator import or_
from django.db.models import Q
from collections import defaultdict
from orgs.models import Organization
from users.models import User, UserGroup
from rbac.models import OrgRoleBinding
from assets.models import Node, Asset
from perms.models import AssetPermission
from django.core.cache import cache
#
# ============== X scheme ==================
#
AssetNodeThrough = Node.assets.through  # implicit M2M through model linking Node and Asset
def get_ancestor_keys(key):
    """Return *key* followed by each of its ancestor keys up to the root.

    NOTE: despite the name, the key itself is included as the first element.
    """
    parts = key.split(":")
    result = []
    for _ in range(len(parts)):
        result.append(":".join(parts))
        parts.pop()
    return result
def x_build_user_org_perm_tree(user, org_id, use_cache=False):
    """Build (or load from cache) the user's in-memory permission tree for one
    org, print its stats, then cross-check the root-node asset count against
    the original UserPermTreeRefreshUtil implementation.
    """
    print(f'构建用户授权树 ({user.username})....')
    t1 = time.time()
    user_perm_tree = cache.get("user_perm_tree") if use_cache else None
    # Bug fix: pre-initialize the stats referenced by the summary print below;
    # on a cache hit they were previously undefined and raised NameError.
    aid_nk_pairs = []
    da_ids = set()
    dn_ids = set()
    node_ids = set()
    if not user_perm_tree:
        # Permission ids the user holds in this org, directly or via groups.
        group_ids = User.groups.through.objects.filter(user_id=user.id).values_list('usergroup_id', flat=True)
        group_perm_ids = AssetPermission.user_groups.through.objects.filter(usergroup_id__in=group_ids).values_list('assetpermission_id', flat=True)
        user_perm_ids = AssetPermission.users.through.objects.filter(user_id=user.id).values_list('assetpermission_id', flat=True)
        orgs_perm_ids = set(list(group_perm_ids) + list(user_perm_ids))
        org_perm_ids = AssetPermission.objects.filter(id__in=orgs_perm_ids, org_id=org_id).values_list('id', flat=True)
        # Deduplicate in Python with set() rather than distinct('asset_id'),
        # which MySQL does not support.
        da_ids_all = AssetPermission.assets.through.objects.filter(assetpermission_id__in=org_perm_ids).values_list('asset_id', flat=True)
        da_ids = set(da_ids_all)
        # All directly-authorized node ids.
        dn_ids_all = AssetPermission.nodes.through.objects.filter(assetpermission_id__in=org_perm_ids).values_list('node_id', flat=True)
        dn_ids = set(dn_ids_all)
        print("未命中缓存,查询所有资产节点关系对...")
        t11 = time.time()
        # Non-redundant node_key scheme: expand every directly-authorized node
        # to all of its descendants in Python, then fetch the relation pairs.
        node_id_key_pairs = dict(Node.objects.filter(org_id=org_id).values_list('id', 'key'))
        for nid in dn_ids:
            node_ids.add(nid)
            nk = node_id_key_pairs[nid]
            children_ids = [_id for _id, key in node_id_key_pairs.items() if key.startswith(nk + ':')]
            node_ids.update(children_ids)
        aid_nid_pairs = list(AssetNodeThrough.objects.filter(Q(asset__id__in=da_ids) | Q(node_id__in=node_ids)).values_list('asset_id', 'node_id'))
        aid_nk_pairs = [(aid, node_id_key_pairs[nid]) for aid, nid in aid_nid_pairs]
        t13 = time.time()
        print(f"非冗余 key: 查询资产节点关系对耗时: {t13 - t11:.2f}s, 关系对数量: {len(aid_nk_pairs):,}")
        # Propagate each asset id up the ancestor chain to count per node.
        mapper = defaultdict(set)
        for aid, nk in aid_nk_pairs:
            mapper[nk].add(aid)
            an_ks = get_ancestor_keys(nk)
            for ak in an_ks:
                mapper[ak].add(aid)
        user_perm_tree = {k: len(v) for k, v in mapper.items()}
        cache.set("user_perm_tree", user_perm_tree, 3600)
    else:
        print("命中缓存,直接使用缓存的资产节点关系对...")
    t2 = time.time()
    # Root keys are pure digits (e.g. '1'); child keys contain ':'.
    ROOT_KEY = [k for k in user_perm_tree.keys() if k.isdigit()]
    if ROOT_KEY:
        ROOT_KEY = ROOT_KEY[0]
    else:
        print("用户授权树中没有根节点,可能用户没有任何授权资产。")
        return
    print(
        f"""
    ===============用户授权树构建 - 基本信息=============================
    用户: {user.name}({user.id})
    组织: {org_id}
    资产总数: {Asset.objects.filter(org_id=org_id).count():,}
    节点总数: {Node.objects.filter(org_id=org_id).count():,}
    资产节点关系对总数: {AssetNodeThrough.objects.filter(node__org_id=org_id).count():,}
    查出来的资产节点关系对总数: {len(aid_nk_pairs):,}
    直接授权资产数(用于查询): {len(da_ids):,}
    直接授权节点数: {len(dn_ids):,}
    用于查询的节点数: {len(node_ids):,}
    总耗时: {t2 - t1:.2f}s
    ROOT_KEY: {ROOT_KEY},
    用于验证数据: ROOT Node count- {user_perm_tree[ROOT_KEY]} 通过rebuild_user_org_perm_tree 函数构建的授权树资产数一致就对了。
    ===================================================================
    """
    )
    # Cross-check the root count against the original implementation.
    from perms.utils.user_perm_tree import UserPermTreeRefreshUtil
    from perms.models import UserAssetGrantedTreeNodeRelation
    UserPermTreeRefreshUtil(user=user)._rebuild_user_perm_tree_for_org(org_id)
    count = UserAssetGrantedTreeNodeRelation.objects.filter(user=user).get(node__key=ROOT_KEY).node_assets_amount
    print(f'''
    使用原始方法获取授权树 ROOT Node count:
    通过 UserAssetGrantedTreeNodeRelation 刷新得到的授权 ROOT Node count - {count}
    ''')
def x_build_org_asset_tree(org_id, use_cache=False):
    """Build (or load from cache) the whole-org asset tree {node_key: count}
    and print stats, including consistency checks between the relation table
    and the asset table.
    """
    print('构建资产树....')
    t1 = time.time()
    asset_tree = cache.get("asset_tree") if use_cache else None
    # Bug fix: pre-initialize the relation pairs used by the stats below;
    # on a cache hit they were previously undefined and raised NameError.
    aid_nid_pairs = []
    if not asset_tree:
        print("未命中缓存,查询所有资产节点关系对...")
        node_id_key_pairs = dict(Node.objects.filter(org_id=org_id).values_list('id', 'key'))
        node_ids = list(node_id_key_pairs.keys())
        aid_nid_pairs = list(AssetNodeThrough.objects.filter(node_id__in=node_ids).values_list('asset_id', 'node_id'))
        # Propagate each asset id to its node and every ancestor.
        mapper = defaultdict(set)
        for aid, nid in aid_nid_pairs:
            nk = node_id_key_pairs[nid]
            mapper[nk].add(aid)
            an_ks = get_ancestor_keys(nk)
            for ak in an_ks:
                mapper[ak].add(aid)
        asset_tree = {k: len(v) for k, v in mapper.items()}
        cache.set("asset_tree", asset_tree, 3600)
    else:
        print("命中缓存,直接使用缓存的资产树...")
    t2 = time.time()
    # Bug fix: guard against an empty tree instead of IndexError on [0],
    # and print with .get() so a missing root shows None.
    root_keys = [k for k in asset_tree.keys() if k.isdigit()]
    ROOT_KEY = root_keys[0] if root_keys else None
    r_aids = list(dict(aid_nid_pairs).keys())
    a_ids = list(Asset.objects.filter(org_id=org_id).values_list('id', flat=True))
    _aids = Asset.objects.filter(nodes__isnull=True).count()
    print(
        f"""
    ===============构建信息=============================
    组织: {org_id}
    组织内资产总数: {Asset.objects.filter(org_id=org_id).count():,}
    组织内节点总数: {Node.objects.filter(org_id=org_id).count():,}
    资产节点关系对总数: {AssetNodeThrough.objects.count():,}
    总耗时: {t2 - t1:.2f}s
    用于验证数据: Root node - {asset_tree.get(ROOT_KEY)} 等于组织内资产总数就对了。
    ===================================================================
    找出资产在关系表中 和 资产表中 不一致的数据:
    关系表中资产数: {len(r_aids):,}
    资产表中资产数: {len(a_ids):,}
    在关系表中有,但资产表中没有的数量: {len(set(r_aids) - set(a_ids)):,} 说明资产被删除了,但是资产和节点的关系还在
    在资产表中有,但关系表中没有的数量: {len(set(a_ids) - set(r_aids)):,}, 说明资产没有挂载节点上
    游离资产总数: { _aids }
    """
    )
def main():
    """Entry point: build the asset tree for the perf-test org, then the perm tree for each test user."""
    users = User.objects.filter(username='admin')
    org = Organization.objects.filter(name__contains='性能').first()
    # Bug fix: .first() may return None; previously org.id crashed with AttributeError.
    if org is None:
        print("未找到测试组织,请先创建一个名称包含 '性能' 的组织。")
        return
    x_build_org_asset_tree(org.id)
    for u in users:
        x_build_user_org_perm_tree(u, org.id)
    print(f"{'='*60}")
    print("✓ 所有操作完成!")
    print(f"{'='*60}\n")
def setup_test_data():
    """Grant the admin user one direct-asset permission (up to 30k assets) and
    one direct-node permission (up to 1k nodes) in the perf-test org.
    """
    print("设置测试数据...")
    org = Organization.objects.filter(name__contains='性能').first()
    if not org:
        print("未找到测试组织,请先创建一个名称包含 '性能' 的组织。")
        return
    user = User.objects.filter(username='admin').first()
    # Bug fix: guard against a missing admin user instead of crashing below.
    if not user:
        print("未找到 admin 用户。")
        return
    a_ids = Asset.objects.filter(org_id=org.id).values_list('id', flat=True)[:30000]
    n_ids = Node.objects.filter(org_id=org.id).values_list('id', flat=True)[:1000]
    # Bug fix: unpacking two permissions crashed with ValueError when the org
    # had fewer than 2 AssetPermissions.
    perms = list(AssetPermission.objects.filter(org_id=org.id)[:2])
    if len(perms) < 2:
        print("组织内授权规则不足 2 条,无法设置测试数据。")
        return
    p1, p2 = perms
    p1.users.add(user)
    p1.assets.add(*a_ids)
    p2.users.add(user)
    p2.nodes.add(*n_ids)
    print(f'给用户 {user.username} 分配两个权限..., 一个直接授权 {len(a_ids)} 资产,一个直接授权 {len(n_ids)} 节点')
if __name__ == '__main__':
    # Seed the admin user's test permissions first, then run the build + verification pass.
    setup_test_data()
    main()
#

View File

@@ -0,0 +1,350 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
生成随机的 Asset-Node 关系数据
用法: python generate_asset_node_through_data.py
"""
import os
import sys
import django
import random
from datetime import datetime
# 配置 Django 设置
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'jumpserver.settings')
sys.path.insert(0, '/Users/bryan/JumpServer/jumpserver/apps')
django.setup()
from django.db import connections, transaction
from assets.models import Asset, Node, AssetNode
from orgs.models import Organization
# ============== Configuration ==============
ORG_ID = str(Organization.default().id)  # target organization id; adjust as needed
# ORG_ID = None
TARGET_COUNT = 1000  # number of relation rows to generate (an earlier comment claiming 500k was stale)
BATCH_SIZE = 50000  # bulk-insert batch size
INCLUDE_NODE_KEY = True  # whether to also populate node_key (only if the through table has that field)
def log(msg):
    """Print *msg* with an HH:MM:SS timestamp prefix."""
    now = datetime.now()
    print(f"[{now:%H:%M:%S}] {msg}")
def get_asset_and_node_ids(org_id):
    """Fetch the Asset and Node ID lists for *org_id*.

    Returns ``(asset_ids, node_ids)``, or ``(None, None)`` when either
    list turns out empty (logged as an error).
    """
    log(f"获取组织 {org_id} 的 Asset 和 Node 数据...")
    asset_ids = list(Asset.objects.filter(org_id=org_id).values_list('id', flat=True))
    node_ids = list(Node.objects.filter(org_id=org_id).values_list('id', flat=True))
    if not asset_ids:
        log(f"✗ 错误:组织 {org_id} 中没有 Asset 数据")
        return None, None
    if not node_ids:
        log(f"✗ 错误:组织 {org_id} 中没有 Node 数据")
        return None, None
    log(f"✓ 获取到 {len(asset_ids):,} 个 Asset")
    log(f"✓ 获取到 {len(node_ids):,} 个 Node\n")
    return asset_ids, node_ids
def generate_and_insert_data_v2(asset_ids, node_ids, batch_size, include_node_key=False):
    """Link every asset to 1-3 randomly chosen nodes, clearing existing rows.

    Args:
        asset_ids: Asset IDs to link.
        node_ids: candidate Node IDs.
        batch_size: rows per bulk_create batch.
        include_node_key: also populate the ``node_key`` column, if the
            through table has it.

    Returns:
        True on success, False if insertion raised.
    """
    log(f"开始为所有 Asset 生成关联节点数据...")
    log(f"每个 asset 随机关联 1-3 个节点")
    log(f"总 asset 数: {len(asset_ids):,}")
    log(f"总 node 数: {len(node_ids):,}")
    log(f"批处理大小: {batch_size:,}")
    # BUG FIX: both branches of the original conditional expression were the
    # empty string, so the flag value was never actually displayed.
    log(f"是否生成 node_key: {'yes' if include_node_key else 'no'}\n")
    start_time = datetime.now()
    inserted = 0
    # Through model backing the Asset <-> Node M2M.
    asset_node_through = Asset.nodes.through
    # Preload the Node.id -> Node.key mapping when node_key is requested.
    node_key_map = {}
    if include_node_key:
        log("加载 Node key 映射...")
        node_key_map = {nid: key for nid, key in Node.objects.filter(id__in=node_ids).values_list('id', 'key')}
        log(f"✓ 加载了 {len(node_key_map):,} 个 node_key\n")
    # Downgrade gracefully when the through table lacks the column.
    has_node_key_field = 'node_key' in [f.name for f in asset_node_through._meta.get_fields()]
    if include_node_key and not has_node_key_field:
        log("⚠ 警告:配置要求生成 node_key但 through 表没有该字段,将跳过 node_key 生成\n")
        include_node_key = False
    # Wipe pre-existing relations so the generated counts are reproducible.
    current_count = asset_node_through.objects.count()
    if current_count > 0:
        log(f"⚠ 清空已有数据: {current_count:,}")
        asset_node_through.objects.all().delete()
        log(f"✓ 清空完成\n")
    try:
        objs = []
        for asset_idx, asset_id in enumerate(asset_ids):
            # 1-3 distinct nodes per asset, capped by the available nodes.
            num_nodes = random.randint(1, 3)
            selected_nodes = random.sample(node_ids, min(num_nodes, len(node_ids)))
            for node_id in selected_nodes:
                if include_node_key and has_node_key_field:
                    node_key = node_key_map.get(node_id, '')
                    objs.append(
                        asset_node_through(asset_id=asset_id, node_id=node_id, node_key=node_key)
                    )
                else:
                    objs.append(
                        asset_node_through(asset_id=asset_id, node_id=node_id)
                    )
            # Flush a full batch and report progress.
            if len(objs) >= batch_size:
                with transaction.atomic():
                    asset_node_through.objects.bulk_create(objs, batch_size=batch_size, ignore_conflicts=True)
                inserted += len(objs)
                progress_pct = (asset_idx + 1) / len(asset_ids) * 100
                elapsed = (datetime.now() - start_time).total_seconds()
                log(f" 已处理 {asset_idx + 1:,}/{len(asset_ids):,} ({progress_pct:.1f}%) asset - 已插入 {inserted:,} 条数据 - 耗时 {elapsed:.1f}s")
                objs = []
        # Flush the remainder.
        if objs:
            with transaction.atomic():
                asset_node_through.objects.bulk_create(objs, batch_size=batch_size, ignore_conflicts=True)
            inserted += len(objs)
        # Final summary.
        total_now = asset_node_through.objects.count()
        elapsed = (datetime.now() - start_time).total_seconds()
        avg_nodes_per_asset = total_now / len(asset_ids) if asset_ids else 0
        log(f"\n✓ 数据生成完成!")
        log(f" 总 asset 数: {len(asset_ids):,}")
        log(f" 生成关系数: {inserted:,}")
        log(f" 平均每个 asset 关联节点数: {avg_nodes_per_asset:.2f}")
        log(f" 表总数: {total_now:,}")
        log(f" 耗时: {elapsed:.1f}s")
        if elapsed > 0:
            log(f" 平均速率: {inserted/elapsed:.0f} 条/秒\n")
        return True
    except Exception as e:
        log(f"\n✗ 插入数据时出错: {str(e)}")
        import traceback
        traceback.print_exc()
        return False
def generate_and_insert_data(asset_ids, node_ids, target_count, batch_size, include_node_key=False):
    """Insert *target_count* random Asset-Node rows; existing rows are kept.

    Each asset is assigned at most one node per "round"; once every asset
    has been used, the assignment map resets and a new round begins.

    Args:
        asset_ids: Asset IDs to choose from.
        node_ids: Node IDs to choose from.
        target_count: number of rows to generate.
        batch_size: rows per bulk_create batch.
        include_node_key: also populate the ``node_key`` column, if the
            through table has it.

    Returns:
        True on success, False if insertion raised.
    """
    log(f"开始生成 {target_count:,} 条 Asset-Node 关系数据...")
    log(f"批处理大小: {batch_size:,}")
    # BUG FIX: both branches of the original conditional expression were the
    # empty string, so the flag value was never actually displayed.
    log(f"是否生成 node_key: {'yes' if include_node_key else 'no'}\n")
    start_time = datetime.now()
    inserted = 0
    # Through model backing the Asset <-> Node M2M.
    asset_node_through = Asset.nodes.through
    # Preload the Node.id -> Node.key mapping when node_key is requested.
    node_key_map = {}
    if include_node_key:
        log("加载 Node key 映射...")
        node_key_map = {nid: key for nid, key in Node.objects.filter(id__in=node_ids).values_list('id', 'key')}
        log(f"✓ 加载了 {len(node_key_map):,} 个 node_key\n")
    # Downgrade gracefully when the through table lacks the column.
    has_node_key_field = 'node_key' in [f.name for f in asset_node_through._meta.get_fields()]
    if include_node_key and not has_node_key_field:
        log("⚠ 警告:配置要求生成 node_key但 through 表没有该字段,将跳过 node_key 生成\n")
        include_node_key = False
    current_count = asset_node_through.objects.count()
    if current_count > 0:
        log(f"⚠ through 表中已有 {current_count:,} 条数据,继续追加...\n")
    try:
        # Tracks which assets are already assigned in the current round.
        asset_to_node_map = {}
        for batch_start in range(0, target_count, batch_size):
            batch_end = min(batch_start + batch_size, target_count)
            batch_size_actual = batch_end - batch_start
            objs = []
            for _ in range(batch_size_actual):
                # All assets used: start a new round.
                if len(asset_to_node_map) >= len(asset_ids):
                    asset_to_node_map = {}
                # Rejection-sample an unassigned asset (bounded attempts).
                attempts = 0
                while attempts < 100:
                    asset_id = random.choice(asset_ids)
                    if asset_id not in asset_to_node_map:
                        break
                    attempts += 1
                if asset_id in asset_to_node_map:
                    continue  # gave up after 100 tries; skip this row
                node_id = random.choice(node_ids)
                asset_to_node_map[asset_id] = node_id
                if include_node_key and has_node_key_field:
                    node_key = node_key_map.get(node_id, '')
                    objs.append(
                        asset_node_through(asset_id=asset_id, node_id=node_id, node_key=node_key)
                    )
                else:
                    objs.append(
                        asset_node_through(asset_id=asset_id, node_id=node_id)
                    )
            # Bulk insert this batch.
            if objs:
                with transaction.atomic():
                    asset_node_through.objects.bulk_create(objs, batch_size=batch_size, ignore_conflicts=True)
                inserted += len(objs)
            progress_pct = (batch_end / target_count) * 100 if target_count > 0 else 0
            elapsed = (datetime.now() - start_time).total_seconds()
            log(f" 已插入 {inserted:,}/{target_count:,} ({progress_pct:.1f}%) - 耗时 {elapsed:.1f}s")
        # Final summary.
        total_now = asset_node_through.objects.count()
        elapsed = (datetime.now() - start_time).total_seconds()
        log(f"\n✓ 数据生成完成!")
        log(f" 插入: {inserted:,} 条新数据")
        log(f" 表总数: {total_now:,}")
        log(f" 耗时: {elapsed:.1f}s")
        # BUG FIX: guard against ZeroDivisionError when elapsed is 0
        # (the sibling v2 function already guards this).
        if elapsed > 0:
            log(f" 平均速率: {inserted/elapsed:.0f} 条/秒\n")
        return True
    except Exception as e:
        log(f"\n✗ 插入数据时出错: {str(e)}")
        import traceback
        traceback.print_exc()
        return False
def main():
    """Drive the generation workflow: show config, validate it, fetch the
    ID lists, then generate the relations. Returns True on success."""
    banner = '=' * 60
    log(f"\n{banner}")
    log("【Asset-Node 关系数据生成工具 (v2 - 每个asset关联1-3个节点)】")
    log(f"{banner}\n")
    log(f"配置:")
    log(f" 组织 ID: {ORG_ID}")
    log(f" 批处理大小: {BATCH_SIZE:,}")
    log(f" 包含 node_key: {INCLUDE_NODE_KEY}\n")
    if ORG_ID is None:
        log("✗ 错误:请在脚本中设置 ORG_ID 变量")
        return False
    asset_ids, node_ids = get_asset_and_node_ids(ORG_ID)
    if not asset_ids or not node_ids:
        return False
    # v2 strategy: link each asset to 1-3 nodes.
    ok = generate_and_insert_data_v2(
        asset_ids, node_ids, BATCH_SIZE, include_node_key=INCLUDE_NODE_KEY,
    )
    log(banner)
    if ok:
        log("✓ 所有操作完成!")
    else:
        log("✗ 操作失败")
    log(banner)
    return ok
def check_node_key_is_empty():
    """Sanity-check: report through-table rows whose node_key is NULL/empty.

    No-op unless INCLUDE_NODE_KEY is enabled; degrades to a warning when the
    through table has no node_key column.
    """
    if not INCLUDE_NODE_KEY:
        return
    # BUG FIX: resolve the through model locally. The original relied on a
    # module-level `asset_node_through` that is only assigned inside the
    # __main__ guard, so calling this function from an import raised
    # NameError.
    asset_node_through = Asset.nodes.through
    has_node_key_field = 'node_key' in [f.name for f in asset_node_through._meta.get_fields()]
    if has_node_key_field:
        from django.db.models import Q
        empty_node_key = asset_node_through.objects.filter(Q(node_key='') | Q(node_key__isnull=True)).count()
        log(f" node_key 为空的数据: {empty_node_key:,}")
        if empty_node_key > 0:
            log(f" ⚠ 警告:存在 {empty_node_key:,} 条 node_key 为空的数据")
        else:
            log(f" ✓ 所有数据的 node_key 都已填充")
    else:
        log(f" ⚠ through 表没有 node_key 字段")
# Script entry point: run main(), then print final statistics. Exit code 0
# on success, 1 on failure/interrupt/error.
if __name__ == '__main__':
    try:
        success = main()
        # Print final statistics
        asset_node_through = Asset.nodes.through
        total_count = asset_node_through.objects.count()
        log(f"{'='*60}")
        log(f"最终统计:")
        log(f" AssetNode through 表总数据量: {total_count:,}")
        check_node_key_is_empty()
        log(f"{'='*60}")
        sys.exit(0 if success else 1)
    except KeyboardInterrupt:
        log("\n✗ 用户中断操作")
        sys.exit(1)
    except Exception as e:
        log(f"\n✗ 发生错误: {str(e)}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

113
test_query_node_assets.py Normal file
View File

@@ -0,0 +1,113 @@
#!/usr/bin/env python
"""
测试查询资产树和授权树下指定节点的资产总数
迭代查询一个节点一个SQL
"""
import os, sys, django
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'jumpserver.settings')
sys.path.insert(0, '/Users/bryan/JumpServer/jumpserver/apps')
django.setup()
def test():
    """Unused stub — imports only, never called.

    NOTE(review): dead code; consider removing or filling in.
    """
    from django.db.models import Count, Q
    from assets.models import Node
def main():
    """Benchmark per-node asset counts, one SQL query per node, on both the
    raw asset tree and a simulated user-permission view.

    NOTE(review): uses ``.distinct('asset_id')`` — PostgreSQL-only — and a
    ``node_key`` column on the through table; confirm both exist.
    """
    # CLEANUP: removed unused `os`/`random` imports and the block of
    # commented-out / immediately-overwritten guessed_* assignments.
    import time, json
    from functools import reduce
    from operator import or_
    from django.db.models import Q
    from collections import defaultdict
    from assets.models import Node, Asset

    # ---- Asset tree: count assets under each sampled node ----
    nodes = Node.objects.all().order_by('key')[:1000]
    node_keys = [n.key for n in nodes]
    # Only the first node is actually timed; the original logged
    # len(node_keys) (1000) while querying just one — report the real count.
    sampled_keys = node_keys[:1]
    t1 = time.time()
    t_mapper = defaultdict(int)
    for key in sampled_keys:
        # A node's subtree = rows whose node_key equals the key or starts
        # with "<key>:".
        count = Node.assets.through.objects.filter(
            Q(node_key=key) | Q(node_key__startswith=key + ':')
        ).distinct('asset_id').count()
        t_mapper[key] = count
    t2 = time.time()
    t_root = t_mapper.get('1', 0)
    count_rs = Node.assets.through.objects.all().count()
    count_asset = Asset.objects.all().count()
    count_node = Node.objects.all().count()

    # ---- Permission tree: simulate a user with direct grants ----
    guessed_perm_das_count = 200000  # simulated directly-authorized assets
    guessed_perm_dn_count = 1000     # simulated directly-authorized nodes
    perm_dns = Node.objects.all()[:guessed_perm_dn_count]
    perm_dn_keys = [str(n.key) for n in perm_dns]
    perm_das = Asset.objects.all()[:guessed_perm_das_count]
    perm_da_ids = [str(a.id) for a in perm_das]
    # qn_keys = Node.objects.all().values_list('key', flat=True)[:100]
    qn_keys = ['1']
    t3 = time.time()
    p_mapper = defaultdict(int)
    # Strategy: build a base queryset restricted to what the user can reach
    # (any granted subtree OR a directly-granted asset), then count each
    # queried node's subtree within that base.
    q_user_assets = [
        Q(node_key=key) | Q(node_key__startswith=key + ':') for key in perm_dn_keys
    ] + [Q(asset_id__in=perm_da_ids)]
    qs_base = Node.assets.through.objects.filter(reduce(or_, q_user_assets))
    for qn_key in qn_keys:
        count = qs_base.filter(
            Q(node_key=qn_key) | Q(node_key__startswith=qn_key + ':')
        ).distinct('asset_id').count()
        p_mapper[qn_key] = count
    t4 = time.time()
    p_root = p_mapper.get('1', 0)

    # ---- Report ----
    print(json.dumps(t_mapper, indent=4))
    print(json.dumps(p_mapper, indent=4))
    print('=' * 50, '查询资产树', '=' * 50)
    print('总资产数:', count_asset, '总节点数:', count_node, '关联表总数:', count_rs)
    print(f"查询 {len(sampled_keys)} 个节点资产数耗时: {t2 - t1:.4f}")
    print('='*50, '查询用户权限下节点资产数', '=' * 50)
    print('假设直接授权节点数: ', len(perm_dn_keys), '假设直接授权资产数:', len(perm_da_ids))
    print(f"查询 {len(qn_keys)} 个节点资产数耗时: {t4 - t3:.4f}")
    print(f"资产树根节点资产数: {t_root}, 授权树根节点资产数: {p_root}")
# Run the benchmark when executed as a script.
if __name__ == '__main__':
    main()