Compare commits

..

2 Commits

Author     SHA1        Message                       Date
ibuler     e41f6e27e2  perf: update schema           2025-12-12 15:40:40 +08:00
ibuler     d2386fb56c  perf: update swagger for mcp  2025-12-10 18:17:10 +08:00
23 changed files with 494 additions and 1667 deletions

View File

@@ -1,126 +0,0 @@
# Generated by Django 4.1.13 on 2025-12-16 09:14
from django.db import migrations, models, transaction
import django.db.models.deletion
def log(msg=''):
print(f' -> {msg}')
def ensure_asset_single_node(apps, schema_editor):
print('')
log('Checking that all assets are linked to only one node...')
Asset = apps.get_model('assets', 'Asset')
Through = Asset.nodes.through
assets_count_multi_nodes = Through.objects.values('asset_id').annotate(
node_count=models.Count('node_id')
).filter(node_count__gt=1).count()
if assets_count_multi_nodes > 0:
raise Exception(
f'There are {assets_count_multi_nodes} assets associated with more than one node. '
'Please ensure each asset is linked to only one node before applying this migration.'
)
else:
log('All assets are linked to only one node. Proceeding with the migration.')
def ensure_asset_has_node(apps, schema_editor):
log('Checking that all assets are linked to at least one node...')
Asset = apps.get_model('assets', 'Asset')
Through = Asset.nodes.through
asset_count = Asset.objects.count()
through_asset_count = Through.objects.values('asset_id').count()
assets_count_without_node = asset_count - through_asset_count
if assets_count_without_node > 0:
raise Exception(
f'Some assets ({assets_count_without_node}) are not associated with any node. '
'Please ensure all assets are linked to a node before applying this migration.'
)
else:
log('All assets are linked to a node. Proceeding with the migration.')
def migrate_asset_node_id_field(apps, schema_editor):
log('Migrating node_id field for all assets...')
Asset = apps.get_model('assets', 'Asset')
Through = Asset.nodes.through
assets = Asset.objects.filter(node_id__isnull=True)
log(f'Found {assets.count()} assets to migrate.')
asset_node_mapper = {
str(asset_id): str(node_id)
for asset_id, node_id in Through.objects.values_list('asset_id', 'node_id')
}
# Test
asset_node_mapper.pop(None, None) # Remove any entries with None keys
for asset in assets:
node_id = asset_node_mapper.get(str(asset.id))
if not node_id:
raise Exception(
f'Asset (ID: {asset.id}) is not associated with any node. '
'Cannot migrate node_id field.'
)
asset.node_id = node_id
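# NOTE: the loop above fully evaluated (and cached) the queryset, so the slices
# below reuse the modified in-memory objects instead of re-querying the database.
# A bulk_update(batch, ['node_id']) per batch could also cut the write volume.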
with transaction.atomic():
total = len(assets)
batch_size = 5000
for i in range(0, total, batch_size):
batch = assets[i:i+batch_size]
start = i + 1
end = min(i + batch_size, total)
for asset in batch:
asset.save(update_fields=['node_id'])
log(f"Migrated {start}-{end}/{total} assets")
count = Asset.objects.filter(node_id__isnull=True).count()
if count > 0:
log('Warning: Some assets still have null node_id after migration.')
raise Exception('Migration failed: Some assets have null node_id.')
count = Asset.objects.filter(node_id__isnull=False).count()
log(f'Successfully migrated node_id for {count} assets.')
class Migration(migrations.Migration):
dependencies = [
('assets', '0019_alter_asset_connectivity'),
]
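# Pattern: add the FK as nullable, backfill node_id from the M2M through table,
# then alter the column to NOT NULL once every row is populated.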
operations = [
migrations.RunPython(
ensure_asset_single_node,
reverse_code=migrations.RunPython.noop
),
migrations.RunPython(
ensure_asset_has_node,
reverse_code=migrations.RunPython.noop
),
migrations.AddField(
model_name='asset',
name='node',
field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.PROTECT, related_name='direct_assets', to='assets.node', verbose_name='Node'),
),
migrations.RunPython(
migrate_asset_node_id_field,
reverse_code=migrations.RunPython.noop
),
migrations.AlterField(
model_name='asset',
name='node',
field=models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, related_name='direct_assets', to='assets.node', verbose_name='Node'),
),
]

View File

@@ -172,11 +172,6 @@ class Asset(NodesRelationMixin, LabeledMixin, AbsConnectivity, JSONFilterMixin,
"assets.Zone", null=True, blank=True, related_name='assets',
verbose_name=_("Zone"), on_delete=models.SET_NULL
)
node = models.ForeignKey(
'assets.Node', null=False, blank=False, on_delete=models.PROTECT,
related_name='direct_assets', verbose_name=_("Node")
)
# TODO: After removing all usages in the code, delete the nodes field and change the node field's related_name to 'assets'
nodes = models.ManyToManyField(
'assets.Node', default=default_node, related_name='assets', verbose_name=_("Nodes")
)

View File

@@ -150,6 +150,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
auto_config = serializers.DictField(read_only=True, label=_('Auto info'))
platform = ObjectRelatedField(queryset=Platform.objects, required=True, label=_('Platform'),
attrs=('id', 'name', 'type'))
spec_info = serializers.DictField(read_only=True, label=_('Spec info'))
accounts_amount = serializers.IntegerField(read_only=True, label=_('Accounts amount'))
_accounts = None
@@ -164,7 +165,8 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
'directory_services',
]
read_only_fields = [
'accounts_amount', 'category', 'type', 'connectivity', 'auto_config',
'accounts_amount', 'category', 'type', 'connectivity',
'auto_config', 'spec_info',
'date_verified', 'created_by', 'date_created', 'date_updated',
]
fields = fields_small + fields_fk + fields_m2m + read_only_fields
@@ -247,6 +249,19 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
return
field_type.choices = AllTypes.filter_choices(category)
@staticmethod
def get_spec_info(obj):
return {}
def get_auto_config(self, obj):
return obj.auto_config()
def get_gathered_info(self, obj):
return obj.gathered_info()
def get_accounts_amount(self, obj):
return obj.accounts_amount()
@classmethod
def setup_eager_loading(cls, queryset):
""" Perform necessary eager loading of data. """
@@ -441,7 +456,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
class DetailMixin(serializers.Serializer):
accounts = AssetAccountSerializer(many=True, required=False, label=_('Accounts'))
spec_info = MethodSerializer(label=_('Spec info'), read_only=True)
spec_info = MethodSerializer(label=_('Spec info'), read_only=True, required=False)
gathered_info = MethodSerializer(label=_('Gathered info'), read_only=True)
auto_config = serializers.DictField(read_only=True, label=_('Auto info'))

View File

@@ -1,4 +1,3 @@
from .aggregate import *
from .dashboard import IndexApi
from .health import PrometheusMetricsApi, HealthCheckView
from .search import GlobalSearchView

View File

@@ -1,9 +0,0 @@
from .detail import ResourceDetailApi
from .list import ResourceListApi
from .supported import ResourceTypeListApi
__all__ = [
'ResourceListApi',
'ResourceDetailApi',
'ResourceTypeListApi',
]

View File

@@ -1,57 +0,0 @@
list_params = [
{
"name": "search",
"in": "query",
"description": "A search term.",
"required": False,
"type": "string"
},
{
"name": "order",
"in": "query",
"description": "Which field to use when ordering the results.",
"required": False,
"type": "string"
},
{
"name": "limit",
"in": "query",
"description": "Number of results to return per page. Default is 10.",
"required": False,
"type": "integer"
},
{
"name": "offset",
"in": "query",
"description": "The initial index from which to return the results.",
"required": False,
"type": "integer"
},
]
common_params = [
{
"name": "resource",
"in": "path",
"description": """Resource to query, e.g. users, assets, permissions, acls, user-groups, policies, nodes, hosts,
devices, clouds, webs, databases,
gpts, ds, customs, platforms, zones, gateways, protocol-settings, labels, virtual-accounts,
gathered-accounts, account-templates, account-template-secrets, account-backups, account-backup-executions,
change-secret-automations, change-secret-executions, change-secret-records, gather-account-automations,
gather-account-executions, push-account-automations, push-account-executions, push-account-records,
check-account-automations, check-account-executions, account-risks, integration-apps, asset-permissions,
zones, gateways, virtual-accounts, gathered-accounts, account-templates, account-template-secrets.
Use GET /api/v1/resources/ to get the full list of supported resources.
""",
"required": True,
"type": "string"
},
{
"name": "X-JMS-ORG",
"in": "header",
"description": "The organization ID to use for the request. Organization is the namespace for resources, if not set, use default org",
"required": False,
"type": "string"
}
]

View File

@@ -1,75 +0,0 @@
# views.py
from drf_spectacular.utils import extend_schema, OpenApiParameter
from rest_framework.permissions import IsAuthenticated
from rest_framework.views import APIView
from .const import common_params
from .proxy import ProxyMixin
from .utils import param_dic_to_param
one_param = [
{
'name': 'id',
'in': 'path',
'required': True,
'description': 'Resource ID',
'type': 'string',
}
]
object_params = [
param_dic_to_param(d)
for d in common_params + one_param
]
class ResourceDetailApi(ProxyMixin, APIView):
permission_classes = [IsAuthenticated]
@extend_schema(
operation_id="get_resource_detail",
summary="Get resource detail",
parameters=object_params,
description="""
Get resource detail.
{resource} is the resource name; use GET /api/v1/resources/ to list all supported resources.
""",
)
def get(self, request, resource, pk=None):
return self._proxy(request, resource, pk=pk, action='retrieve')
@extend_schema(
operation_id="delete_resource",
summary="Delete the resource",
parameters=object_params,
description="Delete the resource, and can not be restored",
)
def delete(self, request, resource, pk=None):
return self._proxy(request, resource, pk, action='destroy')
@extend_schema(
operation_id="update_resource",
summary="Update the resource property",
parameters=object_params,
description="""
Update the resource; all properties will be updated.
{resource} is the resource name; use GET /api/v1/resources/ to list all supported resources.
Use OPTIONS /api/v1/resources/{resource}/{id}/?action=put to get field types and help text.
""",
)
def put(self, request, resource, pk=None):
return self._proxy(request, resource, pk, action='update')
@extend_schema(
operation_id="partial_update_resource",
summary="Update the resource property",
parameters=object_params,
description="""
Partially update the resource; only the properties in the request will be updated.
Use OPTIONS /api/v1/resources/{resource}/{id}/?action=patch to get field types and help text.
""",
)
def patch(self, request, resource, pk=None):
return self._proxy(request, resource, pk, action='partial_update')

View File

@@ -1,87 +0,0 @@
# views.py
from drf_spectacular.utils import extend_schema
from rest_framework.routers import DefaultRouter
from rest_framework.views import APIView
from .const import list_params, common_params
from .proxy import ProxyMixin
from .utils import param_dic_to_param
router = DefaultRouter()
BASE_URL = "http://localhost:8080"
list_params = [
param_dic_to_param(d)
for d in list_params + common_params
]
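# NOTE: this rebinding shadows the imported list_params; the comprehension above
# is evaluated first, so it still reads the imported list.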
create_params = [
param_dic_to_param(d)
for d in common_params
]
list_schema = {
"required": [
"count",
"results"
],
"type": "object",
"properties": {
"count": {
"type": "integer"
},
"next": {
"type": "string",
"format": "uri",
"x-nullable": True
},
"previous": {
"type": "string",
"format": "uri",
"x-nullable": True
},
"results": {
"type": "array",
"items": {
}
}
}
}
from drf_spectacular.utils import OpenApiResponse, OpenApiExample
class ResourceListApi(ProxyMixin, APIView):
@extend_schema(
operation_id="get_resource_list",
summary="Get resource list",
parameters=list_params,
responses={200: OpenApiResponse(description="Resource list response")},
description="""
Get the resource list; set the resource name in the URL.
Use OPTIONS /api/v1/resources/{resource}/?action=get to get each resource type's field types and help text.
""",
)
# ↓↓↓ auto-generated Swagger docs ↓↓↓
def get(self, request, resource):
return self._proxy(request, resource)
@extend_schema(
operation_id="create_resource_by_type",
summary="Create resource",
parameters=create_params,
description="""
Create a resource.
Use OPTIONS /api/v1/resources/{resource}/?action=post to get each resource type's field types and help text,
which shows how to create it.
""",
)
def post(self, request, resource, pk=None):
if not resource:
resource = request.data.pop('resource', '')
return self._proxy(request, resource, pk, action='create')
def options(self, request, resource, pk=None):
return self._proxy(request, resource, pk, action='metadata')

View File

@@ -1,75 +0,0 @@
# views.py
from urllib.parse import urlencode
import requests
from rest_framework.exceptions import NotFound, APIException
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.routers import DefaultRouter
from rest_framework.views import APIView
from .utils import get_full_resource_map
router = DefaultRouter()
BASE_URL = "http://localhost:8080"
class ProxyMixin(APIView):
"""
Generic resource proxy API: supports dynamic paths and automatic documentation generation.
"""
permission_classes = [IsAuthenticated]
def _build_url(self, resource_name: str, pk: str = None, query_params=None):
resource_map = get_full_resource_map()
resource = resource_map.get(resource_name)
if not resource:
raise NotFound(f"Unknown resource: {resource_name}")
base_path = resource['path']
if pk:
base_path += f"{pk}/"
if query_params:
base_path += f"?{urlencode(query_params)}"
return f"{BASE_URL}{base_path}"
def _proxy(self, request, resource: str, pk: str = None, action='list'):
method = request.method.lower()
if method not in ['get', 'post', 'put', 'patch', 'delete', 'options']:
raise APIException("Unsupported method")
if not resource or resource == '{resource}':
if request.data:
resource = request.data.get('resource')
query_params = request.query_params.dict()
if action == 'list':
query_params['limit'] = 10
url = self._build_url(resource, pk, query_params)
headers = {k: v for k, v in request.headers.items() if k.lower() != 'host'}
cookies = request.COOKIES
body = request.body if method in ['post', 'put', 'patch'] else None
try:
resp = requests.request(
method=method,
url=url,
headers=headers,
cookies=cookies,
data=body,
timeout=10,
)
content_type = resp.headers.get('Content-Type', '')
if 'application/json' in content_type:
data = resp.json()
else:
data = resp.text # or resp.content for raw bytes
return Response(data=data, status=resp.status_code)
except requests.RequestException as e:
raise APIException(f"Proxy request failed: {str(e)}")

View File

@@ -1,45 +0,0 @@
# views.py
from drf_spectacular.utils import extend_schema
from rest_framework import serializers
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.routers import DefaultRouter
from rest_framework.views import APIView
from .utils import get_full_resource_map
router = DefaultRouter()
BASE_URL = "http://localhost:8080"
class ResourceTypeResourceSerializer(serializers.Serializer):
name = serializers.CharField()
path = serializers.CharField()
app = serializers.CharField()
verbose_name = serializers.CharField()
description = serializers.CharField()
class ResourceTypeListApi(APIView):
permission_classes = [IsAuthenticated]
@extend_schema(
operation_id="get_supported_resources",
summary="Get-all-support-resources",
description="Get all support resources, name, path, verbose_name description",
responses={200: ResourceTypeResourceSerializer(many=True)}, # Specify the response serializer
)
def get(self, request):
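# Each entry is shaped like (illustrative):
# {'name': 'users', 'app': 'users', 'verbose_name': 'User', 'description': '...',
# 'path': '/api/v1/resources/users/'}; the map's original 'path' is overwritten
# with the proxy URL.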
result = []
resource_map = get_full_resource_map()
for name, desc in resource_map.items():
resource = {
"name": name,
**desc,
"path": f'/api/v1/resources/{name}/',
}
result.append(resource)
return Response(result)

View File

@@ -1,128 +0,0 @@
# views.py
import re
from functools import lru_cache
from typing import Dict
from django.urls import URLPattern
from django.urls import URLResolver
from drf_spectacular.utils import OpenApiParameter
from rest_framework.routers import DefaultRouter
router = DefaultRouter()
BASE_URL = "http://localhost:8080"
def clean_path(path: str) -> str:
"""
Strip the regex constructs that DRF auto-generates so the pattern becomes a plain RESTful URL path.
"""
# Remove format-suffix matches: \.(?P<format>xxx)
path = re.sub(r'\\\.\(\?P<format>[^)]+\)', '', path)
# Remove parenthesized format matches
path = re.sub(r'\(\?P<format>[^)]+\)', '', path)
# Replace DRF regex parameters (?P<param>pattern) with {param}
path = re.sub(r'\(\?P<\w+>[^)]+\)', '{param}', path)
# Unwrap optional groups wrapping nested regexes (e.g. prefix paths)
path = re.sub(r'\(\(([^)]+)\)\?\)', r'\1', path) # ((...))? => ...
# Strip ^ and $ anchors everywhere
path = path.replace('^', '').replace('$', '')
# Strip a trailing ?/
path = re.sub(r'\?/?$', '', path)
# Strip backslashes
path = path.replace('\\', '')
# Collapse repeated slashes
path = re.sub(r'/+', '/', path)
# Trim whitespace and normalize leading/trailing slashes
path = path.strip()
if not path.startswith('/'):
path = '/' + path
if not path.endswith('/'):
path += '/'
return path
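# Illustrative example: '^api/v1/users/(?P<pk>[^/.]+)/$' -> '/api/v1/users/{param}/'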
def extract_resource_paths(urlpatterns, prefix='/api/v1/') -> Dict[str, Dict[str, str]]:
resource_map = {}
for pattern in urlpatterns:
if isinstance(pattern, URLResolver):
nested_prefix = prefix + str(pattern.pattern)
resource_map.update(extract_resource_paths(pattern.url_patterns, nested_prefix))
elif isinstance(pattern, URLPattern):
callback = pattern.callback
actions = getattr(callback, 'actions', {})
if not actions:
continue
if 'get' in actions and actions['get'] == 'list':
path = clean_path(prefix + str(pattern.pattern))
# Try to derive the resource name
name = pattern.name
if name and name.endswith('-list'):
resource = name[:-5]
else:
resource = path.strip('/').split('/')[-1]
# Pluralize the resource name if it does not already end with 's'
resource = resource if resource.endswith('s') else resource + 's'
# Get the view class and the model's verbose_name
view_cls = getattr(callback, 'cls', None)
model = None
if view_cls:
queryset = getattr(view_cls, 'queryset', None)
if queryset is not None:
model = getattr(queryset, 'model', None)
else:
# Some views define get_queryset() instead
try:
instance = view_cls()
qs = instance.get_queryset()
model = getattr(qs, 'model', None)
except Exception:
pass
if not model:
continue
app = str(getattr(model._meta, 'app_label', ''))
verbose_name = str(getattr(model._meta, 'verbose_name', ''))
resource_map[resource] = {
'path': path,
'app': app,
'verbose_name': verbose_name,
'description': str(model.__doc__ or '')
}
print("Extracted resource paths:", list(resource_map.keys()))
return resource_map
def param_dic_to_param(d):
return OpenApiParameter(
name=d['name'], location=d['in'],
description=d['description'], type=d['type'], required=d.get('required', False)
)
@lru_cache()
def get_full_resource_map():
from apps.jumpserver.urls import resource_api
resource_map = extract_resource_paths(resource_api)
print("Building URL for resource:", resource_map)
return resource_map

View File

@@ -5,6 +5,7 @@ from rest_framework.pagination import LimitOffsetPagination
class MaxLimitOffsetPagination(LimitOffsetPagination):
max_limit = settings.MAX_PAGE_SIZE
default_limit = settings.DEFAULT_PAGE_SIZE
def get_count(self, queryset):
try:

View File

@@ -85,6 +85,7 @@ SPECTACULAR_SETTINGS = {
'jumpserver.views.schema.LabeledChoiceFieldExtension',
'jumpserver.views.schema.BitChoicesFieldExtension',
'jumpserver.views.schema.LabelRelatedFieldExtension',
'jumpserver.views.schema.DateTimeFieldExtension',
],
'SECURITY': [{'Bearer': []}],
}

View File

@@ -37,12 +37,6 @@ api_v1 = resource_api + [
path('prometheus/metrics/', api.PrometheusMetricsApi.as_view()),
path('search/', api.GlobalSearchView.as_view()),
]
if settings.MCP_ENABLED:
api_v1.extend([
path('resources/', api.ResourceTypeListApi.as_view(), name='resource-list'),
path('resources/<str:resource>/', api.ResourceListApi.as_view()),
path('resources/<str:resource>/<str:pk>/', api.ResourceDetailApi.as_view()),
])
app_view_patterns = [
path('auth/', include('authentication.urls.view_urls'), name='auth'),

View File

@@ -1,421 +0,0 @@
import re
from drf_spectacular.openapi import AutoSchema
from drf_spectacular.generators import SchemaGenerator
class CustomSchemaGenerator(SchemaGenerator):
from_mcp = False
def get_schema(self, request=None, public=False):
self.from_mcp = request.query_params.get('mcp') or request.path.endswith('swagger.json')
return super().get_schema(request, public)
class CustomAutoSchema(AutoSchema):
def __init__(self, *args, **kwargs):
self.from_mcp = kwargs.get('from_mcp', False)
super().__init__(*args, **kwargs)
def map_parsers(self):
return ['application/json']
def map_renderers(self, *args, **kwargs):
return ['application/json']
def get_tags(self):
operation_keys = self._tokenize_path()
if len(operation_keys) == 1:
return []
tags = ['_'.join(operation_keys[:2])]
return tags
def get_operation(self, path, *args, **kwargs):
if path.endswith('render-to-json/'):
return None
# if not path.startswith('/api/v1/users'):
# return None
operation = super().get_operation(path, *args, **kwargs)
if not operation:
return operation
if not operation.get('summary', ''):
operation['summary'] = operation.get('operationId')
return operation
def get_operation_id(self):
tokenized_path = self._tokenize_path()
# replace dashes as they can be problematic later in code generation
tokenized_path = [t.replace('-', '_') for t in tokenized_path]
action = ''
if hasattr(self.view, 'action'):
action = self.view.action
if not action:
if self.method == 'GET' and self._is_list_view():
action = 'list'
else:
action = self.method_mapping[self.method.lower()]
if action == "bulk_destroy":
action = "bulk_delete"
if not tokenized_path:
tokenized_path.append('root')
if re.search(r'<drf_format_suffix\w*:\w+>', self.path_regex):
tokenized_path.append('formatted')
return '_'.join(tokenized_path + [action])
def get_filter_parameters(self):
if not self.should_filter():
return []
fields = []
if hasattr(self.view, 'get_filter_backends'):
backends = self.view.get_filter_backends()
elif hasattr(self.view, 'filter_backends'):
backends = self.view.filter_backends
else:
backends = []
for filter_backend in backends:
fields += self.probe_inspectors(
self.filter_inspectors, 'get_filter_parameters', filter_backend()
) or []
return fields
def get_auth(self):
return [{'Bearer': []}]
def get_operation_security(self):
"""
Override the operation security config to use Bearer tokens uniformly
"""
return [{'Bearer': []}]
def get_components_security_schemes(self):
"""
Override the security scheme definitions to avoid auth class parsing errors
"""
return {
'Bearer': {
'type': 'http',
'scheme': 'bearer',
'bearerFormat': 'JWT',
'description': 'JWT token for API authentication'
}
}
@staticmethod
def exclude_some_paths(path):
# Paths can be filtered out here
excludes = [
'/report/', '/render-to-json/', '/suggestions/',
'executions', 'automations', 'change-secret-records',
'change-secret-dashboard', '/copy-to-assets/',
'/move-to-assets/', 'dashboard', 'index', 'countries',
'/resources/cache/', 'profile/mfa', 'profile/password',
'profile/permissions', 'prometheus', 'constraints'
]
for p in excludes:
if path.find(p) >= 0:
return True
return False
def exclude_some_app_model(self, path):
parts = path.split('/')
if len(parts) < 5:
return False
apps = []
if self.from_mcp:
apps = [
'ops', 'tickets', 'authentication',
'settings', 'xpack', 'terminal', 'rbac',
'notifications', 'prometheus', 'acls'
]
app_name = parts[3]
if app_name in apps:
return True
models = []
model = parts[4]
if self.from_mcp:
models = [
'users', 'user-groups', 'users-groups-relations', 'assets', 'hosts', 'devices', 'databases',
'webs', 'clouds', 'gpts', 'ds', 'customs', 'platforms', 'nodes', 'zones', 'gateways',
'protocol-settings', 'labels', 'virtual-accounts', 'gathered-accounts', 'account-templates',
'account-template-secrets', 'account-backups', 'account-backup-executions',
'change-secret-automations', 'change-secret-executions', 'change-secret-records',
'gather-account-automations', 'gather-account-executions', 'push-account-automations',
'push-account-executions', 'push-account-records', 'check-account-automations',
'check-account-executions', 'account-risks', 'integration-apps', 'asset-permissions',
'asset-permissions-users-relations', 'asset-permissions-user-groups-relations',
'asset-permissions-assets-relations', 'asset-permissions-nodes-relations', 'terminal-status',
'terminals', 'tasks', 'status', 'replay-storages', 'command-storages', 'session-sharing-records',
'endpoints', 'endpoint-rules', 'applets', 'applet-hosts', 'applet-publications',
'applet-host-deployments', 'virtual-apps', 'app-providers', 'virtual-app-publications',
'celery-period-tasks', 'task-executions', 'adhocs', 'playbooks', 'variables', 'ftp-logs',
'login-logs', 'operate-logs', 'password-change-logs', 'job-logs', 'jobs', 'user-sessions',
'service-access-logs', 'chatai-prompts', 'super-connection-tokens', 'flows',
'apply-assets', 'apply-nodes', 'login-acls', 'login-asset-acls', 'command-filter-acls',
'command-groups', 'connect-method-acls', 'system-msg-subscriptions', 'roles', 'role-bindings',
'system-roles', 'system-role-bindings', 'org-roles', 'org-role-bindings', 'content-types',
'labeled-resources', 'account-backup-plans', 'account-check-engines', 'account-secrets',
'change-secret', 'integration-applications', 'push-account', 'directories', 'connection-token',
'groups', 'accounts', 'resource-types', 'favorite-assets', 'activities', 'platform-automation-methods',
]
if model in models:
return True
return False
def is_excluded(self):
if self.exclude_some_paths(self.path):
return True
if self.exclude_some_app_model(self.path):
return True
return False
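# NOTE: the get_operation below overrides the definition of the same name earlier in this class.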
def get_operation(self, path, *args, **kwargs):
operation = super().get_operation(path, *args, **kwargs)
if not operation:
return operation
operation_id = operation.get('operationId')
if 'bulk' in operation_id:
return None
if not operation.get('summary', ''):
operation['summary'] = operation.get('operationId')
exclude_operations = [
'orgs_orgs_read', 'orgs_orgs_update', 'orgs_orgs_delete',
'orgs_orgs_create', 'orgs_orgs_partial_update',
]
if operation_id in exclude_operations:
return None
return operation
# OpenAPI extensions for custom serializer fields
from drf_spectacular.extensions import OpenApiSerializerFieldExtension
from drf_spectacular.openapi import AutoSchema
from drf_spectacular.plumbing import build_basic_type
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField, BitChoicesField
class ObjectRelatedFieldExtension(OpenApiSerializerFieldExtension):
"""
Provide an OpenAPI schema for ObjectRelatedField
"""
target_class = ObjectRelatedField
def map_serializer_field(self, auto_schema, direction):
field = self.target
# Get the field's basic info
field_type = 'array' if field.many else 'object'
if field_type == 'array':
# Many-to-many relation
return {
'type': 'array',
'items': self._get_openapi_item_schema(field),
'description': getattr(field, 'help_text', ''),
'title': getattr(field, 'label', ''),
}
else:
# Single related object
return {
'type': 'object',
'properties': self._get_openapi_properties_schema(field),
'description': getattr(field, 'help_text', ''),
'title': getattr(field, 'label', ''),
}
def _get_openapi_item_schema(self, field):
"""
OpenAPI schema for array items
"""
return self._get_openapi_object_schema(field)
def _get_openapi_object_schema(self, field):
"""
OpenAPI schema for the related object
"""
properties = {}
# Dynamically infer the type of each attribute in attrs
for attr in field.attrs:
# Try to get field info from the queryset's model
field_type = self._infer_field_type(field, attr)
properties[attr] = {
'type': field_type,
'description': f'{attr} field'
}
return {
'type': 'object',
'properties': properties,
'required': ['id'] if 'id' in field.attrs else []
}
def _infer_field_type(self, field, attr_name):
"""
Infer the field type
"""
try:
# If a queryset exists, try to get field info from the model
if hasattr(field, 'queryset') and field.queryset is not None:
model = field.queryset.model
if hasattr(model, '_meta') and hasattr(model._meta, 'fields'):
model_field = model._meta.get_field(attr_name)
if model_field:
return self._map_django_field_type(model_field)
except Exception:
pass
# Fall back to heuristics when there is no queryset or field info
return self._heuristic_field_type(attr_name)
def _map_django_field_type(self, model_field):
"""
Map Django field types to OpenAPI types
"""
field_type = type(model_field).__name__
# Integer types
if 'Integer' in field_type or 'BigInteger' in field_type or 'SmallInteger' in field_type:
return 'integer'
# Floating-point types
elif 'Float' in field_type or 'Decimal' in field_type:
return 'number'
# Boolean types
elif 'Boolean' in field_type:
return 'boolean'
# Date/time types
elif 'DateTime' in field_type or 'Date' in field_type or 'Time' in field_type:
return 'string'
# File types
elif 'File' in field_type or 'Image' in field_type:
return 'string'
# Everything else defaults to string
else:
return 'string'
def _heuristic_field_type(self, attr_name):
"""
Heuristically infer the field type
"""
# Heuristics based on the attribute name
if attr_name in ['is_active', 'enabled', 'visible'] or attr_name.startswith('is_'):
return 'boolean'
elif attr_name in ['count', 'number', 'size', 'amount']:
return 'integer'
elif attr_name in ['price', 'rate', 'percentage']:
return 'number'
else:
# Default to string
return 'string'
def _get_openapi_properties_schema(self, field):
"""
OpenAPI schema for the object's properties
"""
return self._get_openapi_object_schema(field)['properties']
class LabeledChoiceFieldExtension(OpenApiSerializerFieldExtension):
"""
Provide an OpenAPI schema for LabeledChoiceField
"""
target_class = LabeledChoiceField
def map_serializer_field(self, auto_schema, direction):
field = self.target
if getattr(field, 'many', False):
return {
'type': 'array',
'items': {
'type': 'object',
'properties': {
'value': {'type': 'string'},
'label': {'type': 'string'}
}
},
'description': getattr(field, 'help_text', ''),
'title': getattr(field, 'label', ''),
}
else:
return {
'type': 'object',
'properties': {
'value': {'type': 'string'},
'label': {'type': 'string'}
},
'description': getattr(field, 'help_text', ''),
'title': getattr(field, 'label', ''),
}
class BitChoicesFieldExtension(OpenApiSerializerFieldExtension):
"""
Provide an OpenAPI schema for BitChoicesField
"""
target_class = BitChoicesField
def map_serializer_field(self, auto_schema, direction):
field = self.target
return {
'type': 'array',
'items': {
'type': 'object',
'properties': {
'value': {'type': 'string'},
'label': {'type': 'string'}
}
},
'description': getattr(field, 'help_text', ''),
'title': getattr(field, 'label', ''),
}
class LabelRelatedFieldExtension(OpenApiSerializerFieldExtension):
"""
Provide an OpenAPI schema for LabelRelatedField
"""
target_class = 'common.serializers.fields.LabelRelatedField'
def map_serializer_field(self, auto_schema, direction):
field = self.target
# LabelRelatedField returns an object containing id, name, value, color
return {
'type': 'object',
'properties': {
'id': {
'type': 'string',
'description': 'Label ID'
},
'name': {
'type': 'string',
'description': 'Label name'
},
'value': {
'type': 'string',
'description': 'Label value'
},
'color': {
'type': 'string',
'description': 'Label color'
}
},
'required': ['id', 'name', 'value'],
'description': getattr(field, 'help_text', 'Label information'),
'title': getattr(field, 'label', 'Label'),
}

View File

@@ -0,0 +1,2 @@
from .extension import *
from .schema import *

View File

@@ -0,0 +1,263 @@
# OpenAPI extensions for custom serializer fields
from drf_spectacular.extensions import OpenApiSerializerFieldExtension
from drf_spectacular.openapi import AutoSchema
from drf_spectacular.plumbing import build_basic_type
from rest_framework import serializers
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField, BitChoicesField
__all__ = [
'ObjectRelatedFieldExtension', 'LabeledChoiceFieldExtension',
'BitChoicesFieldExtension', 'LabelRelatedFieldExtension',
'DateTimeFieldExtension'
]
class ObjectRelatedFieldExtension(OpenApiSerializerFieldExtension):
"""
Provide an OpenAPI schema for ObjectRelatedField
"""
target_class = ObjectRelatedField
def map_serializer_field(self, auto_schema, direction):
field = self.target
# Get the field's basic info
field_type = 'array' if field.many else 'object'
if field_type == 'array':
# Many-to-many relation
return {
'type': 'array',
'items': self._get_openapi_item_schema(field),
'description': getattr(field, 'help_text', ''),
'title': getattr(field, 'label', ''),
}
else:
# Single related object
return {
'type': 'object',
'properties': self._get_openapi_properties_schema(field),
'description': getattr(field, 'help_text', ''),
'title': getattr(field, 'label', ''),
}
def _get_openapi_item_schema(self, field):
"""
OpenAPI schema for array items
"""
return self._get_openapi_object_schema(field)
def _get_openapi_object_schema(self, field):
"""
OpenAPI schema for the related object
"""
properties = {}
# Dynamically infer the type of each attribute in attrs
for attr in field.attrs:
# Try to get field info from the queryset's model
field_type = self._infer_field_type(field, attr)
properties[attr] = {
'type': field_type,
'description': f'{attr} field'
}
return {
'type': 'object',
'properties': properties,
'required': ['id'] if 'id' in field.attrs else []
}
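# Illustrative example: attrs=('id', 'name') yields an object schema with
# 'id' and 'name' properties and required=['id'].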
def _infer_field_type(self, field, attr_name):
"""
Infer the field type
"""
try:
# If a queryset exists, try to get field info from the model
if hasattr(field, 'queryset') and field.queryset is not None:
model = field.queryset.model
if hasattr(model, '_meta') and hasattr(model._meta, 'fields'):
model_field = model._meta.get_field(attr_name)
if model_field:
return self._map_django_field_type(model_field)
except Exception:
pass
# Fall back to heuristics when there is no queryset or field info
return self._heuristic_field_type(attr_name)
def _map_django_field_type(self, model_field):
"""
Map Django field types to OpenAPI types
"""
field_type = type(model_field).__name__
# Integer types
if 'Integer' in field_type or 'BigInteger' in field_type or 'SmallInteger' in field_type or 'AutoField' in field_type:
return 'integer'
# Floating-point types
elif 'Float' in field_type or 'Decimal' in field_type:
return 'number'
# Boolean types
elif 'Boolean' in field_type:
return 'boolean'
# Date/time types
elif 'DateTime' in field_type or 'Date' in field_type or 'Time' in field_type:
return 'string'
# File types
elif 'File' in field_type or 'Image' in field_type:
return 'string'
# Everything else defaults to string
else:
return 'string'
def _heuristic_field_type(self, attr_name):
"""
Heuristically infer the field type
"""
# Heuristics based on the attribute name
if attr_name in ['is_active', 'enabled', 'visible'] or attr_name.startswith('is_'):
return 'boolean'
elif attr_name in ['count', 'number', 'size', 'amount']:
return 'integer'
elif attr_name in ['price', 'rate', 'percentage']:
return 'number'
else:
# Default to string
return 'string'
def _get_openapi_properties_schema(self, field):
"""
OpenAPI schema for the object's properties
"""
return self._get_openapi_object_schema(field)['properties']
class LabeledChoiceFieldExtension(OpenApiSerializerFieldExtension):
"""
Provide an OpenAPI schema for LabeledChoiceField
"""
target_class = LabeledChoiceField
def map_serializer_field(self, auto_schema, direction):
field = self.target
if getattr(field, 'many', False):
return {
'type': 'array',
'items': {
'type': 'object',
'properties': {
'value': {'type': 'string'},
'label': {'type': 'string'}
}
},
'description': getattr(field, 'help_text', ''),
'title': getattr(field, 'label', ''),
}
else:
return {
'type': 'object',
'properties': {
'value': {'type': 'string'},
'label': {'type': 'string'}
},
'description': getattr(field, 'help_text', ''),
'title': getattr(field, 'label', ''),
}
class BitChoicesFieldExtension(OpenApiSerializerFieldExtension):
"""
Provide an OpenAPI schema for BitChoicesField
"""
target_class = BitChoicesField
def map_serializer_field(self, auto_schema, direction):
field = self.target
return {
'type': 'array',
'items': {
'type': 'object',
'properties': {
'value': {'type': 'string'},
'label': {'type': 'string'}
}
},
'description': getattr(field, 'help_text', ''),
'title': getattr(field, 'label', ''),
}
class LabelRelatedFieldExtension(OpenApiSerializerFieldExtension):
"""
Provide an OpenAPI schema for LabelRelatedField
"""
target_class = 'common.serializers.fields.LabelRelatedField'
def map_serializer_field(self, auto_schema, direction):
field = self.target
# LabelRelatedField returns an object containing id, name, value, color
return {
'type': 'object',
'properties': {
'id': {
'type': 'string',
'description': 'Label ID'
},
'name': {
'type': 'string',
'description': 'Label name'
},
'value': {
'type': 'string',
'description': 'Label value'
},
'color': {
'type': 'string',
'description': 'Label color'
}
},
'required': ['id', 'name', 'value'],
'description': getattr(field, 'help_text', 'Label information'),
'title': getattr(field, 'label', 'Label'),
}
class DateTimeFieldExtension(OpenApiSerializerFieldExtension):
"""
Provide a custom OpenAPI schema for DateTimeField.
Corrects the datetime format to match what the API actually returns, '%Y/%m/%d %H:%M:%S %z',
instead of the standard ISO 8601 format (date-time).
"""
target_class = serializers.DateTimeField
def map_serializer_field(self, auto_schema, direction):
field = self.target
# Get the field's description and make sure it is always a string
help_text = getattr(field, 'help_text', None) or ''
description = help_text if isinstance(help_text, str) else ''
# Append the format note
format_desc = 'Format: YYYY/MM/DD HH:MM:SS +TZ (e.g., 2023/10/01 12:00:00 +0800)'
if description:
description = f'{description} {format_desc}'
else:
description = format_desc
# Return a string type without format: date-time,
# because the actual response format is '%Y/%m/%d %H:%M:%S %z', not standard ISO 8601
schema = {
'type': 'string',
'description': description,
'title': getattr(field, 'label', '') or '',
'example': '2023/10/01 12:00:00 +0800',
}
return schema

View File

@@ -0,0 +1,206 @@
import re
from drf_spectacular.openapi import AutoSchema
from drf_spectacular.generators import SchemaGenerator
__all__ = [
'CustomSchemaGenerator', 'CustomAutoSchema'
]
class CustomSchemaGenerator(SchemaGenerator):
from_mcp = False
def get_schema(self, request=None, public=False):
self.from_mcp = request.query_params.get('mcp') or request.path.endswith('swagger.json')
return super().get_schema(request, public)
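# A schema request counts as MCP when it carries ?mcp=... or targets the legacy swagger.json path.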
class CustomAutoSchema(AutoSchema):
def __init__(self, *args, **kwargs):
self.from_mcp = True
super().__init__(*args, **kwargs)
def map_parsers(self):
return ['application/json']
def map_renderers(self, *args, **kwargs):
return ['application/json']
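# Restricting parsers and renderers to application/json keeps the generated schema JSON-only.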
def get_tags(self):
operation_keys = self._tokenize_path()
if len(operation_keys) == 1:
return []
tags = ['_'.join(operation_keys[:2])]
return tags
def get_operation_id(self):
tokenized_path = self._tokenize_path()
# replace dashes as they can be problematic later in code generation
tokenized_path = [t.replace('-', '_') for t in tokenized_path]
action = ''
if hasattr(self.view, 'action'):
action = self.view.action
if not action:
if self.method == 'GET' and self._is_list_view():
action = 'list'
else:
action = self.method_mapping[self.method.lower()]
if action == "bulk_destroy":
action = "bulk_delete"
if not tokenized_path:
tokenized_path.append('root')
if re.search(r'<drf_format_suffix\w*:\w+>', self.path_regex):
tokenized_path.append('formatted')
return '_'.join(tokenized_path + [action])
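# e.g. PATCH /api/v1/assets/assets/{id}/ -> operationId 'assets_assets_partial_update'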
def get_filter_parameters(self):
if not self.should_filter():
return []
fields = []
if hasattr(self.view, 'get_filter_backends'):
backends = self.view.get_filter_backends()
elif hasattr(self.view, 'filter_backends'):
backends = self.view.filter_backends
else:
backends = []
for filter_backend in backends:
fields += self.probe_inspectors(
self.filter_inspectors, 'get_filter_parameters', filter_backend()
) or []
return fields
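# NOTE: probe_inspectors/filter_inspectors come from the drf-yasg API;
# drf-spectacular's AutoSchema does not call this method itself.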
def get_auth(self):
return [{'Bearer': []}]
def get_operation_security(self):
"""
Override the operation security config to use Bearer tokens uniformly
"""
return [{'Bearer': []}]
def get_components_security_schemes(self):
"""
Override the security scheme definitions to avoid auth class parsing errors
"""
return {
'Bearer': {
'type': 'http',
'scheme': 'bearer',
'bearerFormat': 'JWT',
'description': 'JWT token for API authentication'
}
}
@staticmethod
def exclude_some_paths(path):
# Paths can be filtered out here
excludes = [
'/report/', '/render-to-json/', '/suggestions/',
'executions', 'automations', 'change-secret-records',
'change-secret-dashboard', '/copy-to-assets/',
'/move-to-assets/', 'dashboard', 'index', 'countries',
'/resources/cache/', 'profile/mfa', 'profile/password',
'profile/permissions', 'prometheus', 'constraints',
'/api/swagger.json', '/api/swagger.yaml',
]
for p in excludes:
if path.find(p) >= 0:
return True
return False
def exclude_some_models(self, model):
models = []
if self.from_mcp:
models = [
'users', 'user-groups',
'assets', 'hosts', 'devices', 'databases',
'webs', 'clouds', 'ds', 'platforms',
'nodes', 'zones', 'labels',
'accounts', 'account-templates',
'asset-permissions',
]
if models and model in models:
return False
return True
def exclude_some_apps(self, app):
apps = []
if self.from_mcp:
apps = [
'users', 'assets', 'accounts',
'perms', 'labels',
]
if apps and app in apps:
return False
return True
def exclude_some_app_model(self, path):
parts = path.split('/')
if len(parts) < 5:
return True
if len(parts) == 7 and parts[5] != "{id}":
return True
if len(parts) > 7:
return True
app_name = parts[3]
if self.exclude_some_apps(app_name):
return True
if self.exclude_some_models(parts[4]):
return True
return False
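# e.g. '/api/v1/assets/assets/' splits into 6 parts; app 'assets' and model 'assets'
# are whitelisted above, so the path is kept.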
def is_excluded(self):
if self.exclude_some_paths(self.path):
return True
if self.exclude_some_app_model(self.path):
return True
return False
def exclude_some_operations(self, operation_id):
exclude_operations = [
'orgs_orgs_read', 'orgs_orgs_update', 'orgs_orgs_delete',
'orgs_orgs_create', 'orgs_orgs_partial_update',
]
if operation_id in exclude_operations:
return True
if 'bulk' in operation_id:
return True
if 'destroy' in operation_id:
return True
if 'update' in operation_id and 'partial' not in operation_id:
return True
return False
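# Net effect: the MCP schema is read-mostly; bulk, destroy and full-update
# operations are dropped, while partial_update is kept.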
def get_operation(self, path, *args, **kwargs):
operation = super().get_operation(path, *args, **kwargs)
if not operation:
return operation
operation_id = operation.get('operationId')
if self.exclude_some_operations(operation_id):
return None
if not operation.get('summary', ''):
operation['summary'] = operation.get('operationId')
# if self.is_excluded():
# return None
return operation

View File

@@ -37,11 +37,11 @@ class SchemeMixin:
}
return Response(schema)
@method_decorator(cache_page(60 * 5,), name="dispatch")
# @method_decorator(cache_page(60 * 5,), name="dispatch")
class JsonApi(SchemeMixin, SpectacularJSONAPIView):
pass
@method_decorator(cache_page(60 * 5,), name="dispatch")
# @method_decorator(cache_page(60 * 5,), name="dispatch")
class YamlApi(SchemeMixin, SpectacularYAMLAPIView):
pass

View File

@@ -73,6 +73,7 @@ class PrivateSettingSerializer(PublicSettingSerializer):
CHAT_AI_ENABLED = serializers.BooleanField()
CHAT_AI_METHOD = serializers.CharField()
CHAT_AI_EMBED_URL = serializers.CharField()
GPT_MODEL = serializers.CharField()
FILE_UPLOAD_SIZE_LIMIT_MB = serializers.IntegerField()
FTP_FILE_MAX_STORE = serializers.IntegerField()
LOKI_LOG_ENABLED = serializers.BooleanField()

View File

@@ -126,6 +126,7 @@ class Terminal(StorageMixin, TerminalStatusMixin, JMSBaseModel):
'GPT_BASE_URL': data.get('url'),
'GPT_API_KEY': data.get('api_key'),
'GPT_PROXY': data.get('proxy'),
'GPT_MODEL': data.get('model'),
'CHAT_AI_PROVIDERS': settings.CHAT_AI_PROVIDERS,
}

View File

@@ -1,358 +0,0 @@
import os
import sys
import django
import random
from datetime import datetime
if os.path.exists('../../apps'):
sys.path.insert(0, '../../apps')
if os.path.exists('../apps'):
sys.path.insert(0, '../apps')
elif os.path.exists('./apps'):
sys.path.insert(0, './apps')
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "jumpserver.settings")
django.setup()
from assets.models import Asset, Node
from orgs.models import Organization
from django.db.models import Count
OUTPUT_FILE = 'report_cleanup_and_keep_one_node_for_multi_parent_nodes_assets.txt'
# Special organization IDs and names
SPECIAL_ORGS = {
'00000000-0000-0000-0000-000000000000': 'GLOBAL',
'00000000-0000-0000-0000-000000000002': 'DEFAULT',
'00000000-0000-0000-0000-000000000004': 'SYSTEM',
}
try:
AssetNodeThrough = Asset.nodes.through
except Exception as e:
print("Failed to get AssetNodeThrough model. Check Asset.nodes field definition.")
raise e
def log(msg=''):
"""Print log with timestamp to console"""
print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}")
def write_report(content):
"""Write content to report file"""
with open(OUTPUT_FILE, 'a', encoding='utf-8') as f:
f.write(content)
def get_org_name(org_id, orgs_map):
"""Get organization name, check special orgs first, then orgs_map"""
# Check if it's a special organization
org_id_str = str(org_id)
if org_id_str in SPECIAL_ORGS:
return SPECIAL_ORGS[org_id_str]
# Try to get from orgs_map
org = orgs_map.get(org_id)
if org:
return org.name
return 'Unknown'
def find_and_cleanup_multi_parent_assets():
"""Find and cleanup assets with multiple parent nodes"""
log("Searching for assets with multiple parent nodes...")
# Find all asset_ids that belong to multiple node_ids
multi_parent_assets = AssetNodeThrough.objects.values('asset_id').annotate(
node_count=Count('node_id', distinct=True)
).filter(node_count__gt=1).order_by('-node_count')
total_count = multi_parent_assets.count()
log(f"Found {total_count:,} assets with multiple parent nodes\n")
if total_count == 0:
log("✓ All assets already have single parent node")
return {}
# Collect all asset_ids and node_ids
asset_ids = [item['asset_id'] for item in multi_parent_assets]
# Get all through records
all_through_records = AssetNodeThrough.objects.filter(asset_id__in=asset_ids)
node_ids = list(set(through.node_id for through in all_through_records))
# Batch fetch all objects
log("Batch loading Asset objects...")
assets_map = {asset.id: asset for asset in Asset.objects.filter(id__in=asset_ids)}
log("Batch loading Node objects...")
nodes_map = {node.id: node for node in Node.objects.filter(id__in=node_ids)}
# Batch fetch all Organization objects
org_ids = list(set(asset.org_id for asset in assets_map.values())) + \
list(set(node.org_id for node in nodes_map.values()))
org_ids = list(set(org_ids))
log("Batch loading Organization objects...")
orgs_map = {org.id: org for org in Organization.objects.filter(id__in=org_ids)}
# Build mapping of asset_id -> list of through_records
asset_nodes_map = {}
for through in all_through_records:
if through.asset_id not in asset_nodes_map:
asset_nodes_map[through.asset_id] = []
asset_nodes_map[through.asset_id].append(through)
# Organize by organization
org_cleanup_data = {} # org_id -> { asset_id -> { keep_node_id, remove_node_ids } }
for item in multi_parent_assets:
asset_id = item['asset_id']
# Get Asset object
asset = assets_map.get(asset_id)
if not asset:
log(f"⚠ Asset {asset_id} not found in map, skipping")
continue
org_id = asset.org_id
# Initialize org data if not exists
if org_id not in org_cleanup_data:
org_cleanup_data[org_id] = {}
# Get all nodes for this asset
through_records = asset_nodes_map.get(asset_id, [])
if len(through_records) < 2:
continue
# Randomly select one node to keep
keep_through = random.choice(through_records)
remove_throughs = [t for t in through_records if t.id != keep_through.id]
org_cleanup_data[org_id][asset_id] = {
'asset_name': asset.name,
'keep_node_id': keep_through.node_id,
'keep_node': nodes_map.get(keep_through.node_id),
'remove_records': remove_throughs,
'remove_nodes': [nodes_map.get(t.node_id) for t in remove_throughs]
}
return org_cleanup_data
def perform_cleanup(org_cleanup_data, dry_run=False):
"""Perform the actual cleanup - delete extra node relationships"""
if dry_run:
log("DRY RUN: Simulating cleanup process (no data will be deleted)...")
else:
log("\nStarting cleanup process...")
total_deleted = 0
for org_id in org_cleanup_data.keys():
for asset_id, cleanup_info in org_cleanup_data[org_id].items():
# Delete the extra relationships
for through_record in cleanup_info['remove_records']:
if not dry_run:
through_record.delete()
total_deleted += 1
return total_deleted
def verify_cleanup():
"""Verify that there are no more assets with multiple parent nodes"""
log("\n" + "="*80)
log("VERIFICATION: Checking for remaining assets with multiple parent nodes...")
log("="*80)
# Find all asset_ids that belong to multiple node_ids
multi_parent_assets = AssetNodeThrough.objects.values('asset_id').annotate(
node_count=Count('node_id', distinct=True)
).filter(node_count__gt=1).order_by('-node_count')
remaining_count = multi_parent_assets.count()
if remaining_count == 0:
log(f"✓ Verification successful: No assets with multiple parent nodes remaining\n")
return True
else:
log(f"✗ Verification failed: Found {remaining_count:,} assets still with multiple parent nodes\n")
# Show some details
for item in multi_parent_assets[:10]:
asset_id = item['asset_id']
node_count = item['node_count']
try:
asset = Asset.objects.get(id=asset_id)
log(f" - Asset: {asset.name} ({asset_id}) has {node_count} parent nodes")
except Exception:
log(f" - Asset ID: {asset_id} has {node_count} parent nodes")
if remaining_count > 10:
log(f" ... and {remaining_count - 10} more")
return False
def generate_report(org_cleanup_data, total_deleted):
"""Generate and write report to file"""
# Clear previous report
if os.path.exists(OUTPUT_FILE):
os.remove(OUTPUT_FILE)
# Write header
write_report(f"Multi-Parent Assets Cleanup Report\n")
write_report(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
write_report(f"{'='*80}\n\n")
# Get all organizations
all_org_ids = list(set(org_id for org_id in org_cleanup_data.keys()))
all_orgs = {org.id: org for org in Organization.objects.filter(id__in=all_org_ids)}
# Calculate statistics
total_orgs = Organization.objects.count()
orgs_processed = len(org_cleanup_data)
orgs_no_issues = total_orgs - orgs_processed
total_assets_cleaned = sum(len(assets) for assets in org_cleanup_data.values())
# Overview
write_report("OVERVIEW\n")
write_report(f"{'-'*80}\n")
write_report(f"Total organizations: {total_orgs:,}\n")
write_report(f"Organizations processed: {orgs_processed:,}\n")
write_report(f"Organizations without issues: {orgs_no_issues:,}\n")
write_report(f"Total assets cleaned: {total_assets_cleaned:,}\n")
total_relationships = AssetNodeThrough.objects.count()
write_report(f"Total relationships (through records): {total_relationships:,}\n")
write_report(f"Total relationships deleted: {total_deleted:,}\n\n")
# Summary by organization
write_report("Summary by Organization:\n")
for org_id in sorted(org_cleanup_data.keys()):
org_name = get_org_name(org_id, all_orgs)
asset_count = len(org_cleanup_data[org_id])
write_report(f" - {org_name} ({org_id}): {asset_count:,} assets cleaned\n")
write_report(f"\n{'='*80}\n\n")
# Detailed cleanup information grouped by organization
for org_id in sorted(org_cleanup_data.keys()):
org_name = get_org_name(org_id, all_orgs)
asset_count = len(org_cleanup_data[org_id])
write_report(f"ORGANIZATION: {org_name} ({org_id})\n")
write_report(f"Total assets cleaned: {asset_count:,}\n")
write_report(f"{'-'*80}\n\n")
for asset_id, cleanup_info in org_cleanup_data[org_id].items():
write_report(f"Asset: {cleanup_info['asset_name']} ({asset_id})\n")
# Kept node
keep_node = cleanup_info['keep_node']
if keep_node:
write_report(f" ✓ Kept: {keep_node.name} (key: {keep_node.key}) (id: {keep_node.id})\n")
else:
write_report(f" ✓ Kept: Unknown (id: {cleanup_info['keep_node_id']})\n")
# Removed nodes
write_report(f" ✗ Removed: {len(cleanup_info['remove_nodes'])} node(s)\n")
for node in cleanup_info['remove_nodes']:
if node:
write_report(f" - {node.name} (key: {node.key}) (id: {node.id})\n")
else:
write_report(f" - Unknown\n")
write_report(f"\n")
write_report(f"{'='*80}\n\n")
log(f"✓ Report written to {OUTPUT_FILE}")
def main():
try:
# Display warning banner
warning_message = """
╔══════════════════════════════════════════════════════════════════════════════╗
║ ⚠️ WARNING ⚠️ ║
║ ║
║ This script is designed for TEST/FAKE DATA ONLY! ║
║ DO NOT run this script in PRODUCTION environment! ║
║ ║
║ This script will DELETE asset-node relationships from the database. ║
║ Use only for data cleanup in development/testing environments. ║
║ ║
╚══════════════════════════════════════════════════════════════════════════════╝
"""
print(warning_message)
# Ask user to confirm before proceeding
confirm = input("Do you understand the warning and want to continue? (yes/no): ").strip().lower()
if confirm not in ['yes', 'y']:
log("✗ Operation cancelled by user")
sys.exit(0)
log("✓ Proceeding with operation\n")
org_cleanup_data = find_and_cleanup_multi_parent_assets()
if not org_cleanup_data:
log("✓ Cleanup complete, no assets to process")
sys.exit(0)
total_assets = sum(len(assets) for assets in org_cleanup_data.values())
log(f"\nProcessing {total_assets:,} assets across {len(org_cleanup_data):,} organizations...")
# First, do a dry-run to show what will be deleted
log("\n" + "="*80)
log("PREVIEW: Simulating cleanup process...")
log("="*80)
total_deleted_preview = perform_cleanup(org_cleanup_data, dry_run=True)
log(f"✓ Dry-run complete: {total_deleted_preview:,} relationships would be deleted\n")
# Generate preview report
generate_report(org_cleanup_data, total_deleted_preview)
log(f"✓ Preview report written to {OUTPUT_FILE}\n")
# Ask for confirmation 3 times before actual deletion
log("="*80)
log("FINAL CONFIRMATION: Do you want to proceed with actual cleanup?")
log("="*80)
confirmation_count = 3
for attempt in range(1, confirmation_count + 1):
response = input(f"Confirm cleanup (attempt {attempt}/{confirmation_count})? (yes/no): ").strip().lower()
if response not in ['yes', 'y']:
log(f"✗ Cleanup cancelled by user at attempt {attempt}")
sys.exit(1)
log("✓ All confirmations received, proceeding with actual cleanup")
# Perform cleanup
total_deleted = perform_cleanup(org_cleanup_data)
log(f"✓ Deleted {total_deleted:,} relationships")
# Generate final report
generate_report(org_cleanup_data, total_deleted)
# Verify cleanup by checking for remaining multi-parent assets
verify_cleanup()
log(f"✓ Cleanup complete: processed {total_assets:,} assets")
sys.exit(0)
except Exception as e:
log(f"✗ Error occurred: {str(e)}")
import traceback
traceback.print_exc()
sys.exit(2)
if __name__ == "__main__":
main()

View File

@@ -1,270 +0,0 @@
import os
import sys
import django
from datetime import datetime
if os.path.exists('../../apps'):
sys.path.insert(0, '../../apps')
if os.path.exists('../apps'):
sys.path.insert(0, '../apps')
elif os.path.exists('./apps'):
sys.path.insert(0, './apps')
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "jumpserver.settings")
django.setup()
from assets.models import Asset, Node
from orgs.models import Organization
from django.db.models import Count
OUTPUT_FILE = 'report_find_multi_parent_nodes_assets.txt'
# Special organization IDs and names
SPECIAL_ORGS = {
'00000000-0000-0000-0000-000000000000': 'GLOBAL',
'00000000-0000-0000-0000-000000000002': 'DEFAULT',
'00000000-0000-0000-0000-000000000004': 'SYSTEM',
}
try:
AssetNodeThrough = Asset.nodes.through
except Exception as e:
print("Failed to get AssetNodeThrough model. Check Asset.nodes field definition.")
raise e
def log(msg=''):
"""Print log with timestamp"""
print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}")
def get_org_name(org_id, orgs_map):
"""Get organization name, check special orgs first, then orgs_map"""
# Check if it's a special organization
org_id_str = str(org_id)
if org_id_str in SPECIAL_ORGS:
return SPECIAL_ORGS[org_id_str]
# Try to get from orgs_map
org = orgs_map.get(org_id)
if org:
return org.name
return 'Unknown'
def write_report(content):
"""Write content to report file"""
with open(OUTPUT_FILE, 'a', encoding='utf-8') as f:
f.write(content)
def find_assets_multiple_parents():
"""Find assets belonging to multiple node_ids organized by organization"""
log("Searching for assets with multiple parent nodes...")
# Find all asset_ids that belong to multiple node_ids
multi_parent_assets = AssetNodeThrough.objects.values('asset_id').annotate(
node_count=Count('node_id', distinct=True)
).filter(node_count__gt=1).order_by('-node_count')
total_count = multi_parent_assets.count()
log(f"Found {total_count:,} assets with multiple parent nodes\n")
if total_count == 0:
log("✓ All assets belong to only one node")
return {}
# Collect all asset_ids and node_ids that need to be fetched
asset_ids = [item['asset_id'] for item in multi_parent_assets]
# Get all through records for these assets
all_through_records = AssetNodeThrough.objects.filter(asset_id__in=asset_ids)
node_ids = list(set(through.node_id for through in all_through_records))
# Batch fetch all Asset and Node objects
log("Batch loading Asset objects...")
assets_map = {asset.id: asset for asset in Asset.objects.filter(id__in=asset_ids)}
log("Batch loading Node objects...")
nodes_map = {node.id: node for node in Node.objects.filter(id__in=node_ids)}
# Batch fetch all Organization objects
org_ids = list(set(asset.org_id for asset in assets_map.values())) + \
list(set(node.org_id for node in nodes_map.values()))
org_ids = list(set(org_ids)) # Remove duplicates
log("Batch loading Organization objects...")
orgs_map = {org.id: org for org in Organization.objects.filter(id__in=org_ids)}
# Build mapping of asset_id -> list of through_records
asset_nodes_map = {}
for through in all_through_records:
if through.asset_id not in asset_nodes_map:
asset_nodes_map[through.asset_id] = []
asset_nodes_map[through.asset_id].append(through)
# Organize by organization first, then by node count, then by asset
org_assets_data = {} # org_id -> { node_count -> [asset_data] }
for item in multi_parent_assets:
asset_id = item['asset_id']
node_count = item['node_count']
# Get Asset object from map
asset = assets_map.get(asset_id)
if not asset:
log(f"⚠ Asset {asset_id} not found in map, skipping")
continue
org_id = asset.org_id
# Initialize org data if not exists
if org_id not in org_assets_data:
org_assets_data[org_id] = {}
# Get all nodes for this asset
through_records = asset_nodes_map.get(asset_id, [])
node_details = []
for through in through_records:
# Get Node object from map
node = nodes_map.get(through.node_id)
if not node:
log(f"⚠ Node {through.node_id} not found in map, skipping")
continue
node_details.append({
'id': node.id,
'name': node.name,
'key': node.key,
'path': node.full_value if hasattr(node, 'full_value') else ''
})
if not node_details:
continue
if node_count not in org_assets_data[org_id]:
org_assets_data[org_id][node_count] = []
org_assets_data[org_id][node_count].append({
'asset_id': asset.id,
'asset_name': asset.name,
'nodes': node_details
})
return org_assets_data
def generate_report(org_assets_data):
"""Generate and write report to file organized by organization"""
# Clear previous report
if os.path.exists(OUTPUT_FILE):
os.remove(OUTPUT_FILE)
# Write header
write_report(f"Multi-Parent Assets Report\n")
write_report(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
write_report(f"{'='*80}\n\n")
# Get all organizations
all_org_ids = list(set(org_id for org_id in org_assets_data.keys()))
all_orgs = {org.id: org for org in Organization.objects.filter(id__in=all_org_ids)}
# Calculate statistics
total_orgs = Organization.objects.count()
orgs_with_issues = len(org_assets_data)
orgs_without_issues = total_orgs - orgs_with_issues
total_assets_with_issues = sum(
len(assets)
for org_id in org_assets_data
for assets in org_assets_data[org_id].values()
)
# Overview
write_report("OVERVIEW\n")
write_report(f"{'-'*80}\n")
write_report(f"Total organizations: {total_orgs:,}\n")
write_report(f"Organizations with multiple-parent assets: {orgs_with_issues:,}\n")
write_report(f"Organizations without issues: {orgs_without_issues:,}\n")
write_report(f"Total assets with multiple parent nodes: {total_assets_with_issues:,}\n\n")
# Summary by organization
write_report("Summary by Organization:\n")
for org_id in sorted(org_assets_data.keys()):
org_name = get_org_name(org_id, all_orgs)
org_asset_count = sum(
len(assets)
for assets in org_assets_data[org_id].values()
)
write_report(f" - {org_name} ({org_id}): {org_asset_count:,} assets\n")
write_report(f"\n{'='*80}\n\n")
# Detailed sections grouped by organization, then node count
for org_id in sorted(org_assets_data.keys()):
org_name = get_org_name(org_id, all_orgs)
org_asset_count = sum(
len(assets)
for assets in org_assets_data[org_id].values()
)
write_report(f"ORGANIZATION: {org_name} ({org_id})\n")
write_report(f"Total assets with issues: {org_asset_count:,}\n")
write_report(f"{'-'*80}\n\n")
# Group by node count within this organization
for node_count in sorted(org_assets_data[org_id].keys(), reverse=True):
assets = org_assets_data[org_id][node_count]
write_report(f" Section: {node_count} Parent Nodes ({len(assets):,} assets)\n")
write_report(f" {'-'*76}\n\n")
for asset in assets:
write_report(f" {asset['asset_name']} ({asset['asset_id']})\n")
for node in asset['nodes']:
write_report(f" {node['name']} ({node['key']}) ({node['path']}) ({node['id']})\n")
write_report(f"\n")
write_report(f"\n")
write_report(f"{'='*80}\n\n")
log(f"✓ Report written to {OUTPUT_FILE}")
def main():
try:
org_assets_data = find_assets_multiple_parents()
if not org_assets_data:
log("✓ Detection complete, no issues found")
sys.exit(0)
total_assets = sum(
len(assets)
for org_id in org_assets_data
for assets in org_assets_data[org_id].values()
)
log(f"Generating report for {total_assets:,} assets across {len(org_assets_data):,} organizations...")
generate_report(org_assets_data)
log(f"✗ Detected {total_assets:,} assets with multiple parent nodes")
sys.exit(1)
except Exception as e:
log(f"✗ Error occurred: {str(e)}")
import traceback
traceback.print_exc()
sys.exit(2)
if __name__ == "__main__":
main()