Merge pull request #9976 from jumpserver/dev

v3.1.0 rc4
This commit is contained in:
Jiangjie.Bai 2023-03-15 19:29:22 +08:00 committed by GitHub
commit 2bcd411164
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
109 changed files with 1573 additions and 878 deletions

View File

@ -7,7 +7,7 @@ assignees: wojiushixiaobai
--- ---
**JumpServer 版本(v1.5.9以下不再支持)** **JumpServer 版本( v2.28 之前的版本不再支持 )**
**浏览器版本** **浏览器版本**
@ -17,6 +17,6 @@ assignees: wojiushixiaobai
**Bug 重现步骤(有截图更好)** **Bug 重现步骤(有截图更好)**
1. 1.
2. 2.
3. 3.

View File

@ -24,6 +24,7 @@ jobs:
build-args: | build-args: |
APT_MIRROR=http://deb.debian.org APT_MIRROR=http://deb.debian.org
PIP_MIRROR=https://pypi.org/simple PIP_MIRROR=https://pypi.org/simple
PIP_JMS_MIRROR=https://pypi.org/simple
cache-from: type=gha cache-from: type=gha
cache-to: type=gha,mode=max cache-to: type=gha,mode=max

View File

@ -10,6 +10,17 @@
<a href="https://github.com/jumpserver/jumpserver"><img src="https://img.shields.io/github/stars/jumpserver/jumpserver?color=%231890FF&style=flat-square" alt="Stars"></a> <a href="https://github.com/jumpserver/jumpserver"><img src="https://img.shields.io/github/stars/jumpserver/jumpserver?color=%231890FF&style=flat-square" alt="Stars"></a>
</p> </p>
<p align="center">
JumpServer <a href="https://github.com/jumpserver/jumpserver/releases/tag/v3.0.0">v3.0</a> 正式发布。
<br>
9 年时间,倾情投入,用心做好一款开源堡垒机。
</p>
| :warning: 注意 :warning: |
|:-------------------------------------------------------------------------------------------------------------------------:|
| 3.0 架构上和 2.0 变化较大,建议全新安装一套环境来体验。如需升级,请务必升级前进行备份,并[查阅文档](https://kb.fit2cloud.com/?p=06638d69-f109-4333-b5bf-65b17b297ed9) |
-------------------------- --------------------------
JumpServer 是广受欢迎的开源堡垒机,是符合 4A 规范的专业运维安全审计系统。 JumpServer 是广受欢迎的开源堡垒机,是符合 4A 规范的专业运维安全审计系统。
@ -27,7 +38,7 @@ JumpServer 是广受欢迎的开源堡垒机,是符合 4A 规范的专业运
## UI 展示 ## UI 展示
![UI展示](https://www.jumpserver.org/images/screenshot/1.png) ![UI展示](https://docs.jumpserver.org/zh/v3/img/dashboard.png)
## 在线体验 ## 在线体验
@ -41,8 +52,7 @@ JumpServer 是广受欢迎的开源堡垒机,是符合 4A 规范的专业运
## 快速开始 ## 快速开始
- [极速安装](https://docs.jumpserver.org/zh/master/install/setup_by_fast/) - [快速入门](https://docs.jumpserver.org/zh/v3/quick_start/)
- [手动安装](https://github.com/jumpserver/installer)
- [产品文档](https://docs.jumpserver.org) - [产品文档](https://docs.jumpserver.org)
- [知识库](https://kb.fit2cloud.com/categories/jumpserver) - [知识库](https://kb.fit2cloud.com/categories/jumpserver)

View File

@ -6,7 +6,7 @@ from rest_framework.response import Response
from accounts import serializers from accounts import serializers
from accounts.filters import AccountFilterSet from accounts.filters import AccountFilterSet
from accounts.models import Account from accounts.models import Account
from assets.models import Asset from assets.models import Asset, Node
from common.permissions import UserConfirmation, ConfirmType from common.permissions import UserConfirmation, ConfirmType
from common.views.mixins import RecordViewLogMixin from common.views.mixins import RecordViewLogMixin
from orgs.mixins.api import OrgBulkModelViewSet from orgs.mixins.api import OrgBulkModelViewSet
@ -28,6 +28,7 @@ class AccountViewSet(OrgBulkModelViewSet):
rbac_perms = { rbac_perms = {
'partial_update': ['accounts.change_account'], 'partial_update': ['accounts.change_account'],
'su_from_accounts': 'accounts.view_account', 'su_from_accounts': 'accounts.view_account',
'username_suggestions': 'accounts.view_account',
} }
@action(methods=['get'], detail=False, url_path='su-from-accounts') @action(methods=['get'], detail=False, url_path='su-from-accounts')
@ -42,11 +43,34 @@ class AccountViewSet(OrgBulkModelViewSet):
asset = get_object_or_404(Asset, pk=asset_id) asset = get_object_or_404(Asset, pk=asset_id)
accounts = asset.accounts.all() accounts = asset.accounts.all()
else: else:
accounts = [] accounts = Account.objects.none()
accounts = self.filter_queryset(accounts) accounts = self.filter_queryset(accounts)
serializer = serializers.AccountSerializer(accounts, many=True) serializer = serializers.AccountSerializer(accounts, many=True)
return Response(data=serializer.data) return Response(data=serializer.data)
@action(methods=['get'], detail=False, url_path='username-suggestions')
def username_suggestions(self, request, *args, **kwargs):
asset_ids = request.query_params.get('assets')
node_keys = request.query_params.get('keys')
username = request.query_params.get('username')
assets = Asset.objects.all()
if asset_ids:
assets = assets.filter(id__in=asset_ids.split(','))
if node_keys:
patten = Node.get_node_all_children_key_pattern(node_keys.split(','))
assets = assets.filter(nodes__key__regex=patten)
accounts = Account.objects.filter(asset__in=assets)
if username:
accounts = accounts.filter(username__icontains=username)
usernames = list(accounts.values_list('username', flat=True).distinct()[:10])
usernames.sort()
common = [i for i in usernames if i in usernames if i.lower() in ['root', 'admin', 'administrator']]
others = [i for i in usernames if i not in common]
usernames = common + others
return Response(data=usernames)
class AccountSecretsViewSet(RecordViewLogMixin, AccountViewSet): class AccountSecretsViewSet(RecordViewLogMixin, AccountViewSet):
""" """

View File

@ -1,15 +1,39 @@
from rbac.permissions import RBACPermission from django_filters import rest_framework as drf_filters
from common.permissions import UserConfirmation, ConfirmType
from common.views.mixins import RecordViewLogMixin from assets.const import Protocol
from orgs.mixins.api import OrgBulkModelViewSet
from accounts import serializers from accounts import serializers
from accounts.models import AccountTemplate from accounts.models import AccountTemplate
from orgs.mixins.api import OrgBulkModelViewSet
from rbac.permissions import RBACPermission
from common.permissions import UserConfirmation, ConfirmType
from common.views.mixins import RecordViewLogMixin
from common.drf.filters import BaseFilterSet
class AccountTemplateFilterSet(BaseFilterSet):
protocols = drf_filters.CharFilter(method='filter_protocols')
class Meta:
model = AccountTemplate
fields = ('username', 'name')
@staticmethod
def filter_protocols(queryset, name, value):
secret_types = set()
protocols = value.split(',')
protocol_secret_type_map = Protocol.settings()
for p in protocols:
if p not in protocol_secret_type_map:
continue
_st = protocol_secret_type_map[p].get('secret_types', [])
secret_types.update(_st)
queryset = queryset.filter(secret_type__in=secret_types)
return queryset
class AccountTemplateViewSet(OrgBulkModelViewSet): class AccountTemplateViewSet(OrgBulkModelViewSet):
model = AccountTemplate model = AccountTemplate
filterset_fields = ("username", 'name') filterset_class = AccountTemplateFilterSet
search_fields = ('username', 'name') search_fields = ('username', 'name')
serializer_classes = { serializer_classes = {
'default': serializers.AccountTemplateSerializer 'default': serializers.AccountTemplateSerializer

View File

@ -9,12 +9,12 @@
name: "{{ account.username }}" name: "{{ account.username }}"
password: "{{ account.secret | password_hash('des') }}" password: "{{ account.secret | password_hash('des') }}"
update_password: always update_password: always
when: secret_type == "password" when: account.secret_type == "password"
- name: create user If it already exists, no operation will be performed - name: create user If it already exists, no operation will be performed
ansible.builtin.user: ansible.builtin.user:
name: "{{ account.username }}" name: "{{ account.username }}"
when: secret_type == "ssh_key" when: account.secret_type == "ssh_key"
- name: remove jumpserver ssh key - name: remove jumpserver ssh key
ansible.builtin.lineinfile: ansible.builtin.lineinfile:
@ -22,7 +22,7 @@
regexp: "{{ kwargs.regexp }}" regexp: "{{ kwargs.regexp }}"
state: absent state: absent
when: when:
- secret_type == "ssh_key" - account.secret_type == "ssh_key"
- kwargs.strategy == "set_jms" - kwargs.strategy == "set_jms"
- name: Change SSH key - name: Change SSH key
@ -30,7 +30,7 @@
user: "{{ account.username }}" user: "{{ account.username }}"
key: "{{ account.secret }}" key: "{{ account.secret }}"
exclusive: "{{ kwargs.exclusive }}" exclusive: "{{ kwargs.exclusive }}"
when: secret_type == "ssh_key" when: account.secret_type == "ssh_key"
- name: Refresh connection - name: Refresh connection
ansible.builtin.meta: reset_connection ansible.builtin.meta: reset_connection
@ -42,7 +42,7 @@
ansible_user: "{{ account.username }}" ansible_user: "{{ account.username }}"
ansible_password: "{{ account.secret }}" ansible_password: "{{ account.secret }}"
ansible_become: no ansible_become: no
when: secret_type == "password" when: account.secret_type == "password"
- name: Verify SSH key - name: Verify SSH key
ansible.builtin.ping: ansible.builtin.ping:
@ -51,4 +51,4 @@
ansible_user: "{{ account.username }}" ansible_user: "{{ account.username }}"
ansible_ssh_private_key_file: "{{ account.private_key_path }}" ansible_ssh_private_key_file: "{{ account.private_key_path }}"
ansible_become: no ansible_become: no
when: secret_type == "ssh_key" when: account.secret_type == "ssh_key"

View File

@ -9,12 +9,12 @@
name: "{{ account.username }}" name: "{{ account.username }}"
password: "{{ account.secret | password_hash('sha512') }}" password: "{{ account.secret | password_hash('sha512') }}"
update_password: always update_password: always
when: secret_type == "password" when: account.secret_type == "password"
- name: create user If it already exists, no operation will be performed - name: create user If it already exists, no operation will be performed
ansible.builtin.user: ansible.builtin.user:
name: "{{ account.username }}" name: "{{ account.username }}"
when: secret_type == "ssh_key" when: account.secret_type == "ssh_key"
- name: remove jumpserver ssh key - name: remove jumpserver ssh key
ansible.builtin.lineinfile: ansible.builtin.lineinfile:
@ -22,7 +22,7 @@
regexp: "{{ kwargs.regexp }}" regexp: "{{ kwargs.regexp }}"
state: absent state: absent
when: when:
- secret_type == "ssh_key" - account.secret_type == "ssh_key"
- kwargs.strategy == "set_jms" - kwargs.strategy == "set_jms"
- name: Change SSH key - name: Change SSH key
@ -30,7 +30,7 @@
user: "{{ account.username }}" user: "{{ account.username }}"
key: "{{ account.secret }}" key: "{{ account.secret }}"
exclusive: "{{ kwargs.exclusive }}" exclusive: "{{ kwargs.exclusive }}"
when: secret_type == "ssh_key" when: account.secret_type == "ssh_key"
- name: Refresh connection - name: Refresh connection
ansible.builtin.meta: reset_connection ansible.builtin.meta: reset_connection
@ -42,7 +42,7 @@
ansible_user: "{{ account.username }}" ansible_user: "{{ account.username }}"
ansible_password: "{{ account.secret }}" ansible_password: "{{ account.secret }}"
ansible_become: no ansible_become: no
when: secret_type == "password" when: account.secret_type == "password"
- name: Verify SSH key - name: Verify SSH key
ansible.builtin.ping: ansible.builtin.ping:
@ -51,4 +51,4 @@
ansible_user: "{{ account.username }}" ansible_user: "{{ account.username }}"
ansible_ssh_private_key_file: "{{ account.private_key_path }}" ansible_ssh_private_key_file: "{{ account.private_key_path }}"
ansible_become: no ansible_become: no
when: secret_type == "ssh_key" when: account.secret_type == "ssh_key"

View File

@ -12,7 +12,7 @@ from accounts.models import ChangeSecretRecord
from accounts.notifications import ChangeSecretExecutionTaskMsg from accounts.notifications import ChangeSecretExecutionTaskMsg
from accounts.serializers import ChangeSecretRecordBackUpSerializer from accounts.serializers import ChangeSecretRecordBackUpSerializer
from assets.const import HostTypes from assets.const import HostTypes
from common.utils import get_logger, lazyproperty from common.utils import get_logger
from common.utils.file import encrypt_and_compress_zip_file from common.utils.file import encrypt_and_compress_zip_file
from common.utils.timezone import local_now_display from common.utils.timezone import local_now_display
from users.models import User from users.models import User
@ -28,23 +28,23 @@ class ChangeSecretManager(AccountBasePlaybookManager):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.method_hosts_mapper = defaultdict(list) self.method_hosts_mapper = defaultdict(list)
self.secret_type = self.execution.snapshot['secret_type'] self.secret_type = self.execution.snapshot.get('secret_type')
self.secret_strategy = self.execution.snapshot.get( self.secret_strategy = self.execution.snapshot.get(
'secret_strategy', SecretStrategy.custom 'secret_strategy', SecretStrategy.custom
) )
self.ssh_key_change_strategy = self.execution.snapshot.get( self.ssh_key_change_strategy = self.execution.snapshot.get(
'ssh_key_change_strategy', SSHKeyStrategy.add 'ssh_key_change_strategy', SSHKeyStrategy.add
) )
self.snapshot_account_usernames = self.execution.snapshot['accounts'] self.account_ids = self.execution.snapshot['accounts']
self.name_recorder_mapper = {} # 做个映射,方便后面处理 self.name_recorder_mapper = {} # 做个映射,方便后面处理
@classmethod @classmethod
def method_type(cls): def method_type(cls):
return AutomationTypes.change_secret return AutomationTypes.change_secret
def get_kwargs(self, account, secret): def get_kwargs(self, account, secret, secret_type):
kwargs = {} kwargs = {}
if self.secret_type != SecretType.SSH_KEY: if secret_type != SecretType.SSH_KEY:
return kwargs return kwargs
kwargs['strategy'] = self.ssh_key_change_strategy kwargs['strategy'] = self.ssh_key_change_strategy
kwargs['exclusive'] = 'yes' if kwargs['strategy'] == SSHKeyStrategy.set else 'no' kwargs['exclusive'] = 'yes' if kwargs['strategy'] == SSHKeyStrategy.set else 'no'
@ -54,18 +54,29 @@ class ChangeSecretManager(AccountBasePlaybookManager):
kwargs['regexp'] = '.*{}$'.format(secret.split()[2].strip()) kwargs['regexp'] = '.*{}$'.format(secret.split()[2].strip())
return kwargs return kwargs
@lazyproperty def secret_generator(self, secret_type):
def secret_generator(self):
return SecretGenerator( return SecretGenerator(
self.secret_strategy, self.secret_type, self.secret_strategy, secret_type,
self.execution.snapshot.get('password_rules') self.execution.snapshot.get('password_rules')
) )
def get_secret(self): def get_secret(self, secret_type):
if self.secret_strategy == SecretStrategy.custom: if self.secret_strategy == SecretStrategy.custom:
return self.execution.snapshot['secret'] return self.execution.snapshot['secret']
else: else:
return self.secret_generator.get_secret() return self.secret_generator(secret_type).get_secret()
def get_accounts(self, privilege_account):
if not privilege_account:
print(f'not privilege account')
return []
asset = privilege_account.asset
accounts = asset.accounts.exclude(username=privilege_account.username)
accounts = accounts.filter(id__in=self.account_ids)
if self.secret_type:
accounts = accounts.filter(secret_type=self.secret_type)
return accounts
def host_callback( def host_callback(
self, host, asset=None, account=None, self, host, asset=None, account=None,
@ -78,17 +89,10 @@ class ChangeSecretManager(AccountBasePlaybookManager):
if host.get('error'): if host.get('error'):
return host return host
accounts = asset.accounts.all() accounts = self.get_accounts(account)
if account:
accounts = accounts.exclude(username=account.username)
if '*' not in self.snapshot_account_usernames:
accounts = accounts.filter(username__in=self.snapshot_account_usernames)
accounts = accounts.filter(secret_type=self.secret_type)
if not accounts: if not accounts:
print('没有发现待改密账号: %s 用户: %s 类型: %s' % ( print('没有发现待改密账号: %s 用户ID: %s 类型: %s' % (
asset.name, self.snapshot_account_usernames, self.secret_type asset.name, self.account_ids, self.secret_type
)) ))
return [] return []
@ -97,16 +101,16 @@ class ChangeSecretManager(AccountBasePlaybookManager):
method_hosts = [h for h in method_hosts if h != host['name']] method_hosts = [h for h in method_hosts if h != host['name']]
inventory_hosts = [] inventory_hosts = []
records = [] records = []
host['secret_type'] = self.secret_type
if asset.type == HostTypes.WINDOWS and self.secret_type == SecretType.SSH_KEY: if asset.type == HostTypes.WINDOWS and self.secret_type == SecretType.SSH_KEY:
print(f'Windows {asset} does not support ssh key push \n') print(f'Windows {asset} does not support ssh key push')
return inventory_hosts return inventory_hosts
for account in accounts: for account in accounts:
h = deepcopy(host) h = deepcopy(host)
secret_type = account.secret_type
h['name'] += '(' + account.username + ')' h['name'] += '(' + account.username + ')'
new_secret = self.get_secret() new_secret = self.get_secret(secret_type)
recorder = ChangeSecretRecord( recorder = ChangeSecretRecord(
asset=asset, account=account, execution=self.execution, asset=asset, account=account, execution=self.execution,
@ -116,15 +120,15 @@ class ChangeSecretManager(AccountBasePlaybookManager):
self.name_recorder_mapper[h['name']] = recorder self.name_recorder_mapper[h['name']] = recorder
private_key_path = None private_key_path = None
if self.secret_type == SecretType.SSH_KEY: if secret_type == SecretType.SSH_KEY:
private_key_path = self.generate_private_key_path(new_secret, path_dir) private_key_path = self.generate_private_key_path(new_secret, path_dir)
new_secret = self.generate_public_key(new_secret) new_secret = self.generate_public_key(new_secret)
h['kwargs'] = self.get_kwargs(account, new_secret) h['kwargs'] = self.get_kwargs(account, new_secret, secret_type)
h['account'] = { h['account'] = {
'name': account.name, 'name': account.name,
'username': account.username, 'username': account.username,
'secret_type': account.secret_type, 'secret_type': secret_type,
'secret': new_secret, 'secret': new_secret,
'private_key_path': private_key_path 'private_key_path': private_key_path
} }
@ -206,7 +210,7 @@ class ChangeSecretManager(AccountBasePlaybookManager):
serializer = serializer_cls(recorders, many=True) serializer = serializer_cls(recorders, many=True)
header = [str(v.label) for v in serializer.child.fields.values()] header = [str(v.label) for v in serializer.child.fields.values()]
rows = [list(row.values()) for row in serializer.data] rows = [[str(i) for i in row.values()] for row in serializer.data]
if not rows: if not rows:
return False return False

View File

@ -60,4 +60,6 @@ class GatherAccountsFilter:
if not run_method_name: if not run_method_name:
return info return info
return getattr(self, f'{run_method_name}_filter')(info) if hasattr(self, f'{run_method_name}_filter'):
return getattr(self, f'{run_method_name}_filter')(info)
return info

View File

@ -22,8 +22,8 @@ class GatherAccountsManager(AccountBasePlaybookManager):
self.host_asset_mapper[host['name']] = asset self.host_asset_mapper[host['name']] = asset
return host return host
def filter_success_result(self, host, result): def filter_success_result(self, tp, result):
result = GatherAccountsFilter(host).run(self.method_id_meta_mapper, result) result = GatherAccountsFilter(tp).run(self.method_id_meta_mapper, result)
return result return result
@staticmethod @staticmethod

View File

@ -1,9 +1,6 @@
from copy import deepcopy from copy import deepcopy
from django.db.models import QuerySet
from accounts.const import AutomationTypes, SecretType from accounts.const import AutomationTypes, SecretType
from accounts.models import Account
from assets.const import HostTypes from assets.const import HostTypes
from common.utils import get_logger from common.utils import get_logger
from ..base.manager import AccountBasePlaybookManager from ..base.manager import AccountBasePlaybookManager
@ -19,36 +16,6 @@ class PushAccountManager(ChangeSecretManager, AccountBasePlaybookManager):
def method_type(cls): def method_type(cls):
return AutomationTypes.push_account return AutomationTypes.push_account
def create_nonlocal_accounts(self, accounts, snapshot_account_usernames, asset):
secret_type = self.secret_type
usernames = accounts.filter(secret_type=secret_type).values_list(
'username', flat=True
)
create_usernames = set(snapshot_account_usernames) - set(usernames)
create_account_objs = [
Account(
name=f'{username}-{secret_type}', username=username,
secret_type=secret_type, asset=asset,
)
for username in create_usernames
]
Account.objects.bulk_create(create_account_objs)
def get_accounts(self, privilege_account, accounts: QuerySet):
if not privilege_account:
print(f'not privilege account')
return []
snapshot_account_usernames = self.execution.snapshot['accounts']
if '*' in snapshot_account_usernames:
return accounts.exclude(username=privilege_account.username)
asset = privilege_account.asset
self.create_nonlocal_accounts(accounts, snapshot_account_usernames, asset)
accounts = asset.accounts.exclude(username=privilege_account.username).filter(
username__in=snapshot_account_usernames, secret_type=self.secret_type
)
return accounts
def host_callback(self, host, asset=None, account=None, automation=None, path_dir=None, **kwargs): def host_callback(self, host, asset=None, account=None, automation=None, path_dir=None, **kwargs):
host = super(ChangeSecretManager, self).host_callback( host = super(ChangeSecretManager, self).host_callback(
host, asset=asset, account=account, automation=automation, host, asset=asset, account=account, automation=automation,
@ -57,34 +24,36 @@ class PushAccountManager(ChangeSecretManager, AccountBasePlaybookManager):
if host.get('error'): if host.get('error'):
return host return host
accounts = asset.accounts.all() accounts = self.get_accounts(account)
accounts = self.get_accounts(account, accounts)
inventory_hosts = [] inventory_hosts = []
host['secret_type'] = self.secret_type
if asset.type == HostTypes.WINDOWS and self.secret_type == SecretType.SSH_KEY: if asset.type == HostTypes.WINDOWS and self.secret_type == SecretType.SSH_KEY:
msg = f'Windows {asset} does not support ssh key push \n' msg = f'Windows {asset} does not support ssh key push'
print(msg) print(msg)
return inventory_hosts return inventory_hosts
for account in accounts: for account in accounts:
h = deepcopy(host) h = deepcopy(host)
secret_type = account.secret_type
h['name'] += '(' + account.username + ')' h['name'] += '(' + account.username + ')'
new_secret = self.get_secret() if self.secret_type is None:
new_secret = account.secret
else:
new_secret = self.get_secret(secret_type)
self.name_recorder_mapper[h['name']] = { self.name_recorder_mapper[h['name']] = {
'account': account, 'new_secret': new_secret, 'account': account, 'new_secret': new_secret,
} }
private_key_path = None private_key_path = None
if self.secret_type == SecretType.SSH_KEY: if secret_type == SecretType.SSH_KEY:
private_key_path = self.generate_private_key_path(new_secret, path_dir) private_key_path = self.generate_private_key_path(new_secret, path_dir)
new_secret = self.generate_public_key(new_secret) new_secret = self.generate_public_key(new_secret)
h['kwargs'] = self.get_kwargs(account, new_secret) h['kwargs'] = self.get_kwargs(account, new_secret, secret_type)
h['account'] = { h['account'] = {
'name': account.name, 'name': account.name,
'username': account.username, 'username': account.username,
'secret_type': account.secret_type, 'secret_type': secret_type,
'secret': new_secret, 'secret': new_secret,
'private_key_path': private_key_path 'private_key_path': private_key_path
} }
@ -112,9 +81,9 @@ class PushAccountManager(ChangeSecretManager, AccountBasePlaybookManager):
logger.error("Pust account error: ", e) logger.error("Pust account error: ", e)
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
if not self.check_secret(): if self.secret_type and not self.check_secret():
return return
super().run(*args, **kwargs) super(ChangeSecretManager, self).run(*args, **kwargs)
# @classmethod # @classmethod
# def trigger_by_asset_create(cls, asset): # def trigger_by_asset_create(cls, asset):

View File

@ -25,6 +25,15 @@ class VerifyAccountManager(AccountBasePlaybookManager):
f.write('ssh_args = -o ControlMaster=no -o ControlPersist=no\n') f.write('ssh_args = -o ControlMaster=no -o ControlPersist=no\n')
return path return path
@classmethod
def method_type(cls):
return AutomationTypes.verify_account
def get_accounts(self, privilege_account, accounts: QuerySet):
account_ids = self.execution.snapshot['accounts']
accounts = accounts.filter(id__in=account_ids)
return accounts
def host_callback(self, host, asset=None, account=None, automation=None, path_dir=None, **kwargs): def host_callback(self, host, asset=None, account=None, automation=None, path_dir=None, **kwargs):
host = super().host_callback( host = super().host_callback(
host, asset=asset, account=account, host, asset=asset, account=account,
@ -62,16 +71,6 @@ class VerifyAccountManager(AccountBasePlaybookManager):
inventory_hosts.append(h) inventory_hosts.append(h)
return inventory_hosts return inventory_hosts
@classmethod
def method_type(cls):
return AutomationTypes.verify_account
def get_accounts(self, privilege_account, accounts: QuerySet):
snapshot_account_usernames = self.execution.snapshot['accounts']
if '*' not in snapshot_account_usernames:
accounts = accounts.filter(username__in=snapshot_account_usernames)
return accounts
def on_host_success(self, host, result): def on_host_success(self, host, result):
account = self.host_account_mapper.get(host) account = self.host_account_mapper.get(host)
account.set_connectivity(Connectivity.OK) account.set_connectivity(Connectivity.OK)

View File

@ -1,6 +1,6 @@
from common.utils import get_logger
from accounts.const import AutomationTypes from accounts.const import AutomationTypes
from assets.automations.ping_gateway.manager import PingGatewayManager from assets.automations.ping_gateway.manager import PingGatewayManager
from common.utils import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
@ -16,6 +16,6 @@ class VerifyGatewayAccountManager(PingGatewayManager):
logger.info(">>> 开始执行测试网关账号可连接性任务") logger.info(">>> 开始执行测试网关账号可连接性任务")
def get_accounts(self, gateway): def get_accounts(self, gateway):
usernames = self.execution.snapshot['accounts'] account_ids = self.execution.snapshot['accounts']
accounts = gateway.accounts.filter(username__in=usernames) accounts = gateway.accounts.filter(id__in=account_ids)
return accounts return accounts

View File

@ -0,0 +1,69 @@
# Generated by Django 3.2.16 on 2023-03-07 07:36
from django.db import migrations
from django.db.models import Q
def get_nodes_all_assets(apps, *nodes):
node_model = apps.get_model('assets', 'Node')
asset_model = apps.get_model('assets', 'Asset')
node_ids = set()
descendant_node_query = Q()
for n in nodes:
node_ids.add(n.id)
descendant_node_query |= Q(key__istartswith=f'{n.key}:')
if descendant_node_query:
_ids = node_model.objects.order_by().filter(descendant_node_query).values_list('id', flat=True)
node_ids.update(_ids)
return asset_model.objects.order_by().filter(nodes__id__in=node_ids).distinct()
def get_all_assets(apps, snapshot):
node_model = apps.get_model('assets', 'Node')
asset_model = apps.get_model('assets', 'Asset')
asset_ids = snapshot.get('assets', [])
node_ids = snapshot.get('nodes', [])
nodes = node_model.objects.filter(id__in=node_ids)
node_asset_ids = get_nodes_all_assets(apps, *nodes).values_list('id', flat=True)
asset_ids = set(list(asset_ids) + list(node_asset_ids))
return asset_model.objects.filter(id__in=asset_ids)
def migrate_account_usernames_to_ids(apps, schema_editor):
db_alias = schema_editor.connection.alias
execution_model = apps.get_model('accounts', 'AutomationExecution')
account_model = apps.get_model('accounts', 'Account')
executions = execution_model.objects.using(db_alias).all()
executions_update = []
for execution in executions:
snapshot = execution.snapshot
accounts = account_model.objects.none()
account_usernames = snapshot.get('accounts', [])
for asset in get_all_assets(apps, snapshot):
accounts = accounts | asset.accounts.all()
secret_type = snapshot.get('secret_type')
if secret_type:
ids = accounts.filter(
username__in=account_usernames,
secret_type=secret_type
).values_list('id', flat=True)
else:
ids = accounts.filter(
username__in=account_usernames
).values_list('id', flat=True)
snapshot['accounts'] = [str(_id) for _id in ids]
execution.snapshot = snapshot
executions_update.append(execution)
execution_model.objects.bulk_update(executions_update, ['snapshot'])
class Migration(migrations.Migration):
dependencies = [
('accounts', '0008_alter_gatheredaccount_options'),
]
operations = [
migrations.RunPython(migrate_account_usernames_to_ids),
]

View File

@ -1,11 +1,12 @@
from django.db import models from django.db import models
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from common.db import fields
from common.db.models import JMSBaseModel
from accounts.const import ( from accounts.const import (
AutomationTypes, SecretType, SecretStrategy, SSHKeyStrategy AutomationTypes, SecretType, SecretStrategy, SSHKeyStrategy
) )
from accounts.models import Account
from common.db import fields
from common.db.models import JMSBaseModel
from .base import AccountBaseAutomation from .base import AccountBaseAutomation
__all__ = ['ChangeSecretAutomation', 'ChangeSecretRecord', 'ChangeSecretMixin'] __all__ = ['ChangeSecretAutomation', 'ChangeSecretRecord', 'ChangeSecretMixin']
@ -27,18 +28,35 @@ class ChangeSecretMixin(models.Model):
default=SSHKeyStrategy.add, verbose_name=_('SSH key change strategy') default=SSHKeyStrategy.add, verbose_name=_('SSH key change strategy')
) )
accounts: list[str] # account usernames
get_all_assets: callable # get all assets
class Meta: class Meta:
abstract = True abstract = True
def create_nonlocal_accounts(self, usernames, asset):
pass
def get_account_ids(self):
usernames = self.accounts
accounts = Account.objects.none()
for asset in self.get_all_assets():
self.create_nonlocal_accounts(usernames, asset)
accounts = accounts | asset.accounts.all()
account_ids = accounts.filter(
username__in=usernames, secret_type=self.secret_type
).values_list('id', flat=True)
return [str(_id) for _id in account_ids]
def to_attr_json(self): def to_attr_json(self):
attr_json = super().to_attr_json() attr_json = super().to_attr_json()
attr_json.update({ attr_json.update({
'secret': self.secret, 'secret': self.secret,
'secret_type': self.secret_type, 'secret_type': self.secret_type,
'secret_strategy': self.secret_strategy, 'accounts': self.get_account_ids(),
'password_rules': self.password_rules, 'password_rules': self.password_rules,
'secret_strategy': self.secret_strategy,
'ssh_key_change_strategy': self.ssh_key_change_strategy, 'ssh_key_change_strategy': self.ssh_key_change_strategy,
}) })
return attr_json return attr_json

View File

@ -2,6 +2,8 @@ from django.db import models
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from accounts.const import AutomationTypes from accounts.const import AutomationTypes
from accounts.models import Account
from jumpserver.utils import has_valid_xpack_license
from .base import AccountBaseAutomation from .base import AccountBaseAutomation
from .change_secret import ChangeSecretMixin from .change_secret import ChangeSecretMixin
@ -13,6 +15,21 @@ class PushAccountAutomation(ChangeSecretMixin, AccountBaseAutomation):
username = models.CharField(max_length=128, verbose_name=_('Username')) username = models.CharField(max_length=128, verbose_name=_('Username'))
action = models.CharField(max_length=16, verbose_name=_('Action')) action = models.CharField(max_length=16, verbose_name=_('Action'))
def create_nonlocal_accounts(self, usernames, asset):
secret_type = self.secret_type
account_usernames = asset.accounts.filter(secret_type=self.secret_type).values_list(
'username', flat=True
)
create_usernames = set(usernames) - set(account_usernames)
create_account_objs = [
Account(
name=f'{username}-{secret_type}', username=username,
secret_type=secret_type, asset=asset,
)
for username in create_usernames
]
Account.objects.bulk_create(create_account_objs)
def set_period_schedule(self): def set_period_schedule(self):
pass pass
@ -27,6 +44,8 @@ class PushAccountAutomation(ChangeSecretMixin, AccountBaseAutomation):
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
self.type = AutomationTypes.push_account self.type = AutomationTypes.push_account
if not has_valid_xpack_license():
self.is_periodic = False
super().save(*args, **kwargs) super().save(*args, **kwargs)
def to_attr_json(self): def to_attr_json(self):

View File

@ -12,7 +12,7 @@ from accounts.const import SecretType
from common.db import fields from common.db import fields
from common.utils import ( from common.utils import (
ssh_key_string_to_obj, ssh_key_gen, get_logger, ssh_key_string_to_obj, ssh_key_gen, get_logger,
random_string, lazyproperty, parse_ssh_public_key_str random_string, lazyproperty, parse_ssh_public_key_str, is_openssh_format_key
) )
from orgs.mixins.models import JMSOrgBaseModel, OrgManager from orgs.mixins.models import JMSOrgBaseModel, OrgManager
@ -118,7 +118,13 @@ class BaseAccount(JMSOrgBaseModel):
key_name = '.' + md5(self.private_key.encode('utf-8')).hexdigest() key_name = '.' + md5(self.private_key.encode('utf-8')).hexdigest()
key_path = os.path.join(tmp_dir, key_name) key_path = os.path.join(tmp_dir, key_name)
if not os.path.exists(key_path): if not os.path.exists(key_path):
self.private_key_obj.write_private_key_file(key_path) # https://github.com/ansible/ansible-runner/issues/544
# ssh requires OpenSSH format keys to have a full ending newline.
# It does not require this for old-style PEM keys.
with open(key_path, 'w') as f:
f.write(self.secret)
if is_openssh_format_key(self.secret.encode('utf-8')):
f.write("\n")
os.chmod(key_path, 0o400) os.chmod(key_path, 0o400)
return key_path return key_path

View File

@ -33,7 +33,8 @@ class AuthValidateMixin(serializers.Serializer):
return secret return secret
elif secret_type == SecretType.SSH_KEY: elif secret_type == SecretType.SSH_KEY:
passphrase = passphrase if passphrase else None passphrase = passphrase if passphrase else None
return validate_ssh_key(secret, passphrase) secret = validate_ssh_key(secret, passphrase)
return secret
else: else:
return secret return secret
@ -41,8 +42,9 @@ class AuthValidateMixin(serializers.Serializer):
secret_type = validated_data.get('secret_type') secret_type = validated_data.get('secret_type')
passphrase = validated_data.get('passphrase') passphrase = validated_data.get('passphrase')
secret = validated_data.pop('secret', None) secret = validated_data.pop('secret', None)
self.handle_secret(secret, secret_type, passphrase) validated_data['secret'] = self.handle_secret(
validated_data['secret'] = secret secret, secret_type, passphrase
)
for field in ('secret',): for field in ('secret',):
value = validated_data.get(field) value = validated_data.get(field)
if not value: if not value:

View File

@ -8,7 +8,7 @@ from orgs.utils import tmp_to_org, tmp_to_root_org
logger = get_logger(__file__) logger = get_logger(__file__)
def task_activity_callback(self, pid, trigger, tp): def task_activity_callback(self, pid, trigger, tp, *args, **kwargs):
model = AutomationTypes.get_type_model(tp) model = AutomationTypes.get_type_model(tp)
with tmp_to_root_org(): with tmp_to_root_org():
instance = get_object_or_none(model, pk=pid) instance = get_object_or_none(model, pk=pid)

View File

@ -9,7 +9,7 @@ from orgs.utils import tmp_to_org, tmp_to_root_org
logger = get_logger(__file__) logger = get_logger(__file__)
def task_activity_callback(self, pid, trigger): def task_activity_callback(self, pid, trigger, *args, **kwargs):
from accounts.models import AccountBackupAutomation from accounts.models import AccountBackupAutomation
with tmp_to_root_org(): with tmp_to_root_org():
plan = get_object_or_none(AccountBackupAutomation, pk=pid) plan = get_object_or_none(AccountBackupAutomation, pk=pid)

View File

@ -27,7 +27,7 @@ def gather_asset_accounts_util(nodes, task_name):
@shared_task( @shared_task(
queue="ansible", verbose_name=_('Gather asset accounts'), queue="ansible", verbose_name=_('Gather asset accounts'),
activity_callback=lambda self, node_ids, task_name=None: (node_ids, None) activity_callback=lambda self, node_ids, task_name=None, *args, **kwargs: (node_ids, None)
) )
def gather_asset_accounts_task(node_ids, task_name=None): def gather_asset_accounts_task(node_ids, task_name=None):
if task_name is None: if task_name is None:

View File

@ -13,7 +13,7 @@ __all__ = [
@shared_task( @shared_task(
queue="ansible", verbose_name=_('Push accounts to assets'), queue="ansible", verbose_name=_('Push accounts to assets'),
activity_callback=lambda self, account_ids, asset_ids: (account_ids, None) activity_callback=lambda self, account_ids, *args, **kwargs: (account_ids, None)
) )
def push_accounts_to_assets_task(account_ids): def push_accounts_to_assets_task(account_ids):
from accounts.models import PushAccountAutomation from accounts.models import PushAccountAutomation
@ -23,12 +23,10 @@ def push_accounts_to_assets_task(account_ids):
task_name = gettext_noop("Push accounts to assets") task_name = gettext_noop("Push accounts to assets")
task_name = PushAccountAutomation.generate_unique_name(task_name) task_name = PushAccountAutomation.generate_unique_name(task_name)
for account in accounts: task_snapshot = {
task_snapshot = { 'accounts': [str(account.id) for account in accounts],
'secret': account.secret, 'assets': [str(account.asset_id) for account in accounts],
'secret_type': account.secret_type, }
'accounts': [account.username],
'assets': [str(account.asset_id)], tp = AutomationTypes.push_account
} quickstart_automation_by_snapshot(task_name, tp, task_snapshot)
tp = AutomationTypes.push_account
quickstart_automation_by_snapshot(task_name, tp, task_snapshot)

View File

@ -17,9 +17,9 @@ __all__ = [
def verify_connectivity_util(assets, tp, accounts, task_name): def verify_connectivity_util(assets, tp, accounts, task_name):
if not assets or not accounts: if not assets or not accounts:
return return
account_usernames = list(accounts.values_list('username', flat=True)) account_ids = [str(account.id) for account in accounts]
task_snapshot = { task_snapshot = {
'accounts': account_usernames, 'accounts': account_ids,
'assets': [str(asset.id) for asset in assets], 'assets': [str(asset.id) for asset in assets],
} }
quickstart_automation_by_snapshot(task_name, tp, task_snapshot) quickstart_automation_by_snapshot(task_name, tp, task_snapshot)

View File

@ -99,13 +99,14 @@ class AssetViewSet(SuggestionMixin, NodeFilterMixin, OrgBulkModelViewSet):
("platform", serializers.PlatformSerializer), ("platform", serializers.PlatformSerializer),
("suggestion", serializers.MiniAssetSerializer), ("suggestion", serializers.MiniAssetSerializer),
("gateways", serializers.GatewaySerializer), ("gateways", serializers.GatewaySerializer),
("spec_info", serializers.SpecSerializer) ("spec_info", serializers.SpecSerializer),
) )
rbac_perms = ( rbac_perms = (
("match", "assets.match_asset"), ("match", "assets.match_asset"),
("platform", "assets.view_platform"), ("platform", "assets.view_platform"),
("gateways", "assets.view_gateway"), ("gateways", "assets.view_gateway"),
("spec_info", "assets.view_asset"), ("spec_info", "assets.view_asset"),
("info", "assets.view_asset"),
) )
extra_filter_backends = [LabelFilterBackend, IpInFilterBackend, NodeFilterBackend] extra_filter_backends = [LabelFilterBackend, IpInFilterBackend, NodeFilterBackend]

View File

@ -21,4 +21,10 @@ class HostViewSet(AssetViewSet):
@action(methods=["GET"], detail=True, url_path="info") @action(methods=["GET"], detail=True, url_path="info")
def info(self, *args, **kwargs): def info(self, *args, **kwargs):
asset = super().get_object() asset = super().get_object()
return Response(asset.info) serializer = self.get_serializer(asset.info)
data = serializer.data
data['asset'] = {
'id': asset.id, 'name': asset.name,
'address': asset.address
}
return Response(data)

View File

@ -12,8 +12,7 @@ from django.utils.translation import gettext as _
from sshtunnel import SSHTunnelForwarder, BaseSSHTunnelForwarderError from sshtunnel import SSHTunnelForwarder, BaseSSHTunnelForwarderError
from assets.automations.methods import platform_automation_methods from assets.automations.methods import platform_automation_methods
from common.utils import get_logger, lazyproperty from common.utils import get_logger, lazyproperty, is_openssh_format_key, ssh_pubkey_gen
from common.utils import ssh_pubkey_gen, ssh_key_string_to_obj
from ops.ansible import JMSInventory, PlaybookRunner, DefaultCallback from ops.ansible import JMSInventory, PlaybookRunner, DefaultCallback
logger = get_logger(__name__) logger = get_logger(__name__)
@ -127,7 +126,13 @@ class BasePlaybookManager:
key_path = os.path.join(path_dir, key_name) key_path = os.path.join(path_dir, key_name)
if not os.path.exists(key_path): if not os.path.exists(key_path):
ssh_key_string_to_obj(secret, password=None).write_private_key_file(key_path) # https://github.com/ansible/ansible-runner/issues/544
# ssh requires OpenSSH format keys to have a full ending newline.
# It does not require this for old-style PEM keys.
with open(key_path, 'w') as f:
f.write(secret)
if is_openssh_format_key(secret.encode('utf-8')):
f.write("\n")
os.chmod(key_path, 0o400) os.chmod(key_path, 0o400)
return key_path return key_path

View File

@ -0,0 +1,35 @@
__all__ = ['FormatAssetInfo']
class FormatAssetInfo:
def __init__(self, tp):
self.tp = tp
@staticmethod
def posix_format(info):
for cpu_model in info.get('cpu_model', []):
if cpu_model.endswith('GHz') or cpu_model.startswith("Intel"):
break
else:
cpu_model = ''
info['cpu_model'] = cpu_model[:48]
info['cpu_count'] = info.get('cpu_count', 0)
return info
def run(self, method_id_meta_mapper, info):
for k, v in info.items():
info[k] = v.strip() if isinstance(v, str) else v
run_method_name = None
for k, v in method_id_meta_mapper.items():
if self.tp not in v['type']:
continue
run_method_name = k.replace(f'{v["method"]}_', '')
if not run_method_name:
return info
if hasattr(self, f'{run_method_name}_format'):
return getattr(self, f'{run_method_name}_format')(info)
return info

View File

@ -11,7 +11,7 @@
cpu_count: "{{ ansible_processor_count }}" cpu_count: "{{ ansible_processor_count }}"
cpu_cores: "{{ ansible_processor_cores }}" cpu_cores: "{{ ansible_processor_cores }}"
cpu_vcpus: "{{ ansible_processor_vcpus }}" cpu_vcpus: "{{ ansible_processor_vcpus }}"
memory: "{{ ansible_memtotal_mb }}" memory: "{{ ansible_memtotal_mb / 1024 | round(2) }}"
disk_total: "{{ (ansible_mounts | map(attribute='size_total') | sum / 1024 / 1024 / 1024) | round(2) }}" disk_total: "{{ (ansible_mounts | map(attribute='size_total') | sum / 1024 / 1024 / 1024) | round(2) }}"
distribution: "{{ ansible_distribution }}" distribution: "{{ ansible_distribution }}"
distribution_version: "{{ ansible_distribution_version }}" distribution_version: "{{ ansible_distribution_version }}"

View File

@ -1,5 +1,6 @@
from common.utils import get_logger
from assets.const import AutomationTypes from assets.const import AutomationTypes
from common.utils import get_logger
from .format_asset_info import FormatAssetInfo
from ..base.manager import BasePlaybookManager from ..base.manager import BasePlaybookManager
logger = get_logger(__name__) logger = get_logger(__name__)
@ -19,13 +20,16 @@ class GatherFactsManager(BasePlaybookManager):
self.host_asset_mapper[host['name']] = asset self.host_asset_mapper[host['name']] = asset
return host return host
def format_asset_info(self, tp, info):
info = FormatAssetInfo(tp).run(self.method_id_meta_mapper, info)
return info
def on_host_success(self, host, result): def on_host_success(self, host, result):
info = result.get('debug', {}).get('res', {}).get('info', {}) info = result.get('debug', {}).get('res', {}).get('info', {})
asset = self.host_asset_mapper.get(host) asset = self.host_asset_mapper.get(host)
if asset and info: if asset and info:
for k, v in info.items(): info = self.format_asset_info(asset.type, info)
info[k] = v.strip() if isinstance(v, str) else v
asset.info = info asset.info = info
asset.save() asset.save(update_fields=['info'])
else: else:
logger.error("Not found info: {}".format(host)) logger.error("Not found info: {}".format(host))

View File

@ -1,10 +1,12 @@
from django.utils.translation import gettext_lazy as _
from .base import BaseType from .base import BaseType
class CloudTypes(BaseType): class CloudTypes(BaseType):
PUBLIC = 'public', 'Public cloud' PUBLIC = 'public', _('Public cloud')
PRIVATE = 'private', 'Private cloud' PRIVATE = 'private', _('Private cloud')
K8S = 'k8s', 'Kubernetes' K8S = 'k8s', _('Kubernetes')
@classmethod @classmethod
def _get_base_constrains(cls) -> dict: def _get_base_constrains(cls) -> dict:

View File

@ -1,3 +1,5 @@
from django.utils.translation import gettext_lazy as _
from .base import BaseType from .base import BaseType
GATEWAY_NAME = 'Gateway' GATEWAY_NAME = 'Gateway'
@ -7,7 +9,7 @@ class HostTypes(BaseType):
LINUX = 'linux', 'Linux' LINUX = 'linux', 'Linux'
WINDOWS = 'windows', 'Windows' WINDOWS = 'windows', 'Windows'
UNIX = 'unix', 'Unix' UNIX = 'unix', 'Unix'
OTHER_HOST = 'other', "Other" OTHER_HOST = 'other', _("Other")
@classmethod @classmethod
def _get_base_constrains(cls) -> dict: def _get_base_constrains(cls) -> dict:

View File

@ -39,7 +39,7 @@ class Protocol(ChoicesMixin, models.TextChoices):
'port': 3389, 'port': 3389,
'secret_types': ['password'], 'secret_types': ['password'],
'setting': { 'setting': {
'console': True, 'console': False,
'security': 'any', 'security': 'any',
} }
}, },

View File

@ -214,10 +214,13 @@ class AllTypes(ChoicesMixin):
tp_node = cls.choice_to_node(tp, category_node['id'], opened=False, meta=meta) tp_node = cls.choice_to_node(tp, category_node['id'], opened=False, meta=meta)
tp_count = category_type_mapper.get(category + '_' + tp, 0) tp_count = category_type_mapper.get(category + '_' + tp, 0)
tp_node['name'] += f'({tp_count})' tp_node['name'] += f'({tp_count})'
platforms = tp_platforms.get(category + '_' + tp, [])
if not platforms:
tp_node['isParent'] = False
nodes.append(tp_node) nodes.append(tp_node)
# Platform 格式化 # Platform 格式化
for p in tp_platforms.get(category + '_' + tp, []): for p in platforms:
platform_node = cls.platform_to_node(p, tp_node['id'], include_asset) platform_node = cls.platform_to_node(p, tp_node['id'], include_asset)
platform_node['name'] += f'({platform_count.get(p.id, 0)})' platform_node['name'] += f'({platform_count.get(p.id, 0)})'
nodes.append(platform_node) nodes.append(platform_node)
@ -306,10 +309,11 @@ class AllTypes(ChoicesMixin):
protocols_data = deepcopy(default_protocols) protocols_data = deepcopy(default_protocols)
if _protocols: if _protocols:
protocols_data = [p for p in protocols_data if p['name'] in _protocols] protocols_data = [p for p in protocols_data if p['name'] in _protocols]
for p in protocols_data: for p in protocols_data:
setting = _protocols_setting.get(p['name'], {}) setting = _protocols_setting.get(p['name'], {})
p['required'] = p.pop('required', False) p['required'] = setting.pop('required', False)
p['default'] = p.pop('default', False) p['default'] = setting.pop('default', False)
p['setting'] = {**p.get('setting', {}), **setting} p['setting'] = {**p.get('setting', {}), **setting}
platform_data = { platform_data = {

View File

@ -93,7 +93,7 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name='asset', model_name='asset',
name='address', name='address',
field=models.CharField(db_index=True, max_length=1024, verbose_name='Address'), field=models.CharField(db_index=True, max_length=767, verbose_name='Address'),
), ),
migrations.AddField( migrations.AddField(
model_name='asset', model_name='asset',

View File

@ -34,8 +34,9 @@ def migrate_database_to_asset(apps, *args):
_attrs = app.attrs or {} _attrs = app.attrs or {}
attrs.update(_attrs) attrs.update(_attrs)
name = 'DB-{}'.format(app.name)
db = db_model( db = db_model(
id=app.id, name=app.name, address=attrs['host'], id=app.id, name=name, address=attrs['host'],
protocols='{}/{}'.format(app.type, attrs['port']), protocols='{}/{}'.format(app.type, attrs['port']),
db_name=attrs['database'] or '', db_name=attrs['database'] or '',
platform=platforms_map[app.type], platform=platforms_map[app.type],
@ -61,8 +62,9 @@ def migrate_cloud_to_asset(apps, *args):
for app in applications: for app in applications:
attrs = app.attrs attrs = app.attrs
print("\t- Create cloud: {}".format(app.name)) print("\t- Create cloud: {}".format(app.name))
name = 'Cloud-{}'.format(app.name)
cloud = cloud_model( cloud = cloud_model(
id=app.id, name=app.name, id=app.id, name=name,
address=attrs.get('cluster', ''), address=attrs.get('cluster', ''),
protocols='k8s/443', platform=platform, protocols='k8s/443', platform=platform,
org_id=app.org_id, org_id=app.org_id,

View File

@ -1,12 +1,15 @@
# Generated by Django 3.2.12 on 2022-07-11 06:13 # Generated by Django 3.2.12 on 2022-07-11 06:13
import time import time
from django.utils import timezone
from itertools import groupby
from django.db import migrations from django.db import migrations
def migrate_asset_accounts(apps, schema_editor): def migrate_asset_accounts(apps, schema_editor):
auth_book_model = apps.get_model('assets', 'AuthBook') auth_book_model = apps.get_model('assets', 'AuthBook')
account_model = apps.get_model('accounts', 'Account') account_model = apps.get_model('accounts', 'Account')
account_history_model = apps.get_model('accounts', 'HistoricalAccount')
count = 0 count = 0
bulk_size = 1000 bulk_size = 1000
@ -20,34 +23,35 @@ def migrate_asset_accounts(apps, schema_editor):
break break
count += len(auth_books) count += len(auth_books)
accounts = []
# auth book 和 account 相同的属性 # auth book 和 account 相同的属性
same_attrs = [ same_attrs = [
'id', 'username', 'comment', 'date_created', 'date_updated', 'id', 'username', 'comment', 'date_created', 'date_updated',
'created_by', 'asset_id', 'org_id', 'created_by', 'asset_id', 'org_id',
] ]
# 认证的属性,可能是 authbook 的,可能是 systemuser 的 # 认证的属性,可能是 auth_book 的,可能是 system_user 的
auth_attrs = ['password', 'private_key', 'token'] auth_attrs = ['password', 'private_key', 'token']
all_attrs = same_attrs + auth_attrs all_attrs = same_attrs + auth_attrs
accounts = []
for auth_book in auth_books: for auth_book in auth_books:
values = {'version': 1} account_values = {'version': 1}
system_user = auth_book.systemuser system_user = auth_book.systemuser
if system_user: if system_user:
# 更新一次系统用户的认证属性 # 更新一次系统用户的认证属性
values.update({attr: getattr(system_user, attr, '') for attr in all_attrs}) account_values.update({attr: getattr(system_user, attr, '') for attr in all_attrs})
values['created_by'] = str(system_user.id) account_values['created_by'] = str(system_user.id)
values['privileged'] = system_user.type == 'admin' account_values['privileged'] = system_user.type == 'admin' \
or system_user.username in ['root', 'Administrator']
auth_book_auth = {attr: getattr(auth_book, attr, '') for attr in all_attrs if getattr(auth_book, attr, '')} auth_book_auth = {attr: getattr(auth_book, attr, '') for attr in all_attrs if getattr(auth_book, attr, '')}
# 最终使用 authbook 的认证属性 # 最终优先使用 auth_book 的认证属性
values.update(auth_book_auth) account_values.update(auth_book_auth)
auth_infos = [] auth_infos = []
username = values['username'] username = account_values['username']
for attr in auth_attrs: for attr in auth_attrs:
secret = values.pop(attr, None) secret = account_values.pop(attr, None)
if not secret: if not secret:
continue continue
@ -66,13 +70,48 @@ def migrate_asset_accounts(apps, schema_editor):
auth_infos.append((username, 'password', '')) auth_infos.append((username, 'password', ''))
for name, secret_type, secret in auth_infos: for name, secret_type, secret in auth_infos:
account = account_model(**values, name=name, secret=secret, secret_type=secret_type) if not name:
continue
account = account_model(**account_values, name=name, secret=secret, secret_type=secret_type)
accounts.append(account) accounts.append(account)
account_model.objects.bulk_create(accounts, ignore_conflicts=True) accounts.sort(key=lambda x: (x.name, x.asset_id, x.date_updated))
grouped_accounts = groupby(accounts, lambda x: (x.name, x.asset_id))
accounts_to_add = []
accounts_to_history = []
for key, _accounts in grouped_accounts:
_accounts = list(_accounts)
if not _accounts:
continue
_account = _accounts[-1]
accounts_to_add.append(_account)
_account_history = []
for ac in _accounts:
if not ac.secret:
continue
if ac.id != _account.id and ac.secret == _account.secret:
continue
history_data = {
'id': _account.id,
'secret': ac.secret,
'secret_type': ac.secret_type,
'history_date': ac.date_updated,
'history_type': '~',
'history_change_reason': 'from account {}'.format(_account.name),
}
_account_history.append(account_history_model(**history_data))
_account.version = len(_account_history)
accounts_to_history.extend(_account_history)
account_model.objects.bulk_create(accounts_to_add, ignore_conflicts=True)
account_history_model.objects.bulk_create(accounts_to_history, ignore_conflicts=True)
print("\t - Create asset accounts: {}-{} using: {:.2f}s".format( print("\t - Create asset accounts: {}-{} using: {:.2f}s".format(
count - len(auth_books), count, time.time() - start count - len(auth_books), count, time.time() - start
)) ))
print("\t - accounts: {}".format(len(accounts_to_add)))
print("\t - histories: {}".format(len(accounts_to_history)))
def migrate_db_accounts(apps, schema_editor): def migrate_db_accounts(apps, schema_editor):
@ -130,6 +169,9 @@ def migrate_db_accounts(apps, schema_editor):
values['secret_type'] = secret_type values['secret_type'] = secret_type
values['secret'] = secret values['secret'] = secret
if not name:
continue
for app in apps: for app in apps:
values['asset_id'] = str(app.id) values['asset_id'] = str(app.id)
account = account_model(**values) account = account_model(**values)

View File

@ -0,0 +1,29 @@
# Generated by Django 3.2.17 on 2023-03-15 09:41
from django.db import migrations
def set_windows_platform_non_console(apps, schema_editor):
Platform = apps.get_model('assets', 'Platform')
names = ['Windows', 'Windows-RDP', 'Windows-TLS', 'RemoteAppHost']
windows = Platform.objects.filter(name__in=names)
if not windows:
return
for p in windows:
rdp = p.protocols.filter(name='rdp').first()
if not rdp:
continue
rdp.setting['console'] = False
rdp.save()
class Migration(migrations.Migration):
dependencies = [
('assets', '0109_alter_asset_options'),
]
operations = [
migrations.RunPython(set_windows_platform_non_console)
]

View File

@ -100,7 +100,7 @@ class Asset(NodesRelationMixin, AbsConnectivity, JMSOrgBaseModel):
Type = const.AllTypes Type = const.AllTypes
name = models.CharField(max_length=128, verbose_name=_('Name')) name = models.CharField(max_length=128, verbose_name=_('Name'))
address = models.CharField(max_length=1024, verbose_name=_('Address'), db_index=True) address = models.CharField(max_length=767, verbose_name=_('Address'), db_index=True)
platform = models.ForeignKey(Platform, on_delete=models.PROTECT, verbose_name=_("Platform"), related_name='assets') platform = models.ForeignKey(Platform, on_delete=models.PROTECT, verbose_name=_("Platform"), related_name='assets')
domain = models.ForeignKey("assets.Domain", null=True, blank=True, related_name='assets', domain = models.ForeignKey("assets.Domain", null=True, blank=True, related_name='assets',
verbose_name=_("Domain"), on_delete=models.SET_NULL) verbose_name=_("Domain"), on_delete=models.SET_NULL)
@ -108,7 +108,7 @@ class Asset(NodesRelationMixin, AbsConnectivity, JMSOrgBaseModel):
verbose_name=_("Nodes")) verbose_name=_("Nodes"))
is_active = models.BooleanField(default=True, verbose_name=_('Is active')) is_active = models.BooleanField(default=True, verbose_name=_('Is active'))
labels = models.ManyToManyField('assets.Label', blank=True, related_name='assets', verbose_name=_("Labels")) labels = models.ManyToManyField('assets.Label', blank=True, related_name='assets', verbose_name=_("Labels"))
info = models.JSONField(verbose_name='Info', default=dict, blank=True) # 资产的一些信息,如 硬件信息 info = models.JSONField(verbose_name=_('Info'), default=dict, blank=True) # 资产的一些信息,如 硬件信息
objects = AssetManager.from_queryset(AssetQuerySet)() objects = AssetManager.from_queryset(AssetQuerySet)()

View File

@ -489,7 +489,7 @@ class SomeNodesMixin:
return cls.default_node() return cls.default_node()
if ori_org and ori_org.is_root(): if ori_org and ori_org.is_root():
return None return cls.default_node()
org_roots = cls.org_root_nodes() org_roots = cls.org_root_nodes()
org_roots_length = len(org_roots) org_roots_length = len(org_roots)

View File

@ -11,7 +11,7 @@ __all__ = ['Platform', 'PlatformProtocol', 'PlatformAutomation']
class PlatformProtocol(models.Model): class PlatformProtocol(models.Model):
SETTING_ATTRS = { SETTING_ATTRS = {
'console': True, 'console': False,
'security': 'any,tls,rdp', 'security': 'any,tls,rdp',
'sftp_enabled': True, 'sftp_enabled': True,
'sftp_home': '/tmp' 'sftp_home': '/tmp'

View File

@ -26,6 +26,13 @@ __all__ = [
class AssetProtocolsSerializer(serializers.ModelSerializer): class AssetProtocolsSerializer(serializers.ModelSerializer):
port = serializers.IntegerField(required=False, allow_null=True, max_value=65535, min_value=1) port = serializers.IntegerField(required=False, allow_null=True, max_value=65535, min_value=1)
def to_file_representation(self, data):
return '{name}/{port}'.format(**data)
def to_file_internal_value(self, data):
name, port = data.split('/')
return {'name': name, 'port': port}
class Meta: class Meta:
model = Protocol model = Protocol
fields = ['name', 'port'] fields = ['name', 'port']
@ -73,7 +80,7 @@ class AssetAccountSerializer(
'is_active', 'version', 'secret_type', 'is_active', 'version', 'secret_type',
] ]
fields_write_only = [ fields_write_only = [
'secret', 'push_now', 'template' 'secret', 'passphrase', 'push_now', 'template'
] ]
fields = fields_mini + fields_write_only fields = fields_mini + fields_write_only
extra_kwargs = { extra_kwargs = {
@ -121,7 +128,8 @@ class AssetSerializer(BulkOrgResourceModelSerializer, WritableNestedModelSeriali
type = LabeledChoiceField(choices=AllTypes.choices(), read_only=True, label=_('Type')) type = LabeledChoiceField(choices=AllTypes.choices(), read_only=True, label=_('Type'))
labels = AssetLabelSerializer(many=True, required=False, label=_('Label')) labels = AssetLabelSerializer(many=True, required=False, label=_('Label'))
protocols = AssetProtocolsSerializer(many=True, required=False, label=_('Protocols'), default=()) protocols = AssetProtocolsSerializer(many=True, required=False, label=_('Protocols'), default=())
accounts = AssetAccountSerializer(many=True, required=False, write_only=True, label=_('Account')) accounts = AssetAccountSerializer(many=True, required=False, allow_null=True, write_only=True, label=_('Account'))
nodes_display = serializers.ListField(read_only=True, label=_("Node path"))
class Meta: class Meta:
model = Asset model = Asset
@ -133,11 +141,11 @@ class AssetSerializer(BulkOrgResourceModelSerializer, WritableNestedModelSeriali
'nodes_display', 'accounts' 'nodes_display', 'accounts'
] ]
read_only_fields = [ read_only_fields = [
'category', 'type', 'connectivity', 'category', 'type', 'connectivity', 'auto_info',
'date_verified', 'created_by', 'date_created', 'date_verified', 'created_by', 'date_created',
'auto_info',
] ]
fields = fields_small + fields_fk + fields_m2m + read_only_fields fields = fields_small + fields_fk + fields_m2m + read_only_fields
fields_unexport = ['auto_info']
extra_kwargs = { extra_kwargs = {
'auto_info': {'label': _('Auto info')}, 'auto_info': {'label': _('Auto info')},
'name': {'label': _("Name")}, 'name': {'label': _("Name")},
@ -150,7 +158,7 @@ class AssetSerializer(BulkOrgResourceModelSerializer, WritableNestedModelSeriali
self._init_field_choices() self._init_field_choices()
def _get_protocols_required_default(self): def _get_protocols_required_default(self):
platform = self._initial_data_platform platform = self._asset_platform
platform_protocols = platform.protocols.all() platform_protocols = platform.protocols.all()
protocols_default = [p for p in platform_protocols if p.default] protocols_default = [p for p in platform_protocols if p.default]
protocols_required = [p for p in platform_protocols if p.required or p.primary] protocols_required = [p for p in platform_protocols if p.required or p.primary]
@ -206,20 +214,22 @@ class AssetSerializer(BulkOrgResourceModelSerializer, WritableNestedModelSeriali
instance.nodes.set(nodes_to_set) instance.nodes.set(nodes_to_set)
@property @property
def _initial_data_platform(self): def _asset_platform(self):
if self.instance:
return self.instance.platform
platform_id = self.initial_data.get('platform') platform_id = self.initial_data.get('platform')
if isinstance(platform_id, dict): if isinstance(platform_id, dict):
platform_id = platform_id.get('id') or platform_id.get('pk') platform_id = platform_id.get('id') or platform_id.get('pk')
platform = Platform.objects.filter(id=platform_id).first()
if not platform_id and self.instance:
platform = self.instance.platform
else:
platform = Platform.objects.filter(id=platform_id).first()
if not platform: if not platform:
raise serializers.ValidationError({'platform': _("Platform not exist")}) raise serializers.ValidationError({'platform': _("Platform not exist")})
return platform return platform
def validate_domain(self, value): def validate_domain(self, value):
platform = self._initial_data_platform platform = self._asset_platform
if platform.domain_enabled: if platform.domain_enabled:
return value return value
else: else:
@ -263,6 +273,8 @@ class AssetSerializer(BulkOrgResourceModelSerializer, WritableNestedModelSeriali
@staticmethod @staticmethod
def accounts_create(accounts_data, asset): def accounts_create(accounts_data, asset):
if not accounts_data:
return
for data in accounts_data: for data in accounts_data:
data['asset'] = asset data['asset'] = asset
AssetAccountSerializer().create(data) AssetAccountSerializer().create(data)

View File

@ -1,26 +1,25 @@
from rest_framework import serializers
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from assets.models import Host from assets.models import Host
from .common import AssetSerializer from .common import AssetSerializer
__all__ = ['HostInfoSerializer', 'HostSerializer'] __all__ = ['HostInfoSerializer', 'HostSerializer']
class HostInfoSerializer(serializers.Serializer): class HostInfoSerializer(serializers.Serializer):
vendor = serializers.CharField(max_length=64, required=False, allow_blank=True, label=_('Vendor')) vendor = serializers.CharField(max_length=64, required=False, allow_blank=True, label=_('Vendor'))
model = serializers.CharField(max_length=54, required=False, allow_blank=True, label=_('Model')) model = serializers.CharField(max_length=54, required=False, allow_blank=True, label=_('Model'))
sn = serializers.CharField(max_length=128, required=False, allow_blank=True, label=_('Serial number')) sn = serializers.CharField(max_length=128, required=False, allow_blank=True, label=_('Serial number'))
cpu_model = serializers.ListField(child=serializers.CharField(max_length=64, allow_blank=True), required=False, label=_('CPU model')) cpu_model = serializers.CharField(max_length=64, allow_blank=True, required=False, label=_('CPU model'))
cpu_count = serializers.IntegerField(required=False, label=_('CPU count')) cpu_count = serializers.CharField(max_length=64, required=False, allow_blank=True, label=_('CPU count'))
cpu_cores = serializers.IntegerField(required=False, label=_('CPU cores')) cpu_cores = serializers.CharField(max_length=64, required=False, allow_blank=True, label=_('CPU cores'))
cpu_vcpus = serializers.IntegerField(required=False, label=_('CPU vcpus')) cpu_vcpus = serializers.CharField(max_length=64, required=False, allow_blank=True, label=_('CPU vcpus'))
memory = serializers.CharField(max_length=64, allow_blank=True, required=False, label=_('Memory')) memory = serializers.CharField(max_length=64, allow_blank=True, required=False, label=_('Memory'))
disk_total = serializers.CharField(max_length=1024, allow_blank=True, required=False, label=_('Disk total')) disk_total = serializers.CharField(max_length=1024, allow_blank=True, required=False, label=_('Disk total'))
distribution = serializers.CharField(max_length=128, allow_blank=True, required=False, label=_('OS')) distribution = serializers.CharField(max_length=128, allow_blank=True, required=False, label=_('OS'))
distribution_version = serializers.CharField(max_length=16, allow_blank=True, required=False, label=_('OS version')) distribution_version = serializers.CharField(max_length=16, allow_blank=True, required=False, label=_('OS version'))
arch = serializers.CharField(max_length=16, allow_blank=True, required=False, label=_('OS arch')) arch = serializers.CharField(max_length=16, allow_blank=True, required=False, label=_('OS arch'))
@ -36,5 +35,3 @@ class HostSerializer(AssetSerializer):
'label': _("IP/Host") 'label': _("IP/Host")
}, },
} }

View File

@ -29,7 +29,8 @@ class LabelSerializer(BulkOrgResourceModelSerializer):
@classmethod @classmethod
def setup_eager_loading(cls, queryset): def setup_eager_loading(cls, queryset):
queryset = queryset.annotate(asset_count=Count('assets')) queryset = queryset.prefetch_related('assets') \
.annotate(asset_count=Count('assets'))
return queryset return queryset

View File

@ -1,6 +1,5 @@
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework import serializers from rest_framework import serializers
from django.core import validators
from assets.const.web import FillType from assets.const.web import FillType
from common.serializers import WritableNestedModelSerializer from common.serializers import WritableNestedModelSerializer
@ -19,7 +18,7 @@ class ProtocolSettingSerializer(serializers.Serializer):
("nla", "NLA"), ("nla", "NLA"),
] ]
# RDP # RDP
console = serializers.BooleanField(required=False) console = serializers.BooleanField(required=False, default=False)
security = serializers.ChoiceField(choices=SECURITY_CHOICES, default="any") security = serializers.ChoiceField(choices=SECURITY_CHOICES, default="any")
# SFTP # SFTP

View File

@ -8,7 +8,7 @@ from orgs.utils import tmp_to_root_org, tmp_to_org
logger = get_logger(__file__) logger = get_logger(__file__)
def task_activity_callback(self, pid, trigger, tp): def task_activity_callback(self, pid, trigger, tp, *args, **kwargs):
model = AutomationTypes.get_type_model(tp) model = AutomationTypes.get_type_model(tp)
with tmp_to_root_org(): with tmp_to_root_org():
instance = get_object_or_none(model, pk=pid) instance = get_object_or_none(model, pk=pid)

View File

@ -1,14 +1,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from urllib.parse import urlencode from urllib.parse import urlencode
from urllib3.exceptions import MaxRetryError, LocationParseError
from kubernetes import client from kubernetes import client
from kubernetes.client import api_client from kubernetes.client import api_client
from kubernetes.client.api import core_v1_api from kubernetes.client.api import core_v1_api
from kubernetes.client.exceptions import ApiException
from common.utils import get_logger from common.utils import get_logger
from common.exceptions import JMSException
from ..const import CloudTypes, Category from ..const import CloudTypes, Category
logger = get_logger(__file__) logger = get_logger(__file__)
@ -20,7 +17,8 @@ class KubernetesClient:
self.token = token self.token = token
self.proxy = proxy self.proxy = proxy
def get_api(self): @property
def api(self):
configuration = client.Configuration() configuration = client.Configuration()
configuration.host = self.url configuration.host = self.url
configuration.proxy = self.proxy configuration.proxy = self.proxy
@ -30,64 +28,29 @@ class KubernetesClient:
api = core_v1_api.CoreV1Api(c) api = core_v1_api.CoreV1Api(c)
return api return api
def get_namespace_list(self): def get_namespaces(self):
api = self.get_api() namespaces = []
namespace_list = [] resp = self.api.list_namespace()
for ns in api.list_namespace().items: for ns in resp.items:
namespace_list.append(ns.metadata.name) namespaces.append(ns.metadata.name)
return namespace_list return namespaces
def get_services(self): def get_pods(self, namespace):
api = self.get_api() pods = []
ret = api.list_service_for_all_namespaces(watch=False) resp = self.api.list_namespaced_pod(namespace)
for i in ret.items: for pd in resp.items:
print("%s \t%s \t%s \t%s \t%s \n" % ( pods.append(pd.metadata.name)
i.kind, i.metadata.namespace, i.metadata.name, i.spec.cluster_ip, i.spec.ports)) return pods
def get_pod_info(self, namespace, pod): def get_containers(self, namespace, pod_name):
api = self.get_api() containers = []
resp = api.read_namespaced_pod(namespace=namespace, name=pod) resp = self.api.read_namespaced_pod(pod_name, namespace)
return resp for container in resp.spec.containers:
containers.append(container.name)
return containers
def get_pod_logs(self, namespace, pod): @staticmethod
api = self.get_api() def get_proxy_url(asset):
log_content = api.read_namespaced_pod_log(pod, namespace, pretty=True, tail_lines=200)
return log_content
def get_pods(self):
api = self.get_api()
try:
ret = api.list_pod_for_all_namespaces(watch=False, _request_timeout=(3, 3))
except LocationParseError as e:
logger.warning("Kubernetes API request url error: {}".format(e))
raise JMSException(code='k8s_tree_error', detail=e)
except MaxRetryError:
msg = "Kubernetes API request timeout"
logger.warning(msg)
raise JMSException(code='k8s_tree_error', detail=msg)
except ApiException as e:
if e.status == 401:
msg = "Kubernetes API request unauthorized"
logger.warning(msg)
else:
msg = e
logger.warning(msg)
raise JMSException(code='k8s_tree_error', detail=msg)
data = {}
for i in ret.items:
namespace = i.metadata.namespace
pod_info = {
'pod_name': i.metadata.name,
'containers': [j.name for j in i.spec.containers]
}
if namespace in data:
data[namespace].append(pod_info)
else:
data[namespace] = [pod_info, ]
return data
@classmethod
def get_proxy_url(cls, asset):
if not asset.domain: if not asset.domain:
return None return None
@ -97,11 +60,14 @@ class KubernetesClient:
return f'{gateway.address}:{gateway.port}' return f'{gateway.address}:{gateway.port}'
@classmethod @classmethod
def get_kubernetes_data(cls, asset, secret): def run(cls, asset, secret, tp, *args):
k8s_url = f'{asset.address}' k8s_url = f'{asset.address}'
proxy_url = cls.get_proxy_url(asset) proxy_url = cls.get_proxy_url(asset)
k8s = cls(k8s_url, secret, proxy=proxy_url) k8s = cls(k8s_url, secret, proxy=proxy_url)
return k8s.get_pods() func_name = f'get_{tp}s'
if hasattr(k8s, func_name):
return getattr(k8s, func_name)(*args)
return []
class KubernetesTree: class KubernetesTree:
@ -117,17 +83,15 @@ class KubernetesTree:
) )
return node return node
def as_namespace_node(self, name, tp, counts=0): def as_namespace_node(self, name, tp):
i = urlencode({'namespace': name}) i = urlencode({'namespace': name})
pid = str(self.asset.id) pid = str(self.asset.id)
name = f'{name}({counts})'
node = self.create_tree_node(i, pid, name, tp, icon='cloud') node = self.create_tree_node(i, pid, name, tp, icon='cloud')
return node return node
def as_pod_tree_node(self, namespace, name, tp, counts=0): def as_pod_tree_node(self, namespace, name, tp):
pid = urlencode({'namespace': namespace}) pid = urlencode({'namespace': namespace})
i = urlencode({'namespace': namespace, 'pod': name}) i = urlencode({'namespace': namespace, 'pod': name})
name = f'{name}({counts})'
node = self.create_tree_node(i, pid, name, tp, icon='cloud') node = self.create_tree_node(i, pid, name, tp, icon='cloud')
return node return node
@ -162,30 +126,26 @@ class KubernetesTree:
def async_tree_node(self, namespace, pod): def async_tree_node(self, namespace, pod):
tree = [] tree = []
data = KubernetesClient.get_kubernetes_data(self.asset, self.secret)
if not data:
return tree
if pod: if pod:
for container in next( tp = 'container'
filter( containers = KubernetesClient.run(
lambda x: x['pod_name'] == pod, data[namespace] self.asset, self.secret, tp, namespace, pod
) )
)['containers']: for container in containers:
container_node = self.as_container_tree_node( container_node = self.as_container_tree_node(
namespace, pod, container, 'container' namespace, pod, container, tp
) )
tree.append(container_node) tree.append(container_node)
elif namespace: elif namespace:
for pod in data[namespace]: tp = 'pod'
pod_nodes = self.as_pod_tree_node( pods = KubernetesClient.run(self.asset, self.secret, tp, namespace)
namespace, pod['pod_name'], 'pod', len(pod['containers']) for pod in pods:
) pod_node = self.as_pod_tree_node(namespace, pod, tp)
tree.append(pod_nodes) tree.append(pod_node)
else: else:
for namespace, pods in data.items(): tp = 'namespace'
namespace_node = self.as_namespace_node( namespaces = KubernetesClient.run(self.asset, self.secret, tp)
namespace, 'namespace', len(pods) for namespace in namespaces:
) namespace_node = self.as_namespace_node(namespace, tp)
tree.append(namespace_node) tree.append(namespace_node)
return tree return tree

View File

@ -10,6 +10,7 @@ from rest_framework.permissions import IsAuthenticated
from common.drf.filters import DatetimeRangeFilter from common.drf.filters import DatetimeRangeFilter
from common.plugins.es import QuerySet as ESQuerySet from common.plugins.es import QuerySet as ESQuerySet
from common.utils import is_uuid from common.utils import is_uuid
from common.utils import lazyproperty
from orgs.mixins.api import OrgReadonlyModelViewSet, OrgModelViewSet from orgs.mixins.api import OrgReadonlyModelViewSet, OrgModelViewSet
from orgs.utils import current_org, tmp_to_root_org from orgs.utils import current_org, tmp_to_root_org
from orgs.models import Organization from orgs.models import Organization
@ -143,13 +144,19 @@ class OperateLogViewSet(OrgReadonlyModelViewSet):
search_fields = ['resource', 'user'] search_fields = ['resource', 'user']
ordering = ['-datetime'] ordering = ['-datetime']
@lazyproperty
def is_action_detail(self):
return self.detail and self.request.query_params.get('type') == 'action_detail'
def get_serializer_class(self): def get_serializer_class(self):
if self.request.query_params.get('type') == 'action_detail': if self.is_action_detail:
return OperateLogActionDetailSerializer return OperateLogActionDetailSerializer
return super().get_serializer_class() return super().get_serializer_class()
def get_queryset(self): def get_queryset(self):
org_q = Q(org_id=Organization.SYSTEM_ID) | Q(org_id=current_org.id) org_q = Q(org_id=current_org.id)
if self.is_action_detail:
org_q |= Q(org_id=Organization.SYSTEM_ID)
with tmp_to_root_org(): with tmp_to_root_org():
qs = OperateLog.objects.filter(org_q) qs = OperateLog.objects.filter(org_q)
es_config = settings.OPERATE_LOG_ELASTICSEARCH_CONFIG es_config = settings.OPERATE_LOG_ELASTICSEARCH_CONFIG

View File

@ -4,7 +4,6 @@ from django.db import transaction
from django.core.cache import cache from django.core.cache import cache
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from users.models import User
from common.utils import get_request_ip, get_logger from common.utils import get_request_ip, get_logger
from common.utils.timezone import as_current_tz from common.utils.timezone import as_current_tz
from common.utils.encode import Singleton from common.utils.encode import Singleton

View File

@ -2,7 +2,7 @@
# #
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers from rest_framework import serializers
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from audits.backends.db import OperateLogStore from audits.backends.db import OperateLogStore
from common.serializers.fields import LabeledChoiceField from common.serializers.fields import LabeledChoiceField
from common.utils import reverse, i18n_trans from common.utils import reverse, i18n_trans
@ -78,7 +78,7 @@ class OperateLogActionDetailSerializer(serializers.ModelSerializer):
return data return data
class OperateLogSerializer(serializers.ModelSerializer): class OperateLogSerializer(BulkOrgResourceModelSerializer):
action = LabeledChoiceField(choices=ActionChoices.choices, label=_("Action")) action = LabeledChoiceField(choices=ActionChoices.choices, label=_("Action"))
resource = serializers.SerializerMethodField(label=_("Resource")) resource = serializers.SerializerMethodField(label=_("Resource"))
resource_type = serializers.SerializerMethodField(label=_('Resource Type')) resource_type = serializers.SerializerMethodField(label=_('Resource Type'))

View File

@ -1,13 +1,15 @@
import codecs import codecs
import copy import copy
import csv import csv
from itertools import chain from itertools import chain
from datetime import datetime
from django.db import models from django.db import models
from django.http import HttpResponse from django.http import HttpResponse
from common.utils.timezone import as_current_tz
from common.utils import validate_ip, get_ip_city, get_logger from common.utils import validate_ip, get_ip_city, get_logger
from settings.serializers import SettingsSerializer
from .const import DEFAULT_CITY from .const import DEFAULT_CITY
logger = get_logger(__name__) logger = get_logger(__name__)
@ -70,6 +72,8 @@ def _get_instance_field_value(
f.verbose_name = 'id' f.verbose_name = 'id'
elif isinstance(value, (list, dict)): elif isinstance(value, (list, dict)):
value = copy.deepcopy(value) value = copy.deepcopy(value)
elif isinstance(value, datetime):
value = as_current_tz(value).strftime('%Y-%m-%d %H:%M:%S')
elif isinstance(f, models.OneToOneField) and isinstance(value, models.Model): elif isinstance(f, models.OneToOneField) and isinstance(value, models.Model):
nested_data = _get_instance_field_value( nested_data = _get_instance_field_value(
value, include_model_fields, model_need_continue_fields, ('id',) value, include_model_fields, model_need_continue_fields, ('id',)

View File

@ -24,7 +24,7 @@ from orgs.mixins.api import RootOrgViewMixin
from perms.models import ActionChoices from perms.models import ActionChoices
from terminal.connect_methods import NativeClient, ConnectMethodUtil from terminal.connect_methods import NativeClient, ConnectMethodUtil
from terminal.models import EndpointRule from terminal.models import EndpointRule
from ..models import ConnectionToken from ..models import ConnectionToken, date_expired_default
from ..serializers import ( from ..serializers import (
ConnectionTokenSerializer, ConnectionTokenSecretSerializer, ConnectionTokenSerializer, ConnectionTokenSecretSerializer,
SuperConnectionTokenSerializer, ConnectTokenAppletOptionSerializer SuperConnectionTokenSerializer, ConnectTokenAppletOptionSerializer
@ -172,6 +172,7 @@ class ExtraActionApiMixin(RDPFileClientProtocolURLMixin):
get_object: callable get_object: callable
get_serializer: callable get_serializer: callable
perform_create: callable perform_create: callable
validate_exchange_token: callable
@action(methods=['POST', 'GET'], detail=True, url_path='rdp-file') @action(methods=['POST', 'GET'], detail=True, url_path='rdp-file')
def get_rdp_file(self, *args, **kwargs): def get_rdp_file(self, *args, **kwargs):
@ -204,6 +205,18 @@ class ExtraActionApiMixin(RDPFileClientProtocolURLMixin):
instance.expire() instance.expire()
return Response(status=status.HTTP_204_NO_CONTENT) return Response(status=status.HTTP_204_NO_CONTENT)
@action(methods=['POST'], detail=False)
def exchange(self, request, *args, **kwargs):
pk = request.data.get('id', None) or request.data.get('pk', None)
# 只能兑换自己使用的 Token
instance = get_object_or_404(ConnectionToken, pk=pk, user=request.user)
instance.id = None
self.validate_exchange_token(instance)
instance.date_expired = date_expired_default()
instance.save()
serializer = self.get_serializer(instance)
return Response(serializer.data, status=status.HTTP_201_CREATED)
class ConnectionTokenViewSet(ExtraActionApiMixin, RootOrgViewMixin, JMSModelViewSet): class ConnectionTokenViewSet(ExtraActionApiMixin, RootOrgViewMixin, JMSModelViewSet):
filterset_fields = ( filterset_fields = (
@ -217,6 +230,7 @@ class ConnectionTokenViewSet(ExtraActionApiMixin, RootOrgViewMixin, JMSModelView
'list': 'authentication.view_connectiontoken', 'list': 'authentication.view_connectiontoken',
'retrieve': 'authentication.view_connectiontoken', 'retrieve': 'authentication.view_connectiontoken',
'create': 'authentication.add_connectiontoken', 'create': 'authentication.add_connectiontoken',
'exchange': 'authentication.add_connectiontoken',
'expire': 'authentication.change_connectiontoken', 'expire': 'authentication.change_connectiontoken',
'get_rdp_file': 'authentication.add_connectiontoken', 'get_rdp_file': 'authentication.add_connectiontoken',
'get_client_protocol_url': 'authentication.add_connectiontoken', 'get_client_protocol_url': 'authentication.add_connectiontoken',
@ -240,10 +254,24 @@ class ConnectionTokenViewSet(ExtraActionApiMixin, RootOrgViewMixin, JMSModelView
user = self.get_user(serializer) user = self.get_user(serializer)
asset = data.get('asset') asset = data.get('asset')
account_name = data.get('account') account_name = data.get('account')
_data = self._validate(user, asset, account_name)
data.update(_data)
return serializer
def validate_exchange_token(self, token):
user = token.user
asset = token.asset
account_name = token.account
_data = self._validate(user, asset, account_name)
for k, v in _data.items():
setattr(token, k, v)
return token
def _validate(self, user, asset, account_name):
data = dict()
data['org_id'] = asset.org_id data['org_id'] = asset.org_id
data['user'] = user data['user'] = user
data['value'] = random_string(16) data['value'] = random_string(16)
account = self._validate_perm(user, asset, account_name) account = self._validate_perm(user, asset, account_name)
if account.has_secret: if account.has_secret:
data['input_secret'] = '' data['input_secret'] = ''
@ -257,8 +285,7 @@ class ConnectionTokenViewSet(ExtraActionApiMixin, RootOrgViewMixin, JMSModelView
if ticket: if ticket:
data['from_ticket'] = ticket data['from_ticket'] = ticket
data['is_active'] = False data['is_active'] = False
return data
return account
@staticmethod @staticmethod
def _validate_perm(user, asset, account_name): def _validate_perm(user, asset, account_name):

View File

@ -225,6 +225,7 @@ class MFAMixin:
self.request.session['auth_mfa_time'] = time.time() self.request.session['auth_mfa_time'] = time.time()
self.request.session['auth_mfa_required'] = 0 self.request.session['auth_mfa_required'] = 0
self.request.session['auth_mfa_type'] = mfa_type self.request.session['auth_mfa_type'] = mfa_type
MFABlockUtils(self.request.user.username, self.get_request_ip()).clean_failed_count()
def clean_mfa_mark(self): def clean_mfa_mark(self):
keys = ['auth_mfa', 'auth_mfa_time', 'auth_mfa_required', 'auth_mfa_type'] keys = ['auth_mfa', 'auth_mfa_time', 'auth_mfa_required', 'auth_mfa_type']

View File

@ -222,7 +222,8 @@ class ConnectionToken(JMSOrgBaseModel):
'secret_type': account.secret_type, 'secret_type': account.secret_type,
'secret': account.secret or self.input_secret, 'secret': account.secret or self.input_secret,
'su_from': account.su_from, 'su_from': account.su_from,
'org_id': account.org_id 'org_id': account.org_id,
'privileged': account.privileged
} }
return Account(**data) return Account(**data)

View File

@ -60,7 +60,7 @@ class FeiShuQRMixin(UserConfirmRequiredExceptionMixin, PermissionsMixin, View):
'state': state, 'state': state,
'redirect_uri': redirect_uri, 'redirect_uri': redirect_uri,
} }
url = URL.AUTHEN + '?' + urlencode(params) url = URL().authen + '?' + urlencode(params)
return url return url
@staticmethod @staticmethod

View File

@ -6,6 +6,7 @@ import os
import datetime import datetime
from typing import Callable from typing import Callable
from django.db import IntegrityError
from django.templatetags.static import static from django.templatetags.static import static
from django.contrib.auth import login as auth_login, logout as auth_logout from django.contrib.auth import login as auth_login, logout as auth_logout
from django.http import HttpResponse, HttpRequest from django.http import HttpResponse, HttpRequest
@ -229,6 +230,23 @@ class UserLoginView(mixins.AuthMixin, UserLoginContextMixin, FormView):
) as e: ) as e:
form.add_error('code', e.msg) form.add_error('code', e.msg)
return super().form_invalid(form) return super().form_invalid(form)
except (IntegrityError,) as e:
# (1062, "Duplicate entry 'youtester001@example.com' for key 'users_user.email'")
error = str(e)
if len(e.args) < 2:
form.add_error(None, error)
return super().form_invalid(form)
msg_list = e.args[1].split("'")
if len(msg_list) < 4:
form.add_error(None, error)
return super().form_invalid(form)
email, field = msg_list[1], msg_list[3]
if field == 'users_user.email':
error = _('User email already exists ({})').format(email)
form.add_error(None, error)
return super().form_invalid(form)
self.clear_rsa_key() self.clear_rsa_key()
return self.redirect_to_guard_view() return self.redirect_to_guard_view()

View File

@ -32,11 +32,14 @@ class UserLoginMFAView(mixins.AuthMixin, FormView):
return super().get(*args, **kwargs) return super().get(*args, **kwargs)
def form_valid(self, form): def form_valid(self, form):
from users.utils import MFABlockUtils
code = form.cleaned_data.get('code') code = form.cleaned_data.get('code')
mfa_type = form.cleaned_data.get('mfa_type') mfa_type = form.cleaned_data.get('mfa_type')
try: try:
self._do_check_user_mfa(code, mfa_type) self._do_check_user_mfa(code, mfa_type)
user, ip = self.get_user_from_session(), self.get_request_ip()
MFABlockUtils(user.username, ip).clean_failed_count()
return redirect_to_guard_view('mfa_ok') return redirect_to_guard_view('mfa_ok')
except (errors.MFAFailedError, errors.BlockMFAError) as e: except (errors.MFAFailedError, errors.BlockMFAError) as e:
form.add_error('code', e.msg) form.add_error('code', e.msg)

View File

@ -2,4 +2,3 @@ from __future__ import absolute_import
# This will make sure the app is always imported when # This will make sure the app is always imported when
# Django starts so that shared_task will use this app. # Django starts so that shared_task will use this app.

View File

@ -3,13 +3,12 @@
from typing import Callable from typing import Callable
from django.utils.translation import ugettext as _ from django.utils.translation import ugettext as _
from rest_framework.response import Response
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response
from common.const.http import POST from common.const.http import POST
__all__ = ['SuggestionMixin', 'RenderToJsonMixin'] __all__ = ['SuggestionMixin', 'RenderToJsonMixin']

View File

@ -1,11 +1,15 @@
import abc import abc
import json
import codecs import codecs
from rest_framework import serializers import json
import re
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework.parsers import BaseParser from rest_framework import serializers
from rest_framework import status from rest_framework import status
from rest_framework.exceptions import ParseError, APIException from rest_framework.exceptions import ParseError, APIException
from rest_framework.parsers import BaseParser
from common.serializers.fields import ObjectRelatedField
from common.utils import get_logger from common.utils import get_logger
logger = get_logger(__file__) logger = get_logger(__file__)
@ -18,11 +22,11 @@ class FileContentOverflowedError(APIException):
class BaseFileParser(BaseParser): class BaseFileParser(BaseParser):
FILE_CONTENT_MAX_LENGTH = 1024 * 1024 * 10 FILE_CONTENT_MAX_LENGTH = 1024 * 1024 * 10
serializer_cls = None serializer_cls = None
serializer_fields = None serializer_fields = None
obj_pattern = re.compile(r'^(.+)\(([a-z0-9-]+)\)$')
def check_content_length(self, meta): def check_content_length(self, meta):
content_length = int(meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0))) content_length = int(meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0)))
@ -74,7 +78,7 @@ class BaseFileParser(BaseParser):
return s.translate(trans_table) return s.translate(trans_table)
@classmethod @classmethod
def process_row(cls, row): def load_row(cls, row):
""" """
构建json数据前的行处理 构建json数据前的行处理
""" """
@ -84,33 +88,63 @@ class BaseFileParser(BaseParser):
col = cls._replace_chinese_quote(col) col = cls._replace_chinese_quote(col)
# 列表/字典转换 # 列表/字典转换
if isinstance(col, str) and ( if isinstance(col, str) and (
(col.startswith('[') and col.endswith(']')) (col.startswith('[') and col.endswith(']')) or
or
(col.startswith("{") and col.endswith("}")) (col.startswith("{") and col.endswith("}"))
): ):
col = json.loads(col) try:
col = json.loads(col)
except json.JSONDecodeError as e:
logger.error('Json load error: ', e)
logger.error('col: ', col)
new_row.append(col) new_row.append(col)
return new_row return new_row
def id_name_to_obj(self, v):
if not v or not isinstance(v, str):
return v
matched = self.obj_pattern.match(v)
if not matched:
return v
obj_name, obj_id = matched.groups()
if len(obj_id) < 36:
obj_id = int(obj_id)
return {'pk': obj_id, 'name': obj_name}
def parse_value(self, field, value):
if value is '-':
return None
elif hasattr(field, 'to_file_internal_value'):
value = field.to_file_internal_value(value)
elif isinstance(field, serializers.BooleanField):
value = value.lower() in ['true', '1', 'yes']
elif isinstance(field, serializers.ChoiceField):
value = value
elif isinstance(field, ObjectRelatedField):
if field.many:
value = [self.id_name_to_obj(v) for v in value]
else:
value = self.id_name_to_obj(value)
elif isinstance(field, serializers.ListSerializer):
value = [self.parse_value(field.child, v) for v in value]
elif isinstance(field, serializers.Serializer):
value = self.id_name_to_obj(value)
elif isinstance(field, serializers.ManyRelatedField):
value = [self.parse_value(field.child_relation, v) for v in value]
elif isinstance(field, serializers.ListField):
value = [self.parse_value(field.child, v) for v in value]
return value
def process_row_data(self, row_data): def process_row_data(self, row_data):
""" """
构建json数据后的行数据处理 构建json数据后的行数据处理
""" """
new_row_data = {} new_row = {}
serializer_fields = self.serializer_fields
for k, v in row_data.items(): for k, v in row_data.items():
if type(v) in [list, dict, int, bool] or (isinstance(v, str) and k.strip() and v.strip()): field = self.serializer_fields.get(k)
# 处理类似disk_info为字符串的'{}'的问题 v = self.parse_value(field, v)
if not isinstance(v, str) and isinstance(serializer_fields[k], serializers.CharField): new_row[k] = v
v = str(v) return new_row
# 处理 BooleanField 的问题, 导出是 'True', 'False'
if isinstance(v, str) and v.strip().lower() == 'true':
v = True
elif isinstance(v, str) and v.strip().lower() == 'false':
v = False
new_row_data[k] = v
return new_row_data
def generate_data(self, fields_name, rows): def generate_data(self, fields_name, rows):
data = [] data = []
@ -118,7 +152,7 @@ class BaseFileParser(BaseParser):
# 空行不处理 # 空行不处理
if not any(row): if not any(row):
continue continue
row = self.process_row(row) row = self.load_row(row)
row_data = dict(zip(fields_name, row)) row_data = dict(zip(fields_name, row))
row_data = self.process_row_data(row_data) row_data = self.process_row_data(row_data)
data.append(row_data) data.append(row_data)
@ -139,7 +173,6 @@ class BaseFileParser(BaseParser):
raise ParseError('The resource does not support imports!') raise ParseError('The resource does not support imports!')
self.check_content_length(meta) self.check_content_length(meta)
try: try:
stream_data = self.get_stream_data(stream) stream_data = self.get_stream_data(stream)
rows = self.generate_rows(stream_data) rows = self.generate_rows(stream_data)
@ -148,6 +181,7 @@ class BaseFileParser(BaseParser):
# 给 `common.mixins.api.RenderToJsonMixin` 提供,暂时只能耦合 # 给 `common.mixins.api.RenderToJsonMixin` 提供,暂时只能耦合
column_title_field_pairs = list(zip(column_titles, field_names)) column_title_field_pairs = list(zip(column_titles, field_names))
column_title_field_pairs = [(k, v) for k, v in column_title_field_pairs if k and v]
if not hasattr(request, 'jms_context'): if not hasattr(request, 'jms_context'):
request.jms_context = {} request.jms_context = {}
request.jms_context['column_title_field_pairs'] = column_title_field_pairs request.jms_context['column_title_field_pairs'] = column_title_field_pairs
@ -157,4 +191,3 @@ class BaseFileParser(BaseParser):
except Exception as e: except Exception as e:
logger.error(e, exc_info=True) logger.error(e, exc_info=True)
raise ParseError(_('Parse file error: {}').format(e)) raise ParseError(_('Parse file error: {}').format(e))

View File

@ -1,13 +1,17 @@
import pyexcel import pyexcel
from django.utils.translation import gettext as _
from .base import BaseFileParser from .base import BaseFileParser
class ExcelFileParser(BaseFileParser): class ExcelFileParser(BaseFileParser):
media_type = 'text/xlsx' media_type = 'text/xlsx'
def generate_rows(self, stream_data): def generate_rows(self, stream_data):
workbook = pyexcel.get_book(file_type='xlsx', file_content=stream_data) try:
workbook = pyexcel.get_book(file_type='xlsx', file_content=stream_data)
except Exception:
raise Exception(_('Invalid excel file'))
# 默认获取第一个工作表sheet # 默认获取第一个工作表sheet
sheet = workbook.sheet_by_index(0) sheet = workbook.sheet_by_index(0)
rows = sheet.rows() rows = sheet.rows()

View File

@ -1,8 +1,11 @@
import abc import abc
from datetime import datetime from datetime import datetime
from rest_framework import serializers
from rest_framework.renderers import BaseRenderer from rest_framework.renderers import BaseRenderer
from rest_framework.utils import encoders, json from rest_framework.utils import encoders, json
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
from common.utils import get_logger from common.utils import get_logger
logger = get_logger(__file__) logger = get_logger(__file__)
@ -38,18 +41,27 @@ class BaseFileRenderer(BaseRenderer):
def get_rendered_fields(self): def get_rendered_fields(self):
fields = self.serializer.fields fields = self.serializer.fields
if self.template == 'import': if self.template == 'import':
return [v for k, v in fields.items() if not v.read_only and k != "org_id" and k != 'id'] fields = [v for k, v in fields.items() if not v.read_only and k != "org_id" and k != 'id']
elif self.template == 'update': elif self.template == 'update':
return [v for k, v in fields.items() if not v.read_only and k != "org_id"] fields = [v for k, v in fields.items() if not v.read_only and k != "org_id"]
else: else:
return [v for k, v in fields.items() if not v.write_only and k != "org_id"] fields = [v for k, v in fields.items() if not v.write_only and k != "org_id"]
meta = getattr(self.serializer, 'Meta', None)
if meta:
fields_unexport = getattr(meta, 'fields_unexport', [])
fields = [v for v in fields if v.field_name not in fields_unexport]
return fields
@staticmethod @staticmethod
def get_column_titles(render_fields): def get_column_titles(render_fields):
return [ titles = []
'*{}'.format(field.label) if field.required else str(field.label) for field in render_fields:
for field in render_fields name = field.label
] if field.required:
name = '*' + name
titles.append(name)
return titles
def process_data(self, data): def process_data(self, data):
results = data['results'] if 'results' in data else data results = data['results'] if 'results' in data else data
@ -59,7 +71,6 @@ class BaseFileRenderer(BaseRenderer):
if self.template == 'import': if self.template == 'import':
results = [results[0]] if results else results results = [results[0]] if results else results
else: else:
# 限制数据数量 # 限制数据数量
results = results[:10000] results = results[:10000]
@ -68,17 +79,53 @@ class BaseFileRenderer(BaseRenderer):
return results return results
@staticmethod @staticmethod
def generate_rows(data, render_fields): def to_id_name(value):
if value is None:
return '-'
pk = str(value.get('id', '') or value.get('pk', ''))
name = value.get('name') or value.get('display_name', '')
return '{}({})'.format(name, pk)
@staticmethod
def to_choice_name(value):
if value is None:
return '-'
value = value.get('value', '')
return value
def render_value(self, field, value):
if value is None:
value = '-'
elif hasattr(field, 'to_file_representation'):
value = field.to_file_representation(value)
elif isinstance(value, bool):
value = 'Yes' if value else 'No'
elif isinstance(field, LabeledChoiceField):
value = value.get('value', '')
elif isinstance(field, ObjectRelatedField):
if field.many:
value = [self.to_id_name(v) for v in value]
else:
value = self.to_id_name(value)
elif isinstance(field, serializers.ListSerializer):
value = [self.render_value(field.child, v) for v in value]
elif isinstance(field, serializers.Serializer) and value.get('id'):
value = self.to_id_name(value)
elif isinstance(field, serializers.ManyRelatedField):
value = [self.render_value(field.child_relation, v) for v in value]
elif isinstance(field, serializers.ListField):
value = [self.render_value(field.child, v) for v in value]
if not isinstance(value, str):
value = json.dumps(value, cls=encoders.JSONEncoder, ensure_ascii=False)
return str(value)
def generate_rows(self, data, render_fields):
for item in data: for item in data:
row = [] row = []
for field in render_fields: for field in render_fields:
value = item.get(field.field_name) value = item.get(field.field_name)
if value is None: value = self.render_value(field, value)
value = ''
elif isinstance(value, dict):
value = json.dumps(value, ensure_ascii=False)
else:
value = str(value)
row.append(value) row.append(value)
yield row yield row
@ -101,6 +148,9 @@ class BaseFileRenderer(BaseRenderer):
def get_rendered_value(self): def get_rendered_value(self):
raise NotImplementedError raise NotImplementedError
def after_render(self):
pass
def render(self, data, accepted_media_type=None, renderer_context=None): def render(self, data, accepted_media_type=None, renderer_context=None):
if data is None: if data is None:
return bytes() return bytes()
@ -129,11 +179,10 @@ class BaseFileRenderer(BaseRenderer):
self.initial_writer() self.initial_writer()
self.write_column_titles(column_titles) self.write_column_titles(column_titles)
self.write_rows(rows) self.write_rows(rows)
self.after_render()
value = self.get_rendered_value() value = self.get_rendered_value()
except Exception as e: except Exception as e:
logger.debug(e, exc_info=True) logger.debug(e, exc_info=True)
value = 'Render error! ({})'.format(self.media_type).encode('utf-8') value = 'Render error! ({})'.format(self.media_type).encode('utf-8')
return value return value
return value return value

View File

@ -1,6 +1,6 @@
from openpyxl import Workbook from openpyxl import Workbook
from openpyxl.writer.excel import save_virtual_workbook
from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE
from openpyxl.writer.excel import save_virtual_workbook
from .base import BaseFileRenderer from .base import BaseFileRenderer
@ -19,12 +19,26 @@ class ExcelFileRenderer(BaseFileRenderer):
def write_row(self, row): def write_row(self, row):
self.row_count += 1 self.row_count += 1
self.ws.row_dimensions[self.row_count].height = 20
column_count = 0 column_count = 0
for cell_value in row: for cell_value in row:
# 处理非法字符 # 处理非法字符
column_count += 1 column_count += 1
cell_value = ILLEGAL_CHARACTERS_RE.sub(r'', cell_value) cell_value = ILLEGAL_CHARACTERS_RE.sub(r'', str(cell_value))
self.ws.cell(row=self.row_count, column=column_count, value=cell_value) self.ws.cell(row=self.row_count, column=column_count, value=str(cell_value))
def after_render(self):
for col in self.ws.columns:
max_length = 0
column = col[0].column_letter
for cell in col:
if len(str(cell.value)) > max_length:
max_length = len(cell.value)
adjusted_width = (max_length + 2) * 1.0
adjusted_width = 300 if adjusted_width > 300 else adjusted_width
adjusted_width = 30 if adjusted_width < 30 else adjusted_width
self.ws.column_dimensions[column].width = adjusted_width
self.wb.save('/tmp/test.xlsx')
def get_rendered_value(self): def get_rendered_value(self):
value = save_virtual_workbook(self.wb) value = save_virtual_workbook(self.wb)

View File

@ -1,7 +1,10 @@
from werkzeug.local import Local from werkzeug.local import Local
from django.utils import translation
thread_local = Local() thread_local = Local()
encrypted_field_set = set() encrypted_field_set = {'password', 'secret'}
def _find(attr): def _find(attr):
@ -10,4 +13,5 @@ def _find(attr):
def add_encrypted_field_set(label): def add_encrypted_field_set(label):
if label: if label:
encrypted_field_set.add(str(label)) with translation.override('en'):
encrypted_field_set.add(str(label))

View File

@ -114,26 +114,28 @@ class ES(object):
self._ensure_index_exists() self._ensure_index_exists()
def _ensure_index_exists(self): def _ensure_index_exists(self):
info = self.es.info()
version = info['version']['number'].split('.')[0]
if version == '6':
mappings = {'mappings': {'data': {'properties': self.properties}}}
else:
mappings = {'mappings': {'properties': self.properties}}
if self.is_index_by_date:
mappings['aliases'] = {
self.query_index: {}
}
try: try:
self.es.indices.create(self.index, body=mappings) info = self.es.info()
return version = info['version']['number'].split('.')[0]
except RequestError as e: if version == '6':
if e.error == 'resource_already_exists_exception': mappings = {'mappings': {'data': {'properties': self.properties}}}
logger.warning(e)
else: else:
logger.exception(e) mappings = {'mappings': {'properties': self.properties}}
if self.is_index_by_date:
mappings['aliases'] = {
self.query_index: {}
}
try:
self.es.indices.create(self.index, body=mappings)
except RequestError as e:
if e.error == 'resource_already_exists_exception':
logger.warning(e)
else:
logger.exception(e)
except Exception as e:
logger.error(e, exc_info=True)
def make_data(self, data): def make_data(self, data):
return [] return []

View File

@ -3,6 +3,7 @@ import json
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework.exceptions import APIException from rest_framework.exceptions import APIException
from django.conf import settings
from common.utils.common import get_logger from common.utils.common import get_logger
from common.sdk.im.utils import digest from common.sdk.im.utils import digest
from common.sdk.im.mixin import RequestMixin, BaseRequest from common.sdk.im.mixin import RequestMixin, BaseRequest
@ -11,14 +12,30 @@ logger = get_logger(__name__)
class URL: class URL:
AUTHEN = 'https://open.feishu.cn/open-apis/authen/v1/index'
GET_TOKEN = 'https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal/'
# https://open.feishu.cn/document/ukTMukTMukTM/uEDO4UjLxgDO14SM4gTN # https://open.feishu.cn/document/ukTMukTMukTM/uEDO4UjLxgDO14SM4gTN
GET_USER_INFO_BY_CODE = 'https://open.feishu.cn/open-apis/authen/v1/access_token' @property
def host(self):
if settings.FEISHU_VERSION == 'feishu':
h = 'https://open.feishu.cn'
else:
h = 'https://open.larksuite.com'
return h
SEND_MESSAGE = 'https://open.feishu.cn/open-apis/im/v1/messages' @property
def authen(self):
return f'{self.host}/open-apis/authen/v1/index'
@property
def get_token(self):
return f'{self.host}/open-apis/auth/v3/tenant_access_token/internal/'
@property
def get_user_info_by_code(self):
return f'{self.host}/open-apis/authen/v1/access_token'
@property
def send_message(self):
return f'{self.host}/open-apis/im/v1/messages'
class ErrorCode: class ErrorCode:
@ -51,7 +68,7 @@ class FeishuRequests(BaseRequest):
def request_access_token(self): def request_access_token(self):
data = {'app_id': self._app_id, 'app_secret': self._app_secret} data = {'app_id': self._app_id, 'app_secret': self._app_secret}
response = self.raw_request('post', url=URL.GET_TOKEN, data=data) response = self.raw_request('post', url=URL().get_token, data=data)
self.check_errcode_is_0(response) self.check_errcode_is_0(response)
access_token = response['tenant_access_token'] access_token = response['tenant_access_token']
@ -86,7 +103,7 @@ class FeiShu(RequestMixin):
'code': code 'code': code
} }
data = self._requests.post(URL.GET_USER_INFO_BY_CODE, json=body, check_errcode_is_0=False) data = self._requests.post(URL().get_user_info_by_code, json=body, check_errcode_is_0=False)
self._requests.check_errcode_is_0(data) self._requests.check_errcode_is_0(data)
return data['data']['user_id'] return data['data']['user_id']
@ -107,7 +124,7 @@ class FeiShu(RequestMixin):
try: try:
logger.info(f'Feishu send text: user_ids={user_ids} msg={msg}') logger.info(f'Feishu send text: user_ids={user_ids} msg={msg}')
self._requests.post(URL.SEND_MESSAGE, params=params, json=body) self._requests.post(URL().send_message, params=params, json=body)
except APIException as e: except APIException as e:
# 只处理可预知的错误 # 只处理可预知的错误
logger.exception(e) logger.exception(e)

View File

@ -55,9 +55,11 @@ class BulkSerializerMixin(object):
# add update_lookup_field field back to validated data # add update_lookup_field field back to validated data
# since super by default strips out read-only fields # since super by default strips out read-only fields
# hence id will no longer be present in validated_data # hence id will no longer be present in validated_data
if all((isinstance(self.root, BulkListSerializer), if all([
id_attr, isinstance(self.root, BulkListSerializer),
request_method in ('PUT', 'PATCH'))): id_attr,
request_method in ('PUT', 'PATCH')
]):
id_field = self.fields.get("id") or self.fields.get('pk') id_field = self.fields.get("id") or self.fields.get('pk')
if data.get("id"): if data.get("id"):
id_value = id_field.to_internal_value(data.get("id")) id_value = id_field.to_internal_value(data.get("id"))
@ -135,7 +137,7 @@ class BulkListSerializerMixin:
pk = item["pk"] pk = item["pk"]
else: else:
raise ValidationError("id or pk not in data") raise ValidationError("id or pk not in data")
child = self.instance.get(id=pk) child = self.instance.get(pk=pk)
self.child.instance = child self.child.instance = child
self.child.initial_data = item self.child.initial_data = item
# raw # raw

View File

@ -32,7 +32,7 @@ class Counter:
return self.counter == other.counter return self.counter == other.counter
def on_request_finished_logging_db_query(sender, **kwargs): def digest_sql_query():
queries = connection.queries queries = connection.queries
counters = defaultdict(Counter) counters = defaultdict(Counter)
table_queries = defaultdict(list) table_queries = defaultdict(list)
@ -79,6 +79,9 @@ def on_request_finished_logging_db_query(sender, **kwargs):
counter.counter, counter.time, name) counter.counter, counter.time, name)
) )
def on_request_finished_logging_db_query(sender, **kwargs):
digest_sql_query()
on_request_finished_release_local(sender, **kwargs) on_request_finished_release_local(sender, **kwargs)

View File

@ -10,7 +10,7 @@ from .utils import get_logger
logger = get_logger(__file__) logger = get_logger(__file__)
def task_activity_callback(self, subject, message, recipient_list, **kwargs): def task_activity_callback(self, subject, message, recipient_list, *args, **kwargs):
from users.models import User from users.models import User
email_list = recipient_list email_list = recipient_list
resource_ids = list(User.objects.filter(email__in=email_list).values_list('id', flat=True)) resource_ids = list(User.objects.filter(email__in=email_list).values_list('id', flat=True))

View File

@ -108,7 +108,7 @@ class Subscription:
try: try:
self.sub.close() self.sub.close()
except Exception as e: except Exception as e:
logger.error('Unsubscribe msg error: {}'.format(e)) logger.debug('Unsubscribe msg error: {}'.format(e))
def retry(self, _next, error, complete): def retry(self, _next, error, complete):
logger.info('Retry subscribe channel: {}'.format(self.ch)) logger.info('Retry subscribe channel: {}'.format(self.ch))

View File

@ -98,7 +98,7 @@ def ssh_private_key_gen(private_key, password=None):
def ssh_pubkey_gen(private_key=None, username='jumpserver', hostname='localhost', password=None): def ssh_pubkey_gen(private_key=None, username='jumpserver', hostname='localhost', password=None):
private_key = ssh_private_key_gen(private_key, password=password) private_key = ssh_private_key_gen(private_key, password=password)
if not isinstance(private_key, (paramiko.RSAKey, paramiko.DSSKey)): if not isinstance(private_key, _supported_paramiko_ssh_key_types):
raise IOError('Invalid private key') raise IOError('Invalid private key')
public_key = "%(key_type)s %(key_content)s %(username)s@%(hostname)s" % { public_key = "%(key_type)s %(key_content)s %(username)s@%(hostname)s" % {

View File

@ -35,7 +35,10 @@ def i18n_trans(s):
tpl, args = s.split(' % ', 1) tpl, args = s.split(' % ', 1)
args = args.split(', ') args = args.split(', ')
args = [gettext(arg) for arg in args] args = [gettext(arg) for arg in args]
return gettext(tpl) % tuple(args) try:
return gettext(tpl) % tuple(args)
except TypeError:
return gettext(tpl)
def hello(): def hello():

View File

@ -214,7 +214,7 @@ class Config(dict):
'REDIS_DB_WS': 6, 'REDIS_DB_WS': 6,
'GLOBAL_ORG_DISPLAY_NAME': '', 'GLOBAL_ORG_DISPLAY_NAME': '',
'SITE_URL': 'http://localhost:8080', 'SITE_URL': 'http://127.0.0.1',
'USER_GUIDE_URL': '', 'USER_GUIDE_URL': '',
'ANNOUNCEMENT_ENABLED': True, 'ANNOUNCEMENT_ENABLED': True,
'ANNOUNCEMENT': {}, 'ANNOUNCEMENT': {},
@ -376,6 +376,7 @@ class Config(dict):
'AUTH_FEISHU': False, 'AUTH_FEISHU': False,
'FEISHU_APP_ID': '', 'FEISHU_APP_ID': '',
'FEISHU_APP_SECRET': '', 'FEISHU_APP_SECRET': '',
'FEISHU_VERSION': 'feishu',
'LOGIN_REDIRECT_TO_BACKEND': '', # 'OPENID / CAS / SAML2 'LOGIN_REDIRECT_TO_BACKEND': '', # 'OPENID / CAS / SAML2
'LOGIN_REDIRECT_MSG_ENABLED': True, 'LOGIN_REDIRECT_MSG_ENABLED': True,

View File

@ -20,7 +20,7 @@ default_context = {
'LOGIN_WECOM_logo_logout': static('img/login_wecom_logo.png'), 'LOGIN_WECOM_logo_logout': static('img/login_wecom_logo.png'),
'LOGIN_DINGTALK_logo_logout': static('img/login_dingtalk_logo.png'), 'LOGIN_DINGTALK_logo_logout': static('img/login_dingtalk_logo.png'),
'LOGIN_FEISHU_logo_logout': static('img/login_feishu_logo.png'), 'LOGIN_FEISHU_logo_logout': static('img/login_feishu_logo.png'),
'COPYRIGHT': 'FIT2CLOUD 飞致云' + ' © 2014-2022', 'COPYRIGHT': 'FIT2CLOUD 飞致云' + ' © 2014-2023',
'INTERFACE': default_interface, 'INTERFACE': default_interface,
} }

View File

@ -137,6 +137,7 @@ DINGTALK_APPSECRET = CONFIG.DINGTALK_APPSECRET
AUTH_FEISHU = CONFIG.AUTH_FEISHU AUTH_FEISHU = CONFIG.AUTH_FEISHU
FEISHU_APP_ID = CONFIG.FEISHU_APP_ID FEISHU_APP_ID = CONFIG.FEISHU_APP_ID
FEISHU_APP_SECRET = CONFIG.FEISHU_APP_SECRET FEISHU_APP_SECRET = CONFIG.FEISHU_APP_SECRET
FEISHU_VERSION = CONFIG.FEISHU_VERSION
# Saml2 auth # Saml2 auth
AUTH_SAML2 = CONFIG.AUTH_SAML2 AUTH_SAML2 = CONFIG.AUTH_SAML2

View File

@ -1,10 +1,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from functools import partial
from werkzeug.local import LocalProxy
from datetime import datetime from datetime import datetime
from functools import partial
from django.conf import settings from django.conf import settings
from werkzeug.local import LocalProxy
from common.local import thread_local from common.local import thread_local
@ -34,7 +35,7 @@ def get_xpack_license_info() -> dict:
corporation = info.get('corporation', '') corporation = info.get('corporation', '')
else: else:
current_year = datetime.now().year current_year = datetime.now().year
corporation = f'Copyright - FIT2CLOUD 飞致云 © 2014-{current_year}' corporation = f'FIT2CLOUD 飞致云 © 2014-{current_year}'
info = { info = {
'corporation': corporation 'corporation': corporation
} }

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:8c2600b7094db2a9e64862169ff1c826d5064fae9b9e71744545a1cea88cbc65 oid sha256:6fa80b59b9b5f95a9cfcad8ec47eacd519bb962d139ab90463795a7b306a0a72
size 136280 size 137935

File diff suppressed because it is too large Load Diff

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:a29193d2982b254444285cfb2d61f7ef7355ae2bab181cdf366446e879ab32fb oid sha256:9819889a6d8b2934b06c5b242e3f63f404997f30851919247a405f542e8a03bc
size 111963 size 113244

File diff suppressed because it is too large Load Diff

View File

@ -17,7 +17,7 @@ class JMSInventory:
:param account_policy: privileged_only, privileged_first, skip :param account_policy: privileged_only, privileged_first, skip
""" """
self.assets = self.clean_assets(assets) self.assets = self.clean_assets(assets)
self.account_prefer = account_prefer self.account_prefer = self.get_account_prefer(account_prefer)
self.account_policy = account_policy self.account_policy = account_policy
self.host_callback = host_callback self.host_callback = host_callback
self.exclude_hosts = {} self.exclude_hosts = {}
@ -140,36 +140,51 @@ class JMSInventory:
return host return host
def get_asset_accounts(self, asset): def get_asset_accounts(self, asset):
return list(asset.accounts.filter(is_active=True)) from assets.const import Connectivity
accounts = asset.accounts.filter(is_active=True).order_by('-privileged', '-date_updated')
accounts_connectivity_ok = list(accounts.filter(connectivity=Connectivity.OK))
accounts_connectivity_no = list(accounts.exclude(connectivity=Connectivity.OK))
return accounts_connectivity_ok + accounts_connectivity_no
@staticmethod
def get_account_prefer(account_prefer):
account_usernames = []
if isinstance(account_prefer, str) and account_prefer:
account_usernames = list(map(lambda x: x.lower(), account_prefer.split(',')))
return account_usernames
def get_refer_account(self, accounts):
account = None
if accounts:
account = list(filter(
lambda a: a.username.lower() in self.account_prefer, accounts
))
account = account[0] if account else None
return account
def select_account(self, asset): def select_account(self, asset):
accounts = self.get_asset_accounts(asset) accounts = self.get_asset_accounts(asset)
if not accounts: if not accounts or self.account_policy == 'skip':
return None return None
account_selected = None account_selected = None
account_usernames = self.account_prefer
if isinstance(self.account_prefer, str): # 首先找到特权账号
account_usernames = self.account_prefer.split(',') privileged_accounts = list(filter(lambda account: account.privileged, accounts))
# 优先使用提供的名称
if account_usernames:
account_matched = list(filter(lambda account: account.username in account_usernames, accounts))
account_selected = account_matched[0] if account_matched else None
if account_selected or self.account_policy == 'skip':
return account_selected
# 不同类型的账号选择,优先使用提供的名称
refer_privileged_account = self.get_refer_account(privileged_accounts)
if self.account_policy in ['privileged_only', 'privileged_first']: if self.account_policy in ['privileged_only', 'privileged_first']:
account_matched = list(filter(lambda account: account.privileged, accounts)) first_privileged = privileged_accounts[0] if privileged_accounts else None
account_selected = account_matched[0] if account_matched else None account_selected = refer_privileged_account or first_privileged
if account_selected: # 此策略不管是否匹配到账号都需强制返回
if self.account_policy == 'privileged_only':
return account_selected return account_selected
if self.account_policy == 'privileged_first': if not account_selected:
account_selected = accounts[0] if accounts else None account_selected = self.get_refer_account(accounts)
return account_selected
return account_selected or accounts[0]
def generate(self, path_dir): def generate(self, path_dir):
hosts = [] hosts = []

View File

@ -83,7 +83,7 @@ class CeleryResultApi(generics.RetrieveAPIView):
def get_object(self): def get_object(self):
pk = self.kwargs.get('pk') pk = self.kwargs.get('pk')
return AsyncResult(pk) return AsyncResult(str(pk))
class CeleryPeriodTaskViewSet(CommonApiMixin, viewsets.ModelViewSet): class CeleryPeriodTaskViewSet(CommonApiMixin, viewsets.ModelViewSet):

View File

@ -32,6 +32,15 @@ class PlaybookViewSet(OrgBulkModelViewSet):
model = Playbook model = Playbook
search_fields = ('name', 'comment') search_fields = ('name', 'comment')
def perform_destroy(self, instance):
instance = self.get_object()
if instance.job_set.exists():
raise JMSException(code='playbook_has_job', detail={"msg": _("Currently playbook is being used in a job")})
instance_id = instance.id
super().perform_destroy(instance)
dest_path = os.path.join(settings.DATA_DIR, "ops", "playbook", instance_id.__str__())
shutil.rmtree(dest_path)
def get_queryset(self): def get_queryset(self):
queryset = super().get_queryset() queryset = super().get_queryset()
queryset = queryset.filter(creator=self.request.user) queryset = queryset.filter(creator=self.request.user)
@ -62,10 +71,10 @@ class PlaybookFileBrowserAPIView(APIView):
rbac_perms = () rbac_perms = ()
permission_classes = (RBACPermission,) permission_classes = (RBACPermission,)
rbac_perms = { rbac_perms = {
'GET': 'ops.change_playbooks', 'GET': 'ops.change_playbook',
'POST': 'ops.change_playbooks', 'POST': 'ops.change_playbook',
'DELETE': 'ops.change_playbooks', 'DELETE': 'ops.change_playbook',
'PATCH': 'ops.change_playbooks', 'PATCH': 'ops.change_playbook',
} }
protected_files = ['root', 'main.yml'] protected_files = ['root', 'main.yml']

View File

@ -3,8 +3,10 @@
from celery import shared_task from celery import shared_task
from celery.exceptions import SoftTimeLimitExceeded from celery.exceptions import SoftTimeLimitExceeded
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django_celery_beat.models import PeriodicTask
from common.utils import get_logger, get_object_or_none from common.utils import get_logger, get_object_or_none
from ops.celery import app
from orgs.utils import tmp_to_org, tmp_to_root_org from orgs.utils import tmp_to_org, tmp_to_root_org
from .celery.decorator import ( from .celery.decorator import (
register_as_period_task, after_app_ready_start register_as_period_task, after_app_ready_start
@ -19,7 +21,7 @@ from .notifications import ServerPerformanceCheckUtil
logger = get_logger(__file__) logger = get_logger(__file__)
def job_task_activity_callback(self, job_id, trigger): def job_task_activity_callback(self, job_id, *args, **kwargs):
job = get_object_or_none(Job, id=job_id) job = get_object_or_none(Job, id=job_id)
if not job: if not job:
return return
@ -48,7 +50,7 @@ def run_ops_job(job_id):
logger.error("Start adhoc execution error: {}".format(e)) logger.error("Start adhoc execution error: {}".format(e))
def job_execution_task_activity_callback(self, execution_id, trigger): def job_execution_task_activity_callback(self, execution_id, *args, **kwargs):
execution = get_object_or_none(JobExecution, id=execution_id) execution = get_object_or_none(JobExecution, id=execution_id)
if not execution: if not execution:
return return
@ -78,16 +80,14 @@ def run_ops_job_execution(execution_id, **kwargs):
@after_app_ready_start @after_app_ready_start
def clean_celery_periodic_tasks(): def clean_celery_periodic_tasks():
"""清除celery定时任务""" """清除celery定时任务"""
need_cleaned_tasks = [ logger.info('Start clean celery periodic tasks.')
'handle_be_interrupted_change_auth_task_periodic', register_tasks = PeriodicTask.objects.all()
] for task in register_tasks:
logger.info('Start clean celery periodic tasks: {}'.format(need_cleaned_tasks)) if task.task in app.tasks:
for task_name in need_cleaned_tasks:
logger.info('Start clean task: {}'.format(task_name))
task = get_celery_periodic_task(task_name)
if task is None:
logger.info('Task does not exist: {}'.format(task_name))
continue continue
task_name = task.name
logger.info('Start clean task: {}'.format(task_name))
disable_celery_periodic_task(task_name) disable_celery_periodic_task(task_name)
delete_celery_periodic_task(task_name) delete_celery_periodic_task(task_name)
task = get_celery_periodic_task(task_name) task = get_celery_periodic_task(task_name)

View File

@ -13,7 +13,7 @@ class CeleryTaskLogView(PermissionsMixin, TemplateView):
template_name = 'ops/celery_task_log.html' template_name = 'ops/celery_task_log.html'
permission_classes = [RBACPermission] permission_classes = [RBACPermission]
rbac_perms = { rbac_perms = {
'GET': 'ops.view_celerytask' 'GET': 'ops.view_celerytaskexecution'
} }
def get_context_data(self, **kwargs): def get_context_data(self, **kwargs):

View File

@ -114,9 +114,7 @@ class OrgResourceStatisticsCache(OrgRelatedCache):
@staticmethod @staticmethod
def compute_total_count_today_active_assets(): def compute_total_count_today_active_assets():
t = local_zero_hour() t = local_zero_hour()
return Session.objects.filter( return Session.objects.filter(date_start__gte=t).values('asset_id').distinct().count()
date_start__gte=t, is_success=False
).values('asset_id').distinct().count()
@staticmethod @staticmethod
def compute_total_count_today_failed_sessions(): def compute_total_count_today_failed_sessions():

View File

@ -102,7 +102,10 @@ def on_post_delete_refresh_org_resource_statistics_cache(sender, instance, **kwa
def _refresh_session_org_resource_statistics_cache(instance: Session): def _refresh_session_org_resource_statistics_cache(instance: Session):
cache_field_name = ['total_count_online_users', 'total_count_online_sessions', 'total_count_today_failed_sessions'] cache_field_name = [
'total_count_online_users', 'total_count_online_sessions',
'total_count_today_active_assets','total_count_today_failed_sessions'
]
org_cache = OrgResourceStatisticsCache(instance.org) org_cache = OrgResourceStatisticsCache(instance.org)
org_cache.expire(*cache_field_name) org_cache.expire(*cache_field_name)

View File

@ -30,6 +30,12 @@ class BaseUserPermedAssetsApi(SelfOrPKUserMixin, ListAPIView):
filterset_class = AssetFilterSet filterset_class = AssetFilterSet
serializer_class = serializers.AssetPermedSerializer serializer_class = serializers.AssetPermedSerializer
def get_serializer_class(self):
serializer_class = super().get_serializer_class()
if self.request.query_params.get('id'):
serializer_class = serializers.AssetPermedDetailSerializer
return serializer_class
def get_queryset(self): def get_queryset(self):
if getattr(self, 'swagger_fake_view', False): if getattr(self, 'swagger_fake_view', False):
return Asset.objects.none() return Asset.objects.none()

View File

@ -23,6 +23,7 @@ def migrate_app_perms_to_assets(apps, schema_editor):
asset_permission = asset_permission_model() asset_permission = asset_permission_model()
for attr in attrs: for attr in attrs:
setattr(asset_permission, attr, getattr(app_perm, attr)) setattr(asset_permission, attr, getattr(app_perm, attr))
asset_permission.name = f"App-{app_perm.name}"
asset_permissions.append(asset_permission) asset_permissions.append(asset_permission)
asset_permission_model.objects.bulk_create(asset_permissions, ignore_conflicts=True) asset_permission_model.objects.bulk_create(asset_permissions, ignore_conflicts=True)

View File

@ -9,11 +9,11 @@ def migrate_system_user_to_accounts(apps, schema_editor):
bulk_size = 10000 bulk_size = 10000
while True: while True:
asset_permissions = asset_permission_model.objects \ asset_permissions = asset_permission_model.objects \
.prefetch_related('system_users')[count:bulk_size] .prefetch_related('system_users')[count:bulk_size]
if not asset_permissions: if not asset_permissions:
break break
count += len(asset_permissions)
count += len(asset_permissions)
updated = [] updated = []
for asset_permission in asset_permissions: for asset_permission in asset_permissions:
asset_permission.accounts = [s.username for s in asset_permission.system_users.all()] asset_permission.accounts = [s.username for s in asset_permission.system_users.all()]

View File

@ -1,10 +1,12 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from django.db.models import Q from django.db.models import Q, QuerySet
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers from rest_framework import serializers
from accounts.models import AccountTemplate, Account
from accounts.tasks import push_accounts_to_assets_task
from assets.models import Asset, Node from assets.models import Asset, Node
from common.serializers.fields import BitChoicesField, ObjectRelatedField from common.serializers.fields import BitChoicesField, ObjectRelatedField
from orgs.mixins.serializers import BulkOrgResourceModelSerializer from orgs.mixins.serializers import BulkOrgResourceModelSerializer
@ -18,6 +20,12 @@ class ActionChoicesField(BitChoicesField):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(choice_cls=ActionChoices, **kwargs) super().__init__(choice_cls=ActionChoices, **kwargs)
def to_file_representation(self, value):
return [v['value'] for v in value]
def to_file_internal_value(self, data):
return data
class AssetPermissionSerializer(BulkOrgResourceModelSerializer): class AssetPermissionSerializer(BulkOrgResourceModelSerializer):
users = ObjectRelatedField(queryset=User.objects, many=True, required=False, label=_('User')) users = ObjectRelatedField(queryset=User.objects, many=True, required=False, label=_('User'))
@ -31,6 +39,8 @@ class AssetPermissionSerializer(BulkOrgResourceModelSerializer):
is_expired = serializers.BooleanField(read_only=True, label=_("Is expired")) is_expired = serializers.BooleanField(read_only=True, label=_("Is expired"))
accounts = serializers.ListField(label=_("Account"), required=False) accounts = serializers.ListField(label=_("Account"), required=False)
template_accounts = AccountTemplate.objects.none()
class Meta: class Meta:
model = AssetPermission model = AssetPermission
fields_mini = ["id", "name"] fields_mini = ["id", "name"]
@ -73,8 +83,55 @@ class AssetPermissionSerializer(BulkOrgResourceModelSerializer):
actions.default = list(actions.choices.keys()) actions.default = list(actions.choices.keys())
@staticmethod @staticmethod
def validate_accounts(accounts): def get_all_assets(nodes, assets):
return list(set(accounts)) node_asset_ids = Node.get_nodes_all_assets(*nodes).values_list('id', flat=True)
direct_asset_ids = [asset.id for asset in assets]
asset_ids = set(direct_asset_ids + list(node_asset_ids))
return Asset.objects.filter(id__in=asset_ids)
def create_accounts(self, assets):
need_create_accounts = []
account_attribute = [
'name', 'username', 'secret_type', 'secret', 'privileged', 'is_active', 'org_id'
]
for asset in assets:
asset_exist_accounts = Account.objects.none()
for template in self.template_accounts:
asset_exist_accounts |= asset.accounts.filter(
username=template.username,
secret_type=template.secret_type,
)
username_secret_type_dict = asset_exist_accounts.values('username', 'secret_type')
for template in self.template_accounts:
condition = {
'username': template.username,
'secret_type': template.secret_type
}
if condition in username_secret_type_dict:
continue
account_data = {key: getattr(template, key) for key in account_attribute}
account_data['name'] = f"{account_data['name']}-clone"
need_create_accounts.append(Account(**{'asset_id': asset.id, **account_data}))
return Account.objects.bulk_create(need_create_accounts)
def create_and_push_account(self, nodes, assets):
if not self.template_accounts:
return
assets = self.get_all_assets(nodes, assets)
accounts = self.create_accounts(assets)
push_accounts_to_assets_task.delay([str(account.id) for account in accounts])
def validate_accounts(self, usernames: list[str]):
template_ids = []
account_usernames = []
for username in usernames:
if username.startswith('%'):
template_ids.append(username[1:])
else:
account_usernames.append(username)
self.template_accounts = AccountTemplate.objects.filter(id__in=template_ids)
template_usernames = list(self.template_accounts.values_list('username', flat=True))
return list(set(account_usernames + template_usernames))
@classmethod @classmethod
def setup_eager_loading(cls, queryset): def setup_eager_loading(cls, queryset):
@ -112,6 +169,13 @@ class AssetPermissionSerializer(BulkOrgResourceModelSerializer):
).distinct() ).distinct()
instance.nodes.add(*nodes_to_set) instance.nodes.add(*nodes_to_set)
def validate(self, attrs):
self.create_and_push_account(
attrs.get("nodes", []),
attrs.get("assets", [])
)
return super().validate(attrs)
def create(self, validated_data): def create(self, validated_data):
display = { display = {
"users_display": validated_data.pop("users_display", ""), "users_display": validated_data.pop("users_display", ""),

View File

@ -15,7 +15,7 @@ from perms.serializers.permission import ActionChoicesField
__all__ = [ __all__ = [
'NodePermedSerializer', 'AssetPermedSerializer', 'NodePermedSerializer', 'AssetPermedSerializer',
'AccountsPermedSerializer' 'AssetPermedDetailSerializer', 'AccountsPermedSerializer'
] ]
@ -46,6 +46,12 @@ class AssetPermedSerializer(OrgResourceModelSerializerMixin):
return queryset return queryset
class AssetPermedDetailSerializer(AssetPermedSerializer):
class Meta(AssetPermedSerializer.Meta):
fields = AssetPermedSerializer.Meta.fields + ['spec_info']
read_only_fields = fields
class NodePermedSerializer(serializers.ModelSerializer): class NodePermedSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = Node model = Node

View File

@ -1,5 +1,6 @@
from collections import defaultdict from collections import defaultdict
from orgs.utils import tmp_to_org
from accounts.models import Account from accounts.models import Account
from accounts.const import AliasAccount from accounts.const import AliasAccount
from .permission import AssetPermissionUtil from .permission import AssetPermissionUtil
@ -16,10 +17,11 @@ class PermAccountUtil(AssetPermissionUtil):
:param asset: Asset :param asset: Asset
:param account_name: 可能是 @USER @INPUT 字符串 :param account_name: 可能是 @USER @INPUT 字符串
""" """
permed_accounts = self.get_permed_accounts_for_user(user, asset) with tmp_to_org(asset.org):
accounts_mapper = {account.alias: account for account in permed_accounts} permed_accounts = self.get_permed_accounts_for_user(user, asset)
account = accounts_mapper.get(account_name) accounts_mapper = {account.alias: account for account in permed_accounts}
return account account = accounts_mapper.get(account_name)
return account
def get_permed_accounts_for_user(self, user, asset): def get_permed_accounts_for_user(self, user, asset):
""" 获取授权给用户某个资产的账号 """ """ 获取授权给用户某个资产的账号 """

View File

@ -18,14 +18,19 @@ user_perms = (
('assets', 'asset', 'match', 'asset'), ('assets', 'asset', 'match', 'asset'),
('assets', 'systemuser', 'match', 'systemuser'), ('assets', 'systemuser', 'match', 'systemuser'),
('assets', 'node', 'match', 'node'), ('assets', 'node', 'match', 'node'),
("ops", "adhoc", "*", "*"),
("ops", "playbook", "*", "*"),
("ops", "job", "*", "*"),
("ops", "jobexecution", "*", "*"),
("ops", "celerytaskexecution", "view", "*"),
) )
system_user_perms = ( system_user_perms = (
('authentication', 'connectiontoken', 'add,change,view', 'connectiontoken'), ('authentication', 'connectiontoken', 'add,change,view', 'connectiontoken'),
('authentication', 'temptoken', 'add,change,view', 'temptoken'), ('authentication', 'temptoken', 'add,change,view', 'temptoken'),
('authentication', 'accesskey', '*', '*'), ('authentication', 'accesskey', '*', '*'),
('tickets', 'ticket', 'view', 'ticket'), ('tickets', 'ticket', 'view', 'ticket'),
) + user_perms + _view_all_joined_org_perms ) + user_perms + _view_all_joined_org_perms
_auditor_perms = ( _auditor_perms = (
('rbac', 'menupermission', 'view', 'audit'), ('rbac', 'menupermission', 'view', 'audit'),
@ -41,7 +46,6 @@ auditor_perms = user_perms + _auditor_perms
system_auditor_perms = system_user_perms + _auditor_perms + _view_root_perms system_auditor_perms = system_user_perms + _auditor_perms + _view_root_perms
app_exclude_perms = [ app_exclude_perms = [
('users', 'user', 'add,delete', 'user'), ('users', 'user', 'add,delete', 'user'),
('orgs', 'org', 'add,delete,change', 'org'), ('orgs', 'org', 'add,delete,change', 'org'),

View File

@ -135,7 +135,7 @@ only_system_permissions = (
('xpack', 'license', '*', '*'), ('xpack', 'license', '*', '*'),
('settings', 'setting', '*', '*'), ('settings', 'setting', '*', '*'),
('tickets', '*', '*', '*'), ('tickets', '*', '*', '*'),
('ops', 'task', 'view', 'taskmonitor'), ('ops', 'celerytask', 'view', 'taskmonitor'),
('terminal', 'terminal', '*', '*'), ('terminal', 'terminal', '*', '*'),
('terminal', 'commandstorage', '*', '*'), ('terminal', 'commandstorage', '*', '*'),
('terminal', 'replaystorage', '*', '*'), ('terminal', 'replaystorage', '*', '*'),

View File

@ -97,13 +97,13 @@ class RBACPermission(permissions.DjangoModelPermissions):
else: else:
model_cls = queryset.model model_cls = queryset.model
except AssertionError as e: except AssertionError as e:
logger.error(f'Error get model cls: {e}') # logger.error(f'Error get model cls: {e}')
model_cls = None model_cls = None
except AttributeError as e: except AttributeError as e:
logger.error(f'Error get model cls: {e}') # logger.error(f'Error get model cls: {e}')
model_cls = None model_cls = None
except Exception as e: except Exception as e:
logger.error('Error get model class: {} of {}'.format(e, view)) # logger.error('Error get model class: {} of {}'.format(e, view))
raise e raise e
return model_cls return model_cls

View File

@ -42,7 +42,7 @@ class MailTestingAPI(APIView):
# if k.startswith('EMAIL'): # if k.startswith('EMAIL'):
# setattr(settings, k, v) # setattr(settings, k, v)
try: try:
subject = settings.EMAIL_SUBJECT_PREFIX + "Test" subject = settings.EMAIL_SUBJECT_PREFIX or '' + "Test"
message = "Test smtp setting" message = "Test smtp setting"
email_from = email_from or email_host_user email_from = email_from or email_host_user
email_recipient = email_recipient or email_from email_recipient = email_recipient or email_from

View File

@ -9,6 +9,13 @@ __all__ = ['FeiShuSettingSerializer']
class FeiShuSettingSerializer(serializers.Serializer): class FeiShuSettingSerializer(serializers.Serializer):
PREFIX_TITLE = _('FeiShu') PREFIX_TITLE = _('FeiShu')
VERSION_CHOICES = (
('feishu', _('FeiShu')),
('lark', 'Lark')
)
AUTH_FEISHU = serializers.BooleanField(default=False, label=_('Enable FeiShu Auth'))
FEISHU_APP_ID = serializers.CharField(max_length=256, required=True, label='App ID') FEISHU_APP_ID = serializers.CharField(max_length=256, required=True, label='App ID')
FEISHU_APP_SECRET = EncryptedField(max_length=256, required=False, label='App Secret') FEISHU_APP_SECRET = EncryptedField(max_length=256, required=False, label='App Secret')
AUTH_FEISHU = serializers.BooleanField(default=False, label=_('Enable FeiShu Auth')) FEISHU_VERSION = serializers.ChoiceField(
choices=VERSION_CHOICES, default='feishu', label=_('Version')
)

View File

@ -74,9 +74,9 @@ class LDAPSettingSerializer(serializers.Serializer):
) )
AUTH_LDAP_CONNECT_TIMEOUT = serializers.IntegerField( AUTH_LDAP_CONNECT_TIMEOUT = serializers.IntegerField(
min_value=1, max_value=300, min_value=1, max_value=300,
required=False, label=_('Connect timeout'), required=False, label=_('Connect timeout (s)'),
) )
AUTH_LDAP_SEARCH_PAGED_SIZE = serializers.IntegerField(required=False, label=_('Search paged size')) AUTH_LDAP_SEARCH_PAGED_SIZE = serializers.IntegerField(required=False, label=_('Search paged size (piece)'))
AUTH_LDAP = serializers.BooleanField(required=False, label=_('Enable LDAP auth')) AUTH_LDAP = serializers.BooleanField(required=False, label=_('Enable LDAP auth'))

View File

@ -87,7 +87,7 @@ class OIDCSettingSerializer(KeycloakSettingSerializer):
) )
AUTH_OPENID_SCOPES = serializers.CharField(required=False, max_length=1024, label=_('Scopes')) AUTH_OPENID_SCOPES = serializers.CharField(required=False, max_length=1024, label=_('Scopes'))
AUTH_OPENID_ID_TOKEN_MAX_AGE = serializers.IntegerField( AUTH_OPENID_ID_TOKEN_MAX_AGE = serializers.IntegerField(
required=False, label=_('Id token max age') required=False, label=_('Id token max age (s)')
) )
AUTH_OPENID_ID_TOKEN_INCLUDE_CLAIMS = serializers.BooleanField( AUTH_OPENID_ID_TOKEN_INCLUDE_CLAIMS = serializers.BooleanField(
required=False, label=_('Id token include claims') required=False, label=_('Id token include claims')

View File

@ -18,6 +18,10 @@
margin: 0 auto; margin: 0 auto;
padding: 100px 20px 20px 20px; padding: 100px 20px 20px 20px;
} }
.ibox-content {
padding: 30px;
}
</style> </style>
{% block custom_head_css_js %} {% endblock %} {% block custom_head_css_js %} {% endblock %}
</head> </head>
@ -30,7 +34,7 @@
<h2 class="font-bold"> <h2 class="font-bold">
{% block title %}{% endblock %} {% block title %}{% endblock %}
</h2> </h2>
<div style="margin: 10px 0"> <div style="margin: 20px 0 0 0">
{% block content %} {% endblock %} {% block content %} {% endblock %}
</div> </div>
</div> </div>

Some files were not shown because too many files have changed in this diff Show More