diff --git a/apps/accounts/api/account/task.py b/apps/accounts/api/account/task.py index c4f6ebd9f..07e6db35a 100644 --- a/apps/accounts/api/account/task.py +++ b/apps/accounts/api/account/task.py @@ -1,11 +1,12 @@ +from django.db.models import Q from rest_framework.generics import CreateAPIView from accounts import serializers +from accounts.models import Account from accounts.permissions import AccountTaskActionPermission from accounts.tasks import ( remove_accounts_task, verify_accounts_connectivity_task, push_accounts_to_assets_task ) -from assets.exceptions import NotSupportedTemporarilyError from authentication.permissions import UserConfirmation, ConfirmType __all__ = [ @@ -26,25 +27,35 @@ class AccountsTaskCreateAPI(CreateAPIView): ] return super().get_permissions() - def perform_create(self, serializer): - data = serializer.validated_data - accounts = data.get('accounts', []) - params = data.get('params') + @staticmethod + def get_account_ids(data, action): + account_type = 'gather_accounts' if action == 'remove' else 'accounts' + accounts = data.get(account_type, []) account_ids = [str(a.id) for a in accounts] - if data['action'] == 'push': - task = push_accounts_to_assets_task.delay(account_ids, params) - elif data['action'] == 'remove': - gather_accounts = data.get('gather_accounts', []) - gather_account_ids = [str(a.id) for a in gather_accounts] - task = remove_accounts_task.delay(gather_account_ids) + if action == 'remove': + return account_ids + + assets = data.get('assets', []) + asset_ids = [str(a.id) for a in assets] + ids = Account.objects.filter( + Q(id__in=account_ids) | Q(asset_id__in=asset_ids) + ).distinct().values_list('id', flat=True) + return [str(_id) for _id in ids] + + def perform_create(self, serializer): + data = serializer.validated_data + action = data['action'] + ids = self.get_account_ids(data, action) + + if action == 'push': + task = push_accounts_to_assets_task.delay(ids, data.get('params')) + elif action == 'remove': + task = remove_accounts_task.delay(ids) + elif action == 'verify': + task = verify_accounts_connectivity_task.delay(ids) else: - account = accounts[0] - asset = account.asset - if not asset.auto_config['ansible_enabled'] or \ - not asset.auto_config['ping_enabled']: - raise NotSupportedTemporarilyError() - task = verify_accounts_connectivity_task.delay(account_ids) + raise ValueError(f"Invalid action: {action}") data = getattr(serializer, '_data', {}) data["task"] = task.id diff --git a/apps/accounts/serializers/account/account.py b/apps/accounts/serializers/account/account.py index 197c64c7d..ec3f8f87b 100644 --- a/apps/accounts/serializers/account/account.py +++ b/apps/accounts/serializers/account/account.py @@ -455,12 +455,14 @@ class AccountHistorySerializer(serializers.ModelSerializer): class AccountTaskSerializer(serializers.Serializer): ACTION_CHOICES = ( - ('test', 'test'), ('verify', 'verify'), ('push', 'push'), ('remove', 'remove'), ) action = serializers.ChoiceField(choices=ACTION_CHOICES, write_only=True) + assets = serializers.PrimaryKeyRelatedField( + queryset=Asset.objects, required=False, allow_empty=True, many=True + ) accounts = serializers.PrimaryKeyRelatedField( queryset=Account.objects, required=False, allow_empty=True, many=True )