Compare commits

..

6 Commits

33 changed files with 1026 additions and 216 deletions

View File

@@ -0,0 +1,126 @@
# Generated by Django 4.1.13 on 2025-12-16 09:14
from django.db import migrations, models, transaction
import django.db.models.deletion
def log(msg=''):
print(f' -> {msg}')
def ensure_asset_single_node(apps, schema_editor):
print('')
log('Checking that all assets are linked to only one node...')
Asset = apps.get_model('assets', 'Asset')
Through = Asset.nodes.through
assets_count_multi_nodes = Through.objects.values('asset_id').annotate(
node_count=models.Count('node_id')
).filter(node_count__gt=1).count()
if assets_count_multi_nodes > 0:
raise Exception(
f'There are {assets_count_multi_nodes} assets associated with more than one node. '
'Please ensure each asset is linked to only one node before applying this migration.'
)
else:
log('All assets are linked to only one node. Proceeding with the migration.')
def ensure_asset_has_node(apps, schema_editor):
log('Checking that all assets are linked to at least one node...')
Asset = apps.get_model('assets', 'Asset')
Through = Asset.nodes.through
asset_count = Asset.objects.count()
through_asset_count = Through.objects.values('asset_id').count()
assets_count_without_node = asset_count - through_asset_count
if assets_count_without_node > 0:
raise Exception(
f'Some assets ({assets_count_without_node}) are not associated with any node. '
'Please ensure all assets are linked to a node before applying this migration.'
)
else:
log('All assets are linked to a node. Proceeding with the migration.')
def migrate_asset_node_id_field(apps, schema_editor):
log('Migrating node_id field for all assets...')
Asset = apps.get_model('assets', 'Asset')
Through = Asset.nodes.through
assets = Asset.objects.filter(node_id__isnull=True)
log (f'Found {assets.count()} assets to migrate.')
asset_node_mapper = {
str(asset_id): str(node_id)
for asset_id, node_id in Through.objects.values_list('asset_id', 'node_id')
}
# 测试
asset_node_mapper.pop(None, None) # Remove any entries with None keys
for asset in assets:
node_id = asset_node_mapper.get(str(asset.id))
if not node_id:
raise Exception(
f'Asset (ID: {asset.id}) is not associated with any node. '
'Cannot migrate node_id field.'
)
asset.node_id = node_id
with transaction.atomic():
total = len(assets)
batch_size = 5000
for i in range(0, total, batch_size):
batch = assets[i:i+batch_size]
start = i + 1
end = min(i + batch_size, total)
for asset in batch:
asset.save(update_fields=['node_id'])
log(f"Migrated {start}-{end}/{total} assets")
count = Asset.objects.filter(node_id__isnull=True).count()
if count > 0:
log('Warning: Some assets still have null node_id after migration.')
raise Exception('Migration failed: Some assets have null node_id.')
count = Asset.objects.filter(node_id__isnull=False).count()
log(f'Successfully migrated node_id for {count} assets.')
class Migration(migrations.Migration):
dependencies = [
('assets', '0019_alter_asset_connectivity'),
]
operations = [
migrations.RunPython(
ensure_asset_single_node,
reverse_code=migrations.RunPython.noop
),
migrations.RunPython(
ensure_asset_has_node,
reverse_code=migrations.RunPython.noop
),
migrations.AddField(
model_name='asset',
name='node',
field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.PROTECT, related_name='direct_assets', to='assets.node', verbose_name='Node'),
),
migrations.RunPython(
migrate_asset_node_id_field,
reverse_code=migrations.RunPython.noop
),
migrations.AlterField(
model_name='asset',
name='node',
field=models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, related_name='direct_assets', to='assets.node', verbose_name='Node'),
),
]

View File

@@ -172,6 +172,11 @@ class Asset(NodesRelationMixin, LabeledMixin, AbsConnectivity, JSONFilterMixin,
"assets.Zone", null=True, blank=True, related_name='assets',
verbose_name=_("Zone"), on_delete=models.SET_NULL
)
node = models.ForeignKey(
'assets.Node', null=False, blank=False, on_delete=models.PROTECT,
related_name='direct_assets', verbose_name=_("Node")
)
# TODO: 删除完代码中所有使用的地方后,再删除 nodes 字段,并将 node 字段的 related_name 改为 'assets'
nodes = models.ManyToManyField(
'assets.Node', default=default_node, related_name='assets', verbose_name=_("Nodes")
)

View File

@@ -186,6 +186,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
super().__init__(*args, **kwargs)
self._init_field_choices()
self._extract_accounts()
self._set_platform()
def _extract_accounts(self):
if not getattr(self, 'initial_data', None):
@@ -217,6 +218,21 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
protocols_data = [{'name': p.name, 'port': p.port} for p in protocols]
self.initial_data['protocols'] = protocols_data
def _set_platform(self):
if not hasattr(self, 'initial_data'):
return
platform_id = self.initial_data.get('platform')
if not platform_id:
return
if isinstance(platform_id, int) or str(platform_id).isdigit() or not isinstance(platform_id, str):
return
platform = Platform.objects.filter(name=platform_id).first()
if not platform:
return
self.initial_data['platform'] = platform.id
def _init_field_choices(self):
request = self.context.get('request')
if not request:
@@ -265,8 +281,10 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
if not platform_id and self.instance:
platform = self.instance.platform
else:
elif isinstance(platform_id, int):
platform = Platform.objects.filter(id=platform_id).first()
else:
platform = Platform.objects.filter(name=platform_id).first()
if not platform:
raise serializers.ValidationError({'platform': _("Platform not exist")})
@@ -297,6 +315,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
def is_valid(self, raise_exception=False):
self._set_protocols_default()
self._set_platform()
return super().is_valid(raise_exception=raise_exception)
def validate_protocols(self, protocols_data):

View File

@@ -67,6 +67,7 @@ class UserLoginMFAView(mixins.AuthMixin, FormView):
def get_context_data(self, **kwargs):
user = self.get_user_from_session()
mfa_context = self.get_user_mfa_context(user)
print(mfa_context)
kwargs.update(mfa_context)
return kwargs

View File

@@ -1,6 +1,5 @@
from collections import defaultdict
from django.core.cache import cache
from django.db.models import Count, Max, F, CharField, Q
from django.db.models.functions import Cast
from django.http.response import JsonResponse
@@ -145,7 +144,6 @@ class DateTimeMixin:
class DatesLoginMetricMixin:
days: int
dates_list: list
date_start_end: tuple
command_type_queryset_list: list
@@ -157,8 +155,6 @@ class DatesLoginMetricMixin:
operate_logs_queryset: OperateLog.objects
password_change_logs_queryset: PasswordChangeLog.objects
CACHE_TIMEOUT = 60
@lazyproperty
def get_type_to_assets(self):
result = Asset.objects.annotate(type=F('platform__type')). \
@@ -218,34 +214,19 @@ class DatesLoginMetricMixin:
return date_metrics_dict.get('id', [])
def get_dates_login_times_assets(self):
cache_key = f"stats:top10_assets:{self.days}"
data = cache.get(cache_key)
if data is not None:
return data
assets = self.sessions_queryset.values("asset") \
.annotate(total=Count("asset")) \
.annotate(last=Cast(Max("date_start"), output_field=CharField())) \
.order_by("-total")
result = list(assets[:10])
cache.set(cache_key, result, self.CACHE_TIMEOUT)
return result
return list(assets[:10])
def get_dates_login_times_users(self):
cache_key = f"stats:top10_users:{self.days}"
data = cache.get(cache_key)
if data is not None:
return data
users = self.sessions_queryset.values("user_id") \
.annotate(total=Count("user_id")) \
.annotate(user=Max('user')) \
.annotate(last=Cast(Max("date_start"), output_field=CharField())) \
.order_by("-total")
result = list(users[:10])
cache.set(cache_key, result, self.CACHE_TIMEOUT)
return result
return list(users[:10])
def get_dates_login_record_sessions(self):
sessions = self.sessions_queryset.order_by('-date_start')

View File

@@ -701,15 +701,7 @@ class Config(dict):
'CHAT_AI_ENABLED': False,
'CHAT_AI_METHOD': 'api',
'CHAT_AI_EMBED_URL': '',
'CHAT_AI_TYPE': 'gpt',
'GPT_BASE_URL': '',
'GPT_API_KEY': '',
'GPT_PROXY': '',
'GPT_MODEL': 'gpt-4o-mini',
'DEEPSEEK_BASE_URL': '',
'DEEPSEEK_API_KEY': '',
'DEEPSEEK_PROXY': '',
'DEEPSEEK_MODEL': 'deepseek-chat',
'CHAT_AI_PROVIDERS': [],
'VIRTUAL_APP_ENABLED': False,
'FILE_UPLOAD_SIZE_LIMIT_MB': 200,

View File

@@ -241,15 +241,7 @@ ASSET_SIZE = 'small'
CHAT_AI_ENABLED = CONFIG.CHAT_AI_ENABLED
CHAT_AI_METHOD = CONFIG.CHAT_AI_METHOD
CHAT_AI_EMBED_URL = CONFIG.CHAT_AI_EMBED_URL
CHAT_AI_TYPE = CONFIG.CHAT_AI_TYPE
GPT_BASE_URL = CONFIG.GPT_BASE_URL
GPT_API_KEY = CONFIG.GPT_API_KEY
GPT_PROXY = CONFIG.GPT_PROXY
GPT_MODEL = CONFIG.GPT_MODEL
DEEPSEEK_BASE_URL = CONFIG.DEEPSEEK_BASE_URL
DEEPSEEK_API_KEY = CONFIG.DEEPSEEK_API_KEY
DEEPSEEK_PROXY = CONFIG.DEEPSEEK_PROXY
DEEPSEEK_MODEL = CONFIG.DEEPSEEK_MODEL
CHAT_AI_DEFAULT_PROVIDER = CONFIG.CHAT_AI_DEFAULT_PROVIDER
VIRTUAL_APP_ENABLED = CONFIG.VIRTUAL_APP_ENABLED
@@ -268,4 +260,6 @@ LOKI_BASE_URL = CONFIG.LOKI_BASE_URL
TOOL_USER_ENABLED = CONFIG.TOOL_USER_ENABLED
SUGGESTION_LIMIT = CONFIG.SUGGESTION_LIMIT
MCP_ENABLED = CONFIG.MCP_ENABLED
MCP_ENABLED = CONFIG.MCP_ENABLED
CHAT_AI_PROVIDERS = CONFIG.CHAT_AI_PROVIDERS

View File

@@ -1,98 +1,10 @@
import httpx
import openai
from django.conf import settings
from django.utils.translation import gettext_lazy as _
from rest_framework import status
from rest_framework.generics import GenericAPIView
from rest_framework.response import Response
from common.api import JMSModelViewSet
from common.permissions import IsValidUser, OnlySuperUser
from .. import serializers
from ..const import ChatAITypeChoices
from ..models import ChatPrompt
from ..prompt import DefaultChatPrompt
class ChatAITestingAPI(GenericAPIView):
serializer_class = serializers.ChatAISettingSerializer
rbac_perms = {
'POST': 'settings.change_chatai'
}
def get_config(self, request):
serializer = self.serializer_class(data=request.data)
serializer.is_valid(raise_exception=True)
data = self.serializer_class().data
data.update(serializer.validated_data)
for k, v in data.items():
if v:
continue
# 页面没有传递值, 从 settings 中获取
data[k] = getattr(settings, k, None)
return data
def post(self, request):
config = self.get_config(request)
chat_ai_enabled = config['CHAT_AI_ENABLED']
if not chat_ai_enabled:
return Response(
status=status.HTTP_400_BAD_REQUEST,
data={'msg': _('Chat AI is not enabled')}
)
tp = config['CHAT_AI_TYPE']
if tp == ChatAITypeChoices.gpt:
url = config['GPT_BASE_URL']
api_key = config['GPT_API_KEY']
proxy = config['GPT_PROXY']
model = config['GPT_MODEL']
else:
url = config['DEEPSEEK_BASE_URL']
api_key = config['DEEPSEEK_API_KEY']
proxy = config['DEEPSEEK_PROXY']
model = config['DEEPSEEK_MODEL']
kwargs = {
'base_url': url or None,
'api_key': api_key,
}
try:
if proxy:
kwargs['http_client'] = httpx.Client(
proxies=proxy,
transport=httpx.HTTPTransport(local_address='0.0.0.0')
)
client = openai.OpenAI(**kwargs)
ok = False
error = ''
client.chat.completions.create(
messages=[
{
"role": "user",
"content": "Say this is a test",
}
],
model=model,
)
ok = True
except openai.APIConnectionError as e:
error = str(e.__cause__) # an underlying Exception, likely raised within httpx.
except openai.APIStatusError as e:
error = str(e.message)
except Exception as e:
ok, error = False, str(e)
if ok:
_status, msg = status.HTTP_200_OK, _('Test success')
else:
_status, msg = status.HTTP_400_BAD_REQUEST, error
return Response(status=_status, data={'msg': msg})
class ChatPromptViewSet(JMSModelViewSet):
serializer_classes = {
'default': serializers.ChatPromptSerializer,

View File

@@ -154,7 +154,10 @@ class SettingsApi(generics.RetrieveUpdateAPIView):
def parse_serializer_data(self, serializer):
data = []
fields = self.get_fields()
encrypted_items = [name for name, field in fields.items() if field.write_only]
encrypted_items = [
name for name, field in fields.items()
if field.write_only or getattr(field, 'encrypted', False)
]
category = self.request.query_params.get('category', '')
for name, value in serializer.validated_data.items():
encrypted = name in encrypted_items

View File

@@ -14,18 +14,5 @@ class ChatAIMethodChoices(TextChoices):
class ChatAITypeChoices(TextChoices):
gpt = 'gpt', 'GPT'
deep_seek = 'deep-seek', 'DeepSeek'
class GPTModelChoices(TextChoices):
gpt_4o_mini = 'gpt-4o-mini', 'gpt-4o-mini'
gpt_4o = 'gpt-4o', 'gpt-4o'
o3_mini = 'o3-mini', 'o3-mini'
o1_mini = 'o1-mini', 'o1-mini'
o1 = 'o1', 'o1'
class DeepSeekModelChoices(TextChoices):
deepseek_chat = 'deepseek-chat', 'DeepSeek-V3'
deepseek_reasoner = 'deepseek-reasoner', 'DeepSeek-R1'
openai = 'openai', 'Openai'
ollama = 'ollama', 'Ollama'

View File

@@ -1,6 +1,7 @@
import json
import os
import shutil
from typing import Any, Dict, List
from django.conf import settings
from django.core.files.base import ContentFile
@@ -14,7 +15,6 @@ from rest_framework.utils.encoders import JSONEncoder
from common.db.models import JMSBaseModel
from common.db.utils import Encryptor
from common.utils import get_logger
from .const import ChatAITypeChoices
from .signals import setting_changed
logger = get_logger(__name__)
@@ -196,20 +196,25 @@ class ChatPrompt(JMSBaseModel):
return self.name
def get_chatai_data():
data = {
'url': settings.GPT_BASE_URL,
'api_key': settings.GPT_API_KEY,
'proxy': settings.GPT_PROXY,
'model': settings.GPT_MODEL,
}
if settings.CHAT_AI_TYPE != ChatAITypeChoices.gpt:
data['url'] = settings.DEEPSEEK_BASE_URL
data['api_key'] = settings.DEEPSEEK_API_KEY
data['proxy'] = settings.DEEPSEEK_PROXY
data['model'] = settings.DEEPSEEK_MODEL
def get_chatai_data() -> Dict[str, Any]:
raw_providers = settings.CHAT_AI_PROVIDERS
providers: List[dict] = [p for p in raw_providers if isinstance(p, dict)]
return data
if not providers:
return {}
selected = next(
(p for p in providers if p.get('is_assistant')),
providers[0],
)
return {
'url': selected.get('base_url'),
'api_key': selected.get('api_key'),
'proxy': selected.get('proxy'),
'model': selected.get('model'),
'name': selected.get('name'),
}
def init_sqlite_db():

View File

@@ -10,11 +10,12 @@ from common.utils import date_expired_default
__all__ = [
'AnnouncementSettingSerializer', 'OpsSettingSerializer', 'VaultSettingSerializer',
'HashicorpKVSerializer', 'AzureKVSerializer', 'TicketSettingSerializer',
'ChatAISettingSerializer', 'VirtualAppSerializer', 'AmazonSMSerializer',
'ChatAIProviderSerializer', 'ChatAISettingSerializer',
'VirtualAppSerializer', 'AmazonSMSerializer',
]
from settings.const import (
ChatAITypeChoices, GPTModelChoices, DeepSeekModelChoices, ChatAIMethodChoices
ChatAITypeChoices, ChatAIMethodChoices
)
@@ -120,6 +121,29 @@ class AmazonSMSerializer(serializers.Serializer):
)
class ChatAIProviderListSerializer(serializers.ListSerializer):
# 标记整个列表需要加密存储,避免明文保存 API Key
encrypted = True
class ChatAIProviderSerializer(serializers.Serializer):
type = serializers.ChoiceField(
default=ChatAITypeChoices.openai, choices=ChatAITypeChoices.choices,
label=_("Types"), required=False,
)
base_url = serializers.CharField(
allow_blank=True, required=False, label=_('Base URL'),
help_text=_('The base URL of the Chat service.')
)
api_key = EncryptedField(
allow_blank=True, required=False, label=_('API Key'),
)
proxy = serializers.CharField(
allow_blank=True, required=False, label=_('Proxy'),
help_text=_('The proxy server address of the GPT service. For example: http://ip:port')
)
class ChatAISettingSerializer(serializers.Serializer):
PREFIX_TITLE = _('Chat AI')
@@ -130,44 +154,14 @@ class ChatAISettingSerializer(serializers.Serializer):
default=ChatAIMethodChoices.api, choices=ChatAIMethodChoices.choices,
label=_("Method"), required=False,
)
CHAT_AI_PROVIDERS = ChatAIProviderListSerializer(
child=ChatAIProviderSerializer(),
allow_empty=True, required=False, default=list, label=_('Providers')
)
CHAT_AI_EMBED_URL = serializers.CharField(
allow_blank=True, required=False, label=_('Base URL'),
help_text=_('The base URL of the Chat service.')
)
CHAT_AI_TYPE = serializers.ChoiceField(
default=ChatAITypeChoices.gpt, choices=ChatAITypeChoices.choices,
label=_("Types"), required=False,
)
GPT_BASE_URL = serializers.CharField(
allow_blank=True, required=False, label=_('Base URL'),
help_text=_('The base URL of the Chat service.')
)
GPT_API_KEY = EncryptedField(
allow_blank=True, required=False, label=_('API Key'),
)
GPT_PROXY = serializers.CharField(
allow_blank=True, required=False, label=_('Proxy'),
help_text=_('The proxy server address of the GPT service. For example: http://ip:port')
)
GPT_MODEL = serializers.ChoiceField(
default=GPTModelChoices.gpt_4o_mini, choices=GPTModelChoices.choices,
label=_("GPT Model"), required=False,
)
DEEPSEEK_BASE_URL = serializers.CharField(
allow_blank=True, required=False, label=_('Base URL'),
help_text=_('The base URL of the Chat service.')
)
DEEPSEEK_API_KEY = EncryptedField(
allow_blank=True, required=False, label=_('API Key'),
)
DEEPSEEK_PROXY = serializers.CharField(
allow_blank=True, required=False, label=_('Proxy'),
help_text=_('The proxy server address of the GPT service. For example: http://ip:port')
)
DEEPSEEK_MODEL = serializers.ChoiceField(
default=DeepSeekModelChoices.deepseek_chat, choices=DeepSeekModelChoices.choices,
label=_("DeepSeek Model"), required=False,
)
class TicketSettingSerializer(serializers.Serializer):

View File

@@ -73,8 +73,6 @@ class PrivateSettingSerializer(PublicSettingSerializer):
CHAT_AI_ENABLED = serializers.BooleanField()
CHAT_AI_METHOD = serializers.CharField()
CHAT_AI_EMBED_URL = serializers.CharField()
CHAT_AI_TYPE = serializers.CharField()
GPT_MODEL = serializers.CharField()
FILE_UPLOAD_SIZE_LIMIT_MB = serializers.IntegerField()
FTP_FILE_MAX_STORE = serializers.IntegerField()
LOKI_LOG_ENABLED = serializers.BooleanField()

View File

@@ -21,7 +21,6 @@ urlpatterns = [
path('sms/<str:backend>/testing/', api.SMSTestingAPI.as_view(), name='sms-testing'),
path('sms/backend/', api.SMSBackendAPI.as_view(), name='sms-backend'),
path('vault/<str:backend>/testing/', api.VaultTestingAPI.as_view(), name='vault-testing'),
path('chatai/testing/', api.ChatAITestingAPI.as_view(), name='chatai-testing'),
path('vault/sync/', api.VaultSyncDataAPI.as_view(), name='vault-sync'),
path('security/block-ip/', api.BlockIPSecurityAPI.as_view(), name='block-ip'),
path('security/unlock-ip/', api.UnlockIPSecurityAPI.as_view(), name='unlock-ip'),

View File

@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
#
from .applet import *
from .chat import *
from .component import *
from .session import *
from .virtualapp import *

View File

@@ -0,0 +1 @@
from .chat import *

View File

@@ -0,0 +1,15 @@
from common.api import JMSBulkModelViewSet
from terminal import serializers
from terminal.filters import ChatFilter
from terminal.models import Chat
__all__ = ['ChatViewSet']
class ChatViewSet(JMSBulkModelViewSet):
queryset = Chat.objects.all()
serializer_class = serializers.ChatSerializer
filterset_class = ChatFilter
search_fields = ['title']
ordering_fields = ['date_updated']
ordering = ['-date_updated']

View File

@@ -2,7 +2,7 @@ from django.db.models import QuerySet
from django_filters import rest_framework as filters
from orgs.utils import filter_org_queryset
from terminal.models import Command, CommandStorage, Session
from terminal.models import Command, CommandStorage, Session, Chat
class CommandFilter(filters.FilterSet):
@@ -79,7 +79,34 @@ class CommandStorageFilter(filters.FilterSet):
model = CommandStorage
fields = ['real', 'name', 'type', 'is_default']
def filter_real(self, queryset, name, value):
@staticmethod
def filter_real(queryset, name, value):
if value:
queryset = queryset.exclude(name='null')
return queryset
class ChatFilter(filters.FilterSet):
ids = filters.BooleanFilter(method='filter_ids')
folder_ids = filters.BooleanFilter(method='filter_folder_ids')
class Meta:
model = Chat
fields = [
'title', 'user_id', 'pinned', 'folder_id',
'archived', 'socket_id', 'share_id'
]
@staticmethod
def filter_ids(queryset, name, value):
ids = value.split(',')
queryset = queryset.filter(id__in=ids)
return queryset
@staticmethod
def filter_folder_ids(queryset, name, value):
ids = value.split(',')
queryset = queryset.filter(folder_id__in=ids)
return queryset

View File

@@ -0,0 +1,38 @@
# Generated by Django 4.1.13 on 2025-09-30 06:57
from django.db import migrations, models
import uuid
class Migration(migrations.Migration):
dependencies = [
('terminal', '0010_alter_command_risk_level_alter_session_login_from_and_more'),
]
operations = [
migrations.CreateModel(
name='Chat',
fields=[
('created_by', models.CharField(blank=True, max_length=128, null=True, verbose_name='Created by')),
('updated_by', models.CharField(blank=True, max_length=128, null=True, verbose_name='Updated by')),
('date_created', models.DateTimeField(auto_now_add=True, null=True, verbose_name='Date created')),
('date_updated', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
('comment', models.TextField(blank=True, default='', verbose_name='Comment')),
('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)),
('title', models.CharField(max_length=256, verbose_name='Title')),
('chat', models.JSONField(default=dict, verbose_name='Chat')),
('meta', models.JSONField(default=dict, verbose_name='Meta')),
('pinned', models.BooleanField(default=False, verbose_name='Pinned')),
('archived', models.BooleanField(default=False, verbose_name='Archived')),
('share_id', models.CharField(blank=True, default='', max_length=36)),
('folder_id', models.CharField(blank=True, default='', max_length=36)),
('socket_id', models.CharField(blank=True, default='', max_length=36)),
('user_id', models.CharField(blank=True, db_index=True, default='', max_length=36)),
('session_info', models.JSONField(default=dict, verbose_name='Session Info')),
],
options={
'verbose_name': 'Chat',
},
),
]

View File

@@ -1,4 +1,5 @@
from .session import *
from .component import *
from .applet import *
from .chat import *
from .component import *
from .session import *
from .virtualapp import *

View File

@@ -0,0 +1 @@
from .chat import *

View File

@@ -0,0 +1,30 @@
from django.db import models
from django.utils.translation import gettext_lazy as _
from common.db.models import JMSBaseModel
from common.utils import get_logger
logger = get_logger(__name__)
__all__ = ['Chat']
class Chat(JMSBaseModel):
# id == session_id # 36 chars
title = models.CharField(max_length=256, verbose_name=_('Title'))
chat = models.JSONField(default=dict, verbose_name=_('Chat'))
meta = models.JSONField(default=dict, verbose_name=_('Meta'))
pinned = models.BooleanField(default=False, verbose_name=_('Pinned'))
archived = models.BooleanField(default=False, verbose_name=_('Archived'))
share_id = models.CharField(blank=True, default='', max_length=36)
folder_id = models.CharField(blank=True, default='', max_length=36)
socket_id = models.CharField(blank=True, default='', max_length=36)
user_id = models.CharField(blank=True, default='', max_length=36, db_index=True)
session_info = models.JSONField(default=dict, verbose_name=_('Session Info'))
class Meta:
verbose_name = _('Chat')
def __str__(self):
return self.title

View File

@@ -123,11 +123,10 @@ class Terminal(StorageMixin, TerminalStatusMixin, JMSBaseModel):
def get_chat_ai_setting():
data = get_chatai_data()
return {
'GPT_BASE_URL': data['url'],
'GPT_API_KEY': data['api_key'],
'GPT_PROXY': data['proxy'],
'GPT_MODEL': data['model'],
'CHAT_AI_TYPE': settings.CHAT_AI_TYPE,
'GPT_BASE_URL': data.get('url'),
'GPT_API_KEY': data.get('api_key'),
'GPT_PROXY': data.get('proxy'),
'CHAT_AI_PROVIDERS': settings.CHAT_AI_PROVIDERS,
}
@staticmethod

View File

@@ -2,8 +2,10 @@
#
from .applet import *
from .applet_host import *
from .chat import *
from .command import *
from .endpoint import *
from .loki import *
from .session import *
from .sharing import *
from .storage import *
@@ -11,4 +13,3 @@ from .task import *
from .terminal import *
from .virtualapp import *
from .virtualapp_provider import *
from .loki import *

View File

@@ -0,0 +1,28 @@
from rest_framework import serializers
from common.serializers import CommonBulkModelSerializer
from terminal.models import Chat
__all__ = ['ChatSerializer']
class ChatSerializer(CommonBulkModelSerializer):
created_at = serializers.SerializerMethodField()
updated_at = serializers.SerializerMethodField()
class Meta:
model = Chat
fields_mini = ['id', 'title', 'created_at', 'updated_at']
fields = fields_mini + [
'chat', 'meta', 'pinned', 'archived',
'share_id', 'folder_id',
'user_id', 'session_info'
]
@staticmethod
def get_created_at(obj):
return int(obj.date_created.timestamp())
@staticmethod
def get_updated_at(obj):
return int(obj.date_updated.timestamp())

View File

@@ -32,6 +32,7 @@ router.register(r'virtual-apps', api.VirtualAppViewSet, 'virtual-app')
router.register(r'app-providers', api.AppProviderViewSet, 'app-provider')
router.register(r'app-providers/((?P<provider>[^/.]+)/)?apps', api.AppProviderAppViewSet, 'app-provider-app')
router.register(r'virtual-app-publications', api.VirtualAppPublicationViewSet, 'virtual-app-publication')
router.register(r'chats', api.ChatViewSet, 'chat')
urlpatterns = [
path('my-sessions/', api.MySessionAPIView.as_view(), name='my-session'),

View File

@@ -199,11 +199,19 @@ class UserChangePasswordApi(UserQuerysetMixin, generics.UpdateAPIView):
class UserUnblockPKApi(UserQuerysetMixin, generics.UpdateAPIView):
serializer_class = serializers.UserSerializer
def get_object(self):
pk = self.kwargs.get('pk')
if is_uuid(pk):
return super().get_object()
else:
return self.get_queryset().filter(username=pk).first()
def perform_update(self, serializer):
user = self.get_object()
username = user.username if user else ''
LoginBlockUtil.unblock_user(username)
MFABlockUtils.unblock_user(username)
if not user:
return Response({"error": _("User not found")}, status=404)
user.unblock_login()
class UserResetMFAApi(UserQuerysetMixin, generics.RetrieveAPIView):

View File

@@ -274,8 +274,8 @@ class User(
LoginBlockUtil.unblock_user(self.username)
MFABlockUtils.unblock_user(self.username)
@lazyproperty
def login_blocked(self):
@property
def is_login_blocked(self):
from users.utils import LoginBlockUtil, MFABlockUtils
if LoginBlockUtil.is_user_block(self.username):
@@ -284,6 +284,13 @@ class User(
return True
return False
@classmethod
def block_login(cls, username):
from users.utils import LoginBlockUtil, MFABlockUtils
LoginBlockUtil.block_user(username)
MFABlockUtils.block_user(username)
def delete(self, using=None, keep_parents=False):
if self.pk == 1 or self.username == "admin":
raise PermissionDenied(_("Can not delete admin user"))

View File

@@ -123,7 +123,7 @@ class UserSerializer(
mfa_force_enabled = serializers.BooleanField(
read_only=True, label=_("MFA force enabled")
)
login_blocked = serializers.BooleanField(read_only=True, label=_("Login blocked"))
is_login_blocked = serializers.BooleanField(read_only=True, label=_("Login blocked"))
is_expired = serializers.BooleanField(read_only=True, label=_("Is expired"))
is_valid = serializers.BooleanField(read_only=True, label=_("Is valid"))
is_otp_secret_key_bound = serializers.BooleanField(
@@ -193,6 +193,7 @@ class UserSerializer(
"is_valid", "is_expired", "is_active", # 布尔字段
"is_otp_secret_key_bound", "can_public_key_auth",
"mfa_enabled", "need_update_password", "is_face_code_set",
"is_login_blocked",
]
# 包含不太常用的字段,可以没有
fields_verbose = (
@@ -211,7 +212,7 @@ class UserSerializer(
# 多对多字段
fields_m2m = ["groups", "system_roles", "org_roles", "orgs_roles", "labels"]
# 在serializer 上定义的字段
fields_custom = ["login_blocked", "password_strategy"]
fields_custom = ["is_login_blocked", "password_strategy"]
fields = fields_verbose + fields_fk + fields_m2m + fields_custom
fields_unexport = ["avatar_url", "is_service_account"]

View File

@@ -28,6 +28,6 @@ urlpatterns = [
path('users/<uuid:pk>/password/', api.UserChangePasswordApi.as_view(), name='change-user-password'),
path('users/<uuid:pk>/password/reset/', api.UserResetPasswordApi.as_view(), name='user-reset-password'),
path('users/<uuid:pk>/pubkey/reset/', api.UserResetPKApi.as_view(), name='user-public-key-reset'),
path('users/<uuid:pk>/unblock/', api.UserUnblockPKApi.as_view(), name='user-unblock'),
path('users/<str:pk>/unblock/', api.UserUnblockPKApi.as_view(), name='user-unblock'),
]
urlpatterns += router.urls

View File

@@ -186,6 +186,13 @@ class BlockUtilBase:
def is_block(self):
return bool(cache.get(self.block_key))
@classmethod
def block_user(cls, username):
username = username.lower()
block_key = cls.BLOCK_KEY_TMPL.format(username)
key_ttl = int(settings.SECURITY_LOGIN_LIMIT_TIME) * 60
cache.set(block_key, True, key_ttl)
@classmethod
def get_blocked_usernames(cls):
key = cls.BLOCK_KEY_TMPL.format('*')

View File

@@ -0,0 +1,358 @@
import os
import sys
import django
import random
from datetime import datetime
if os.path.exists('../../apps'):
sys.path.insert(0, '../../apps')
if os.path.exists('../apps'):
sys.path.insert(0, '../apps')
elif os.path.exists('./apps'):
sys.path.insert(0, './apps')
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "jumpserver.settings")
django.setup()
from assets.models import Asset, Node
from orgs.models import Organization
from django.db.models import Count
OUTPUT_FILE = 'report_cleanup_and_keep_one_node_for_multi_parent_nodes_assets.txt'
# Special organization IDs and names
SPECIAL_ORGS = {
'00000000-0000-0000-0000-000000000000': 'GLOBAL',
'00000000-0000-0000-0000-000000000002': 'DEFAULT',
'00000000-0000-0000-0000-000000000004': 'SYSTEM',
}
try:
AssetNodeThrough = Asset.nodes.through
except Exception as e:
print("Failed to get AssetNodeThrough model. Check Asset.nodes field definition.")
raise e
def log(msg=''):
"""Print log with timestamp to console"""
print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}")
def write_report(content):
"""Write content to report file"""
with open(OUTPUT_FILE, 'a', encoding='utf-8') as f:
f.write(content)
def get_org_name(org_id, orgs_map):
"""Get organization name, check special orgs first, then orgs_map"""
# Check if it's a special organization
org_id_str = str(org_id)
if org_id_str in SPECIAL_ORGS:
return SPECIAL_ORGS[org_id_str]
# Try to get from orgs_map
org = orgs_map.get(org_id)
if org:
return org.name
return 'Unknown'
def find_and_cleanup_multi_parent_assets():
"""Find and cleanup assets with multiple parent nodes"""
log("Searching for assets with multiple parent nodes...")
# Find all asset_ids that belong to multiple node_ids
multi_parent_assets = AssetNodeThrough.objects.values('asset_id').annotate(
node_count=Count('node_id', distinct=True)
).filter(node_count__gt=1).order_by('-node_count')
total_count = multi_parent_assets.count()
log(f"Found {total_count:,} assets with multiple parent nodes\n")
if total_count == 0:
log("✓ All assets already have single parent node")
return {}
# Collect all asset_ids and node_ids
asset_ids = [item['asset_id'] for item in multi_parent_assets]
# Get all through records
all_through_records = AssetNodeThrough.objects.filter(asset_id__in=asset_ids)
node_ids = list(set(through.node_id for through in all_through_records))
# Batch fetch all objects
log("Batch loading Asset objects...")
assets_map = {asset.id: asset for asset in Asset.objects.filter(id__in=asset_ids)}
log("Batch loading Node objects...")
nodes_map = {node.id: node for node in Node.objects.filter(id__in=node_ids)}
# Batch fetch all Organization objects
org_ids = list(set(asset.org_id for asset in assets_map.values())) + \
list(set(node.org_id for node in nodes_map.values()))
org_ids = list(set(org_ids))
log("Batch loading Organization objects...")
orgs_map = {org.id: org for org in Organization.objects.filter(id__in=org_ids)}
# Build mapping of asset_id -> list of through_records
asset_nodes_map = {}
for through in all_through_records:
if through.asset_id not in asset_nodes_map:
asset_nodes_map[through.asset_id] = []
asset_nodes_map[through.asset_id].append(through)
# Organize by organization
org_cleanup_data = {} # org_id -> { asset_id -> { keep_node_id, remove_node_ids } }
for item in multi_parent_assets:
asset_id = item['asset_id']
# Get Asset object
asset = assets_map.get(asset_id)
if not asset:
log(f"⚠ Asset {asset_id} not found in map, skipping")
continue
org_id = asset.org_id
# Initialize org data if not exists
if org_id not in org_cleanup_data:
org_cleanup_data[org_id] = {}
# Get all nodes for this asset
through_records = asset_nodes_map.get(asset_id, [])
if len(through_records) < 2:
continue
# Randomly select one node to keep
keep_through = random.choice(through_records)
remove_throughs = [t for t in through_records if t.id != keep_through.id]
org_cleanup_data[org_id][asset_id] = {
'asset_name': asset.name,
'keep_node_id': keep_through.node_id,
'keep_node': nodes_map.get(keep_through.node_id),
'remove_records': remove_throughs,
'remove_nodes': [nodes_map.get(t.node_id) for t in remove_throughs]
}
return org_cleanup_data
def perform_cleanup(org_cleanup_data, dry_run=False):
"""Perform the actual cleanup - delete extra node relationships"""
if dry_run:
log("DRY RUN: Simulating cleanup process (no data will be deleted)...")
else:
log("\nStarting cleanup process...")
total_deleted = 0
for org_id in org_cleanup_data.keys():
for asset_id, cleanup_info in org_cleanup_data[org_id].items():
# Delete the extra relationships
for through_record in cleanup_info['remove_records']:
if not dry_run:
through_record.delete()
total_deleted += 1
return total_deleted
def verify_cleanup():
"""Verify that there are no more assets with multiple parent nodes"""
log("\n" + "="*80)
log("VERIFICATION: Checking for remaining assets with multiple parent nodes...")
log("="*80)
# Find all asset_ids that belong to multiple node_ids
multi_parent_assets = AssetNodeThrough.objects.values('asset_id').annotate(
node_count=Count('node_id', distinct=True)
).filter(node_count__gt=1).order_by('-node_count')
remaining_count = multi_parent_assets.count()
if remaining_count == 0:
log(f"✓ Verification successful: No assets with multiple parent nodes remaining\n")
return True
else:
log(f"✗ Verification failed: Found {remaining_count:,} assets still with multiple parent nodes\n")
# Show some details
for item in multi_parent_assets[:10]:
asset_id = item['asset_id']
node_count = item['node_count']
try:
asset = Asset.objects.get(id=asset_id)
log(f" - Asset: {asset.name} ({asset_id}) has {node_count} parent nodes")
except:
log(f" - Asset ID: {asset_id} has {node_count} parent nodes")
if remaining_count > 10:
log(f" ... and {remaining_count - 10} more")
return False
def generate_report(org_cleanup_data, total_deleted):
"""Generate and write report to file"""
# Clear previous report
if os.path.exists(OUTPUT_FILE):
os.remove(OUTPUT_FILE)
# Write header
write_report(f"Multi-Parent Assets Cleanup Report\n")
write_report(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
write_report(f"{'='*80}\n\n")
# Get all organizations
all_org_ids = list(set(org_id for org_id in org_cleanup_data.keys()))
all_orgs = {org.id: org for org in Organization.objects.filter(id__in=all_org_ids)}
# Calculate statistics
total_orgs = Organization.objects.count()
orgs_processed = len(org_cleanup_data)
orgs_no_issues = total_orgs - orgs_processed
total_assets_cleaned = sum(len(assets) for assets in org_cleanup_data.values())
# Overview
write_report("OVERVIEW\n")
write_report(f"{'-'*80}\n")
write_report(f"Total organizations: {total_orgs:,}\n")
write_report(f"Organizations processed: {orgs_processed:,}\n")
write_report(f"Organizations without issues: {orgs_no_issues:,}\n")
write_report(f"Total assets cleaned: {total_assets_cleaned:,}\n")
total_relationships = AssetNodeThrough.objects.count()
write_report(f"Total relationships (through records): {total_relationships:,}\n")
write_report(f"Total relationships deleted: {total_deleted:,}\n\n")
# Summary by organization
write_report("Summary by Organization:\n")
for org_id in sorted(org_cleanup_data.keys()):
org_name = get_org_name(org_id, all_orgs)
asset_count = len(org_cleanup_data[org_id])
write_report(f" - {org_name} ({org_id}): {asset_count:,} assets cleaned\n")
write_report(f"\n{'='*80}\n\n")
# Detailed cleanup information grouped by organization
for org_id in sorted(org_cleanup_data.keys()):
org_name = get_org_name(org_id, all_orgs)
asset_count = len(org_cleanup_data[org_id])
write_report(f"ORGANIZATION: {org_name} ({org_id})\n")
write_report(f"Total assets cleaned: {asset_count:,}\n")
write_report(f"{'-'*80}\n\n")
for asset_id, cleanup_info in org_cleanup_data[org_id].items():
write_report(f"Asset: {cleanup_info['asset_name']} ({asset_id})\n")
# Kept node
keep_node = cleanup_info['keep_node']
if keep_node:
write_report(f" ✓ Kept: {keep_node.name} (key: {keep_node.key}) (id: {keep_node.id})\n")
else:
write_report(f" ✓ Kept: Unknown (id: {cleanup_info['keep_node_id']})\n")
# Removed nodes
write_report(f" ✗ Removed: {len(cleanup_info['remove_nodes'])} node(s)\n")
for node in cleanup_info['remove_nodes']:
if node:
write_report(f" - {node.name} (key: {node.key}) (id: {node.id})\n")
else:
write_report(f" - Unknown\n")
write_report(f"\n")
write_report(f"{'='*80}\n\n")
log(f"✓ Report written to {OUTPUT_FILE}")
def main():
try:
# Display warning banner
warning_message = """
╔══════════════════════════════════════════════════════════════════════════════╗
║ ⚠️ WARNING ⚠️ ║
║ ║
║ This script is designed for TEST/FAKE DATA ONLY! ║
║ DO NOT run this script in PRODUCTION environment! ║
║ ║
║ This script will DELETE asset-node relationships from the database. ║
║ Use only for data cleanup in development/testing environments. ║
║ ║
╚══════════════════════════════════════════════════════════════════════════════╝
"""
print(warning_message)
# Ask user to confirm before proceeding
confirm = input("Do you understand the warning and want to continue? (yes/no): ").strip().lower()
if confirm not in ['yes', 'y']:
log("✗ Operation cancelled by user")
sys.exit(0)
log("✓ Proceeding with operation\n")
org_cleanup_data = find_and_cleanup_multi_parent_assets()
if not org_cleanup_data:
log("✓ Cleanup complete, no assets to process")
sys.exit(0)
total_assets = sum(len(assets) for assets in org_cleanup_data.values())
log(f"\nProcessing {total_assets:,} assets across {len(org_cleanup_data):,} organizations...")
# First, do a dry-run to show what will be deleted
log("\n" + "="*80)
log("PREVIEW: Simulating cleanup process...")
log("="*80)
total_deleted_preview = perform_cleanup(org_cleanup_data, dry_run=True)
log(f"✓ Dry-run complete: {total_deleted_preview:,} relationships would be deleted\n")
# Generate preview report
generate_report(org_cleanup_data, total_deleted_preview)
log(f"✓ Preview report written to {OUTPUT_FILE}\n")
# Ask for confirmation 3 times before actual deletion
log("="*80)
log("FINAL CONFIRMATION: Do you want to proceed with actual cleanup?")
log("="*80)
confirmation_count = 3
for attempt in range(1, confirmation_count + 1):
response = input(f"Confirm cleanup (attempt {attempt}/{confirmation_count})? (yes/no): ").strip().lower()
if response not in ['yes', 'y']:
log(f"✗ Cleanup cancelled by user at attempt {attempt}")
sys.exit(1)
log("✓ All confirmations received, proceeding with actual cleanup")
# Perform cleanup
total_deleted = perform_cleanup(org_cleanup_data)
log(f"✓ Deleted {total_deleted:,} relationships")
# Generate final report
generate_report(org_cleanup_data, total_deleted)
# Verify cleanup by checking for remaining multi-parent assets
verify_cleanup()
log(f"✓ Cleanup complete: processed {total_assets:,} assets")
sys.exit(0)
except Exception as e:
log(f"✗ Error occurred: {str(e)}")
import traceback
traceback.print_exc()
sys.exit(2)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,270 @@
import os
import sys
import django
from datetime import datetime
if os.path.exists('../../apps'):
sys.path.insert(0, '../../apps')
if os.path.exists('../apps'):
sys.path.insert(0, '../apps')
elif os.path.exists('./apps'):
sys.path.insert(0, './apps')
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "jumpserver.settings")
django.setup()
from assets.models import Asset, Node
from orgs.models import Organization
from django.db.models import Count
OUTPUT_FILE = 'report_find_multi_parent_nodes_assets.txt'
# Special organization IDs and names
SPECIAL_ORGS = {
'00000000-0000-0000-0000-000000000000': 'GLOBAL',
'00000000-0000-0000-0000-000000000002': 'DEFAULT',
'00000000-0000-0000-0000-000000000004': 'SYSTEM',
}
try:
AssetNodeThrough = Asset.nodes.through
except Exception as e:
print("Failed to get AssetNodeThrough model. Check Asset.nodes field definition.")
raise e
def log(msg=''):
"""Print log with timestamp"""
print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}")
def get_org_name(org_id, orgs_map):
"""Get organization name, check special orgs first, then orgs_map"""
# Check if it's a special organization
org_id_str = str(org_id)
if org_id_str in SPECIAL_ORGS:
return SPECIAL_ORGS[org_id_str]
# Try to get from orgs_map
org = orgs_map.get(org_id)
if org:
return org.name
return 'Unknown'
def write_report(content):
"""Write content to report file"""
with open(OUTPUT_FILE, 'a', encoding='utf-8') as f:
f.write(content)
def find_assets_multiple_parents():
"""Find assets belonging to multiple node_ids organized by organization"""
log("Searching for assets with multiple parent nodes...")
# Find all asset_ids that belong to multiple node_ids
multi_parent_assets = AssetNodeThrough.objects.values('asset_id').annotate(
node_count=Count('node_id', distinct=True)
).filter(node_count__gt=1).order_by('-node_count')
total_count = multi_parent_assets.count()
log(f"Found {total_count:,} assets with multiple parent nodes\n")
if total_count == 0:
log("✓ All assets belong to only one node")
return {}
# Collect all asset_ids and node_ids that need to be fetched
asset_ids = [item['asset_id'] for item in multi_parent_assets]
# Get all through records for these assets
all_through_records = AssetNodeThrough.objects.filter(asset_id__in=asset_ids)
node_ids = list(set(through.node_id for through in all_through_records))
# Batch fetch all Asset and Node objects
log("Batch loading Asset objects...")
assets_map = {asset.id: asset for asset in Asset.objects.filter(id__in=asset_ids)}
log("Batch loading Node objects...")
nodes_map = {node.id: node for node in Node.objects.filter(id__in=node_ids)}
# Batch fetch all Organization objects
org_ids = list(set(asset.org_id for asset in assets_map.values())) + \
list(set(node.org_id for node in nodes_map.values()))
org_ids = list(set(org_ids)) # Remove duplicates
log("Batch loading Organization objects...")
orgs_map = {org.id: org for org in Organization.objects.filter(id__in=org_ids)}
# Build mapping of asset_id -> list of through_records
asset_nodes_map = {}
for through in all_through_records:
if through.asset_id not in asset_nodes_map:
asset_nodes_map[through.asset_id] = []
asset_nodes_map[through.asset_id].append(through)
# Organize by organization first, then by node count, then by asset
org_assets_data = {} # org_id -> { node_count -> [asset_data] }
for item in multi_parent_assets:
asset_id = item['asset_id']
node_count = item['node_count']
# Get Asset object from map
asset = assets_map.get(asset_id)
if not asset:
log(f"⚠ Asset {asset_id} not found in map, skipping")
continue
org_id = asset.org_id
# Initialize org data if not exists
if org_id not in org_assets_data:
org_assets_data[org_id] = {}
# Get all nodes for this asset
through_records = asset_nodes_map.get(asset_id, [])
node_details = []
for through in through_records:
# Get Node object from map
node = nodes_map.get(through.node_id)
if not node:
log(f"⚠ Node {through.node_id} not found in map, skipping")
continue
node_details.append({
'id': node.id,
'name': node.name,
'key': node.key,
'path': node.full_value if hasattr(node, 'full_value') else ''
})
if not node_details:
continue
if node_count not in org_assets_data[org_id]:
org_assets_data[org_id][node_count] = []
org_assets_data[org_id][node_count].append({
'asset_id': asset.id,
'asset_name': asset.name,
'nodes': node_details
})
return org_assets_data
def generate_report(org_assets_data):
"""Generate and write report to file organized by organization"""
# Clear previous report
if os.path.exists(OUTPUT_FILE):
os.remove(OUTPUT_FILE)
# Write header
write_report(f"Multi-Parent Assets Report\n")
write_report(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
write_report(f"{'='*80}\n\n")
# Get all organizations
all_org_ids = list(set(org_id for org_id in org_assets_data.keys()))
all_orgs = {org.id: org for org in Organization.objects.filter(id__in=all_org_ids)}
# Calculate statistics
total_orgs = Organization.objects.count()
orgs_with_issues = len(org_assets_data)
orgs_without_issues = total_orgs - orgs_with_issues
total_assets_with_issues = sum(
len(assets)
for org_id in org_assets_data
for assets in org_assets_data[org_id].values()
)
# Overview
write_report("OVERVIEW\n")
write_report(f"{'-'*80}\n")
write_report(f"Total organizations: {total_orgs:,}\n")
write_report(f"Organizations with multiple-parent assets: {orgs_with_issues:,}\n")
write_report(f"Organizations without issues: {orgs_without_issues:,}\n")
write_report(f"Total assets with multiple parent nodes: {total_assets_with_issues:,}\n\n")
# Summary by organization
write_report("Summary by Organization:\n")
for org_id in sorted(org_assets_data.keys()):
org_name = get_org_name(org_id, all_orgs)
org_asset_count = sum(
len(assets)
for assets in org_assets_data[org_id].values()
)
write_report(f" - {org_name} ({org_id}): {org_asset_count:,} assets\n")
write_report(f"\n{'='*80}\n\n")
# Detailed sections grouped by organization, then node count
for org_id in sorted(org_assets_data.keys()):
org_name = get_org_name(org_id, all_orgs)
org_asset_count = sum(
len(assets)
for assets in org_assets_data[org_id].values()
)
write_report(f"ORGANIZATION: {org_name} ({org_id})\n")
write_report(f"Total assets with issues: {org_asset_count:,}\n")
write_report(f"{'-'*80}\n\n")
# Group by node count within this organization
for node_count in sorted(org_assets_data[org_id].keys(), reverse=True):
assets = org_assets_data[org_id][node_count]
write_report(f" Section: {node_count} Parent Nodes ({len(assets):,} assets)\n")
write_report(f" {'-'*76}\n\n")
for asset in assets:
write_report(f" {asset['asset_name']} ({asset['asset_id']})\n")
for node in asset['nodes']:
write_report(f" {node['name']} ({node['key']}) ({node['path']}) ({node['id']})\n")
write_report(f"\n")
write_report(f"\n")
write_report(f"{'='*80}\n\n")
log(f"✓ Report written to {OUTPUT_FILE}")
def main():
try:
org_assets_data = find_assets_multiple_parents()
if not org_assets_data:
log("✓ Detection complete, no issues found")
sys.exit(0)
total_assets = sum(
len(assets)
for org_id in org_assets_data
for assets in org_assets_data[org_id].values()
)
log(f"Generating report for {total_assets:,} assets across {len(org_assets_data):,} organizations...")
generate_report(org_assets_data)
log(f"✗ Detected {total_assets:,} assets with multiple parent nodes")
sys.exit(1)
except Exception as e:
log(f"✗ Error occurred: {str(e)}")
import traceback
traceback.print_exc()
sys.exit(2)
if __name__ == "__main__":
main()