Mirror of https://github.com/jumpserver/jumpserver.git (synced 2025-12-25 13:32:36 +00:00)

Compare commits: 23 commits
pr@dev@per ... v5_refacto
| Author | SHA1 | Date |
|---|---|---|
|  | a474d9be3e |  |
|  | 3a4e93af2f |  |
|  | 9c2ddbba7e |  |
|  | 39129cecbe |  |
|  | 88819bbf26 |  |
|  | a88e35156a |  |
|  | 22a27946a7 |  |
|  | 4983465a23 |  |
|  | 4d9fc9dfd6 |  |
|  | c7cb83fa1d |  |
|  | ee92c72b50 |  |
|  | 6a05fbe0fe |  |
|  | 0284be169a |  |
|  | a4e9d4f815 |  |
|  | bbe549696a |  |
|  | 56f720271a |  |
|  | 9755076f7f |  |
|  | 8d7abef191 |  |
|  | aaa40722c4 |  |
|  | ca39344937 |  |
|  | 4b9a8227c9 |  |
|  | f362163af1 |  |
|  | 5f1ba56e56 |  |
@@ -5,6 +5,7 @@ from rest_framework.request import Request
from assets.models import Node, Platform, Protocol, MyAsset
from assets.utils import get_node_from_request, is_query_node_all_assets
from common.utils import lazyproperty, timeit
from assets.tree.asset_tree import AssetTreeNode


class SerializeToTreeNodeMixin:
@@ -19,22 +20,22 @@ class SerializeToTreeNodeMixin:
        return False

    @timeit
    def serialize_nodes(self, nodes: List[Node], with_asset_amount=False):
        if with_asset_amount:
            def _name(node: Node):
                return '{} ({})'.format(node.value, node.assets_amount)
        else:
            def _name(node: Node):
                return node.value
    def serialize_nodes(self, nodes: List[AssetTreeNode], with_asset_amount=False, expand_level=1, with_assets=False):
        if not nodes:
            return []

        def _open(node):
            if not self.is_sync:
                # When the asset tree loads asynchronously, expand nodes by default
                return True
            if not node.parent_key:
                return True
        def _name(node: AssetTreeNode):
            v = node.value
            if not with_asset_amount:
                return v
            v = f'{v} ({node.assets_amount_total})'
            return v

        def is_parent(node: AssetTreeNode):
            if with_assets:
                return node.assets_amount > 0 or not node.is_leaf
            else:
                return False
            return not node.is_leaf

        data = [
            {
@@ -42,15 +43,17 @@ class SerializeToTreeNodeMixin:
                'name': _name(node),
                'title': _name(node),
                'pId': node.parent_key,
                'isParent': True,
                'open': _open(node),
                'isParent': is_parent(node),
                'open': node.level <= expand_level,
                'meta': {
                    'type': 'node',
                    'data': {
                        "id": node.id,
                        "key": node.key,
                        "value": node.value,
                        "assets_amount": node.assets_amount,
                        "assets_amount_total": node.assets_amount_total,
                    },
                    'type': 'node'
                }
            }
            for node in nodes
@@ -72,6 +75,9 @@ class SerializeToTreeNodeMixin:

    @timeit
    def serialize_assets(self, assets, node_key=None, get_pid=None):
        if not assets:
            return []

        if not get_pid and not node_key:
            get_pid = lambda asset, platform: getattr(asset, 'parent_key', '')
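For reference, one node serialized by the new `serialize_nodes` might look like this (IDs and values are illustrative; with `expand_level=1`, a level-2 node comes back closed):

```python
{
    'id': 'a3c0...',                      # illustrative UUID
    'name': 'Web (12)', 'title': 'Web (12)',
    'pId': '1', 'isParent': True,
    'open': False,                        # level 2 > expand_level 1
    'meta': {
        'type': 'node',
        'data': {
            'id': 'a3c0...', 'key': '1:1', 'value': 'Web',
            'assets_amount': 2,           # direct assets on this node
            'assets_amount_total': 12,    # direct + all descendants
        },
    },
}
```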
@@ -1,6 +1,6 @@
# ~*~ coding: utf-8 ~*~

from django.db.models import Q
from django.db.models import Q, Count
from django.utils.translation import gettext_lazy as _
from rest_framework.generics import get_object_or_404
from rest_framework.response import Response
@@ -11,12 +11,16 @@ from common.tree import TreeNodeSerializer
from common.utils import get_logger
from orgs.mixins import generics
from orgs.utils import current_org
from orgs.models import Organization
from .mixin import SerializeToTreeNodeMixin
from .. import serializers
from ..const import AllTypes
from ..models import Node, Platform, Asset
from assets.tree.asset_tree import AssetTree


logger = get_logger(__file__)

__all__ = [
    'NodeChildrenApi',
    'NodeChildrenAsTreeApi',
@@ -25,14 +29,13 @@ __all__ = [


class NodeChildrenApi(generics.ListCreateAPIView):
    """
    Node CRUD
    """
    ''' Node CRUD '''
    serializer_class = serializers.NodeSerializer
    search_fields = ('value',)

    instance = None
    is_initial = False
    perm_model = Node

    def initial(self, request, *args, **kwargs):
        super().initial(request, *args, **kwargs)
@@ -65,42 +68,16 @@ class NodeChildrenApi(generics.ListCreateAPIView):
        else:
            node = Node.org_root()
        return node

        if pk:
            node = get_object_or_404(Node, pk=pk)
        else:
            node = get_object_or_404(Node, key=key)
        return node

    def get_org_root_queryset(self, query_all):
        if query_all:
            return Node.objects.all()
        else:
            return Node.org_root_nodes()

    def get_queryset(self):
        query_all = self.request.query_params.get("all", "0") == "all"

        if self.is_initial and current_org.is_root():
            return self.get_org_root_queryset(query_all)

        if self.is_initial:
            with_self = True
        else:
            with_self = False

        if not self.instance:
            return Node.objects.none()

        if query_all:
            queryset = self.instance.get_all_children(with_self=with_self)
        else:
            queryset = self.instance.get_children(with_self=with_self)
        return queryset


class NodeChildrenAsTreeApi(SerializeToTreeNodeMixin, NodeChildrenApi):
    """
    Return the node's children as a tree,
    ''' Return the node's children as a tree,
    [
        {
            "id": "",
@@ -109,51 +86,96 @@ class NodeChildrenAsTreeApi(SerializeToTreeNodeMixin, NodeChildrenApi):
            "meta": ""
        }
    ]
    '''

    """
    model = Node

    def filter_queryset(self, queryset):
        """ queryset is Node queryset """
        if not self.request.GET.get('search'):
            return queryset
        queryset = super().filter_queryset(queryset)
        queryset = self.model.get_ancestor_queryset(queryset)
        return queryset

    def get_queryset_for_assets(self):
        query_all = self.request.query_params.get("all", "0") == "all"
        include_assets = self.request.query_params.get('assets', '0') == '1'
        if not self.instance or not include_assets:
            return Asset.objects.none()
        if not self.request.GET.get('search') and self.instance.is_org_root():
            return Asset.objects.none()
        if query_all:
            assets = self.instance.get_all_assets()
        else:
            assets = self.instance.get_assets()
        return assets.only(
            "id", "name", "address", "platform_id",
            "org_id", "is_active", 'comment'
        ).prefetch_related('platform')

    def filter_queryset_for_assets(self, assets):
        search = self.request.query_params.get('search')
        if search:
            q = Q(name__icontains=search) | Q(address__icontains=search)
            assets = assets.filter(q)
        return assets

    def list(self, request, *args, **kwargs):
        nodes = self.filter_queryset(self.get_queryset()).order_by('value')
        search = request.query_params.get('search')
        with_assets = request.query_params.get('assets', '0') == '1'
        with_asset_amount = request.query_params.get('asset_amount', '1') == '1'
        nodes = self.serialize_nodes(nodes, with_asset_amount=with_asset_amount)
        assets = self.filter_queryset_for_assets(self.get_queryset_for_assets())
        node_key = self.instance.key if self.instance else None
        assets = self.serialize_assets(assets, node_key=node_key)
        with_asset_amount = True
        nodes, assets, expand_level = self.get_nodes_assets(search, with_assets)
        nodes = self.serialize_nodes(nodes, with_asset_amount=with_asset_amount, expand_level=expand_level)
        assets = self.serialize_assets(assets)
        data = [*nodes, *assets]
        return Response(data=data)

    def get_nodes_assets(self, search, with_assets):
        #
        # Asset management - node tree
        #

        # Global org: initialize the node tree, return all nodes, no assets, nodes collapsed
        # Real org: initialize the node tree, return all nodes, no assets, expand first-level nodes
        # Frontend search
        if not with_assets:
            if current_org.is_root():
                orgs = Organization.objects.all()
                expand_level = 0
            else:
                orgs = [current_org]
                expand_level = 1

            nodes = []
            assets = []
            for org in orgs:
                tree = AssetTree(org=org)
                org_nodes = tree.get_nodes()
                nodes.extend(org_nodes)
            return nodes, assets, expand_level

        #
        # Permissions, account discovery, risk detection - asset node tree
        #

        # Global org: search assets, build the asset node tree, keep the first 1000 assets per org, expand all nodes
        # Real org: search assets, build the asset node tree, keep the first 1000 assets, expand all nodes
        if search:
            if current_org.is_root():
                orgs = list(Organization.objects.all())
            else:
                orgs = [current_org]
            nodes = []
            assets = []
            assets_q_object = Q(name__icontains=search) | Q(address__icontains=search)
            with_assets_limit = 1000 // len(orgs)
            for org in orgs:
                tree = AssetTree(
                    assets_q_object=assets_q_object, org=org,
                    with_assets=True, with_assets_limit=with_assets_limit, full_tree=False
                )
                nodes.extend(tree.get_nodes())
                assets.extend(tree.get_assets())
            expand_level = 10000  # expand all nodes when searching
            return nodes, assets, expand_level

        # Global org: expand a given node and its assets
        # Real org: expand a given node and its assets
        # Real org: when initializing the asset node tree, the root node and its assets expand automatically, so the nodes must include the node itself (special case)
        if self.instance:
            nodes = []
            tree = AssetTree(with_assets_node_id=self.instance.id, org=self.instance.org)
            nodes_with_self = False
            if not current_org.is_root() and self.instance.is_org_root():
                nodes_with_self = True
            nodes = tree.get_node_children(key=self.instance.key, with_self=nodes_with_self)
            assets = tree.get_assets()
            expand_level = 1  # only the first level expands by default
            return nodes, assets, expand_level

        # Global org: initialize the asset node tree, return only each org's root node, collapsed
        orgs = Organization.objects.all()
        nodes = []
        assets = []
        for org in orgs:
            tree = AssetTree(org=org, with_assets=False)
            if not tree.root:
                continue
            nodes.append(tree.root)
        expand_level = 0  # nodes collapsed by default
        return nodes, assets, expand_level


class CategoryTreeApi(SerializeToTreeNodeMixin, generics.ListAPIView):
    serializer_class = TreeNodeSerializer
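A sketch of how a client would exercise the three branches above (the route and host are assumptions; the parameter names come from `list` and `get_queryset_for_assets`, and how the node is addressed follows the `get_object_or_404` lookup earlier in the file):

```python
import requests  # any HTTP client works

BASE = 'https://jms.example.com/api/v1/assets/nodes/children/tree/'  # hypothetical route

# Branch 1 - plain node tree: all nodes, no assets, expand level set by org type
requests.get(BASE, params={'assets': '0'})

# Branch 2 - search: per-org trees capped at ~1000 assets total, everything expanded
requests.get(BASE, params={'assets': '1', 'search': 'web'})

# Branch 3 - expand one node: its children plus its direct assets
requests.get(BASE, params={'assets': '1', 'key': '1:3'})
```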
@@ -63,11 +63,11 @@ class NodeFilterBackend(filters.BaseFilterBackend):
        query_all = is_query_node_all_assets(request)
        if query_all:
            return queryset.filter(
                Q(nodes__key__startswith=f'{node.key}:') |
                Q(nodes__key=node.key)
                Q(node__key__startswith=f'{node.key}:') |
                Q(node__key=node.key)
            ).distinct()
        else:
            return queryset.filter(nodes__key=node.key).distinct()
            return queryset.filter(node__key=node.key).distinct()


class IpInFilterBackend(filters.BaseFilterBackend):
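The two `Q` clauses implement subtree membership over materialized-path keys: a descendant's key extends the ancestor's key with a `:`-separated segment. The trailing colon in the prefix match is what keeps siblings like `1:30` out of `1:3`'s subtree. A standalone restatement of the rule:

```python
def in_subtree(asset_node_key: str, node_key: str) -> bool:
    # Same rule as Q(node__key=node.key) | Q(node__key__startswith=f'{node.key}:')
    return asset_node_key == node_key or asset_node_key.startswith(node_key + ':')

assert in_subtree('1:3:7', '1:3')     # descendant
assert in_subtree('1:3', '1:3')       # the node itself
assert not in_subtree('1:30', '1:3')  # sibling key, excluded by the ':'
```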
apps/assets/migrations/0020_asset_node.py (new file, 126 lines)
@@ -0,0 +1,126 @@
# Generated by Django 4.1.13 on 2025-12-16 09:14

from django.db import migrations, models, transaction
import django.db.models.deletion


def log(msg=''):
    print(f' -> {msg}')


def ensure_asset_single_node(apps, schema_editor):
    print('')
    log('Checking that all assets are linked to only one node...')
    Asset = apps.get_model('assets', 'Asset')
    Through = Asset.nodes.through

    assets_count_multi_nodes = Through.objects.values('asset_id').annotate(
        node_count=models.Count('node_id')
    ).filter(node_count__gt=1).count()

    if assets_count_multi_nodes > 0:
        raise Exception(
            f'There are {assets_count_multi_nodes} assets associated with more than one node. '
            'Please ensure each asset is linked to only one node before applying this migration.'
        )
    else:
        log('All assets are linked to only one node. Proceeding with the migration.')


def ensure_asset_has_node(apps, schema_editor):
    log('Checking that all assets are linked to at least one node...')
    Asset = apps.get_model('assets', 'Asset')
    Through = Asset.nodes.through

    asset_count = Asset.objects.count()
    through_asset_count = Through.objects.values('asset_id').count()

    assets_count_without_node = asset_count - through_asset_count

    if assets_count_without_node > 0:
        raise Exception(
            f'Some assets ({assets_count_without_node}) are not associated with any node. '
            'Please ensure all assets are linked to a node before applying this migration.'
        )
    else:
        log('All assets are linked to a node. Proceeding with the migration.')


def migrate_asset_node_id_field(apps, schema_editor):
    log('Migrating node_id field for all assets...')

    Asset = apps.get_model('assets', 'Asset')
    Through = Asset.nodes.through

    assets = Asset.objects.filter(node_id__isnull=True)
    log(f'Found {assets.count()} assets to migrate.')

    asset_node_mapper = {
        str(asset_id): str(node_id)
        for asset_id, node_id in Through.objects.values_list('asset_id', 'node_id')
    }
    asset_node_mapper.pop(None, None)  # Remove any entries with None keys

    for asset in assets:
        node_id = asset_node_mapper.get(str(asset.id))
        if not node_id:
            raise Exception(
                f'Asset (ID: {asset.id}) is not associated with any node. '
                'Cannot migrate node_id field.'
            )
        asset.node_id = node_id

    with transaction.atomic():
        total = len(assets)
        batch_size = 5000

        for i in range(0, total, batch_size):
            batch = assets[i:i + batch_size]
            start = i + 1
            end = min(i + batch_size, total)

            for asset in batch:
                asset.save(update_fields=['node_id'])

            log(f"Migrated {start}-{end}/{total} assets")

    count = Asset.objects.filter(node_id__isnull=True).count()
    if count > 0:
        log('Warning: Some assets still have null node_id after migration.')
        raise Exception('Migration failed: Some assets have null node_id.')

    count = Asset.objects.filter(node_id__isnull=False).count()
    log(f'Successfully migrated node_id for {count} assets.')


class Migration(migrations.Migration):

    dependencies = [
        ('assets', '0019_alter_asset_connectivity'),
    ]

    operations = [
        migrations.RunPython(
            ensure_asset_single_node,
            reverse_code=migrations.RunPython.noop
        ),
        migrations.RunPython(
            ensure_asset_has_node,
            reverse_code=migrations.RunPython.noop
        ),
        migrations.AddField(
            model_name='asset',
            name='node',
            field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.PROTECT, related_name='direct_assets', to='assets.node', verbose_name='Node'),
        ),
        migrations.RunPython(
            migrate_asset_node_id_field,
            reverse_code=migrations.RunPython.noop
        ),
        migrations.AlterField(
            model_name='asset',
            name='node',
            field=models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, related_name='direct_assets', to='assets.node', verbose_name='Node'),
        ),
    ]
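The batch loop above issues one UPDATE per asset. On large installs, Django's `bulk_update` would issue one statement per batch instead; a sketch of that alternative (not what the migration ships):

```python
def migrate_with_bulk_update(Asset, assets, batch_size=5000):
    # assets: Asset instances whose node_id has already been set in memory
    total = len(assets)
    for i in range(0, total, batch_size):
        batch = assets[i:i + batch_size]
        Asset.objects.bulk_update(batch, ['node_id'])
        print(f' -> Migrated {min(i + batch_size, total)}/{total} assets')
```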
@@ -172,6 +172,11 @@ class Asset(NodesRelationMixin, LabeledMixin, AbsConnectivity, JSONFilterMixin,
        "assets.Zone", null=True, blank=True, related_name='assets',
        verbose_name=_("Zone"), on_delete=models.SET_NULL
    )
    node = models.ForeignKey(
        'assets.Node', null=False, blank=False, on_delete=models.PROTECT,
        related_name='direct_assets', verbose_name=_("Node")
    )
    # TODO: Once every remaining usage in the code is removed, drop the nodes field and change the node field's related_name to 'assets'
    nodes = models.ManyToManyField(
        'assets.Node', default=default_node, related_name='assets', verbose_name=_("Nodes")
    )
@@ -394,7 +394,7 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):

    def get_all_assets(self):
        from .asset import Asset
        q = Q(nodes__key__startswith=f'{self.key}:') | Q(nodes__key=self.key)
        q = Q(node__key__startswith=f'{self.key}:') | Q(node__key=self.key)
        return Asset.objects.filter(q).distinct()

    def get_assets_amount(self):
@@ -416,8 +416,8 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):

    def get_assets(self):
        from .asset import Asset
        assets = Asset.objects.filter(nodes=self)
        return assets.distinct()
        assets = Asset.objects.filter(node=self)
        return assets

    def get_valid_assets(self):
        return self.get_assets().valid()
@@ -531,6 +531,15 @@ class SomeNodesMixin:
        root_nodes = cls.objects.filter(parent_key='', key__regex=r'^[0-9]+$') \
            .exclude(key__startswith='-').order_by('key')
        return root_nodes

    @classmethod
    def get_or_create_org_root(cls, org):
        org_root = cls.org_root_nodes().filter(org_id=org.id).first()
        if org_root:
            return org_root
        with tmp_to_org(org):
            org_root = cls.create_org_root_node()
        return org_root


class Node(JMSOrgBaseModel, SomeNodesMixin, FamilyMixin, NodeAssetsMixin):
@@ -186,6 +186,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
        super().__init__(*args, **kwargs)
        self._init_field_choices()
        self._extract_accounts()
        self._set_platform()

    def _extract_accounts(self):
        if not getattr(self, 'initial_data', None):
@@ -217,6 +218,21 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
        protocols_data = [{'name': p.name, 'port': p.port} for p in protocols]
        self.initial_data['protocols'] = protocols_data

    def _set_platform(self):
        if not hasattr(self, 'initial_data'):
            return
        platform_id = self.initial_data.get('platform')
        if not platform_id:
            return

        if isinstance(platform_id, int) or str(platform_id).isdigit() or not isinstance(platform_id, str):
            return

        platform = Platform.objects.filter(name=platform_id).first()
        if not platform:
            return
        self.initial_data['platform'] = platform.id

    def _init_field_choices(self):
        request = self.context.get('request')
        if not request:
@@ -265,8 +281,10 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa

        if not platform_id and self.instance:
            platform = self.instance.platform
        else:
        elif isinstance(platform_id, int):
            platform = Platform.objects.filter(id=platform_id).first()
        else:
            platform = Platform.objects.filter(name=platform_id).first()

        if not platform:
            raise serializers.ValidationError({'platform': _("Platform not exist")})
@@ -297,6 +315,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa

    def is_valid(self, raise_exception=False):
        self._set_protocols_default()
        self._set_platform()
        return super().is_valid(raise_exception=raise_exception)

    def validate_protocols(self, protocols_data):
apps/assets/tree/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .asset_tree import *
apps/assets/tree/asset_tree.py (new file, 246 lines)
@@ -0,0 +1,246 @@
from collections import defaultdict
from django.db.models import Count, Q

from orgs.utils import current_org
from orgs.models import Organization
from assets.models import Asset, Node, Platform
from assets.const.category import Category
from common.utils import get_logger, timeit, lazyproperty

from .tree import TreeNode, Tree

logger = get_logger(__name__)


__all__ = ['AssetTree', 'AssetTreeNode']


class AssetTreeNodeAsset:

    def __init__(self, id, node_id, parent_key, name, address,
                 platform_id, is_active, comment, org_id):

        self.id = id
        self.node_id = node_id
        self.parent_key = parent_key
        self.name = name
        self.address = address
        self.platform_id = platform_id
        self.is_active = is_active
        self.comment = comment
        self.org_id = org_id

    @lazyproperty
    def org(self):
        return Organization.get_instance(self.org_id)

    @property
    def org_name(self) -> str:
        return self.org.name


class AssetTreeNode(TreeNode):

    def __init__(self, _id, key, value, assets_amount=0, assets=None):
        super().__init__(_id, key, value)
        self.assets_amount = assets_amount
        self.assets_amount_total = 0
        self.assets: list[AssetTreeNodeAsset] = []
        self.init_assets(assets)

    def init_assets(self, assets):
        if not assets:
            return
        for asset in assets:
            asset['parent_key'] = self.key
            self.assets.append(AssetTreeNodeAsset(**asset))

    def get_assets(self):
        return self.assets

    def as_dict(self, simple=True):
        data = super().as_dict(simple=simple)
        data.update({
            'assets_amount_total': self.assets_amount_total,
            'assets_amount': self.assets_amount,
            'assets': len(self.assets),
        })
        return data


class AssetTree(Tree):

    TreeNode = AssetTreeNode

    def __init__(self, assets_q_object: Q = None, category=None, org=None,
                 with_assets=False, with_assets_node_id=None, with_assets_limit=1000,
                 full_tree=True):
        '''
        :param assets_q_object: only build the tree of nodes holding these assets
        :param category: only build the tree of nodes holding assets of this category
        :param org: only build the asset node tree under this organization

        :param with_assets_node_id: only the given node carries assets
        :param with_assets: every node carries assets
        :param with_assets_limit: when carrying assets, the maximum total number of assets

        :param full_tree: a full tree keeps all nodes; otherwise only nodes whose total asset count is non-zero
        '''

        super().__init__()
        ## Build the node tree from assets; supports Q, category, org and similar filters ##
        self._assets_q_object: Q = assets_q_object or Q()
        self._category = self._check_category(category)
        self._category_platform_ids = set()
        self._org: Organization = org or current_org

        # Mapping of every node's attributes under the org; the asset tree is built from the complete node set
        self._nodes_attr_mapper = defaultdict(dict)
        # Mapping of each node's direct asset count, used to compute the subtree totals
        self._nodes_assets_amount_mapper = defaultdict(int)
        # Whether nodes carry assets
        self._with_assets = with_assets  # every node carries assets
        self._with_assets_node_id = with_assets_node_id  # only the given node carries assets; takes precedence over with_assets
        self._with_assets_limit = with_assets_limit
        self._node_assets_mapper = defaultdict(dict)

        # Whether to keep nodes whose total asset count is 0
        self._full_tree = full_tree

        # Build the tree at construction time
        self.build()

    def _check_category(self, category):
        if category is None:
            return None
        if category in Category.values:
            return category
        logger.warning(f"Invalid category '{category}' for AssetTree.")
        return None

    @timeit
    def build(self):
        self._load_nodes_attr_mapper()
        self._load_category_platforms_if_needed()
        self._load_nodes_assets_amount()
        self._load_nodes_assets_if_needed()
        self._init_tree()
        self._compute_assets_amount_total()
        self._remove_nodes_with_zero_assets_if_needed()

    @timeit
    def _load_category_platforms_if_needed(self):
        if self._category is None:
            return
        ids = Platform.objects.filter(category=self._category).values_list('id', flat=True)
        ids = self._uuids_to_string(ids)
        self._category_platform_ids = ids

    @timeit
    def _load_nodes_attr_mapper(self):
        nodes = Node.objects.filter(org_id=self._org.id).values('id', 'key', 'value')
        # Load nodes in key order so parents always precede their children when the tree is built
        nodes = sorted(nodes, key=lambda n: [int(i) for i in n['key'].split(':')])
        for node in list(nodes):
            node['id'] = str(node['id'])
            self._nodes_attr_mapper[node['id']] = node

    @timeit
    def _load_nodes_assets_amount(self):
        q = self._make_assets_q_object()
        nodes_amount = Asset.objects.filter(q).values('node_id').annotate(
            amount=Count('id')
        ).values('node_id', 'amount')
        for nc in list(nodes_amount):
            nid = str(nc['node_id'])
            self._nodes_assets_amount_mapper[nid] = nc['amount']

    @timeit
    def _load_nodes_assets_if_needed(self):
        need_load = self._with_assets or self._with_assets_node_id
        if not need_load:
            return

        q = self._make_assets_q_object()
        if self._with_assets_node_id:
            # Only the given node carries assets; takes precedence over with_assets
            q &= Q(node_id=self._with_assets_node_id)

        assets = Asset.objects.filter(q).values(
            'node_id', 'id', 'platform_id', 'name', 'address', 'is_active', 'comment', 'org_id'
        )
        # Order by node key so earlier nodes keep as many of their assets as possible,
        # then cap the number of assets
        assets = assets.order_by('node__key')[:self._with_assets_limit]
        for asset in list(assets):
            nid = asset['node_id'] = str(asset['node_id'])
            aid = asset['id'] = str(asset['id'])
            self._node_assets_mapper[nid][aid] = asset

    @timeit
    def _make_assets_q_object(self) -> Q:
        q = Q(org_id=self._org.id)
        if self._category_platform_ids:
            q &= Q(platform_id__in=self._category_platform_ids)
        if self._assets_q_object:
            q &= self._assets_q_object
        return q

    @timeit
    def _init_tree(self):
        for nid in self._nodes_attr_mapper.keys():
            data = self._get_tree_node_data(nid)
            node = self.TreeNode(**data)
            self.add_node(node)

    def _get_tree_node_data(self, node_id):
        attr = self._nodes_attr_mapper[node_id]
        assets_amount = self._nodes_assets_amount_mapper.get(node_id, 0)
        data = {
            '_id': node_id,
            'key': attr['key'],
            'value': attr['value'],
            'assets_amount': assets_amount,
        }

        assets = self._node_assets_mapper[node_id].values()
        if assets:
            assets = list(assets)
            data.update({'assets': assets})
        return data

    @timeit
    def _compute_assets_amount_total(self):
        for node in reversed(list(self.nodes.values())):
            total = node.assets_amount
            for child in node.children:
                child: AssetTreeNode
                total += child.assets_amount_total
            node: AssetTreeNode
            node.assets_amount_total = total

    @timeit
    def _remove_nodes_with_zero_assets_if_needed(self):
        if self._full_tree:
            return
        nodes: list[AssetTreeNode] = list(self.nodes.values())
        nodes_to_remove = [
            node for node in nodes if not node.is_root and node.assets_amount_total == 0
        ]
        for node in nodes_to_remove:
            self.remove_node(node)

    def get_assets(self):
        assets = []
        for node in self.nodes.values():
            node: AssetTreeNode
            _assets = node.get_assets()
            assets.extend(_assets)
        return assets

    def _uuids_to_string(self, uuids):
        return [str(u) for u in uuids]

    def print(self, count=20, simple=True):
        print('org_name: ', getattr(self._org, 'name', 'No-org'))
        print(f'asset_category: {self._category}')
        super().print(count=count, simple=simple)
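Given the constructor parameters documented above, typical construction looks like this (the org lookup and search term are illustrative):

```python
from django.db.models import Q

from assets.tree.asset_tree import AssetTree
from orgs.models import Organization

org = Organization.objects.first()  # illustrative; normally the current org

# Full node tree with per-node asset counts, no asset rows attached
tree = AssetTree(org=org)
print(tree.size, tree.root.key if tree.root else None)

# Search-style tree: only matching subtrees survive, assets attached, capped
search_tree = AssetTree(
    assets_q_object=Q(name__icontains='web'), org=org,
    with_assets=True, with_assets_limit=1000, full_tree=False,
)
nodes, assets = search_tree.get_nodes(), search_tree.get_assets()
```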
apps/assets/tree/tree.py (new file, 164 lines)
@@ -0,0 +1,164 @@
from common.utils import get_logger, lazyproperty


__all__ = ['TreeNode', 'Tree']

logger = get_logger(__name__)


class TreeNode(object):

    def __init__(self, _id, key, value):
        self.id = _id
        self.key = key
        self.value = value
        self.children = []
        self.parent = None

    @lazyproperty
    def parent_key(self):
        if self.is_root:
            return None
        return ':'.join(self.key.split(':')[:-1])

    @property
    def is_root(self):
        return self.key.isdigit()

    def add_child(self, child_node: 'TreeNode'):
        child_node.parent = self
        self.children.append(child_node)

    def remove_child(self, child_node: 'TreeNode'):
        self.children.remove(child_node)
        child_node.parent = None

    @property
    def is_leaf(self):
        return len(self.children) == 0

    @lazyproperty
    def level(self):
        return self.key.count(':') + 1

    @property
    def children_count(self):
        return len(self.children)

    def as_dict(self, simple=True):
        data = {
            'key': self.key,
        }
        if simple:
            return data

        data.update({
            'id': self.id,
            'value': self.value,
            'level': self.level,
            'children_count': self.children_count,
            'is_root': self.is_root,
            'is_leaf': self.is_leaf,
        })
        return data

    def print(self, simple=True, is_print_keys=False):
        def info_as_string(_info):
            return ' | '.join(s.ljust(25) for s in _info)

        if is_print_keys:
            info_keys = [k for k in self.as_dict(simple=simple).keys()]
            info_keys_string = info_as_string(info_keys)
            print('-' * len(info_keys_string))
            print(info_keys_string)
            print('-' * len(info_keys_string))

        info_values = [str(v) for v in self.as_dict(simple=simple).values()]
        info_values_as_string = info_as_string(info_values)
        print(info_values_as_string)
        print('-' * len(info_values_as_string))


class Tree(object):

    def __init__(self):
        self.root = None
        # { key -> TreeNode }
        self.nodes: dict[str, TreeNode] = {}

    @property
    def size(self):
        return len(self.nodes)

    @property
    def is_empty(self):
        return self.size == 0

    @property
    def depth(self):
        " Return the tree's maximum depth; also prints the key of the deepest node "
        if self.is_empty:
            return 0
        node = max(self.nodes.values(), key=lambda node: node.level)
        node: TreeNode
        print(f"max_depth_node_key: {node.key}")
        return node.level

    @property
    def width(self):
        " Return the tree's maximum width; also prints the level where it occurs "
        if self.is_empty:
            return 0
        node = max(self.nodes.values(), key=lambda node: node.children_count)
        node: TreeNode
        print(f"max_width_level: {node.level + 1}")
        return node.children_count

    def add_node(self, node: TreeNode):
        if node.is_root:
            self.root = node
            self.nodes[node.key] = node
            return
        parent = self.get_node(node.parent_key)
        if not parent:
            error = (f"Cannot add node {node.key}: parent key {node.parent_key} not found. "
                     "Please ensure parent nodes are added before child nodes.")
            raise ValueError(error)
        parent.add_child(node)
        self.nodes[node.key] = node

    def get_node(self, key: str) -> TreeNode:
        return self.nodes.get(key)

    def remove_node(self, node: TreeNode):
        if node.is_root:
            self.root = None
        else:
            parent: TreeNode = node.parent
            parent.remove_child(node)
        self.nodes.pop(node.key, None)

    def get_nodes(self):
        return list(self.nodes.values())

    def get_node_children(self, key, with_self=False):
        node = self.get_node(key)
        if not node:
            return []
        nodes = []
        if with_self:
            nodes.append(node)
        nodes.extend(node.children)
        return nodes

    def print(self, count=10, simple=True):
        print('tree_root_key: ', getattr(self.root, 'key', 'No-root'))
        print('tree_size: ', self.size)
        print('tree_depth: ', self.depth)
        print('tree_width: ', self.width)

        is_print_key = True
        for n in list(self.nodes.values())[:count]:
            n: TreeNode
            n.print(simple=simple, is_print_keys=is_print_key)
            is_print_key = False
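`add_node` raises unless the parent is already present, which is why `AssetTree._load_nodes_attr_mapper` sorts nodes by key before inserting. A minimal standalone use of the base classes:

```python
from assets.tree.tree import Tree, TreeNode

tree = Tree()
# Root keys are pure digits; each child appends a ':<n>' segment
tree.add_node(TreeNode('id-1', '1', 'Default'))
tree.add_node(TreeNode('id-2', '1:1', 'Web'))
tree.add_node(TreeNode('id-3', '1:1:1', 'Nginx'))

assert tree.get_node('1:1').children_count == 1
assert tree.get_node('1:1:1').level == 3
tree.remove_node(tree.get_node('1:1:1'))  # detaches from parent and the key index
```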
@@ -67,6 +67,7 @@ class UserLoginMFAView(mixins.AuthMixin, FormView):
    def get_context_data(self, **kwargs):
        user = self.get_user_from_session()
        mfa_context = self.get_user_mfa_context(user)
        print(mfa_context)
        kwargs.update(mfa_context)
        return kwargs
@@ -701,15 +701,7 @@ class Config(dict):
        'CHAT_AI_ENABLED': False,
        'CHAT_AI_METHOD': 'api',
        'CHAT_AI_EMBED_URL': '',
        'CHAT_AI_TYPE': 'gpt',
        'GPT_BASE_URL': '',
        'GPT_API_KEY': '',
        'GPT_PROXY': '',
        'GPT_MODEL': 'gpt-4o-mini',
        'DEEPSEEK_BASE_URL': '',
        'DEEPSEEK_API_KEY': '',
        'DEEPSEEK_PROXY': '',
        'DEEPSEEK_MODEL': 'deepseek-chat',
        'CHAT_AI_PROVIDERS': [],
        'VIRTUAL_APP_ENABLED': False,

        'FILE_UPLOAD_SIZE_LIMIT_MB': 200,

@@ -241,15 +241,7 @@ ASSET_SIZE = 'small'
CHAT_AI_ENABLED = CONFIG.CHAT_AI_ENABLED
CHAT_AI_METHOD = CONFIG.CHAT_AI_METHOD
CHAT_AI_EMBED_URL = CONFIG.CHAT_AI_EMBED_URL
CHAT_AI_TYPE = CONFIG.CHAT_AI_TYPE
GPT_BASE_URL = CONFIG.GPT_BASE_URL
GPT_API_KEY = CONFIG.GPT_API_KEY
GPT_PROXY = CONFIG.GPT_PROXY
GPT_MODEL = CONFIG.GPT_MODEL
DEEPSEEK_BASE_URL = CONFIG.DEEPSEEK_BASE_URL
DEEPSEEK_API_KEY = CONFIG.DEEPSEEK_API_KEY
DEEPSEEK_PROXY = CONFIG.DEEPSEEK_PROXY
DEEPSEEK_MODEL = CONFIG.DEEPSEEK_MODEL
CHAT_AI_DEFAULT_PROVIDER = CONFIG.CHAT_AI_DEFAULT_PROVIDER

VIRTUAL_APP_ENABLED = CONFIG.VIRTUAL_APP_ENABLED

@@ -268,4 +260,6 @@ LOKI_BASE_URL = CONFIG.LOKI_BASE_URL
TOOL_USER_ENABLED = CONFIG.TOOL_USER_ENABLED

SUGGESTION_LIMIT = CONFIG.SUGGESTION_LIMIT
MCP_ENABLED = CONFIG.MCP_ENABLED
CHAT_AI_PROVIDERS = CONFIG.CHAT_AI_PROVIDERS
apps/perms/tree.py (new file, 73 lines)
@@ -0,0 +1,73 @@
from collections import defaultdict
from django.db.models import Q, Count

from common.utils import get_logger
from users.models import User
from assets.tree.asset_tree import AssetTree, AssetTreeNode
from perms.utils.utils import UserPermUtil


__all__ = ['UserPermTree']


logger = get_logger(__name__)


class PermTreeNode(AssetTreeNode):

    class Type:
        # Neither a permission node nor a node with direct permission assets
        BRIDGE = 'bridge'
        # Node with direct permission
        DN = 'dn'
        # Node with only direct permission assets
        DA = 'da'

    def __init__(self, tp, _id, key, value, assets_amount=0, assets=None):
        super().__init__(_id, key, value, assets_amount, assets)
        self.type = tp or self.Type.BRIDGE

    def as_dict(self, simple=True):
        data = super().as_dict(simple=simple)
        data.update({
            'type': self.type,
        })
        return data


class UserPermTree(AssetTree):

    TreeNode = PermTreeNode

    def __init__(self, user=None, assets_q_object=None, category=None, org=None, with_assets=False):
        # Set user/util before super().__init__, which builds the tree and needs them
        self._user: User = user
        self._util = UserPermUtil(user, org=org)
        super().__init__(
            assets_q_object=assets_q_object,
            category=category,
            org=org,
            with_assets=with_assets,
            full_tree=False
        )

    def _make_assets_q_object(self):
        q = super()._make_assets_q_object()
        q_perm_assets = Q(id__in=self._util._user_direct_asset_ids)
        q_perm_nodes = Q(node_id__in=self._util._user_direct_node_all_children_ids)
        q = q & (q_perm_assets | q_perm_nodes)
        return q

    def _get_tree_node_data(self, node_id):
        data = super()._get_tree_node_data(node_id)
        if node_id in self._util._user_direct_node_all_children_ids:
            tp = PermTreeNode.Type.DN
        elif self._nodes_assets_amount_mapper.get(node_id, 0) > 0:
            tp = PermTreeNode.Type.DA
        else:
            tp = PermTreeNode.Type.BRIDGE
        data.update({'tp': tp})
        return data

    def print(self, simple=True, count=10):
        self._util.print()
        super().print(simple=simple, count=count)
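The `Type` constants classify how a node entered the user's tree; restated as a plain function (same decision order as `_get_tree_node_data` above):

```python
def classify(node_id, direct_node_ids, direct_asset_amounts) -> str:
    if node_id in direct_node_ids:                 # the node or an ancestor is granted
        return 'dn'
    if direct_asset_amounts.get(node_id, 0) > 0:   # only assets on this node are granted
        return 'da'
    return 'bridge'                                # kept just to connect the tree
```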
apps/perms/utils/utils.py (new file, 138 lines)
@@ -0,0 +1,138 @@
from django.db.models import Q

from common.utils import timeit, lazyproperty, get_logger, is_uuid
from orgs.utils import current_org
from users.models import User
from assets.models import Node, Asset
from perms.models import AssetPermission


logger = get_logger(__name__)


__all__ = ['UserPermUtil']


class UserPermUtil(object):

    UserGroupThrough = User.groups.through
    PermUserThrough = AssetPermission.users.through
    PermUserGroupThrough = AssetPermission.user_groups.through
    PermAssetThrough = AssetPermission.assets.through
    PermNodeThrough = AssetPermission.nodes.through

    def __init__(self, user, org=None):
        self._user: User = user
        self._org = org or current_org
        self._user_permission_ids = set()
        self._user_group_ids = set()
        self._user_group_permission_ids = set()
        self._user_all_permission_ids = set()
        self._user_direct_asset_ids = set()
        self._user_direct_node_ids = set()
        self._user_direct_node_all_children_ids = set()
        self._init()

    def _init(self):
        self._load_user_permission_ids()
        self._load_user_group_ids()
        self._load_user_group_permission_ids()
        self._load_user_direct_asset_ids()
        self._load_user_direct_node_ids()
        self._load_user_direct_node_all_children_ids()

    @timeit
    def _load_user_permission_ids(self):
        perm_ids = self.PermUserThrough.objects.filter(
            user_id=self._user.id
        ).distinct('assetpermission_id').values_list('assetpermission_id', flat=True)
        perm_ids = self._uuids_to_string(perm_ids)
        self._user_permission_ids.update(perm_ids)
        self._user_all_permission_ids.update(perm_ids)

    @timeit
    def _load_user_group_ids(self):
        group_ids = self.UserGroupThrough.objects.filter(
            user_id=self._user.id
        ).distinct('usergroup_id').values_list('usergroup_id', flat=True)
        group_ids = self._uuids_to_string(group_ids)
        self._user_group_ids.update(group_ids)

    @timeit
    def _load_user_group_permission_ids(self):
        perm_ids = self.PermUserGroupThrough.objects.filter(
            usergroup_id__in=self._user_group_ids
        ).distinct('assetpermission_id').values_list('assetpermission_id', flat=True)
        perm_ids = self._uuids_to_string(perm_ids)
        self._user_group_permission_ids.update(perm_ids)
        self._user_all_permission_ids.update(perm_ids)

    @timeit
    def _load_user_direct_asset_ids(self):
        asset_ids = self.PermAssetThrough.objects.filter(
            assetpermission_id__in=self._user_all_permission_ids
        ).distinct('asset_id').values_list('asset_id', flat=True)
        asset_ids = self._uuids_to_string(asset_ids)
        self._user_direct_asset_ids.update(asset_ids)

    @timeit
    def _load_user_direct_node_ids(self):
        node_ids = self.PermNodeThrough.objects.filter(
            assetpermission_id__in=self._user_all_permission_ids
        ).distinct('node_id').values_list('node_id', flat=True)
        node_ids = self._uuids_to_string(node_ids)
        self._user_direct_node_ids.update(node_ids)

    @timeit
    def _load_user_direct_node_all_children_ids(self):
        nid_key_pairs = Node.objects.filter(org_id=self._org.id).values_list('id', 'key')
        nid_key_mapper = {str(nid): key for nid, key in nid_key_pairs}

        dn_keys = [nid_key_mapper[nid] for nid in self._user_direct_node_ids]

        def has_ancestor_in_direct_nodes(key: str) -> bool:
            ancestor_keys = [':'.join(key.split(':')[:i]) for i in range(1, key.count(':') + 1)]
            return bool(set(ancestor_keys) & set(dn_keys))

        dn_children_ids = [nid for nid, key in nid_key_mapper.items() if has_ancestor_in_direct_nodes(key)]

        self._user_direct_node_all_children_ids.update(self._user_direct_node_ids)
        self._user_direct_node_all_children_ids.update(dn_children_ids)

    def get_node_assets(self, node: Node):
        ''' Get the authorized direct assets under a node; needed when the Luna page expands a node '''
        q = Q(node_id=node.id)
        if str(node.id) not in self._user_direct_node_all_children_ids:
            q &= Q(id__in=self._user_direct_asset_ids)
        assets = Asset.objects.filter(q)
        return assets

    def get_node_all_assets(self, node: Node):
        ''' Get all authorized assets under a node and its descendants; needed for tests '''
        if str(node.id) in self._user_direct_node_all_children_ids:
            assets = node.get_all_assets()
            return assets

        children_ids = node.get_all_children(with_self=True).values_list('id', flat=True)
        children_ids = self._uuids_to_string(children_ids)
        dn_all_nodes_ids = set(children_ids) & self._user_direct_node_all_children_ids
        other_nodes_ids = set(children_ids) - dn_all_nodes_ids

        q = Q(node_id__in=dn_all_nodes_ids)
        q |= Q(node_id__in=other_nodes_ids) & Q(id__in=self._user_direct_asset_ids)
        assets = Asset.objects.filter(q)
        return assets

    def _uuids_to_string(self, uuids):
        return [str(u) for u in uuids]

    def print(self):
        print('user_perm_tree:', self._user.username)
        print('user_permission_ids_count:', len(self._user_permission_ids))
        print('user_group_ids_count:', len(self._user_group_ids))
        print('user_group_permission_ids_count:', len(self._user_group_permission_ids))
        print('user_all_permission_ids_count:', len(self._user_all_permission_ids))
        print('user_direct_asset_ids_count:', len(self._user_direct_asset_ids))
        print('user_direct_node_ids_count:', len(self._user_direct_node_ids))
        print('user_direct_node_all_children_ids_count:', len(self._user_direct_node_all_children_ids))
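Worked example of the ancestor-key enumeration in `_load_user_direct_node_all_children_ids`, assuming `1:2` is directly granted:

```python
key = '1:2:5:9'
ancestor_keys = [':'.join(key.split(':')[:i]) for i in range(1, key.count(':') + 1)]
print(ancestor_keys)  # ['1', '1:2', '1:2:5'] -- the key itself is excluded

dn_keys = {'1:2'}
print(bool(set(ancestor_keys) & dn_keys))  # True: '1:2:5:9' inherits the grant
```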
@@ -1,98 +1,10 @@
import httpx
import openai
from django.conf import settings
from django.utils.translation import gettext_lazy as _
from rest_framework import status
from rest_framework.generics import GenericAPIView
from rest_framework.response import Response

from common.api import JMSModelViewSet
from common.permissions import IsValidUser, OnlySuperUser
from .. import serializers
from ..const import ChatAITypeChoices
from ..models import ChatPrompt
from ..prompt import DefaultChatPrompt


class ChatAITestingAPI(GenericAPIView):
    serializer_class = serializers.ChatAISettingSerializer
    rbac_perms = {
        'POST': 'settings.change_chatai'
    }

    def get_config(self, request):
        serializer = self.serializer_class(data=request.data)
        serializer.is_valid(raise_exception=True)
        data = self.serializer_class().data
        data.update(serializer.validated_data)
        for k, v in data.items():
            if v:
                continue
            # The page didn't pass a value; fall back to settings
            data[k] = getattr(settings, k, None)
        return data

    def post(self, request):
        config = self.get_config(request)
        chat_ai_enabled = config['CHAT_AI_ENABLED']
        if not chat_ai_enabled:
            return Response(
                status=status.HTTP_400_BAD_REQUEST,
                data={'msg': _('Chat AI is not enabled')}
            )

        tp = config['CHAT_AI_TYPE']
        if tp == ChatAITypeChoices.gpt:
            url = config['GPT_BASE_URL']
            api_key = config['GPT_API_KEY']
            proxy = config['GPT_PROXY']
            model = config['GPT_MODEL']
        else:
            url = config['DEEPSEEK_BASE_URL']
            api_key = config['DEEPSEEK_API_KEY']
            proxy = config['DEEPSEEK_PROXY']
            model = config['DEEPSEEK_MODEL']

        kwargs = {
            'base_url': url or None,
            'api_key': api_key,
        }
        try:
            if proxy:
                kwargs['http_client'] = httpx.Client(
                    proxies=proxy,
                    transport=httpx.HTTPTransport(local_address='0.0.0.0')
                )
            client = openai.OpenAI(**kwargs)

            ok = False
            error = ''

            client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": "Say this is a test",
                    }
                ],
                model=model,
            )
            ok = True
        except openai.APIConnectionError as e:
            error = str(e.__cause__)  # an underlying Exception, likely raised within httpx.
        except openai.APIStatusError as e:
            error = str(e.message)
        except Exception as e:
            ok, error = False, str(e)

        if ok:
            _status, msg = status.HTTP_200_OK, _('Test success')
        else:
            _status, msg = status.HTTP_400_BAD_REQUEST, error

        return Response(status=_status, data={'msg': msg})


class ChatPromptViewSet(JMSModelViewSet):
    serializer_classes = {
        'default': serializers.ChatPromptSerializer,
@@ -154,7 +154,10 @@ class SettingsApi(generics.RetrieveUpdateAPIView):
    def parse_serializer_data(self, serializer):
        data = []
        fields = self.get_fields()
        encrypted_items = [name for name, field in fields.items() if field.write_only]
        encrypted_items = [
            name for name, field in fields.items()
            if field.write_only or getattr(field, 'encrypted', False)
        ]
        category = self.request.query_params.get('category', '')
        for name, value in serializer.validated_data.items():
            encrypted = name in encrypted_items
@@ -14,18 +14,5 @@ class ChatAIMethodChoices(TextChoices):


class ChatAITypeChoices(TextChoices):
    gpt = 'gpt', 'GPT'
    deep_seek = 'deep-seek', 'DeepSeek'


class GPTModelChoices(TextChoices):
    gpt_4o_mini = 'gpt-4o-mini', 'gpt-4o-mini'
    gpt_4o = 'gpt-4o', 'gpt-4o'
    o3_mini = 'o3-mini', 'o3-mini'
    o1_mini = 'o1-mini', 'o1-mini'
    o1 = 'o1', 'o1'


class DeepSeekModelChoices(TextChoices):
    deepseek_chat = 'deepseek-chat', 'DeepSeek-V3'
    deepseek_reasoner = 'deepseek-reasoner', 'DeepSeek-R1'
    openai = 'openai', 'Openai'
    ollama = 'ollama', 'Ollama'
@@ -1,6 +1,7 @@
import json
import os
import shutil
from typing import Any, Dict, List

from django.conf import settings
from django.core.files.base import ContentFile
@@ -14,7 +15,6 @@ from rest_framework.utils.encoders import JSONEncoder
from common.db.models import JMSBaseModel
from common.db.utils import Encryptor
from common.utils import get_logger
from .const import ChatAITypeChoices
from .signals import setting_changed

logger = get_logger(__name__)
@@ -196,20 +196,25 @@ class ChatPrompt(JMSBaseModel):
        return self.name


def get_chatai_data():
    data = {
        'url': settings.GPT_BASE_URL,
        'api_key': settings.GPT_API_KEY,
        'proxy': settings.GPT_PROXY,
        'model': settings.GPT_MODEL,
    }
    if settings.CHAT_AI_TYPE != ChatAITypeChoices.gpt:
        data['url'] = settings.DEEPSEEK_BASE_URL
        data['api_key'] = settings.DEEPSEEK_API_KEY
        data['proxy'] = settings.DEEPSEEK_PROXY
        data['model'] = settings.DEEPSEEK_MODEL
def get_chatai_data() -> Dict[str, Any]:
    raw_providers = settings.CHAT_AI_PROVIDERS
    providers: List[dict] = [p for p in raw_providers if isinstance(p, dict)]

    return data
    if not providers:
        return {}

    selected = next(
        (p for p in providers if p.get('is_assistant')),
        providers[0],
    )

    return {
        'url': selected.get('base_url'),
        'api_key': selected.get('api_key'),
        'proxy': selected.get('proxy'),
        'model': selected.get('model'),
        'name': selected.get('name'),
    }


def init_sqlite_db():
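With providers stored as a list of dicts, selection is: the entry flagged `is_assistant`, else the first. An illustrative `CHAT_AI_PROVIDERS` value (keys taken from what `get_chatai_data` reads; the exact stored shape is defined by `ChatAIProviderSerializer`):

```python
CHAT_AI_PROVIDERS = [
    {'name': 'gpt', 'type': 'openai', 'base_url': '',
     'api_key': 'sk-...', 'proxy': '', 'model': 'gpt-4o-mini'},
    {'name': 'deepseek', 'type': 'openai', 'base_url': 'https://api.deepseek.com',
     'api_key': 'sk-...', 'proxy': '', 'model': 'deepseek-chat', 'is_assistant': True},
]
# get_chatai_data() returns the second entry's url/api_key/proxy/model/name
```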
@@ -10,11 +10,12 @@ from common.utils import date_expired_default
__all__ = [
    'AnnouncementSettingSerializer', 'OpsSettingSerializer', 'VaultSettingSerializer',
    'HashicorpKVSerializer', 'AzureKVSerializer', 'TicketSettingSerializer',
    'ChatAISettingSerializer', 'VirtualAppSerializer', 'AmazonSMSerializer',
    'ChatAIProviderSerializer', 'ChatAISettingSerializer',
    'VirtualAppSerializer', 'AmazonSMSerializer',
]

from settings.const import (
    ChatAITypeChoices, GPTModelChoices, DeepSeekModelChoices, ChatAIMethodChoices
    ChatAITypeChoices, ChatAIMethodChoices
)
@@ -120,6 +121,29 @@ class AmazonSMSerializer(serializers.Serializer):
    )


class ChatAIProviderListSerializer(serializers.ListSerializer):
    # Mark the whole list for encrypted storage, so API keys are not saved in plaintext
    encrypted = True


class ChatAIProviderSerializer(serializers.Serializer):
    type = serializers.ChoiceField(
        default=ChatAITypeChoices.openai, choices=ChatAITypeChoices.choices,
        label=_("Types"), required=False,
    )
    base_url = serializers.CharField(
        allow_blank=True, required=False, label=_('Base URL'),
        help_text=_('The base URL of the Chat service.')
    )
    api_key = EncryptedField(
        allow_blank=True, required=False, label=_('API Key'),
    )
    proxy = serializers.CharField(
        allow_blank=True, required=False, label=_('Proxy'),
        help_text=_('The proxy server address of the GPT service. For example: http://ip:port')
    )


class ChatAISettingSerializer(serializers.Serializer):
    PREFIX_TITLE = _('Chat AI')

@@ -130,44 +154,14 @@ class ChatAISettingSerializer(serializers.Serializer):
        default=ChatAIMethodChoices.api, choices=ChatAIMethodChoices.choices,
        label=_("Method"), required=False,
    )
    CHAT_AI_PROVIDERS = ChatAIProviderListSerializer(
        child=ChatAIProviderSerializer(),
        allow_empty=True, required=False, default=list, label=_('Providers')
    )
    CHAT_AI_EMBED_URL = serializers.CharField(
        allow_blank=True, required=False, label=_('Base URL'),
        help_text=_('The base URL of the Chat service.')
    )
    CHAT_AI_TYPE = serializers.ChoiceField(
        default=ChatAITypeChoices.gpt, choices=ChatAITypeChoices.choices,
        label=_("Types"), required=False,
    )
    GPT_BASE_URL = serializers.CharField(
        allow_blank=True, required=False, label=_('Base URL'),
        help_text=_('The base URL of the Chat service.')
    )
    GPT_API_KEY = EncryptedField(
        allow_blank=True, required=False, label=_('API Key'),
    )
    GPT_PROXY = serializers.CharField(
        allow_blank=True, required=False, label=_('Proxy'),
        help_text=_('The proxy server address of the GPT service. For example: http://ip:port')
    )
    GPT_MODEL = serializers.ChoiceField(
        default=GPTModelChoices.gpt_4o_mini, choices=GPTModelChoices.choices,
        label=_("GPT Model"), required=False,
    )
    DEEPSEEK_BASE_URL = serializers.CharField(
        allow_blank=True, required=False, label=_('Base URL'),
        help_text=_('The base URL of the Chat service.')
    )
    DEEPSEEK_API_KEY = EncryptedField(
        allow_blank=True, required=False, label=_('API Key'),
    )
    DEEPSEEK_PROXY = serializers.CharField(
        allow_blank=True, required=False, label=_('Proxy'),
        help_text=_('The proxy server address of the GPT service. For example: http://ip:port')
    )
    DEEPSEEK_MODEL = serializers.ChoiceField(
        default=DeepSeekModelChoices.deepseek_chat, choices=DeepSeekModelChoices.choices,
        label=_("DeepSeek Model"), required=False,
    )


class TicketSettingSerializer(serializers.Serializer):
@@ -73,8 +73,6 @@ class PrivateSettingSerializer(PublicSettingSerializer):
    CHAT_AI_ENABLED = serializers.BooleanField()
    CHAT_AI_METHOD = serializers.CharField()
    CHAT_AI_EMBED_URL = serializers.CharField()
    CHAT_AI_TYPE = serializers.CharField()
    GPT_MODEL = serializers.CharField()
    FILE_UPLOAD_SIZE_LIMIT_MB = serializers.IntegerField()
    FTP_FILE_MAX_STORE = serializers.IntegerField()
    LOKI_LOG_ENABLED = serializers.BooleanField()
@@ -21,7 +21,6 @@ urlpatterns = [
    path('sms/<str:backend>/testing/', api.SMSTestingAPI.as_view(), name='sms-testing'),
    path('sms/backend/', api.SMSBackendAPI.as_view(), name='sms-backend'),
    path('vault/<str:backend>/testing/', api.VaultTestingAPI.as_view(), name='vault-testing'),
    path('chatai/testing/', api.ChatAITestingAPI.as_view(), name='chatai-testing'),
    path('vault/sync/', api.VaultSyncDataAPI.as_view(), name='vault-sync'),
    path('security/block-ip/', api.BlockIPSecurityAPI.as_view(), name='block-ip'),
    path('security/unlock-ip/', api.UnlockIPSecurityAPI.as_view(), name='unlock-ip'),
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
#
from .applet import *
from .chat import *
from .component import *
from .session import *
from .virtualapp import *
apps/terminal/api/chat/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .chat import *
apps/terminal/api/chat/chat.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from common.api import JMSBulkModelViewSet
from terminal import serializers
from terminal.filters import ChatFilter
from terminal.models import Chat

__all__ = ['ChatViewSet']


class ChatViewSet(JMSBulkModelViewSet):
    queryset = Chat.objects.all()
    serializer_class = serializers.ChatSerializer
    filterset_class = ChatFilter
    search_fields = ['title']
    ordering_fields = ['date_updated']
    ordering = ['-date_updated']
@@ -2,7 +2,7 @@ from django.db.models import QuerySet
from django_filters import rest_framework as filters

from orgs.utils import filter_org_queryset
from terminal.models import Command, CommandStorage, Session
from terminal.models import Command, CommandStorage, Session, Chat


class CommandFilter(filters.FilterSet):
@@ -79,7 +79,34 @@ class CommandStorageFilter(filters.FilterSet):
        model = CommandStorage
        fields = ['real', 'name', 'type', 'is_default']

    def filter_real(self, queryset, name, value):
    @staticmethod
    def filter_real(queryset, name, value):
        if value:
            queryset = queryset.exclude(name='null')
        return queryset


class ChatFilter(filters.FilterSet):
    # CharFilter, not BooleanFilter: the methods below split comma-separated id lists
    ids = filters.CharFilter(method='filter_ids')
    folder_ids = filters.CharFilter(method='filter_folder_ids')

    class Meta:
        model = Chat
        fields = [
            'title', 'user_id', 'pinned', 'folder_id',
            'archived', 'socket_id', 'share_id'
        ]

    @staticmethod
    def filter_ids(queryset, name, value):
        ids = value.split(',')
        queryset = queryset.filter(id__in=ids)
        return queryset

    @staticmethod
    def filter_folder_ids(queryset, name, value):
        ids = value.split(',')
        queryset = queryset.filter(folder_id__in=ids)
        return queryset
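With `ids` and `folder_ids` declared as comma-separated `CharFilter`s, list requests can batch-fetch chats (route is an assumption):

```python
from rest_framework.test import APIClient  # DRF test client, assumed available

client = APIClient()
# Comma-separated values are split by filter_ids / filter_folder_ids above
client.get('/api/v1/terminal/chats/', {'ids': 'id1,id2'})  # hypothetical route
client.get('/api/v1/terminal/chats/', {'folder_ids': 'f1,f2', 'pinned': 'true'})
```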
apps/terminal/migrations/0011_chat.py (new file, 38 lines)
@@ -0,0 +1,38 @@
# Generated by Django 4.1.13 on 2025-09-30 06:57

from django.db import migrations, models
import uuid


class Migration(migrations.Migration):

    dependencies = [
        ('terminal', '0010_alter_command_risk_level_alter_session_login_from_and_more'),
    ]

    operations = [
        migrations.CreateModel(
            name='Chat',
            fields=[
                ('created_by', models.CharField(blank=True, max_length=128, null=True, verbose_name='Created by')),
                ('updated_by', models.CharField(blank=True, max_length=128, null=True, verbose_name='Updated by')),
                ('date_created', models.DateTimeField(auto_now_add=True, null=True, verbose_name='Date created')),
                ('date_updated', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
                ('comment', models.TextField(blank=True, default='', verbose_name='Comment')),
                ('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)),
                ('title', models.CharField(max_length=256, verbose_name='Title')),
                ('chat', models.JSONField(default=dict, verbose_name='Chat')),
                ('meta', models.JSONField(default=dict, verbose_name='Meta')),
                ('pinned', models.BooleanField(default=False, verbose_name='Pinned')),
                ('archived', models.BooleanField(default=False, verbose_name='Archived')),
                ('share_id', models.CharField(blank=True, default='', max_length=36)),
                ('folder_id', models.CharField(blank=True, default='', max_length=36)),
                ('socket_id', models.CharField(blank=True, default='', max_length=36)),
                ('user_id', models.CharField(blank=True, db_index=True, default='', max_length=36)),
                ('session_info', models.JSONField(default=dict, verbose_name='Session Info')),
            ],
            options={
                'verbose_name': 'Chat',
            },
        ),
    ]
@@ -1,4 +1,5 @@
from .session import *
from .component import *
from .applet import *
from .chat import *
from .component import *
from .session import *
from .virtualapp import *
1
apps/terminal/models/chat/__init__.py
Normal file
1
apps/terminal/models/chat/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .chat import *
|
||||
30
apps/terminal/models/chat/chat.py
Normal file
30
apps/terminal/models/chat/chat.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from django.db import models
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from common.db.models import JMSBaseModel
|
||||
from common.utils import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
__all__ = ['Chat']
|
||||
|
||||
|
||||
class Chat(JMSBaseModel):
|
||||
# id == session_id # 36 chars
|
||||
title = models.CharField(max_length=256, verbose_name=_('Title'))
|
||||
chat = models.JSONField(default=dict, verbose_name=_('Chat'))
|
||||
meta = models.JSONField(default=dict, verbose_name=_('Meta'))
|
||||
pinned = models.BooleanField(default=False, verbose_name=_('Pinned'))
|
||||
archived = models.BooleanField(default=False, verbose_name=_('Archived'))
|
||||
share_id = models.CharField(blank=True, default='', max_length=36)
|
||||
folder_id = models.CharField(blank=True, default='', max_length=36)
|
||||
socket_id = models.CharField(blank=True, default='', max_length=36)
|
||||
user_id = models.CharField(blank=True, default='', max_length=36, db_index=True)
|
||||
|
||||
session_info = models.JSONField(default=dict, verbose_name=_('Session Info'))
|
||||
|
||||
class Meta:
|
||||
verbose_name = _('Chat')
|
||||
|
||||
def __str__(self):
|
||||
return self.title
|
||||
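A minimal usage sketch for the new model (not part of the commit; the id value is a hypothetical placeholder):

from terminal.models import Chat

chat = Chat.objects.create(
    title='Troubleshoot web01',
    chat={'messages': []},      # free-form JSON conversation payload
    user_id='user-uuid-here',   # hypothetical; plain CharField, not a FK
)
chat.pinned = True
chat.save(update_fields=['pinned'])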
@@ -123,11 +123,10 @@ class Terminal(StorageMixin, TerminalStatusMixin, JMSBaseModel):
    def get_chat_ai_setting():
        data = get_chatai_data()
        return {
            'GPT_BASE_URL': data['url'],
            'GPT_API_KEY': data['api_key'],
            'GPT_PROXY': data['proxy'],
            'GPT_MODEL': data['model'],
            'CHAT_AI_TYPE': settings.CHAT_AI_TYPE,
            'GPT_BASE_URL': data.get('url'),
            'GPT_API_KEY': data.get('api_key'),
            'GPT_PROXY': data.get('proxy'),
            'CHAT_AI_PROVIDERS': settings.CHAT_AI_PROVIDERS,
        }

    @staticmethod
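The subscription-to-.get() change means a provider payload missing a key now yields None instead of raising. A one-line illustration (not part of the commit):

data = {'url': 'https://gpt.example.com'}   # hypothetical get_chatai_data() result
# data['api_key']   -> would raise KeyError and break the settings payload
data.get('api_key') # -> None, so the returned dict is still well-formed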
@@ -2,8 +2,10 @@
#
from .applet import *
from .applet_host import *
from .chat import *
from .command import *
from .endpoint import *
from .loki import *
from .session import *
from .sharing import *
from .storage import *
@@ -11,4 +13,3 @@ from .task import *
from .terminal import *
from .virtualapp import *
from .virtualapp_provider import *
from .loki import *
28 apps/terminal/serializers/chat.py Normal file
@@ -0,0 +1,28 @@
from rest_framework import serializers

from common.serializers import CommonBulkModelSerializer
from terminal.models import Chat

__all__ = ['ChatSerializer']


class ChatSerializer(CommonBulkModelSerializer):
    created_at = serializers.SerializerMethodField()
    updated_at = serializers.SerializerMethodField()

    class Meta:
        model = Chat
        fields_mini = ['id', 'title', 'created_at', 'updated_at']
        fields = fields_mini + [
            'chat', 'meta', 'pinned', 'archived',
            'share_id', 'folder_id',
            'user_id', 'session_info'
        ]

    @staticmethod
    def get_created_at(obj):
        return int(obj.date_created.timestamp())

    @staticmethod
    def get_updated_at(obj):
        return int(obj.date_updated.timestamp())
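created_at/updated_at go over the wire as Unix epoch seconds rather than ISO strings. A quick check of the conversion (not part of the commit):

from datetime import datetime, timezone

dt = datetime(2025, 9, 30, 6, 57, tzinfo=timezone.utc)
int(dt.timestamp())  # 1759215420 -- the integer clients receive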
@@ -32,6 +32,7 @@ router.register(r'virtual-apps', api.VirtualAppViewSet, 'virtual-app')
router.register(r'app-providers', api.AppProviderViewSet, 'app-provider')
router.register(r'app-providers/((?P<provider>[^/.]+)/)?apps', api.AppProviderAppViewSet, 'app-provider-app')
router.register(r'virtual-app-publications', api.VirtualAppPublicationViewSet, 'virtual-app-publication')
router.register(r'chats', api.ChatViewSet, 'chat')

urlpatterns = [
    path('my-sessions/', api.MySessionAPIView.as_view(), name='my-session'),
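With the route registered, chat CRUD becomes reachable under the terminal API. A hedged request sketch (not part of the commit; the /api/v1/terminal/ prefix and token auth scheme are assumptions about how these routes are mounted, not something this diff shows):

import requests

BASE = 'https://js.example.com/api/v1/terminal'   # assumed URL prefix
HEADERS = {'Authorization': 'Token <token>'}      # assumed auth scheme

# Pinned chats in two (hypothetical) folders, newest first
resp = requests.get(
    f'{BASE}/chats/',
    params={'pinned': True, 'folder_ids': 'f1,f2', 'ordering': '-date_updated'},
    headers=HEADERS,
)
print(resp.json())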
@@ -199,11 +199,19 @@ class UserChangePasswordApi(UserQuerysetMixin, generics.UpdateAPIView):
class UserUnblockPKApi(UserQuerysetMixin, generics.UpdateAPIView):
    serializer_class = serializers.UserSerializer

    def get_object(self):
        pk = self.kwargs.get('pk')
        if is_uuid(pk):
            return super().get_object()
        else:
            return self.get_queryset().filter(username=pk).first()

    def perform_update(self, serializer):
        user = self.get_object()
        username = user.username if user else ''
        LoginBlockUtil.unblock_user(username)
        MFABlockUtils.unblock_user(username)
        if not user:
            return Response({"error": _("User not found")}, status=404)

        user.unblock_login()


class UserResetMFAApi(UserQuerysetMixin, generics.RetrieveAPIView):
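get_object now dispatches on the pk format, which is what the users/<str:pk>/unblock/ route change later in this diff enables. A minimal sketch of the rule (the common.utils import path for is_uuid is an assumption):

from common.utils import is_uuid  # import path assumed

def lookup_user(queryset, pk):
    # A UUID pk keeps the normal primary-key lookup;
    # anything else is treated as a username.
    if is_uuid(pk):
        return queryset.filter(pk=pk).first()
    return queryset.filter(username=pk).first()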
@@ -274,8 +274,8 @@ class User(
        LoginBlockUtil.unblock_user(self.username)
        MFABlockUtils.unblock_user(self.username)

    @lazyproperty
    def login_blocked(self):
    @property
    def is_login_blocked(self):
        from users.utils import LoginBlockUtil, MFABlockUtils

        if LoginBlockUtil.is_user_block(self.username):
@@ -284,6 +284,13 @@ class User(
            return True
        return False

    @classmethod
    def block_login(cls, username):
        from users.utils import LoginBlockUtil, MFABlockUtils

        LoginBlockUtil.block_user(username)
        MFABlockUtils.block_user(username)

    def delete(self, using=None, keep_parents=False):
        if self.pk == 1 or self.username == "admin":
            raise PermissionDenied(_("Can not delete admin user"))
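A minimal sketch (not part of the commit) of how the new classmethod pairs with the renamed property:

from users.models import User

User.block_login('alice')                  # sets both login and MFA block keys
user = User.objects.get(username='alice')
print(user.is_login_blocked)               # True until the cache TTL expires
user.unblock_login()                       # clears both keys again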
@@ -123,7 +123,7 @@ class UserSerializer(
    mfa_force_enabled = serializers.BooleanField(
        read_only=True, label=_("MFA force enabled")
    )
    login_blocked = serializers.BooleanField(read_only=True, label=_("Login blocked"))
    is_login_blocked = serializers.BooleanField(read_only=True, label=_("Login blocked"))
    is_expired = serializers.BooleanField(read_only=True, label=_("Is expired"))
    is_valid = serializers.BooleanField(read_only=True, label=_("Is valid"))
    is_otp_secret_key_bound = serializers.BooleanField(
@@ -193,6 +193,7 @@ class UserSerializer(
        "is_valid", "is_expired", "is_active",  # boolean fields
        "is_otp_secret_key_bound", "can_public_key_auth",
        "mfa_enabled", "need_update_password", "is_face_code_set",
        "is_login_blocked",
    ]
    # Less commonly used fields; may be omitted
    fields_verbose = (
@@ -211,7 +212,7 @@ class UserSerializer(
    # many-to-many fields
    fields_m2m = ["groups", "system_roles", "org_roles", "orgs_roles", "labels"]
    # fields defined on the serializer itself
    fields_custom = ["login_blocked", "password_strategy"]
    fields_custom = ["is_login_blocked", "password_strategy"]
    fields = fields_verbose + fields_fk + fields_m2m + fields_custom
    fields_unexport = ["avatar_url", "is_service_account"]
@@ -28,6 +28,6 @@ urlpatterns = [
    path('users/<uuid:pk>/password/', api.UserChangePasswordApi.as_view(), name='change-user-password'),
    path('users/<uuid:pk>/password/reset/', api.UserResetPasswordApi.as_view(), name='user-reset-password'),
    path('users/<uuid:pk>/pubkey/reset/', api.UserResetPKApi.as_view(), name='user-public-key-reset'),
    path('users/<uuid:pk>/unblock/', api.UserUnblockPKApi.as_view(), name='user-unblock'),
    path('users/<str:pk>/unblock/', api.UserUnblockPKApi.as_view(), name='user-unblock'),
]
urlpatterns += router.urls
@@ -186,6 +186,13 @@ class BlockUtilBase:
    def is_block(self):
        return bool(cache.get(self.block_key))

    @classmethod
    def block_user(cls, username):
        username = username.lower()
        block_key = cls.BLOCK_KEY_TMPL.format(username)
        key_ttl = int(settings.SECURITY_LOGIN_LIMIT_TIME) * 60
        cache.set(block_key, True, key_ttl)

    @classmethod
    def get_blocked_usernames(cls):
        key = cls.BLOCK_KEY_TMPL.format('*')
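block_user stores a boolean under a per-user cache key whose TTL comes from SECURITY_LOGIN_LIMIT_TIME (minutes). A minimal sketch of that arithmetic, with hypothetical template and setting values:

BLOCK_KEY_TMPL = '_LOGIN_BLOCK_{}'   # hypothetical template value
SECURITY_LOGIN_LIMIT_TIME = 30       # minutes (hypothetical setting value)

block_key = BLOCK_KEY_TMPL.format('Alice'.lower())  # '_LOGIN_BLOCK_alice'
key_ttl = int(SECURITY_LOGIN_LIMIT_TIME) * 60       # 1800 seconds
# cache.set(block_key, True, key_ttl) -> is_block() flips back to False after 30 minutes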
@@ -0,0 +1,358 @@
import os
import sys
import django
import random
from datetime import datetime

if os.path.exists('../../apps'):
    sys.path.insert(0, '../../apps')
if os.path.exists('../apps'):
    sys.path.insert(0, '../apps')
elif os.path.exists('./apps'):
    sys.path.insert(0, './apps')

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "jumpserver.settings")
django.setup()


from assets.models import Asset, Node
from orgs.models import Organization
from django.db.models import Count

OUTPUT_FILE = 'report_cleanup_and_keep_one_node_for_multi_parent_nodes_assets.txt'

# Special organization IDs and names
SPECIAL_ORGS = {
    '00000000-0000-0000-0000-000000000000': 'GLOBAL',
    '00000000-0000-0000-0000-000000000002': 'DEFAULT',
    '00000000-0000-0000-0000-000000000004': 'SYSTEM',
}

try:
    AssetNodeThrough = Asset.nodes.through
except Exception as e:
    print("Failed to get AssetNodeThrough model. Check Asset.nodes field definition.")
    raise e


def log(msg=''):
    """Print log with timestamp to console"""
    print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}")


def write_report(content):
    """Write content to report file"""
    with open(OUTPUT_FILE, 'a', encoding='utf-8') as f:
        f.write(content)


def get_org_name(org_id, orgs_map):
    """Get organization name, check special orgs first, then orgs_map"""
    # Check if it's a special organization
    org_id_str = str(org_id)
    if org_id_str in SPECIAL_ORGS:
        return SPECIAL_ORGS[org_id_str]

    # Try to get from orgs_map
    org = orgs_map.get(org_id)
    if org:
        return org.name

    return 'Unknown'


def find_and_cleanup_multi_parent_assets():
    """Find and cleanup assets with multiple parent nodes"""

    log("Searching for assets with multiple parent nodes...")

    # Find all asset_ids that belong to multiple node_ids
    multi_parent_assets = AssetNodeThrough.objects.values('asset_id').annotate(
        node_count=Count('node_id', distinct=True)
    ).filter(node_count__gt=1).order_by('-node_count')

    total_count = multi_parent_assets.count()
    log(f"Found {total_count:,} assets with multiple parent nodes\n")

    if total_count == 0:
        log("✓ All assets already have a single parent node")
        return {}

    # Collect all asset_ids and node_ids
    asset_ids = [item['asset_id'] for item in multi_parent_assets]

    # Get all through records
    all_through_records = AssetNodeThrough.objects.filter(asset_id__in=asset_ids)
    node_ids = list(set(through.node_id for through in all_through_records))

    # Batch fetch all objects
    log("Batch loading Asset objects...")
    assets_map = {asset.id: asset for asset in Asset.objects.filter(id__in=asset_ids)}

    log("Batch loading Node objects...")
    nodes_map = {node.id: node for node in Node.objects.filter(id__in=node_ids)}

    # Batch fetch all Organization objects
    org_ids = list(set(asset.org_id for asset in assets_map.values())) + \
              list(set(node.org_id for node in nodes_map.values()))
    org_ids = list(set(org_ids))

    log("Batch loading Organization objects...")
    orgs_map = {org.id: org for org in Organization.objects.filter(id__in=org_ids)}

    # Build mapping of asset_id -> list of through_records
    asset_nodes_map = {}
    for through in all_through_records:
        if through.asset_id not in asset_nodes_map:
            asset_nodes_map[through.asset_id] = []
        asset_nodes_map[through.asset_id].append(through)

    # Organize by organization
    org_cleanup_data = {}  # org_id -> { asset_id -> { keep_node_id, remove_node_ids } }

    for item in multi_parent_assets:
        asset_id = item['asset_id']

        # Get Asset object
        asset = assets_map.get(asset_id)
        if not asset:
            log(f"⚠ Asset {asset_id} not found in map, skipping")
            continue

        org_id = asset.org_id

        # Initialize org data if not exists
        if org_id not in org_cleanup_data:
            org_cleanup_data[org_id] = {}

        # Get all nodes for this asset
        through_records = asset_nodes_map.get(asset_id, [])

        if len(through_records) < 2:
            continue

        # Randomly select one node to keep
        keep_through = random.choice(through_records)
        remove_throughs = [t for t in through_records if t.id != keep_through.id]

        org_cleanup_data[org_id][asset_id] = {
            'asset_name': asset.name,
            'keep_node_id': keep_through.node_id,
            'keep_node': nodes_map.get(keep_through.node_id),
            'remove_records': remove_throughs,
            'remove_nodes': [nodes_map.get(t.node_id) for t in remove_throughs]
        }

    return org_cleanup_data


def perform_cleanup(org_cleanup_data, dry_run=False):
    """Perform the actual cleanup - delete extra node relationships"""

    if dry_run:
        log("DRY RUN: Simulating cleanup process (no data will be deleted)...")
    else:
        log("\nStarting cleanup process...")

    total_deleted = 0

    for org_id in org_cleanup_data.keys():
        for asset_id, cleanup_info in org_cleanup_data[org_id].items():
            # Delete the extra relationships
            for through_record in cleanup_info['remove_records']:
                if not dry_run:
                    through_record.delete()
                total_deleted += 1

    return total_deleted


def verify_cleanup():
    """Verify that there are no more assets with multiple parent nodes"""
    log("\n" + "="*80)
    log("VERIFICATION: Checking for remaining assets with multiple parent nodes...")
    log("="*80)

    # Find all asset_ids that belong to multiple node_ids
    multi_parent_assets = AssetNodeThrough.objects.values('asset_id').annotate(
        node_count=Count('node_id', distinct=True)
    ).filter(node_count__gt=1).order_by('-node_count')

    remaining_count = multi_parent_assets.count()

    if remaining_count == 0:
        log("✓ Verification successful: No assets with multiple parent nodes remaining\n")
        return True
    else:
        log(f"✗ Verification failed: Found {remaining_count:,} assets still with multiple parent nodes\n")
        # Show some details
        for item in multi_parent_assets[:10]:
            asset_id = item['asset_id']
            node_count = item['node_count']
            try:
                asset = Asset.objects.get(id=asset_id)
                log(f"  - Asset: {asset.name} ({asset_id}) has {node_count} parent nodes")
            except Exception:
                log(f"  - Asset ID: {asset_id} has {node_count} parent nodes")

        if remaining_count > 10:
            log(f"  ... and {remaining_count - 10} more")

        return False


def generate_report(org_cleanup_data, total_deleted):
    """Generate and write report to file"""
    # Clear previous report
    if os.path.exists(OUTPUT_FILE):
        os.remove(OUTPUT_FILE)

    # Write header
    write_report("Multi-Parent Assets Cleanup Report\n")
    write_report(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
    write_report(f"{'='*80}\n\n")

    # Get all organizations
    all_org_ids = list(set(org_id for org_id in org_cleanup_data.keys()))
    all_orgs = {org.id: org for org in Organization.objects.filter(id__in=all_org_ids)}

    # Calculate statistics
    total_orgs = Organization.objects.count()
    orgs_processed = len(org_cleanup_data)
    orgs_no_issues = total_orgs - orgs_processed
    total_assets_cleaned = sum(len(assets) for assets in org_cleanup_data.values())

    # Overview
    write_report("OVERVIEW\n")
    write_report(f"{'-'*80}\n")
    write_report(f"Total organizations: {total_orgs:,}\n")
    write_report(f"Organizations processed: {orgs_processed:,}\n")
    write_report(f"Organizations without issues: {orgs_no_issues:,}\n")
    write_report(f"Total assets cleaned: {total_assets_cleaned:,}\n")
    total_relationships = AssetNodeThrough.objects.count()
    write_report(f"Total relationships (through records): {total_relationships:,}\n")
    write_report(f"Total relationships deleted: {total_deleted:,}\n\n")

    # Summary by organization
    write_report("Summary by Organization:\n")
    for org_id in sorted(org_cleanup_data.keys()):
        org_name = get_org_name(org_id, all_orgs)
        asset_count = len(org_cleanup_data[org_id])
        write_report(f"  - {org_name} ({org_id}): {asset_count:,} assets cleaned\n")

    write_report(f"\n{'='*80}\n\n")

    # Detailed cleanup information grouped by organization
    for org_id in sorted(org_cleanup_data.keys()):
        org_name = get_org_name(org_id, all_orgs)
        asset_count = len(org_cleanup_data[org_id])

        write_report(f"ORGANIZATION: {org_name} ({org_id})\n")
        write_report(f"Total assets cleaned: {asset_count:,}\n")
        write_report(f"{'-'*80}\n\n")

        for asset_id, cleanup_info in org_cleanup_data[org_id].items():
            write_report(f"Asset: {cleanup_info['asset_name']} ({asset_id})\n")

            # Kept node
            keep_node = cleanup_info['keep_node']
            if keep_node:
                write_report(f"  ✓ Kept: {keep_node.name} (key: {keep_node.key}) (id: {keep_node.id})\n")
            else:
                write_report(f"  ✓ Kept: Unknown (id: {cleanup_info['keep_node_id']})\n")

            # Removed nodes
            write_report(f"  ✗ Removed: {len(cleanup_info['remove_nodes'])} node(s)\n")
            for node in cleanup_info['remove_nodes']:
                if node:
                    write_report(f"    - {node.name} (key: {node.key}) (id: {node.id})\n")
                else:
                    write_report("    - Unknown\n")

            write_report("\n")

        write_report(f"{'='*80}\n\n")

    log(f"✓ Report written to {OUTPUT_FILE}")


def main():
    try:
        # Display warning banner
        warning_message = """
╔══════════════════════════════════════════════════════════════════════════════╗
║                               ⚠️  WARNING  ⚠️                                   ║
║                                                                                ║
║              This script is designed for TEST/FAKE DATA ONLY!                  ║
║              DO NOT run this script in PRODUCTION environment!                 ║
║                                                                                ║
║        This script will DELETE asset-node relationships from the database.    ║
║          Use only for data cleanup in development/testing environments.       ║
║                                                                                ║
╚══════════════════════════════════════════════════════════════════════════════╝
"""
        print(warning_message)

        # Ask user to confirm before proceeding
        confirm = input("Do you understand the warning and want to continue? (yes/no): ").strip().lower()
        if confirm not in ['yes', 'y']:
            log("✗ Operation cancelled by user")
            sys.exit(0)

        log("✓ Proceeding with operation\n")

        org_cleanup_data = find_and_cleanup_multi_parent_assets()

        if not org_cleanup_data:
            log("✓ Cleanup complete, no assets to process")
            sys.exit(0)

        total_assets = sum(len(assets) for assets in org_cleanup_data.values())
        log(f"\nProcessing {total_assets:,} assets across {len(org_cleanup_data):,} organizations...")

        # First, do a dry-run to show what will be deleted
        log("\n" + "="*80)
        log("PREVIEW: Simulating cleanup process...")
        log("="*80)
        total_deleted_preview = perform_cleanup(org_cleanup_data, dry_run=True)
        log(f"✓ Dry-run complete: {total_deleted_preview:,} relationships would be deleted\n")

        # Generate preview report
        generate_report(org_cleanup_data, total_deleted_preview)
        log(f"✓ Preview report written to {OUTPUT_FILE}\n")

        # Ask for confirmation 3 times before actual deletion
        log("="*80)
        log("FINAL CONFIRMATION: Do you want to proceed with actual cleanup?")
        log("="*80)
        confirmation_count = 3
        for attempt in range(1, confirmation_count + 1):
            response = input(f"Confirm cleanup (attempt {attempt}/{confirmation_count})? (yes/no): ").strip().lower()
            if response not in ['yes', 'y']:
                log(f"✗ Cleanup cancelled by user at attempt {attempt}")
                sys.exit(1)

        log("✓ All confirmations received, proceeding with actual cleanup")

        # Perform cleanup
        total_deleted = perform_cleanup(org_cleanup_data)
        log(f"✓ Deleted {total_deleted:,} relationships")

        # Generate final report
        generate_report(org_cleanup_data, total_deleted)

        # Verify cleanup by checking for remaining multi-parent assets
        verify_cleanup()

        log(f"✓ Cleanup complete: processed {total_assets:,} assets")
        sys.exit(0)

    except Exception as e:
        log(f"✗ Error occurred: {str(e)}")
        import traceback
        traceback.print_exc()
        sys.exit(2)


if __name__ == "__main__":
    main()
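Both scripts pivot on the same grouped-count query over the Asset-to-Node through table. Isolated, the pattern looks like this (not part of the commit; it reuses the scripts' own AssetNodeThrough = Asset.nodes.through):

from django.db.models import Count

from assets.models import Asset

AssetNodeThrough = Asset.nodes.through

# Group m2m rows by asset and count distinct parent nodes; node_count > 1
# selects exactly the assets attached to more than one node.
dupes = (
    AssetNodeThrough.objects
    .values('asset_id')
    .annotate(node_count=Count('node_id', distinct=True))
    .filter(node_count__gt=1)
    .order_by('-node_count')
)
for row in dupes[:5]:
    print(row['asset_id'], row['node_count'])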
@@ -0,0 +1,270 @@
import os
import sys
import django
from datetime import datetime

if os.path.exists('../../apps'):
    sys.path.insert(0, '../../apps')
if os.path.exists('../apps'):
    sys.path.insert(0, '../apps')
elif os.path.exists('./apps'):
    sys.path.insert(0, './apps')

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "jumpserver.settings")
django.setup()


from assets.models import Asset, Node
from orgs.models import Organization
from django.db.models import Count

OUTPUT_FILE = 'report_find_multi_parent_nodes_assets.txt'

# Special organization IDs and names
SPECIAL_ORGS = {
    '00000000-0000-0000-0000-000000000000': 'GLOBAL',
    '00000000-0000-0000-0000-000000000002': 'DEFAULT',
    '00000000-0000-0000-0000-000000000004': 'SYSTEM',
}

try:
    AssetNodeThrough = Asset.nodes.through
except Exception as e:
    print("Failed to get AssetNodeThrough model. Check Asset.nodes field definition.")
    raise e


def log(msg=''):
    """Print log with timestamp"""
    print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}")


def get_org_name(org_id, orgs_map):
    """Get organization name, check special orgs first, then orgs_map"""
    # Check if it's a special organization
    org_id_str = str(org_id)
    if org_id_str in SPECIAL_ORGS:
        return SPECIAL_ORGS[org_id_str]

    # Try to get from orgs_map
    org = orgs_map.get(org_id)
    if org:
        return org.name

    return 'Unknown'


def write_report(content):
    """Write content to report file"""
    with open(OUTPUT_FILE, 'a', encoding='utf-8') as f:
        f.write(content)


def find_assets_multiple_parents():
    """Find assets belonging to multiple node_ids organized by organization"""

    log("Searching for assets with multiple parent nodes...")

    # Find all asset_ids that belong to multiple node_ids
    multi_parent_assets = AssetNodeThrough.objects.values('asset_id').annotate(
        node_count=Count('node_id', distinct=True)
    ).filter(node_count__gt=1).order_by('-node_count')

    total_count = multi_parent_assets.count()
    log(f"Found {total_count:,} assets with multiple parent nodes\n")

    if total_count == 0:
        log("✓ All assets belong to only one node")
        return {}

    # Collect all asset_ids and node_ids that need to be fetched
    asset_ids = [item['asset_id'] for item in multi_parent_assets]

    # Get all through records for these assets
    all_through_records = AssetNodeThrough.objects.filter(asset_id__in=asset_ids)
    node_ids = list(set(through.node_id for through in all_through_records))

    # Batch fetch all Asset and Node objects
    log("Batch loading Asset objects...")
    assets_map = {asset.id: asset for asset in Asset.objects.filter(id__in=asset_ids)}

    log("Batch loading Node objects...")
    nodes_map = {node.id: node for node in Node.objects.filter(id__in=node_ids)}

    # Batch fetch all Organization objects
    org_ids = list(set(asset.org_id for asset in assets_map.values())) + \
              list(set(node.org_id for node in nodes_map.values()))
    org_ids = list(set(org_ids))  # Remove duplicates

    log("Batch loading Organization objects...")
    orgs_map = {org.id: org for org in Organization.objects.filter(id__in=org_ids)}

    # Build mapping of asset_id -> list of through_records
    asset_nodes_map = {}
    for through in all_through_records:
        if through.asset_id not in asset_nodes_map:
            asset_nodes_map[through.asset_id] = []
        asset_nodes_map[through.asset_id].append(through)

    # Organize by organization first, then by node count, then by asset
    org_assets_data = {}  # org_id -> { node_count -> [asset_data] }

    for item in multi_parent_assets:
        asset_id = item['asset_id']
        node_count = item['node_count']

        # Get Asset object from map
        asset = assets_map.get(asset_id)
        if not asset:
            log(f"⚠ Asset {asset_id} not found in map, skipping")
            continue

        org_id = asset.org_id

        # Initialize org data if not exists
        if org_id not in org_assets_data:
            org_assets_data[org_id] = {}

        # Get all nodes for this asset
        through_records = asset_nodes_map.get(asset_id, [])

        node_details = []
        for through in through_records:
            # Get Node object from map
            node = nodes_map.get(through.node_id)
            if not node:
                log(f"⚠ Node {through.node_id} not found in map, skipping")
                continue

            node_details.append({
                'id': node.id,
                'name': node.name,
                'key': node.key,
                'path': node.full_value if hasattr(node, 'full_value') else ''
            })

        if not node_details:
            continue

        if node_count not in org_assets_data[org_id]:
            org_assets_data[org_id][node_count] = []

        org_assets_data[org_id][node_count].append({
            'asset_id': asset.id,
            'asset_name': asset.name,
            'nodes': node_details
        })

    return org_assets_data


def generate_report(org_assets_data):
    """Generate and write report to file organized by organization"""
    # Clear previous report
    if os.path.exists(OUTPUT_FILE):
        os.remove(OUTPUT_FILE)

    # Write header
    write_report("Multi-Parent Assets Report\n")
    write_report(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
    write_report(f"{'='*80}\n\n")

    # Get all organizations
    all_org_ids = list(set(org_id for org_id in org_assets_data.keys()))
    all_orgs = {org.id: org for org in Organization.objects.filter(id__in=all_org_ids)}

    # Calculate statistics
    total_orgs = Organization.objects.count()
    orgs_with_issues = len(org_assets_data)
    orgs_without_issues = total_orgs - orgs_with_issues
    total_assets_with_issues = sum(
        len(assets)
        for org_id in org_assets_data
        for assets in org_assets_data[org_id].values()
    )

    # Overview
    write_report("OVERVIEW\n")
    write_report(f"{'-'*80}\n")
    write_report(f"Total organizations: {total_orgs:,}\n")
    write_report(f"Organizations with multiple-parent assets: {orgs_with_issues:,}\n")
    write_report(f"Organizations without issues: {orgs_without_issues:,}\n")
    write_report(f"Total assets with multiple parent nodes: {total_assets_with_issues:,}\n\n")

    # Summary by organization
    write_report("Summary by Organization:\n")
    for org_id in sorted(org_assets_data.keys()):
        org_name = get_org_name(org_id, all_orgs)

        org_asset_count = sum(
            len(assets)
            for assets in org_assets_data[org_id].values()
        )
        write_report(f"  - {org_name} ({org_id}): {org_asset_count:,} assets\n")

    write_report(f"\n{'='*80}\n\n")

    # Detailed sections grouped by organization, then node count
    for org_id in sorted(org_assets_data.keys()):
        org_name = get_org_name(org_id, all_orgs)

        org_asset_count = sum(
            len(assets)
            for assets in org_assets_data[org_id].values()
        )

        write_report(f"ORGANIZATION: {org_name} ({org_id})\n")
        write_report(f"Total assets with issues: {org_asset_count:,}\n")
        write_report(f"{'-'*80}\n\n")

        # Group by node count within this organization
        for node_count in sorted(org_assets_data[org_id].keys(), reverse=True):
            assets = org_assets_data[org_id][node_count]

            write_report(f"  Section: {node_count} Parent Nodes ({len(assets):,} assets)\n")
            write_report(f"  {'-'*76}\n\n")

            for asset in assets:
                write_report(f"    {asset['asset_name']} ({asset['asset_id']})\n")

                for node in asset['nodes']:
                    write_report(f"      {node['name']} ({node['key']}) ({node['path']}) ({node['id']})\n")

                write_report("\n")

            write_report("\n")

        write_report(f"{'='*80}\n\n")

    log(f"✓ Report written to {OUTPUT_FILE}")


def main():
    try:
        org_assets_data = find_assets_multiple_parents()

        if not org_assets_data:
            log("✓ Detection complete, no issues found")
            sys.exit(0)

        total_assets = sum(
            len(assets)
            for org_id in org_assets_data
            for assets in org_assets_data[org_id].values()
        )
        log(f"Generating report for {total_assets:,} assets across {len(org_assets_data):,} organizations...")

        generate_report(org_assets_data)

        log(f"✗ Detected {total_assets:,} assets with multiple parent nodes")
        sys.exit(1)

    except Exception as e:
        log(f"✗ Error occurred: {str(e)}")
        import traceback
        traceback.print_exc()
        sys.exit(2)


if __name__ == "__main__":
    main()
@@ -3,7 +3,7 @@ from random import choice

import forgery_py

from assets.const import AllTypes
from assets.const import AllTypes, Category
from assets.models import *
from .base import FakeDataGenerator

@@ -48,12 +48,12 @@ class AssetsGenerator(FakeDataGenerator):

    def pre_generate(self):
        self.node_ids = list(Node.objects.all().values_list('id', flat=True))
        self.platform_ids = list(Platform.objects.filter(category='host').values_list('id', flat=True))
        self.platform_ids = list(Platform.objects.filter(category=Category.DATABASE).values_list('id', flat=True))

    def set_assets_nodes(self, assets):
        for asset in assets:
            nodes_id_add_to = random.sample(self.node_ids, 3)
            asset.nodes.add(*nodes_id_add_to)
            nodes_id_add_to = random.choice(self.node_ids)
            asset.node_id = nodes_id_add_to

    def do_generate(self, batch, batch_size):
        assets = []