From bbeadf7dbe9ce81d7e8c4b7148cbad0c24d39e6c Mon Sep 17 00:00:00 2001
From: wangruidong <940853815@qq.com>
Date: Mon, 7 Apr 2025 17:13:01 +0800
Subject: [PATCH] perf: optimize adhoc asset selection experience

---
 apps/ops/ansible/inventory.py | 63 +++++++++++++++++++++++++++++++++++
 apps/ops/api/__init__.py      |  1 +
 apps/ops/api/inventory.py     | 42 +++++++++++++++++++++++
 apps/ops/urls/api_urls.py     |  1 +
 4 files changed, 107 insertions(+)
 create mode 100644 apps/ops/api/inventory.py

diff --git a/apps/ops/ansible/inventory.py b/apps/ops/ansible/inventory.py
index fabb0c06d..74913f065 100644
--- a/apps/ops/ansible/inventory.py
+++ b/apps/ops/ansible/inventory.py
@@ -4,6 +4,7 @@ import os
 import re
 from collections import defaultdict
 
+import sys
 from django.utils.translation import gettext as _
 
 __all__ = ['JMSInventory']
@@ -281,6 +282,68 @@ class JMSInventory:
             setattr(p, 'setting', platform_protocols.get(p.name, {}))
         return asset_protocols
 
+    def get_classified_hosts(self, path_dir):
+        """
+        返回三种类型的主机:可运行的、错误的和跳过的
+        :param path_dir: 存储密钥的路径
+        :return: dict,包含三类主机信息
+        """
+        hosts = []
+        platform_assets = self.group_by_platform(self.assets)
+        runnable_hosts = []
+        error_hosts = []
+
+        for platform, assets in platform_assets.items():
+            automation = platform.automation
+            platform_protocols = {
+                p['name']: p['setting'] for p in platform.protocols.values('name', 'setting')
+            }
+            for asset in assets:
+                protocols = self.set_platform_protocol_setting_to_asset(asset, platform_protocols)
+                account = self.select_account(asset)
+                host = self.asset_to_host(asset, account, automation, protocols, platform, path_dir)
+
+                if not automation.ansible_enabled:
+                    host['error'] = _('Ansible disabled')
+
+                if isinstance(host, list):
+                    hosts.extend(host)
+                else:
+                    hosts.append(host)
+
+        # 分类主机
+        for host in hosts:
+            if host.get('error'):
+                self.exclude_hosts[host['name']] = host['error']
+                error_hosts.append({
+                    'name': host['name'],
+                    'id': host.get('jms_asset', {}).get('id'),
+                    'error': host['error']
+                })
+            else:
+                runnable_hosts.append({
+                    'name': host['name'],
+                    'ip': host['ansible_host'],
+                    'id': host.get('jms_asset', {}).get('id')
+                })
+
+        # 获取跳过的主机
+        skipped_hosts = []
+        for name, error in self.exclude_hosts.items():
+            if any(h['name'] == name for h in error_hosts):
+                continue
+            skipped_hosts.append({
+                'name': name,
+                'error': error
+            })
+
+        result = {
+            'runnable': runnable_hosts,
+            'error': error_hosts,
+            'skipped': skipped_hosts
+        }
+        return result
+
     def generate(self, path_dir):
         hosts = []
         platform_assets = self.group_by_platform(self.assets)
diff --git a/apps/ops/api/__init__.py b/apps/ops/api/__init__.py
index 0a96bbbdb..4cdb987c9 100644
--- a/apps/ops/api/__init__.py
+++ b/apps/ops/api/__init__.py
@@ -2,6 +2,7 @@
 #
 from .adhoc import *
 from .celery import *
+from .inventory import *
 from .job import *
 from .playbook import *
 from .variable import *
diff --git a/apps/ops/api/inventory.py b/apps/ops/api/inventory.py
new file mode 100644
index 000000000..8d65ec39c
--- /dev/null
+++ b/apps/ops/api/inventory.py
@@ -0,0 +1,42 @@
+import os
+import uuid
+
+from django.conf import settings
+from rest_framework.response import Response
+from rest_framework.views import APIView
+
+from assets.models import Asset
+from common.permissions import IsValidUser
+from ops.models.job import JMSPermedInventory
+
+__all__ = ['InventoryClassifiedHostsAPI']
+
+
+class InventoryClassifiedHostsAPI(APIView):
+    permission_classes = [IsValidUser]
+
+    def post(self, request, **kwargs):
+        asset_ids = request.data.get('assets', [])
+        node_ids = request.data.get('nodes', [])
+        runas_policy = request.data.get('runas_policy', 'privileged_first')
+        account_prefer = request.data.get('account_prefer', 'root,Administrator')
+        module = request.data.get('module', 'shell')
+        # 合并节点和资产
+        assets = list(Asset.objects.filter(id__in=asset_ids).all())
+
+        # 创建临时目录
+        tmp_dir = os.path.join(settings.PROJECT_DIR, 'inventory', str(uuid.uuid4()))
+        os.makedirs(tmp_dir, exist_ok=True)
+
+        # 创建库存对象并获取分类的主机
+        inventory = JMSPermedInventory(
+            assets=assets,
+            nodes=node_ids,
+            module=module,
+            account_policy=runas_policy,
+            account_prefer=account_prefer,
+            user=self.request.user
+        )
+        classified_hosts = inventory.get_classified_hosts(tmp_dir)
+
+        return Response(data=classified_hosts)
diff --git a/apps/ops/urls/api_urls.py b/apps/ops/urls/api_urls.py
index 12801b1a3..c410cbcf2 100644
--- a/apps/ops/urls/api_urls.py
+++ b/apps/ops/urls/api_urls.py
@@ -28,6 +28,7 @@ urlpatterns = [
     path('job-execution/task-detail/<uuid:task_id>/', api.JobExecutionTaskDetail.as_view(), name='task-detail'),
     path('username-hints/', api.UsernameHintsAPI.as_view(), name='username-hints'),
     path('ansible/job-execution/<uuid:pk>/log/', api.AnsibleTaskLogApi.as_view(), name='job-execution-log'),
+    path('inventory/classified-hosts/', api.InventoryClassifiedHostsAPI.as_view(), name='inventory-classified-hosts'),
 
     path('celery/task/<uuid:name>/task-execution/<uuid:pk>/log/', api.CeleryTaskExecutionLogApi.as_view(),
          name='celery-task-execution-log'),