Merge pull request #12734 from jumpserver/master

v3.10.4 (branch-v3.10)
Bryan 2024-02-29 16:39:55 +08:00, committed by GitHub
commit eedc2f1b41
122 changed files with 2611 additions and 1422 deletions

View File

@@ -19,11 +19,11 @@ ARG BUILD_DEPENDENCIES=" \
 ARG DEPENDENCIES=" \
     freetds-dev \
-    libpq-dev \
     libffi-dev \
     libjpeg-dev \
     libkrb5-dev \
     libldap2-dev \
+    libpq-dev \
     libsasl2-dev \
     libssl-dev \
     libxml2-dev \

@@ -75,6 +75,7 @@ ENV LANG=zh_CN.UTF-8 \
 ARG DEPENDENCIES=" \
     libjpeg-dev \
+    libpq-dev \
     libx11-dev \
     freerdp2-dev \
     libxmlsec1-openssl"

View File

@@ -1,11 +1,12 @@
+from django.db.models import Q
 from rest_framework.generics import CreateAPIView

 from accounts import serializers
+from accounts.models import Account
 from accounts.permissions import AccountTaskActionPermission
 from accounts.tasks import (
     remove_accounts_task, verify_accounts_connectivity_task, push_accounts_to_assets_task
 )
-from assets.exceptions import NotSupportedTemporarilyError
 from authentication.permissions import UserConfirmation, ConfirmType

 __all__ = [

@@ -26,25 +27,35 @@ class AccountsTaskCreateAPI(CreateAPIView):
         ]
         return super().get_permissions()

-    def perform_create(self, serializer):
-        data = serializer.validated_data
-        accounts = data.get('accounts', [])
-        params = data.get('params')
-        account_ids = [str(a.id) for a in accounts]
-        if data['action'] == 'push':
-            task = push_accounts_to_assets_task.delay(account_ids, params)
-        elif data['action'] == 'remove':
-            gather_accounts = data.get('gather_accounts', [])
-            gather_account_ids = [str(a.id) for a in gather_accounts]
-            task = remove_accounts_task.delay(gather_account_ids)
+    @staticmethod
+    def get_account_ids(data, action):
+        account_type = 'gather_accounts' if action == 'remove' else 'accounts'
+        accounts = data.get(account_type, [])
+        account_ids = [str(a.id) for a in accounts]
+        if action == 'remove':
+            return account_ids
+
+        assets = data.get('assets', [])
+        asset_ids = [str(a.id) for a in assets]
+        ids = Account.objects.filter(
+            Q(id__in=account_ids) | Q(asset_id__in=asset_ids)
+        ).distinct().values_list('id', flat=True)
+        return [str(_id) for _id in ids]
+
+    def perform_create(self, serializer):
+        data = serializer.validated_data
+        action = data['action']
+        ids = self.get_account_ids(data, action)
+        if action == 'push':
+            task = push_accounts_to_assets_task.delay(ids, data.get('params'))
+        elif action == 'remove':
+            task = remove_accounts_task.delay(ids)
+        elif action == 'verify':
+            task = verify_accounts_connectivity_task.delay(ids)
         else:
-            account = accounts[0]
-            asset = account.asset
-            if not asset.auto_config['ansible_enabled'] or \
-                    not asset.auto_config['ping_enabled']:
-                raise NotSupportedTemporarilyError()
-            task = verify_accounts_connectivity_task.delay(account_ids)
+            raise ValueError(f"Invalid action: {action}")

         data = getattr(serializer, '_data', {})
         data["task"] = task.id

View File

@@ -168,9 +168,8 @@ class AccountBackupHandler:
         if not user.secret_key:
             attachment_list = []
         else:
-            password = user.secret_key.encode('utf8')
             attachment = os.path.join(PATH, f'{plan_name}-{local_now_filename()}-{time.time()}.zip')
-            encrypt_and_compress_zip_file(attachment, password, files)
+            encrypt_and_compress_zip_file(attachment, user.secret_key, files)
             attachment_list = [attachment, ]
         AccountBackupExecutionTaskMsg(plan_name, user).publish(attachment_list)
         print('邮件已发送至{}({})'.format(user, user.email))

@@ -191,7 +190,6 @@ class AccountBackupHandler:
         attachment = os.path.join(PATH, f'{plan_name}-{local_now_filename()}-{time.time()}.zip')
         if password:
             print('\033[32m>>> 使用加密密码对文件进行加密中\033[0m')
-            password = password.encode('utf8')
             encrypt_and_compress_zip_file(attachment, password, files)
         else:
             zip_files(attachment, files)

View File

@@ -7,6 +7,7 @@ type:
   - all
 method: change_secret
 protocol: ssh
+priority: 50
 params:
   - name: commands
     type: list

View File

@@ -39,3 +39,4 @@
     login_host: "{{ jms_asset.address }}"
     login_port: "{{ jms_asset.port }}"
     login_database: "{{ jms_asset.spec_info.db_name }}"
+    mode: "{{ account.mode }}"

View File

@@ -5,6 +5,7 @@ method: change_secret
 category: host
 type:
   - windows
+priority: 49
 params:
   - name: groups
     type: str

View File

@@ -4,6 +4,7 @@ from copy import deepcopy
 from django.conf import settings
 from django.utils import timezone
+from django.utils.translation import gettext_lazy as _
 from xlsxwriter import Workbook

 from accounts.const import AutomationTypes, SecretType, SSHKeyStrategy, SecretStrategy

@@ -118,6 +119,10 @@ class ChangeSecretManager(AccountBasePlaybookManager):
         else:
             new_secret = self.get_secret(secret_type)

+        if new_secret is None:
+            print(f'new_secret is None, account: {account}')
+            continue
+
         if self.record_id is None:
             recorder = ChangeSecretRecord(
                 asset=asset, account=account, execution=self.execution,

@@ -183,17 +188,33 @@ class ChangeSecretManager(AccountBasePlaybookManager):
             return False
         return True

+    @staticmethod
+    def get_summary(recorders):
+        total, succeed, failed = 0, 0, 0
+        for recorder in recorders:
+            if recorder.status == 'success':
+                succeed += 1
+            else:
+                failed += 1
+            total += 1
+
+        summary = _('Success: %s, Failed: %s, Total: %s') % (succeed, failed, total)
+        return summary
+
     def run(self, *args, **kwargs):
         if self.secret_type and not self.check_secret():
             return
         super().run(*args, **kwargs)
+        recorders = list(self.name_recorder_mapper.values())
+        summary = self.get_summary(recorders)
+        print(summary, end='')
+
         if self.record_id:
             return
-        recorders = self.name_recorder_mapper.values()
-        recorders = list(recorders)
-        self.send_recorder_mail(recorders)
+        self.send_recorder_mail(recorders, summary)

-    def send_recorder_mail(self, recorders):
+    def send_recorder_mail(self, recorders, summary):
         recipients = self.execution.recipients
         if not recorders or not recipients:
             return

@@ -209,11 +230,10 @@ class ChangeSecretManager(AccountBasePlaybookManager):
         for user in recipients:
             attachments = []
             if user.secret_key:
-                password = user.secret_key.encode('utf8')
                 attachment = os.path.join(path, f'{name}-{local_now_filename()}-{time.time()}.zip')
-                encrypt_and_compress_zip_file(attachment, password, [filename])
+                encrypt_and_compress_zip_file(attachment, user.secret_key, [filename])
                 attachments = [attachment]
-            ChangeSecretExecutionTaskMsg(name, user).publish(attachments)
+            ChangeSecretExecutionTaskMsg(name, user, summary).publish(attachments)
             os.remove(filename)

     @staticmethod

View File

@@ -1,9 +1,10 @@
 - hosts: demo
   gather_facts: no
   tasks:
-    - name: Gather posix account
+    - name: Gather windows account
       ansible.builtin.win_shell: net user
       register: result
+      ignore_errors: true

     - name: Define info by set_fact
       ansible.builtin.set_fact:
View File

@@ -39,3 +39,4 @@
     login_host: "{{ jms_asset.address }}"
     login_port: "{{ jms_asset.port }}"
     login_database: "{{ jms_asset.spec_info.db_name }}"
+    mode: "{{ account.mode }}"

View File

@@ -5,6 +5,7 @@ method: push_account
 category: host
 type:
   - windows
+priority: 49
 params:
   - name: groups
     type: str

View File

@@ -6,6 +6,7 @@ type:
   - windows
 method: verify_account
 protocol: rdp
+priority: 1
 i18n:
   Windows rdp account verify:

View File

@@ -7,6 +7,7 @@ type:
   - all
 method: verify_account
 protocol: ssh
+priority: 50
 i18n:
   SSH account verify:

View File

@@ -51,6 +51,9 @@ class VerifyAccountManager(AccountBasePlaybookManager):
             h['name'] += '(' + account.username + ')'
             self.host_account_mapper[h['name']] = account
             secret = account.secret
+            if secret is None:
+                print(f'account {account.name} secret is None')
+                continue

             private_key_path = None
             if account.secret_type == SecretType.SSH_KEY:

@@ -62,7 +65,7 @@ class VerifyAccountManager(AccountBasePlaybookManager):
                 'name': account.name,
                 'username': account.username,
                 'secret_type': account.secret_type,
                 'secret': account.escape_jinja2_syntax(secret),
                 'private_key_path': private_key_path,
                 'become': account.get_ansible_become_auth(),
             }

View File

@@ -52,6 +52,7 @@ class AccountFilterSet(BaseFilterSet):
 class GatheredAccountFilterSet(BaseFilterSet):
     node_id = drf_filters.CharFilter(method='filter_nodes')
     asset_id = drf_filters.CharFilter(field_name='asset_id', lookup_expr='exact')
+    asset_name = drf_filters.CharFilter(field_name='asset__name', lookup_expr='icontains')

     @staticmethod
     def filter_nodes(queryset, name, value):

View File

@@ -54,20 +54,23 @@ class AccountBackupByObjStorageExecutionTaskMsg(object):
 class ChangeSecretExecutionTaskMsg(object):
     subject = _('Notification of implementation result of encryption change plan')

-    def __init__(self, name: str, user: User):
+    def __init__(self, name: str, user: User, summary):
         self.name = name
         self.user = user
+        self.summary = summary

     @property
     def message(self):
         name = self.name
         if self.user.secret_key:
-            return _('{} - The encryption change task has been completed. '
-                     'See the attachment for details').format(name)
+            default_message = _('{} - The encryption change task has been completed. '
+                                'See the attachment for details').format(name)
         else:
-            return _("{} - The encryption change task has been completed: the encryption "
-                     "password has not been set - please go to personal information -> "
-                     "file encryption password to set the encryption password").format(name)
+            default_message = _("{} - The encryption change task has been completed: the encryption "
+                                "password has not been set - please go to personal information -> "
+                                "set encryption password in preferences").format(name)
+        return self.summary + '\n' + default_message

     def publish(self, attachments=None):
         send_mail_attachment_async(

View File

@@ -58,7 +58,7 @@ class AccountCreateUpdateSerializerMixin(serializers.Serializer):
         for data in initial_data:
             if not data.get('asset') and not self.instance:
                 raise serializers.ValidationError({'asset': UniqueTogetherValidator.missing_message})
-            asset = data.get('asset') or self.instance.asset
+            asset = data.get('asset') or getattr(self.instance, 'asset', None)
             self.from_template_if_need(data)
             self.set_uniq_name_if_need(data, asset)

@@ -455,12 +455,14 @@ class AccountHistorySerializer(serializers.ModelSerializer):
 class AccountTaskSerializer(serializers.Serializer):
     ACTION_CHOICES = (
-        ('test', 'test'),
         ('verify', 'verify'),
         ('push', 'push'),
         ('remove', 'remove'),
     )
     action = serializers.ChoiceField(choices=ACTION_CHOICES, write_only=True)
+    assets = serializers.PrimaryKeyRelatedField(
+        queryset=Asset.objects, required=False, allow_empty=True, many=True
+    )
     accounts = serializers.PrimaryKeyRelatedField(
         queryset=Account.objects, required=False, allow_empty=True, many=True
     )

View File

@@ -63,7 +63,7 @@ def create_accounts_activities(account, action='create'):
 def on_account_create_by_template(sender, instance, created=False, **kwargs):
     if not created or instance.source != 'template':
         return
-    push_accounts_if_need(accounts=(instance,))
+    push_accounts_if_need.delay(accounts=(instance,))
     create_accounts_activities(instance, action='create')

View File

@@ -41,21 +41,21 @@ class UserLoginReminderMsg(UserMessage):
 class AssetLoginReminderMsg(UserMessage):
     subject = _('Asset login reminder')

-    def __init__(self, user, asset: Asset, login_user: User, account_username):
+    def __init__(self, user, asset: Asset, login_user: User, account: Account, input_username):
         self.asset = asset
         self.login_user = login_user
-        self.account_username = account_username
+        self.account = account
+        self.input_username = input_username
         super().__init__(user)

     def get_html_msg(self) -> dict:
-        account = Account.objects.get(asset=self.asset, username=self.account_username)
         context = {
             'recipient': self.user,
             'username': self.login_user.username,
             'name': self.login_user.name,
             'asset': str(self.asset),
-            'account': self.account_username,
-            'account_name': account.name,
+            'account': self.input_username,
+            'account_name': self.account.name,
         }
         message = render_to_string('acls/asset_login_reminder.html', context)

View File

@@ -92,6 +92,7 @@ class AssetViewSet(SuggestionMixin, OrgBulkModelViewSet):
     model = Asset
     filterset_class = AssetFilterSet
     search_fields = ("name", "address", "comment")
+    ordering = ('name',)
     ordering_fields = ('name', 'address', 'connectivity', 'platform', 'date_updated', 'date_created')
     serializer_classes = (
         ("default", serializers.AssetSerializer),

View File

@@ -48,7 +48,7 @@ class AssetPermUserListApi(BaseAssetPermUserOrUserGroupListApi):
     def get_queryset(self):
         perms = self.get_asset_related_perms()
-        users = User.objects.filter(
+        users = User.get_queryset().filter(
             Q(assetpermissions__in=perms) | Q(groups__assetpermissions__in=perms)
         ).distinct()
         return users

View File

@@ -1,2 +1,2 @@
 from .endpoint import ExecutionManager
-from .methods import platform_automation_methods, filter_platform_methods
+from .methods import platform_automation_methods, filter_platform_methods, sorted_methods

View File

@@ -68,6 +68,10 @@ def filter_platform_methods(category, tp_name, method=None, methods=None):
     return methods

+def sorted_methods(methods):
+    return sorted(methods, key=lambda x: x.get('priority', 10))
+
 BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 platform_automation_methods = get_platform_automation_methods(BASE_DIR)
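The priority values added to the manifests in this commit (1 for the RDP variants, 49/50 for the generic ones) feed straight into sorted_methods, with 10 as the default when a manifest sets none. A standalone sketch with illustrative method dicts; only the 'priority' key and its default come from the diff:

def sorted_methods(methods):
    return sorted(methods, key=lambda x: x.get('priority', 10))

methods = [
    {'id': 'verify_account_ssh', 'priority': 50},
    {'id': 'verify_account_rdp', 'priority': 1},
    {'id': 'verify_account_custom'},  # no priority -> defaults to 10
]
print([m['id'] for m in sorted_methods(methods)])
# ['verify_account_rdp', 'verify_account_custom', 'verify_account_ssh']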

View File

@@ -7,6 +7,7 @@ type:
   - windows
 method: ping
 protocol: rdp
+priority: 1
 i18n:
   Ping by pyfreerdp:

View File

@@ -7,6 +7,7 @@ type:
   - all
 method: ping
 protocol: ssh
+priority: 50
 i18n:
   Ping by paramiko:

View File

@@ -90,7 +90,7 @@ class AllTypes(ChoicesMixin):
     @classmethod
     def set_automation_methods(cls, category, tp_name, constraints):
-        from assets.automations import filter_platform_methods
+        from assets.automations import filter_platform_methods, sorted_methods
         automation = constraints.get('automation', {})
         automation_methods = {}
         platform_automation_methods = cls.get_automation_methods()

@@ -101,6 +101,7 @@ class AllTypes(ChoicesMixin):
             methods = filter_platform_methods(
                 category, tp_name, item_name, methods=platform_automation_methods
             )
+            methods = sorted_methods(methods)
             methods = [{'name': m['name'], 'id': m['id']} for m in methods]
             automation_methods[item_name + '_methods'] = methods
         automation.update(automation_methods)

View File

@@ -12,6 +12,6 @@ class Migration(migrations.Migration):
     operations = [
         migrations.AlterModelOptions(
             name='asset',
-            options={'ordering': ['name'], 'permissions': [('refresh_assethardwareinfo', 'Can refresh asset hardware info'), ('test_assetconnectivity', 'Can test asset connectivity'), ('match_asset', 'Can match asset'), ('change_assetnodes', 'Can change asset nodes')], 'verbose_name': 'Asset'},
+            options={'ordering': [], 'permissions': [('refresh_assethardwareinfo', 'Can refresh asset hardware info'), ('test_assetconnectivity', 'Can test asset connectivity'), ('match_asset', 'Can match asset'), ('change_assetnodes', 'Can change asset nodes')], 'verbose_name': 'Asset'},
         ),
     ]

View File

@@ -348,7 +348,7 @@ class Asset(NodesRelationMixin, LabeledMixin, AbsConnectivity, JSONFilterMixin,
     class Meta:
         unique_together = [('org_id', 'name')]
         verbose_name = _("Asset")
-        ordering = ["name", ]
+        ordering = []
         permissions = [
             ('refresh_assethardwareinfo', _('Can refresh asset hardware info')),
             ('test_assetconnectivity', _('Can test asset connectivity')),

View File

@@ -429,7 +429,7 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
     @classmethod
     @timeit
-    def get_nodes_all_assets(cls, *nodes):
+    def get_nodes_all_assets(cls, *nodes, distinct=True):
         from .asset import Asset
         node_ids = set()
         descendant_node_query = Q()

@@ -439,7 +439,10 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
         if descendant_node_query:
             _ids = Node.objects.order_by().filter(descendant_node_query).values_list('id', flat=True)
             node_ids.update(_ids)
-        return Asset.objects.order_by().filter(nodes__id__in=node_ids).distinct()
+        assets = Asset.objects.order_by().filter(nodes__id__in=node_ids)
+        if distinct:
+            assets = assets.distinct()
+        return assets

     def get_all_asset_ids(self):
         asset_ids = self.get_all_asset_ids_by_node_key(org_id=self.org_id, node_key=self.key)

View File

@@ -63,13 +63,13 @@ def on_asset_create(sender, instance=None, created=False, **kwargs):
         return

     logger.info("Asset create signal recv: {}".format(instance))
-    ensure_asset_has_node(assets=(instance,))
+    ensure_asset_has_node.delay(assets=(instance,))

     # 获取资产硬件信息
     auto_config = instance.auto_config
     if auto_config.get('ping_enabled'):
         logger.debug('Asset {} ping enabled, test connectivity'.format(instance.name))
-        test_assets_connectivity_handler(assets=(instance,))
+        test_assets_connectivity_handler.delay(assets=(instance,))
     if auto_config.get('gather_facts_enabled'):
         logger.debug('Asset {} gather facts enabled, gather facts'.format(instance.name))
         gather_assets_facts_handler(assets=(instance,))

View File

@@ -2,14 +2,16 @@
 #
 from operator import add, sub

+from django.conf import settings
 from django.db.models.signals import m2m_changed
 from django.dispatch import receiver

 from assets.models import Asset, Node
 from common.const.signals import PRE_CLEAR, POST_ADD, PRE_REMOVE
 from common.decorators import on_transaction_commit, merge_delay_run
+from common.signals import django_ready
 from common.utils import get_logger
-from orgs.utils import tmp_to_org
+from orgs.utils import tmp_to_org, tmp_to_root_org
 from ..tasks import check_node_assets_amount_task

 logger = get_logger(__file__)

@@ -34,7 +36,7 @@ def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
         node_ids = [instance.id]
     else:
         node_ids = list(pk_set)
-    update_nodes_assets_amount(node_ids=node_ids)
+    update_nodes_assets_amount.delay(node_ids=node_ids)

 @merge_delay_run(ttl=30)

@@ -52,3 +54,18 @@ def update_nodes_assets_amount(node_ids=()):
         node.assets_amount = node.get_assets_amount()

     Node.objects.bulk_update(nodes, ['assets_amount'])
+
+
+@receiver(django_ready)
+def set_assets_size_to_setting(sender, **kwargs):
+    from assets.models import Asset
+    try:
+        with tmp_to_root_org():
+            amount = Asset.objects.order_by().count()
+    except:
+        amount = 0
+    if amount > 20000:
+        settings.ASSET_SIZE = 'large'
+    elif amount > 2000:
+        settings.ASSET_SIZE = 'medium'

View File

@@ -44,18 +44,18 @@ def on_node_post_create(sender, instance, created, update_fields, **kwargs):
         need_expire = False

     if need_expire:
-        expire_node_assets_mapping(org_ids=(instance.org_id,))
+        expire_node_assets_mapping.delay(org_ids=(instance.org_id,))


 @receiver(post_delete, sender=Node)
 def on_node_post_delete(sender, instance, **kwargs):
-    expire_node_assets_mapping(org_ids=(instance.org_id,))
+    expire_node_assets_mapping.delay(org_ids=(instance.org_id,))


 @receiver(m2m_changed, sender=Asset.nodes.through)
 def on_node_asset_change(sender, instance, action='pre_remove', **kwargs):
     if action.startswith('post'):
-        expire_node_assets_mapping(org_ids=(instance.org_id,))
+        expire_node_assets_mapping.delay(org_ids=(instance.org_id,))


 @receiver(django_ready)

View File

@@ -20,6 +20,7 @@ from common.const.http import GET, POST
 from common.drf.filters import DatetimeRangeFilterBackend
 from common.permissions import IsServiceAccount
 from common.plugins.es import QuerySet as ESQuerySet
+from common.sessions.cache import user_session_manager
 from common.storage.ftp_file import FTPFileStorageHandler
 from common.utils import is_uuid, get_logger, lazyproperty
 from orgs.mixins.api import OrgReadonlyModelViewSet, OrgModelViewSet

@@ -30,7 +31,7 @@ from terminal.models import default_storage
 from users.models import User
 from .backends import TYPE_ENGINE_MAPPING
 from .const import ActivityChoices
-from .filters import UserSessionFilterSet
+from .filters import UserSessionFilterSet, OperateLogFilterSet
 from .models import (
     FTPLog, UserLoginLog, OperateLog, PasswordChangeLog,
     ActivityLog, JobLog, UserSession

@@ -204,10 +205,7 @@ class OperateLogViewSet(OrgReadonlyModelViewSet):
     date_range_filter_fields = [
         ('datetime', ('date_from', 'date_to'))
     ]
-    filterset_fields = [
-        'user', 'action', 'resource_type', 'resource',
-        'remote_addr'
-    ]
+    filterset_class = OperateLogFilterSet
     search_fields = ['resource', 'user']
     ordering = ['-datetime']

@@ -289,8 +287,7 @@ class UserSessionViewSet(CommonApiMixin, viewsets.ModelViewSet):
             return Response(status=status.HTTP_200_OK)

         keys = queryset.values_list('key', flat=True)
-        session_store_cls = import_module(settings.SESSION_ENGINE).SessionStore
         for key in keys:
-            session_store_cls(key).delete()
+            user_session_manager.decrement_or_remove(key)
         queryset.delete()
         return Response(status=status.HTTP_200_OK)

View File

@@ -1,12 +1,13 @@
-from django.core.cache import cache
+from django.apps import apps
+from django.utils import translation
 from django_filters import rest_framework as drf_filters
 from rest_framework import filters
 from rest_framework.compat import coreapi, coreschema

 from common.drf.filters import BaseFilterSet
-from notifications.ws import WS_SESSION_KEY
+from common.sessions.cache import user_session_manager
 from orgs.utils import current_org
-from .models import UserSession
+from .models import UserSession, OperateLog

 __all__ = ['CurrentOrgMembersFilter']

@@ -41,15 +42,32 @@ class UserSessionFilterSet(BaseFilterSet):
     @staticmethod
     def filter_is_active(queryset, name, is_active):
-        redis_client = cache.client.get_client()
-        members = redis_client.smembers(WS_SESSION_KEY)
-        members = [member.decode('utf-8') for member in members]
+        keys = user_session_manager.get_active_keys()
         if is_active:
-            queryset = queryset.filter(key__in=members)
+            queryset = queryset.filter(key__in=keys)
         else:
-            queryset = queryset.exclude(key__in=members)
+            queryset = queryset.exclude(key__in=keys)
         return queryset

     class Meta:
         model = UserSession
         fields = ['id', 'ip', 'city', 'type']
+
+
+class OperateLogFilterSet(BaseFilterSet):
+    resource_type = drf_filters.CharFilter(method='filter_resource_type')
+
+    @staticmethod
+    def filter_resource_type(queryset, name, resource_type):
+        current_lang = translation.get_language()
+        with translation.override(current_lang):
+            mapper = {str(m._meta.verbose_name): m._meta.verbose_name_raw for m in apps.get_models()}
+            tp = mapper.get(resource_type)
+            queryset = queryset.filter(resource_type=tp)
+        return queryset
+
+    class Meta:
+        model = OperateLog
+        fields = [
+            'user', 'action', 'resource', 'remote_addr'
+        ]
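The resource_type filter now accepts the localized verbose name shown in the UI and maps it back to the untranslated value stored on OperateLog. A minimal sketch of the mapping trick; it needs a configured Django app registry, and the 'Asset' lookup value is illustrative:

from django.apps import apps

# verbose_name is the translated label, verbose_name_raw the original
# string that OperateLog.resource_type stores at write time.
mapper = {
    str(m._meta.verbose_name): m._meta.verbose_name_raw
    for m in apps.get_models()
}
raw_type = mapper.get('Asset')  # resolves the UI label back to the raw value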

View File

@@ -4,15 +4,15 @@ from datetime import timedelta
 from importlib import import_module

 from django.conf import settings
-from django.core.cache import caches, cache
+from django.core.cache import caches
 from django.db import models
 from django.db.models import Q
 from django.utils import timezone
 from django.utils.translation import gettext, gettext_lazy as _

 from common.db.encoder import ModelJSONFieldEncoder
+from common.sessions.cache import user_session_manager
 from common.utils import lazyproperty, i18n_trans
-from notifications.ws import WS_SESSION_KEY
 from ops.models import JobExecution
 from orgs.mixins.models import OrgModelMixin, Organization
 from orgs.utils import current_org

@@ -278,8 +278,7 @@ class UserSession(models.Model):
     @property
     def is_active(self):
-        redis_client = cache.client.get_client()
-        return redis_client.sismember(WS_SESSION_KEY, self.key)
+        return user_session_manager.check_active(self.key)

     @property
     def date_expired(self):

View File

@@ -23,7 +23,7 @@ class JobLogSerializer(JobExecutionSerializer):
     class Meta:
         model = models.JobLog
         read_only_fields = [
-            "id", "material", "time_cost", 'date_start',
+            "id", "material", 'job_type', "time_cost", 'date_start',
             'date_finished', 'date_created',
             'is_finished', 'is_success',
             'task_id', 'creator_name'

View File

@@ -19,7 +19,7 @@ from ops.celery.decorator import (
 from ops.models import CeleryTaskExecution
 from terminal.models import Session, Command
 from terminal.backends import server_replay_storage
-from .models import UserLoginLog, OperateLog, FTPLog, ActivityLog
+from .models import UserLoginLog, OperateLog, FTPLog, ActivityLog, PasswordChangeLog

 logger = get_logger(__name__)

@@ -38,6 +38,14 @@ def clean_operation_log_period():
     OperateLog.objects.filter(datetime__lt=expired_day).delete()

+def clean_password_change_log_period():
+    now = timezone.now()
+    days = get_log_keep_day('PASSWORD_CHANGE_LOG_KEEP_DAYS')
+    expired_day = now - datetime.timedelta(days=days)
+    PasswordChangeLog.objects.filter(datetime__lt=expired_day).delete()
+    logger.info("Clean password change log done")
+
 def clean_activity_log_period():
     now = timezone.now()
     days = get_log_keep_day('ACTIVITY_LOG_KEEP_DAYS')

@@ -109,6 +117,7 @@ def clean_audits_log_period():
     clean_activity_log_period()
     clean_celery_tasks_period()
     clean_expired_session_period()
+    clean_password_change_log_period()

 @shared_task(verbose_name=_('Upload FTP file to external storage'))

View File

@@ -205,7 +205,7 @@ class RDPFileClientProtocolURLMixin:
         return data

     def get_smart_endpoint(self, protocol, asset=None):
-        endpoint = Endpoint.match_by_instance_label(asset, protocol)
+        endpoint = Endpoint.match_by_instance_label(asset, protocol, self.request)
         if not endpoint:
             target_ip = asset.get_target_ip() if asset else ''
             endpoint = EndpointRule.match_endpoint(

@@ -443,7 +443,7 @@ class ConnectionTokenViewSet(ExtraActionApiMixin, RootOrgViewMixin, JMSModelView
         self._record_operate_log(acl, asset)
         for reviewer in reviewers:
             AssetLoginReminderMsg(
-                reviewer, asset, user, self.input_username
+                reviewer, asset, user, account, self.input_username
             ).publish_async()

     def create(self, request, *args, **kwargs):

View File

@@ -10,6 +10,7 @@ from rest_framework import authentication, exceptions
 from common.auth import signature
 from common.decorators import merge_delay_run
 from common.utils import get_object_or_none, get_request_ip_or_data, contains_ip
+from users.models import User
 from ..models import AccessKey, PrivateToken

@@ -19,22 +20,23 @@ def date_more_than(d, seconds):
 @merge_delay_run(ttl=60)
 def update_token_last_used(tokens=()):
-    for token in tokens:
-        token.date_last_used = timezone.now()
-        token.save(update_fields=['date_last_used'])
+    access_keys_ids = [token.id for token in tokens if isinstance(token, AccessKey)]
+    private_token_keys = [token.key for token in tokens if isinstance(token, PrivateToken)]
+    if len(access_keys_ids) > 0:
+        AccessKey.objects.filter(id__in=access_keys_ids).update(date_last_used=timezone.now())
+    if len(private_token_keys) > 0:
+        PrivateToken.objects.filter(key__in=private_token_keys).update(date_last_used=timezone.now())

 @merge_delay_run(ttl=60)
 def update_user_last_used(users=()):
-    for user in users:
-        user.date_api_key_last_used = timezone.now()
-        user.save(update_fields=['date_api_key_last_used'])
+    User.objects.filter(id__in=users).update(date_api_key_last_used=timezone.now())

 def after_authenticate_update_date(user, token=None):
-    update_user_last_used(users=(user,))
+    update_user_last_used.delay(users=(user.id,))
     if token:
-        update_token_last_used(tokens=(token,))
+        update_token_last_used.delay(tokens=(token,))

 class AccessTokenAuthentication(authentication.BaseAuthentication):
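This rewrite matters because merge_delay_run batches many authentications into one invocation: instead of one UPDATE per token, each model now gets a single UPDATE ... WHERE ... IN (...). A generic sketch of the pattern; the token model passed in is hypothetical:

from django.utils import timezone

def touch_tokens(tokens):
    # Before: one UPDATE per row, N queries for N tokens.
    for token in tokens:
        token.date_last_used = timezone.now()
        token.save(update_fields=['date_last_used'])

def touch_tokens_bulk(token_model, tokens):
    # After: a single UPDATE for the whole batch.
    token_model.objects.filter(id__in=[t.id for t in tokens]).update(
        date_last_used=timezone.now()
    )

Note that queryset.update() skips save() signals and auto_now fields, which is acceptable for a pure timestamp column like this.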

View File

@@ -98,16 +98,19 @@ class OAuth2Backend(JMSModelBackend):
         access_token_url = '{url}{separator}{query}'.format(
             url=settings.AUTH_OAUTH2_ACCESS_TOKEN_ENDPOINT, separator=separator, query=urlencode(query_dict)
         )
+        # token_method -> get, post(post_data), post_json
         token_method = settings.AUTH_OAUTH2_ACCESS_TOKEN_METHOD.lower()
-        requests_func = getattr(requests, token_method, requests.get)
         logger.debug(log_prompt.format('Call the access token endpoint[method: %s]' % token_method))
         headers = {
             'Accept': 'application/json'
         }
-        if token_method == 'post':
-            access_token_response = requests_func(access_token_url, headers=headers, data=query_dict)
+        if token_method.startswith('post'):
+            body_key = 'json' if token_method.endswith('json') else 'data'
+            access_token_response = requests.post(
+                access_token_url, headers=headers, **{body_key: query_dict}
+            )
         else:
-            access_token_response = requests_func(access_token_url, headers=headers)
+            access_token_response = requests.get(access_token_url, headers=headers)
         try:
             access_token_response.raise_for_status()
             access_token_response_data = access_token_response.json()
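Roughly, the three AUTH_OAUTH2_ACCESS_TOKEN_METHOD values now map onto requests calls like this; the URL and payload below are placeholders:

import requests

url = 'https://idp.example.org/oauth2/token'
payload = {'grant_type': 'authorization_code', 'code': '<code>'}

requests.get(url, params=payload)   # 'get'       -> parameters in the query string
requests.post(url, data=payload)    # 'post'      -> form-encoded body
requests.post(url, json=payload)    # 'post_json' -> JSON body (newly supported)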

View File

@@ -18,7 +18,7 @@ class EncryptedField(forms.CharField):
 class UserLoginForm(forms.Form):
     days_auto_login = int(settings.SESSION_COOKIE_AGE / 3600 / 24)
-    disable_days_auto_login = settings.SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE \
+    disable_days_auto_login = settings.SESSION_EXPIRE_AT_BROWSER_CLOSE \
                               or days_auto_login < 1

     username = forms.CharField(

View File

@@ -142,23 +142,7 @@ class SessionCookieMiddleware(MiddlewareMixin):
             return response
         response.set_cookie(key, value)

-    @staticmethod
-    def set_cookie_session_expire(request, response):
-        if not request.session.get('auth_session_expiration_required'):
-            return
-
-        value = 'age'
-        if settings.SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE or \
-                not request.session.get('auto_login', False):
-            value = 'close'
-        age = request.session.get_expiry_age()
-        expire_timestamp = request.session.get_expiry_date().timestamp()
-        response.set_cookie('jms_session_expire_timestamp', expire_timestamp)
-        response.set_cookie('jms_session_expire', value, max_age=age)
-        request.session.pop('auth_session_expiration_required', None)
-
     def process_response(self, request, response: HttpResponse):
         self.set_cookie_session_prefix(request, response)
         self.set_cookie_public_key(request, response)
-        self.set_cookie_session_expire(request, response)
         return response

View File

@@ -37,9 +37,6 @@ def on_user_auth_login_success(sender, user, request, **kwargs):
         UserSession.objects.filter(key=session_key).delete()
     cache.set(lock_key, request.session.session_key, None)

-    # 标记登录,设置 cookie前端可以控制刷新, Middleware 会拦截这个生成 cookie
-    request.session['auth_session_expiration_required'] = 1

 @receiver(cas_user_authenticated)
 def on_cas_user_login_success(sender, request, user, **kwargs):

View File

@@ -407,6 +407,15 @@
         $('#password-hidden').val(passwordEncrypted); //返回给密码输入input
         $('#login-form').submit(); //post提交
     }
+
+    function checkHealth() {
+        let url = "{% url 'health' %}";
+        requestApi({
+            url: url,
+            method: "GET",
+            flash_message: false,
+        })
+    }
+    setInterval(checkHealth, 30 * 1000);
 </script>
 </html>

View File

@@ -70,11 +70,12 @@ class DingTalkQRMixin(DingTalkBaseMixin, View):
         self.request.session[DINGTALK_STATE_SESSION_KEY] = state

         params = {
-            'appid': settings.DINGTALK_APPKEY,
+            'client_id': settings.DINGTALK_APPKEY,
             'response_type': 'code',
-            'scope': 'snsapi_login',
+            'scope': 'openid',
             'state': state,
             'redirect_uri': redirect_uri,
+            'prompt': 'consent'
         }
         url = URL.QR_CONNECT + '?' + urlencode(params)
         return url

View File

@@ -19,3 +19,17 @@ class Status(models.TextChoices):
     failed = 'failed', _("Failed")
     error = 'error', _("Error")
     canceled = 'canceled', _("Canceled")
+
+
+COUNTRY_CALLING_CODES = [
+    {'name': 'China(中国)', 'value': '+86'},
+    {'name': 'HongKong(中国香港)', 'value': '+852'},
+    {'name': 'Macao(中国澳门)', 'value': '+853'},
+    {'name': 'Taiwan(中国台湾)', 'value': '+886'},
+    {'name': 'America(America)', 'value': '+1'},
+    {'name': 'Russia(Россия)', 'value': '+7'},
+    {'name': 'France(français)', 'value': '+33'},
+    {'name': 'Britain(Britain)', 'value': '+44'},
+    {'name': 'Germany(Deutschland)', 'value': '+49'},
+    {'name': 'Japan(日本)', 'value': '+81'},
+    {'name': 'Korea(한국)', 'value': '+82'},
+    {'name': 'India(भारत)', 'value': '+91'}
+]

View File

@@ -362,11 +362,15 @@ class RelatedManager:
             if name is None or val is None:
                 continue

-            if custom_attr_filter:
+            custom_filter_q = None
+            spec_attr_filter = getattr(to_model, "get_{}_filter_attr_q".format(name), None)
+            if spec_attr_filter:
+                custom_filter_q = spec_attr_filter(val, match)
+            elif custom_attr_filter:
                 custom_filter_q = custom_attr_filter(name, val, match)
             if custom_filter_q:
                 filters.append(custom_filter_q)
                 continue

             if match == 'ip_in':
                 q = cls.get_ip_in_q(name, val)

@@ -464,11 +468,15 @@ class JSONManyToManyDescriptor:
             rule_value = rule.get('value', '')
             rule_match = rule.get('match', 'exact')

-            if custom_attr_filter:
-                q = custom_attr_filter(rule['name'], rule_value, rule_match)
-                if q:
-                    custom_q &= q
-                    continue
+            custom_filter_q = None
+            spec_attr_filter = getattr(to_model, "get_filter_{}_attr_q".format(rule['name']), None)
+            if spec_attr_filter:
+                custom_filter_q = spec_attr_filter(rule_value, rule_match)
+            elif custom_attr_filter:
+                custom_filter_q = custom_attr_filter(rule['name'], rule_value, rule_match)
+            if custom_filter_q:
+                custom_q &= custom_filter_q
+                continue

             if rule_match == 'in':
                 res &= value in rule_value or '*' in rule_value

@@ -517,7 +525,6 @@ class JSONManyToManyDescriptor:
                     res &= rule_value.issubset(value)
                 else:
                     res &= bool(value & rule_value)
-            else:
             else:
                 logging.error("unknown match: {}".format(rule['match']))
                 res &= False
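Both code paths now consult a per-attribute hook on the target model, named get_<attr>_filter_attr_q in the manager path and get_filter_<attr>_attr_q in the descriptor path, before falling back to the generic custom_attr_filter. A hypothetical hook; the class, field, and match handling below are illustrative, only the signature (value, match) -> Q-or-None comes from the diff:

from django.db.models import Q

class AssetStub:
    # Hypothetical host model standing in for a real Django model.
    @classmethod
    def get_address_filter_attr_q(cls, value, match):
        if match == 'startswith':
            return Q(address__istartswith=value)
        return None  # None falls through to the built-in match handling

# Dispatch as the manager does it:
hook = getattr(AssetStub, 'get_{}_filter_attr_q'.format('address'), None)
q = hook('10.', 'startswith') if hook else None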

View File

@@ -3,6 +3,7 @@
 import asyncio
 import functools
 import inspect
+import os
 import threading
 import time
 from concurrent.futures import ThreadPoolExecutor

@@ -101,7 +102,11 @@ def run_debouncer_func(cache_key, org, ttl, func, *args, **kwargs):
         first_run_time = current

     if current - first_run_time > ttl:
+        _loop_debouncer_func_args_cache.pop(cache_key, None)
+        _loop_debouncer_func_task_time_cache.pop(cache_key, None)
         executor.submit(run_func_partial, *args, **kwargs)
+        logger.debug('pid {} executor submit run {}'.format(
+            os.getpid(), func.__name__, ))
         return

     loop = _loop_thread.get_loop()

@@ -133,13 +138,26 @@ class Debouncer(object):
         return await self.loop.run_in_executor(self.executor, func)

+ignore_err_exceptions = (
+    "(3101, 'Plugin instructed the server to rollback the current transaction.')",
+)
+

 def _run_func_with_org(key, org, func, *args, **kwargs):
     from orgs.utils import set_current_org
     try:
-        set_current_org(org)
-        func(*args, **kwargs)
+        with transaction.atomic():
+            set_current_org(org)
+            func(*args, **kwargs)
     except Exception as e:
-        logger.error('delay run error: %s' % e)
+        msg = str(e)
+        log_func = logger.error
+        if msg in ignore_err_exceptions:
+            log_func = logger.info
+        pid = os.getpid()
+        thread_name = threading.current_thread()
+        log_func('pid {} thread {} delay run {} error: {}'.format(
+            pid, thread_name, func.__name__, msg))
     _loop_debouncer_func_task_cache.pop(key, None)
     _loop_debouncer_func_args_cache.pop(key, None)
     _loop_debouncer_func_task_time_cache.pop(key, None)

@@ -181,6 +199,32 @@ def merge_delay_run(ttl=5, key=None):
     :return:
     """

+    def delay(func, *args, **kwargs):
+        from orgs.utils import get_current_org
+        suffix_key_func = key if key else default_suffix_key
+        org = get_current_org()
+        func_name = f'{func.__module__}_{func.__name__}'
+        key_suffix = suffix_key_func(*args, **kwargs)
+        cache_key = f'MERGE_DELAY_RUN_{func_name}_{key_suffix}'
+        cache_kwargs = _loop_debouncer_func_args_cache.get(cache_key, {})
+        for k, v in kwargs.items():
+            if not isinstance(v, (tuple, list, set)):
+                raise ValueError('func kwargs value must be list or tuple: %s %s' % (func.__name__, v))
+            v = set(v)
+            if k not in cache_kwargs:
+                cache_kwargs[k] = v
+            else:
+                cache_kwargs[k] = cache_kwargs[k].union(v)
+        _loop_debouncer_func_args_cache[cache_key] = cache_kwargs
+        run_debouncer_func(cache_key, org, ttl, func, *args, **cache_kwargs)
+
+    def apply(func, sync=False, *args, **kwargs):
+        if sync:
+            return func(*args, **kwargs)
+        else:
+            return delay(func, *args, **kwargs)
+
     def inner(func):
         sigs = inspect.signature(func)
         if len(sigs.parameters) != 1:

@@ -188,27 +232,12 @@ def merge_delay_run(ttl=5, key=None):
         param = list(sigs.parameters.values())[0]
         if not isinstance(param.default, tuple):
             raise ValueError('func default must be tuple: %s' % param.default)
-        suffix_key_func = key if key else default_suffix_key
+        func.delay = functools.partial(delay, func)
+        func.apply = functools.partial(apply, func)

         @functools.wraps(func)
         def wrapper(*args, **kwargs):
-            from orgs.utils import get_current_org
-            org = get_current_org()
-            func_name = f'{func.__module__}_{func.__name__}'
-            key_suffix = suffix_key_func(*args, **kwargs)
-            cache_key = f'MERGE_DELAY_RUN_{func_name}_{key_suffix}'
-            cache_kwargs = _loop_debouncer_func_args_cache.get(cache_key, {})
-            for k, v in kwargs.items():
-                if not isinstance(v, (tuple, list, set)):
-                    raise ValueError('func kwargs value must be list or tuple: %s %s' % (func.__name__, v))
-                v = set(v)
-                if k not in cache_kwargs:
-                    cache_kwargs[k] = v
-                else:
-                    cache_kwargs[k] = cache_kwargs[k].union(v)
-            _loop_debouncer_func_args_cache[cache_key] = cache_kwargs
-            run_debouncer_func(cache_key, org, ttl, func, *args, **cache_kwargs)
+            return func(*args, **kwargs)

         return wrapper
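After this refactor a decorated function has three call styles: a plain call runs synchronously at once, .delay() debounces and merges keyword arguments across calls inside the ttl window, and .apply(sync=True) bypasses the debouncer explicitly. A usage sketch; the function below is illustrative, while the contract of a single tuple-default kwarg comes from the decorator's own checks:

from common.decorators import merge_delay_run

@merge_delay_run(ttl=30)
def expire_caches(org_ids=()):
    print('expiring:', sorted(org_ids))

expire_caches(org_ids=('org1',))          # plain call: runs immediately
expire_caches.delay(org_ids=('org1',))    # queued, merged over the ttl window
expire_caches.delay(org_ids=('org2',))    # merged into the same pending run
# -> eventually one call with org_ids={'org1', 'org2'}
expire_caches.apply(sync=True, org_ids=('org3',))  # force immediate execution

This is what the .delay() call sites added throughout this commit (expire_node_assets_mapping.delay, update_nodes_assets_amount.delay, update_token_last_used.delay, and so on) rely on.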

View File

@@ -6,7 +6,7 @@ import logging
 from django.core.cache import cache
 from django.core.exceptions import ImproperlyConfigured
-from django.db.models import Q, Count
+from django.db.models import Q
 from django_filters import rest_framework as drf_filters
 from rest_framework import filters
 from rest_framework.compat import coreapi, coreschema

@@ -180,36 +180,30 @@ class LabelFilterBackend(filters.BaseFilterBackend):
     ]

     @staticmethod
-    def filter_resources(resources, labels_id):
+    def parse_label_ids(labels_id):
+        from labels.models import Label
         label_ids = [i.strip() for i in labels_id.split(',')]
+        cleaned = []
         args = []
         for label_id in label_ids:
             kwargs = {}
             if ':' in label_id:
                 k, v = label_id.split(':', 1)
-                kwargs['label__name'] = k.strip()
+                kwargs['name'] = k.strip()
                 if v != '*':
-                    kwargs['label__value'] = v.strip()
+                    kwargs['value'] = v.strip()
+                args.append(kwargs)
             else:
-                kwargs['label_id'] = label_id
-            args.append(kwargs)
+                cleaned.append(label_id)

-        if len(args) == 1:
-            resources = resources.filter(**args[0])
-            return resources
+        if len(args) != 0:
+            q = Q()
+            for kwarg in args:
+                q |= Q(**kwarg)
+            ids = Label.objects.filter(q).values_list('id', flat=True)
+            cleaned.extend(list(ids))
+        return cleaned

-        q = Q()
-        for kwarg in args:
-            q |= Q(**kwarg)
-        resources = resources.filter(q) \
-            .values('res_id') \
-            .order_by('res_id') \
-            .annotate(count=Count('res_id', distinct=True)) \
-            .values('res_id', 'count') \
-            .filter(count=len(args))
-        return resources

     def filter_queryset(self, request, queryset, view):
         labels_id = request.query_params.get('labels')

@@ -230,7 +224,8 @@ class LabelFilterBackend(filters.BaseFilterBackend):
         resources = labeled_resource_cls.objects.filter(
             res_type__app_label=app_label, res_type__model=model_name,
         )
-        resources = self.filter_resources(resources, labels_id)
+        label_ids = self.parse_label_ids(labels_id)
+        resources = model.filter_resources_by_labels(resources, label_ids)
         res_ids = resources.values_list('res_id', flat=True)
         queryset = queryset.filter(id__in=set(res_ids))
         return queryset
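parse_label_ids splits the ?labels= parameter into bare label IDs and name:value pairs, resolves the pairs to IDs in one OR query, and hands the combined list to the model's filter_resources_by_labels. A pure-Python sketch of the parsing rules, with the Label lookup stubbed out:

def parse_label_ids(labels_id, lookup_label_ids):
    cleaned, args = [], []
    for label_id in (i.strip() for i in labels_id.split(',')):
        if ':' in label_id:
            k, v = label_id.split(':', 1)
            kwargs = {'name': k.strip()}
            if v != '*':              # 'team:*' matches any value of 'team'
                kwargs['value'] = v.strip()
            args.append(kwargs)
        else:
            cleaned.append(label_id)  # bare value: already a label id
    if args:
        cleaned.extend(lookup_label_ids(args))  # OR of the name/value pairs
    return cleaned

print(parse_label_ids('env:prod,team:*,4b1c', lambda args: ['id-1', 'id-2']))
# ['4b1c', 'id-1', 'id-2']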

View File

@@ -87,7 +87,7 @@ class BaseFileRenderer(BaseRenderer):
         if value is None:
             return '-'
         pk = str(value.get('id', '') or value.get('pk', ''))
-        name = value.get('name') or value.get('display_name', '')
+        name = value.get('display_name', '') or value.get('name', '')
         return '{}({})'.format(name, pk)

     @staticmethod

View File

@ -28,9 +28,10 @@ class ErrorCode:
class URL: class URL:
QR_CONNECT = 'https://oapi.dingtalk.com/connect/qrconnect' QR_CONNECT = 'https://login.dingtalk.com/oauth2/auth'
OAUTH_CONNECT = 'https://oapi.dingtalk.com/connect/oauth2/sns_authorize' OAUTH_CONNECT = 'https://oapi.dingtalk.com/connect/oauth2/sns_authorize'
GET_USER_INFO_BY_CODE = 'https://oapi.dingtalk.com/sns/getuserinfo_bycode' GET_USER_ACCESSTOKEN = 'https://api.dingtalk.com/v1.0/oauth2/userAccessToken'
GET_USER_INFO = 'https://api.dingtalk.com/v1.0/contact/users/me'
GET_TOKEN = 'https://oapi.dingtalk.com/gettoken' GET_TOKEN = 'https://oapi.dingtalk.com/gettoken'
SEND_MESSAGE_BY_TEMPLATE = 'https://oapi.dingtalk.com/topapi/message/corpconversation/sendbytemplate' SEND_MESSAGE_BY_TEMPLATE = 'https://oapi.dingtalk.com/topapi/message/corpconversation/sendbytemplate'
SEND_MESSAGE = 'https://oapi.dingtalk.com/topapi/message/corpconversation/asyncsend_v2' SEND_MESSAGE = 'https://oapi.dingtalk.com/topapi/message/corpconversation/asyncsend_v2'
@ -72,8 +73,9 @@ class DingTalkRequests(BaseRequest):
def get(self, url, params=None, def get(self, url, params=None,
with_token=False, with_sign=False, with_token=False, with_sign=False,
check_errcode_is_0=True, check_errcode_is_0=True,
**kwargs): **kwargs) -> dict:
pass pass
get = as_request(get) get = as_request(get)
def post(self, url, json=None, params=None, def post(self, url, json=None, params=None,
@ -81,6 +83,7 @@ class DingTalkRequests(BaseRequest):
check_errcode_is_0=True, check_errcode_is_0=True,
**kwargs) -> dict: **kwargs) -> dict:
pass pass
post = as_request(post) post = as_request(post)
def _add_sign(self, kwargs: dict): def _add_sign(self, kwargs: dict):
@ -123,17 +126,22 @@ class DingTalk:
) )
def get_userinfo_bycode(self, code): def get_userinfo_bycode(self, code):
# https://developers.dingtalk.com/document/app/obtain-the-user-information-based-on-the-sns-temporary-authorization?spm=ding_open_doc.document.0.0.3a256573y8Y7yg#topic-1995619
body = { body = {
"tmp_auth_code": code 'clientId': self._appid,
'clientSecret': self._appsecret,
'code': code,
'grantType': 'authorization_code'
} }
data = self._request.post(URL.GET_USER_ACCESSTOKEN, json=body, check_errcode_is_0=False)
token = data['accessToken']
data = self._request.post(URL.GET_USER_INFO_BY_CODE, json=body, with_sign=True) user = self._request.get(URL.GET_USER_INFO,
return data['user_info'] headers={'x-acs-dingtalk-access-token': token}, check_errcode_is_0=False)
return user
def get_user_id_by_code(self, code): def get_user_id_by_code(self, code):
user_info = self.get_userinfo_bycode(code) user_info = self.get_userinfo_bycode(code)
unionid = user_info['unionid'] unionid = user_info['unionId']
userid = self.get_userid_by_unionid(unionid) userid = self.get_userid_by_unionid(unionid)
return userid, None return userid, None
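This moves DingTalk login off the deprecated sns/getuserinfo_bycode flow onto the v1.0 userAccessToken flow. A standalone sketch of the two HTTP calls; the endpoints are the ones added above, the credentials are placeholders:

import requests

resp = requests.post(
    'https://api.dingtalk.com/v1.0/oauth2/userAccessToken',
    json={
        'clientId': '<appkey>',
        'clientSecret': '<appsecret>',
        'code': '<authorization-code>',
        'grantType': 'authorization_code',
    },
)
token = resp.json()['accessToken']

user = requests.get(
    'https://api.dingtalk.com/v1.0/contact/users/me',
    headers={'x-acs-dingtalk-access-token': token},
).json()
print(user['unionId'])  # camelCase in the v1.0 API, hence the key change above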

View File

View File

@ -0,0 +1,56 @@
import re
from django.contrib.sessions.backends.cache import (
SessionStore as DjangoSessionStore
)
from django.core.cache import cache
from jumpserver.utils import get_current_request
class SessionStore(DjangoSessionStore):
ignore_urls = [
r'^/api/v1/users/profile/'
]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ignore_pattern = re.compile('|'.join(self.ignore_urls))
def save(self, *args, **kwargs):
request = get_current_request()
if request is None or not self.ignore_pattern.match(request.path):
super().save(*args, **kwargs)
class RedisUserSessionManager:
JMS_SESSION_KEY = 'jms_session_key'
def __init__(self):
self.client = cache.client.get_client()
def add_or_increment(self, session_key):
self.client.hincrby(self.JMS_SESSION_KEY, session_key, 1)
def decrement_or_remove(self, session_key):
new_count = self.client.hincrby(self.JMS_SESSION_KEY, session_key, -1)
if new_count <= 0:
self.client.hdel(self.JMS_SESSION_KEY, session_key)
def check_active(self, session_key):
count = self.client.hget(self.JMS_SESSION_KEY, session_key)
count = 0 if count is None else int(count.decode('utf-8'))
return count > 0
def get_active_keys(self):
session_keys = []
for k, v in self.client.hgetall(self.JMS_SESSION_KEY).items():
count = int(v.decode('utf-8'))
if count <= 0:
continue
key = k.decode('utf-8')
session_keys.append(key)
return session_keys
user_session_manager = RedisUserSessionManager()
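
The manager above keeps a Redis hash mapping session_key to an open-connection count. A minimal in-memory stand-in (a plain dict instead of Redis, for illustration only) that mirrors the same semantics:

# In-memory stand-in for RedisUserSessionManager; a dict plays the role of
# the Redis hash, and the method semantics match the class above.
class CounterStore:
    def __init__(self):
        self.counts = {}

    def add_or_increment(self, key):
        self.counts[key] = self.counts.get(key, 0) + 1

    def decrement_or_remove(self, key):
        remaining = self.counts.get(key, 0) - 1
        if remaining <= 0:
            self.counts.pop(key, None)  # drop the field, like HDEL
        else:
            self.counts[key] = remaining

    def check_active(self, key):
        return self.counts.get(key, 0) > 0


store = CounterStore()
store.add_or_increment('sess-1')     # first websocket for this session
store.add_or_increment('sess-1')     # a second browser tab
store.decrement_or_remove('sess-1')  # one tab closed
assert store.check_active('sess-1')  # still one active connection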

View File

@ -69,7 +69,7 @@ def digest_sql_query():
for query in queries: for query in queries:
sql = query['sql'] sql = query['sql']
print(" # {}: {}".format(query['time'], sql[:1000])) print(" # {}: {}".format(query['time'], sql[:1000]))
if len(queries) < 3: if len(queries) < 3:
continue continue
print("- Table: {}".format(table_name)) print("- Table: {}".format(table_name))

View File

@ -21,6 +21,8 @@ def encrypt_and_compress_zip_file(filename, secret_password, encrypted_filenames
with pyzipper.AESZipFile( with pyzipper.AESZipFile(
filename, 'w', compression=pyzipper.ZIP_LZMA, encryption=pyzipper.WZ_AES filename, 'w', compression=pyzipper.ZIP_LZMA, encryption=pyzipper.WZ_AES
) as zf: ) as zf:
if secret_password and isinstance(secret_password, str):
secret_password = secret_password.encode('utf8')
zf.setpassword(secret_password) zf.setpassword(secret_password)
for encrypted_filename in encrypted_filenames: for encrypted_filename in encrypted_filenames:
with open(encrypted_filename, 'rb') as f: with open(encrypted_filename, 'rb') as f:
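
The two added lines guard setpassword(): pyzipper expects a bytes password, so str values are encoded to UTF-8 first. A self-contained sketch of the pattern, assuming pyzipper is installed:

import pyzipper

# str passwords are encoded before setpassword(), which expects bytes.
secret_password = 's3cret'
if secret_password and isinstance(secret_password, str):
    secret_password = secret_password.encode('utf8')
with pyzipper.AESZipFile(
    'demo.zip', 'w', compression=pyzipper.ZIP_LZMA, encryption=pyzipper.WZ_AES
) as zf:
    zf.setpassword(secret_password)
    zf.writestr('hello.txt', b'hello world')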

View File

@ -547,7 +547,6 @@ class Config(dict):
'REFERER_CHECK_ENABLED': False, 'REFERER_CHECK_ENABLED': False,
'SESSION_ENGINE': 'cache', 'SESSION_ENGINE': 'cache',
'SESSION_SAVE_EVERY_REQUEST': True, 'SESSION_SAVE_EVERY_REQUEST': True,
'SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE': False,
'SERVER_REPLAY_STORAGE': {}, 'SERVER_REPLAY_STORAGE': {},
'SECURITY_DATA_CRYPTO_ALGO': None, 'SECURITY_DATA_CRYPTO_ALGO': None,
'GMSSL_ENABLED': False, 'GMSSL_ENABLED': False,
@ -564,8 +563,10 @@ class Config(dict):
'FTP_LOG_KEEP_DAYS': 180, 'FTP_LOG_KEEP_DAYS': 180,
'CLOUD_SYNC_TASK_EXECUTION_KEEP_DAYS': 180, 'CLOUD_SYNC_TASK_EXECUTION_KEEP_DAYS': 180,
'JOB_EXECUTION_KEEP_DAYS': 180, 'JOB_EXECUTION_KEEP_DAYS': 180,
'PASSWORD_CHANGE_LOG_KEEP_DAYS': 999,
'TICKETS_ENABLED': True, 'TICKETS_ENABLED': True,
'TICKETS_DIRECT_APPROVE': False,
# Deprecated # Deprecated
'DEFAULT_ORG_SHOW_ALL_USERS': True, 'DEFAULT_ORG_SHOW_ALL_USERS': True,
@ -606,7 +607,9 @@ class Config(dict):
'GPT_MODEL': 'gpt-3.5-turbo', 'GPT_MODEL': 'gpt-3.5-turbo',
'VIRTUAL_APP_ENABLED': False, 'VIRTUAL_APP_ENABLED': False,
'FILE_UPLOAD_SIZE_LIMIT_MB': 200 'FILE_UPLOAD_SIZE_LIMIT_MB': 200,
'TICKET_APPLY_ASSET_SCOPE': 'all'
} }
old_config_map = { old_config_map = {
@ -701,7 +704,8 @@ class Config(dict):
def compatible_redis(self): def compatible_redis(self):
redis_config = { redis_config = {
'REDIS_PASSWORD': quote(str(self.REDIS_PASSWORD)), 'REDIS_PASSWORD': str(self.REDIS_PASSWORD),
'REDIS_PASSWORD_QUOTE': quote(str(self.REDIS_PASSWORD)),
} }
for key, value in redis_config.items(): for key, value in redis_config.items():
self[key] = value self[key] = value
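
The change above keeps two values: the raw password for clients that accept it as a parameter, and a percent-encoded copy (REDIS_PASSWORD_QUOTE) for embedding in redis:// URLs. A small illustration of why the quoted form is needed:

from urllib.parse import quote

# Characters like '@' and ':' would break URL parsing unless percent-encoded.
raw = 'p@ss:word'
quoted = quote(raw)
assert quoted == 'p%40ss%3Aword'
url = 'redis://:%s@127.0.0.1:6379/0' % quoted  # now safe to parse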

View File

@ -66,11 +66,6 @@ class RequestMiddleware:
def __call__(self, request): def __call__(self, request):
set_current_request(request) set_current_request(request)
response = self.get_response(request) response = self.get_response(request)
is_request_api = request.path.startswith('/api')
if not settings.SESSION_EXPIRE_AT_BROWSER_CLOSE and \
not is_request_api:
age = request.session.get_expiry_age()
request.session.set_expiry(age)
return response return response

View File

@ -3,6 +3,7 @@
path_perms_map = { path_perms_map = {
'xpack': '*', 'xpack': '*',
'settings': '*', 'settings': '*',
'img': '*',
'replay': 'default', 'replay': 'default',
'applets': 'terminal.view_applet', 'applets': 'terminal.view_applet',
'virtual_apps': 'terminal.view_virtualapp', 'virtual_apps': 'terminal.view_virtualapp',

View File

@ -234,11 +234,9 @@ CSRF_COOKIE_NAME = '{}csrftoken'.format(SESSION_COOKIE_NAME_PREFIX)
SESSION_COOKIE_NAME = '{}sessionid'.format(SESSION_COOKIE_NAME_PREFIX) SESSION_COOKIE_NAME = '{}sessionid'.format(SESSION_COOKIE_NAME_PREFIX)
SESSION_COOKIE_AGE = CONFIG.SESSION_COOKIE_AGE SESSION_COOKIE_AGE = CONFIG.SESSION_COOKIE_AGE
SESSION_EXPIRE_AT_BROWSER_CLOSE = True
# Custom setting: SESSION_EXPIRE_AT_BROWSER_CLOSE is always True; the setting below controls whether cookies expire after the browser is forcibly closed
SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE = CONFIG.SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE
SESSION_SAVE_EVERY_REQUEST = CONFIG.SESSION_SAVE_EVERY_REQUEST SESSION_SAVE_EVERY_REQUEST = CONFIG.SESSION_SAVE_EVERY_REQUEST
SESSION_ENGINE = "django.contrib.sessions.backends.{}".format(CONFIG.SESSION_ENGINE) SESSION_EXPIRE_AT_BROWSER_CLOSE = CONFIG.SESSION_EXPIRE_AT_BROWSER_CLOSE
SESSION_ENGINE = "common.sessions.{}".format(CONFIG.SESSION_ENGINE)
MESSAGE_STORAGE = 'django.contrib.messages.storage.cookie.CookieStorage' MESSAGE_STORAGE = 'django.contrib.messages.storage.cookie.CookieStorage'
# Database # Database
@ -408,7 +406,7 @@ if REDIS_SENTINEL_SERVICE_NAME and REDIS_SENTINELS:
else: else:
REDIS_LOCATION_NO_DB = '%(protocol)s://:%(password)s@%(host)s:%(port)s/{}' % { REDIS_LOCATION_NO_DB = '%(protocol)s://:%(password)s@%(host)s:%(port)s/{}' % {
'protocol': REDIS_PROTOCOL, 'protocol': REDIS_PROTOCOL,
'password': CONFIG.REDIS_PASSWORD, 'password': CONFIG.REDIS_PASSWORD_QUOTE,
'host': CONFIG.REDIS_HOST, 'host': CONFIG.REDIS_HOST,
'port': CONFIG.REDIS_PORT, 'port': CONFIG.REDIS_PORT,
} }

View File

@ -122,11 +122,11 @@ WS_LISTEN_PORT = CONFIG.WS_LISTEN_PORT
LOGIN_LOG_KEEP_DAYS = CONFIG.LOGIN_LOG_KEEP_DAYS LOGIN_LOG_KEEP_DAYS = CONFIG.LOGIN_LOG_KEEP_DAYS
TASK_LOG_KEEP_DAYS = CONFIG.TASK_LOG_KEEP_DAYS TASK_LOG_KEEP_DAYS = CONFIG.TASK_LOG_KEEP_DAYS
OPERATE_LOG_KEEP_DAYS = CONFIG.OPERATE_LOG_KEEP_DAYS OPERATE_LOG_KEEP_DAYS = CONFIG.OPERATE_LOG_KEEP_DAYS
PASSWORD_CHANGE_LOG_KEEP_DAYS = CONFIG.PASSWORD_CHANGE_LOG_KEEP_DAYS
ACTIVITY_LOG_KEEP_DAYS = CONFIG.ACTIVITY_LOG_KEEP_DAYS ACTIVITY_LOG_KEEP_DAYS = CONFIG.ACTIVITY_LOG_KEEP_DAYS
FTP_LOG_KEEP_DAYS = CONFIG.FTP_LOG_KEEP_DAYS FTP_LOG_KEEP_DAYS = CONFIG.FTP_LOG_KEEP_DAYS
CLOUD_SYNC_TASK_EXECUTION_KEEP_DAYS = CONFIG.CLOUD_SYNC_TASK_EXECUTION_KEEP_DAYS CLOUD_SYNC_TASK_EXECUTION_KEEP_DAYS = CONFIG.CLOUD_SYNC_TASK_EXECUTION_KEEP_DAYS
JOB_EXECUTION_KEEP_DAYS = CONFIG.JOB_EXECUTION_KEEP_DAYS JOB_EXECUTION_KEEP_DAYS = CONFIG.JOB_EXECUTION_KEEP_DAYS
ORG_CHANGE_TO_URL = CONFIG.ORG_CHANGE_TO_URL ORG_CHANGE_TO_URL = CONFIG.ORG_CHANGE_TO_URL
WINDOWS_SKIP_ALL_MANUAL_PASSWORD = CONFIG.WINDOWS_SKIP_ALL_MANUAL_PASSWORD WINDOWS_SKIP_ALL_MANUAL_PASSWORD = CONFIG.WINDOWS_SKIP_ALL_MANUAL_PASSWORD
@ -137,6 +137,7 @@ CHANGE_AUTH_PLAN_SECURE_MODE_ENABLED = CONFIG.CHANGE_AUTH_PLAN_SECURE_MODE_ENABL
DATETIME_DISPLAY_FORMAT = '%Y-%m-%d %H:%M:%S' DATETIME_DISPLAY_FORMAT = '%Y-%m-%d %H:%M:%S'
TICKETS_ENABLED = CONFIG.TICKETS_ENABLED TICKETS_ENABLED = CONFIG.TICKETS_ENABLED
TICKETS_DIRECT_APPROVE = CONFIG.TICKETS_DIRECT_APPROVE
REFERER_CHECK_ENABLED = CONFIG.REFERER_CHECK_ENABLED REFERER_CHECK_ENABLED = CONFIG.REFERER_CHECK_ENABLED
CONNECTION_TOKEN_ENABLED = CONFIG.CONNECTION_TOKEN_ENABLED CONNECTION_TOKEN_ENABLED = CONFIG.CONNECTION_TOKEN_ENABLED
@ -214,6 +215,9 @@ PERM_TREE_REGEN_INTERVAL = CONFIG.PERM_TREE_REGEN_INTERVAL
MAGNUS_ORACLE_PORTS = CONFIG.MAGNUS_ORACLE_PORTS MAGNUS_ORACLE_PORTS = CONFIG.MAGNUS_ORACLE_PORTS
LIMIT_SUPER_PRIV = CONFIG.LIMIT_SUPER_PRIV LIMIT_SUPER_PRIV = CONFIG.LIMIT_SUPER_PRIV
# Asset and account counts may be very large
ASSET_SIZE = 'small'
# Chat AI # Chat AI
CHAT_AI_ENABLED = CONFIG.CHAT_AI_ENABLED CHAT_AI_ENABLED = CONFIG.CHAT_AI_ENABLED
GPT_API_KEY = CONFIG.GPT_API_KEY GPT_API_KEY = CONFIG.GPT_API_KEY
@ -224,3 +228,5 @@ GPT_MODEL = CONFIG.GPT_MODEL
VIRTUAL_APP_ENABLED = CONFIG.VIRTUAL_APP_ENABLED VIRTUAL_APP_ENABLED = CONFIG.VIRTUAL_APP_ENABLED
FILE_UPLOAD_SIZE_LIMIT_MB = CONFIG.FILE_UPLOAD_SIZE_LIMIT_MB FILE_UPLOAD_SIZE_LIMIT_MB = CONFIG.FILE_UPLOAD_SIZE_LIMIT_MB
TICKET_APPLY_ASSET_SCOPE = CONFIG.TICKET_APPLY_ASSET_SCOPE

View File

@ -82,7 +82,6 @@ BOOTSTRAP3 = {
# Django channels support websocket # Django channels support websocket
REDIS_LAYERS_HOST = { REDIS_LAYERS_HOST = {
'db': CONFIG.REDIS_DB_WS, 'db': CONFIG.REDIS_DB_WS,
'password': CONFIG.REDIS_PASSWORD or None,
} }
REDIS_LAYERS_SSL_PARAMS = {} REDIS_LAYERS_SSL_PARAMS = {}
@ -97,6 +96,7 @@ if REDIS_USE_SSL:
if REDIS_SENTINEL_SERVICE_NAME and REDIS_SENTINELS: if REDIS_SENTINEL_SERVICE_NAME and REDIS_SENTINELS:
REDIS_LAYERS_HOST['sentinels'] = REDIS_SENTINELS REDIS_LAYERS_HOST['sentinels'] = REDIS_SENTINELS
REDIS_LAYERS_HOST['password'] = CONFIG.REDIS_PASSWORD or None
REDIS_LAYERS_HOST['master_name'] = REDIS_SENTINEL_SERVICE_NAME REDIS_LAYERS_HOST['master_name'] = REDIS_SENTINEL_SERVICE_NAME
REDIS_LAYERS_HOST['sentinel_kwargs'] = { REDIS_LAYERS_HOST['sentinel_kwargs'] = {
'password': REDIS_SENTINEL_PASSWORD, 'password': REDIS_SENTINEL_PASSWORD,
@ -111,7 +111,7 @@ else:
# More info see: https://github.com/django/channels_redis/issues/334 # More info see: https://github.com/django/channels_redis/issues/334
# REDIS_LAYERS_HOST['address'] = (CONFIG.REDIS_HOST, CONFIG.REDIS_PORT) # REDIS_LAYERS_HOST['address'] = (CONFIG.REDIS_HOST, CONFIG.REDIS_PORT)
REDIS_LAYERS_ADDRESS = '{protocol}://:{password}@{host}:{port}/{db}'.format( REDIS_LAYERS_ADDRESS = '{protocol}://:{password}@{host}:{port}/{db}'.format(
protocol=REDIS_PROTOCOL, password=CONFIG.REDIS_PASSWORD, protocol=REDIS_PROTOCOL, password=CONFIG.REDIS_PASSWORD_QUOTE,
host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, db=CONFIG.REDIS_DB_WS host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, db=CONFIG.REDIS_DB_WS
) )
REDIS_LAYERS_HOST['address'] = REDIS_LAYERS_ADDRESS REDIS_LAYERS_HOST['address'] = REDIS_LAYERS_ADDRESS
@ -153,7 +153,7 @@ if REDIS_SENTINEL_SERVICE_NAME and REDIS_SENTINELS:
else: else:
CELERY_BROKER_URL = CELERY_BROKER_URL_FORMAT % { CELERY_BROKER_URL = CELERY_BROKER_URL_FORMAT % {
'protocol': REDIS_PROTOCOL, 'protocol': REDIS_PROTOCOL,
'password': CONFIG.REDIS_PASSWORD, 'password': CONFIG.REDIS_PASSWORD_QUOTE,
'host': CONFIG.REDIS_HOST, 'host': CONFIG.REDIS_HOST,
'port': CONFIG.REDIS_PORT, 'port': CONFIG.REDIS_PORT,
'db': CONFIG.REDIS_DB_CELERY, 'db': CONFIG.REDIS_DB_CELERY,
@ -187,6 +187,7 @@ ANSIBLE_LOG_DIR = os.path.join(PROJECT_DIR, 'data', 'ansible')
REDIS_HOST = CONFIG.REDIS_HOST REDIS_HOST = CONFIG.REDIS_HOST
REDIS_PORT = CONFIG.REDIS_PORT REDIS_PORT = CONFIG.REDIS_PORT
REDIS_PASSWORD = CONFIG.REDIS_PASSWORD REDIS_PASSWORD = CONFIG.REDIS_PASSWORD
REDIS_PASSWORD_QUOTE = CONFIG.REDIS_PASSWORD_QUOTE
DJANGO_REDIS_SCAN_ITERSIZE = 1000 DJANGO_REDIS_SCAN_ITERSIZE = 1000

View File

@ -1,6 +1,6 @@
from django.contrib.contenttypes.fields import GenericRelation from django.contrib.contenttypes.fields import GenericRelation
from django.db import models from django.db import models
from django.db.models import OneToOneField from django.db.models import OneToOneField, Count
from common.utils import lazyproperty from common.utils import lazyproperty
from .models import LabeledResource from .models import LabeledResource
@ -36,3 +36,37 @@ class LabeledMixin(models.Model):
@res_labels.setter @res_labels.setter
def res_labels(self, value): def res_labels(self, value):
self.real.labels.set(value, bulk=False) self.real.labels.set(value, bulk=False)
@classmethod
def filter_resources_by_labels(cls, resources, label_ids):
return cls._get_filter_res_by_labels_m2m_all(resources, label_ids)
@classmethod
def _get_filter_res_by_labels_m2m_in(cls, resources, label_ids):
return resources.filter(label_id__in=label_ids)
@classmethod
def _get_filter_res_by_labels_m2m_all(cls, resources, label_ids):
if len(label_ids) == 1:
return cls._get_filter_res_by_labels_m2m_in(resources, label_ids)
resources = resources.filter(label_id__in=label_ids) \
.values('res_id') \
.order_by('res_id') \
.annotate(count=Count('res_id', distinct=True)) \
.values('res_id', 'count') \
.filter(count=len(label_ids))
return resources
@classmethod
def get_labels_filter_attr_q(cls, value, match):
resources = LabeledResource.objects.all()
if not value:
return None
if match != 'm2m_all':
resources = cls._get_filter_res_by_labels_m2m_in(resources, value)
else:
resources = cls._get_filter_res_by_labels_m2m_all(resources, value)
res_ids = set(resources.values_list('res_id', flat=True))
return models.Q(id__in=res_ids)
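
The m2m_all branch above implements "has ALL of the given labels" by grouping the through-table rows per res_id and keeping groups whose distinct label count equals len(label_ids). The same idea in plain Python, with hypothetical ids:

from collections import Counter

# A resource qualifies only when it carries every requested label.
label_ids = {'l1', 'l2'}
rows = [                       # (res_id, label_id) pairs, like the m2m table
    ('res-a', 'l1'), ('res-a', 'l2'),
    ('res-b', 'l1'),
]
counts = Counter(res for res, label in rows if label in label_ids)
matched = [res for res, n in counts.items() if n == len(label_ids)]
assert matched == ['res-a']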

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:7879f4eeb499e920ad6c4bfdb0b1f334936ca344c275be056f12fcf7485f2bf6 oid sha256:d04781f4f0b0de3ac5f707febb222e239553d6103bca0cec41ab2fd5ab044571
size 170948 size 173799

File diff suppressed because it is too large

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:19d3a111cc245f9a9d36b860fd95447df916ad66c918bef672bacdad6bc77a8f oid sha256:e66a6fa05d25f1c502f95001b5ff0d0a310affd32eac939fd7b840845028074f
size 140119 size 142298

File diff suppressed because it is too large

View File

@ -1,28 +1,32 @@
import json import json
import time
from threading import Thread
from channels.generic.websocket import JsonWebsocketConsumer from channels.generic.websocket import JsonWebsocketConsumer
from django.core.cache import cache from django.conf import settings
from common.db.utils import safe_db_connection from common.db.utils import safe_db_connection
from common.sessions.cache import user_session_manager
from common.utils import get_logger from common.utils import get_logger
from .signal_handlers import new_site_msg_chan from .signal_handlers import new_site_msg_chan
from .site_msg import SiteMessageUtil from .site_msg import SiteMessageUtil
logger = get_logger(__name__) logger = get_logger(__name__)
WS_SESSION_KEY = 'ws_session_key'
class SiteMsgWebsocket(JsonWebsocketConsumer): class SiteMsgWebsocket(JsonWebsocketConsumer):
sub = None sub = None
refresh_every_seconds = 10 refresh_every_seconds = 10
@property
def session(self):
return self.scope['session']
def connect(self): def connect(self):
user = self.scope["user"] user = self.scope["user"]
if user.is_authenticated: if user.is_authenticated:
self.accept() self.accept()
session = self.scope['session'] user_session_manager.add_or_increment(self.session.session_key)
redis_client = cache.client.get_client()
redis_client.sadd(WS_SESSION_KEY, session.session_key)
self.sub = self.watch_recv_new_site_msg() self.sub = self.watch_recv_new_site_msg()
else: else:
self.close() self.close()
@ -66,6 +70,32 @@ class SiteMsgWebsocket(JsonWebsocketConsumer):
if not self.sub: if not self.sub:
return return
self.sub.unsubscribe() self.sub.unsubscribe()
session = self.scope['session']
redis_client = cache.client.get_client() user_session_manager.decrement_or_remove(self.session.session_key)
redis_client.srem(WS_SESSION_KEY, session.session_key) if self.should_delete_session():
thread = Thread(target=self.delay_delete_session)
thread.start()
def should_delete_session(self):
return (self.session.modified or settings.SESSION_SAVE_EVERY_REQUEST) and \
not self.session.is_empty() and \
self.session.get_expire_at_browser_close() and \
not user_session_manager.check_active(self.session.session_key)
def delay_delete_session(self):
timeout = 6
check_interval = 0.5
start_time = time.time()
while time.time() - start_time < timeout:
time.sleep(check_interval)
if user_session_manager.check_active(self.session.session_key):
return
self.delete_session()
def delete_session(self):
try:
self.session.delete()
except Exception as e:
logger.info(f'delete session error: {e}')
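
disconnect() now defers session deletion to a thread that polls for up to six seconds; if the session becomes active again within the window (for example the user was just refreshing the page), the delete is skipped. The polling pattern in isolation:

import time

# Grace-period poll from delay_delete_session(): wait up to `timeout`
# seconds, checking every `check_interval`, and abort the delete if the
# session comes back in the meantime.
def wait_then_delete(is_active, delete, timeout=6, check_interval=0.5):
    start = time.time()
    while time.time() - start < timeout:
        time.sleep(check_interval)
        if is_active():
            return False  # the user reconnected; keep the session
    delete()
    return True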

View File

@ -1,3 +1,4 @@
import os
from collections import defaultdict from collections import defaultdict
from functools import reduce from functools import reduce
@ -29,6 +30,8 @@ class DefaultCallback:
) )
self.status = 'running' self.status = 'running'
self.finished = False self.finished = False
self.local_pid = 0
self.private_data_dir = None
@property @property
def host_results(self): def host_results(self):
@ -45,6 +48,9 @@ class DefaultCallback:
event = data.get('event', None) event = data.get('event', None)
if not event: if not event:
return return
pid = data.get('pid', None)
if pid:
self.write_pid(pid)
event_data = data.get('event_data', {}) event_data = data.get('event_data', {})
host = event_data.get('remote_addr', '') host = event_data.get('remote_addr', '')
task = event_data.get('task', '') task = event_data.get('task', '')
@ -152,3 +158,11 @@ class DefaultCallback:
def status_handler(self, data, **kwargs): def status_handler(self, data, **kwargs):
status = data.get('status', '') status = data.get('status', '')
self.status = self.STATUS_MAPPER.get(status, 'unknown') self.status = self.STATUS_MAPPER.get(status, 'unknown')
rc = kwargs.get('runner_config', None)
self.private_data_dir = rc.private_data_dir if rc else '/tmp/'
def write_pid(self, pid):
pid_filepath = os.path.join(self.private_data_dir, 'local.pid')
with open(pid_filepath, 'w') as f:
f.write(str(pid))
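
The callback persists the runner's pid under the private data dir so that a later stop request can signal the process. A standalone sketch of that hand-off (tempfile stands in for the runner's directory):

import os
import tempfile

# The event callback writes the reported pid to <private_data_dir>/local.pid
# for a later stop() to read and kill.
private_data_dir = tempfile.mkdtemp()  # stand-in for runner_config's dir
pid = os.getpid()                      # stand-in for the pid in the event
with open(os.path.join(private_data_dir, 'local.pid'), 'w') as f:
    f.write(str(pid))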

View File

@ -2,6 +2,7 @@
# #
import os import os
import re import re
from collections import defaultdict
from celery.result import AsyncResult from celery.result import AsyncResult
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
@ -166,16 +167,58 @@ class CeleryTaskViewSet(
i.next_exec_time = now + next_run_at i.next_exec_time = now + next_run_at
return queryset return queryset
def generate_summary_state(self, execution_qs):
model = self.get_queryset().model
executions = execution_qs.order_by('-date_published').values('name', 'state')
summary_state_dict = defaultdict(
lambda: {
'states': [], 'state': 'green',
'summary': {'total': 0, 'success': 0}
}
)
for execution in executions:
name = execution['name']
state = execution['state']
summary = summary_state_dict[name]['summary']
summary['total'] += 1
summary['success'] += 1 if state == 'SUCCESS' else 0
states = summary_state_dict[name].get('states')
if states is not None and len(states) >= 5:
color = model.compute_state_color(states)
summary_state_dict[name]['state'] = color
summary_state_dict[name].pop('states', None)
elif isinstance(states, list):
states.append(state)
return summary_state_dict
def loading_summary_state(self, queryset):
if isinstance(queryset, list):
names = [i.name for i in queryset]
execution_qs = CeleryTaskExecution.objects.filter(name__in=names)
else:
execution_qs = CeleryTaskExecution.objects.all()
summary_state_dict = self.generate_summary_state(execution_qs)
for i in queryset:
i.summary = summary_state_dict.get(i.name, {}).get('summary', {})
i.state = summary_state_dict.get(i.name, {}).get('state', 'green')
return queryset
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
queryset = self.filter_queryset(self.get_queryset()) queryset = self.filter_queryset(self.get_queryset())
page = self.paginate_queryset(queryset) page = self.paginate_queryset(queryset)
if page is not None: if page is not None:
page = self.generate_execute_time(page) page = self.generate_execute_time(page)
page = self.loading_summary_state(page)
serializer = self.get_serializer(page, many=True) serializer = self.get_serializer(page, many=True)
return self.get_paginated_response(serializer.data) return self.get_paginated_response(serializer.data)
queryset = self.generate_execute_time(queryset) queryset = self.generate_execute_time(queryset)
queryset = self.loading_summary_state(queryset)
serializer = self.get_serializer(queryset, many=True) serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data) return Response(serializer.data)
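
generate_summary_state() replaces per-task queries with a single pass over all executions, ordered newest first: totals and successes are counted, and only the first five states per name are kept for the color computation. Condensed, with made-up rows:

from collections import defaultdict

# Single-pass aggregation: executions arrive newest-first.
executions = [
    {'name': 'task-a', 'state': 'SUCCESS'},
    {'name': 'task-a', 'state': 'FAILURE'},
    {'name': 'task-b', 'state': 'SUCCESS'},
]
summary = defaultdict(lambda: {'states': [], 'summary': {'total': 0, 'success': 0}})
for e in executions:
    entry = summary[e['name']]
    entry['summary']['total'] += 1
    if e['state'] == 'SUCCESS':
        entry['summary']['success'] += 1
    if len(entry['states']) < 5:        # only the 5 most recent states matter
        entry['states'].append(e['state'])
assert summary['task-a']['summary'] == {'total': 2, 'success': 1}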

View File

@ -1,9 +1,11 @@
import json import json
import os import os
from celery.result import AsyncResult
from django.conf import settings from django.conf import settings
from django.db import transaction from django.db import transaction
from django.db.models import Count from django.db.models import Count
from django.http import Http404
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.utils._os import safe_join from django.utils._os import safe_join
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@ -14,9 +16,10 @@ from rest_framework.views import APIView
from assets.models import Asset from assets.models import Asset
from common.const.http import POST from common.const.http import POST
from common.permissions import IsValidUser from common.permissions import IsValidUser
from ops.celery import app
from ops.const import Types from ops.const import Types
from ops.models import Job, JobExecution from ops.models import Job, JobExecution
from ops.serializers.job import JobSerializer, JobExecutionSerializer, FileSerializer from ops.serializers.job import JobSerializer, JobExecutionSerializer, FileSerializer, JobTaskStopSerializer
__all__ = [ __all__ = [
'JobViewSet', 'JobExecutionViewSet', 'JobRunVariableHelpAPIView', 'JobViewSet', 'JobExecutionViewSet', 'JobRunVariableHelpAPIView',
@ -187,6 +190,33 @@ class JobExecutionViewSet(OrgBulkModelViewSet):
queryset = queryset.filter(creator=self.request.user) queryset = queryset.filter(creator=self.request.user)
return queryset return queryset
@action(methods=[POST], detail=False, serializer_class=JobTaskStopSerializer, permission_classes=[IsValidUser, ],
url_path='stop')
def stop(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
if not serializer.is_valid():
return Response({'error': serializer.errors}, status=400)
task_id = serializer.validated_data['task_id']
try:
instance = get_object_or_404(JobExecution, task_id=task_id, creator=request.user)
except Http404:
return Response(
{'error': _('The task is being created and cannot be interrupted. Please try again later.')},
status=400
)
task = AsyncResult(task_id, app=app)
inspect = app.control.inspect()
for worker in inspect.registered().keys():
if task_id not in [at['id'] for at in inspect.active().get(worker, [])]:
# Not running yet (still queued); cancel it with revoke
task.revoke(terminate=True)
instance.set_error('Job stop by "revoke task {}"'.format(task_id))
return Response({'task_id': task_id}, status=200)
instance.stop()
return Response({'task_id': task_id}, status=200)
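
The stop action distinguishes queued tasks from running ones: a task id that is not active on any worker is revoked, otherwise the execution's own stop() (the pid-file kill shown further down) takes over. A sketch of the queued half, hardened against inspect() returning None when no worker replies; the broker URL is a placeholder:

from celery import Celery
from celery.result import AsyncResult

app = Celery(broker='redis://localhost:6379/0')  # placeholder broker

def revoke_if_not_active(task_id):
    # inspect().active() may be None if no workers respond in time.
    active = app.control.inspect().active() or {}
    running = {t['id'] for tasks in active.values() for t in tasks}
    if task_id not in running:
        AsyncResult(task_id, app=app).revoke(terminate=True)
        return True
    return False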
class JobAssetDetail(APIView): class JobAssetDetail(APIView):
rbac_perms = { rbac_perms = {

View File

@ -15,6 +15,9 @@ class CeleryTask(models.Model):
name = models.CharField(max_length=1024, verbose_name=_('Name')) name = models.CharField(max_length=1024, verbose_name=_('Name'))
date_last_publish = models.DateTimeField(null=True, verbose_name=_("Date last publish")) date_last_publish = models.DateTimeField(null=True, verbose_name=_("Date last publish"))
__summary = None
__state = None
@property @property
def meta(self): def meta(self):
task = app.tasks.get(self.name, None) task = app.tasks.get(self.name, None)
@ -25,23 +28,43 @@ class CeleryTask(models.Model):
@property @property
def summary(self): def summary(self):
if self.__summary is not None:
return self.__summary
executions = CeleryTaskExecution.objects.filter(name=self.name) executions = CeleryTaskExecution.objects.filter(name=self.name)
total = executions.count() total = executions.count()
success = executions.filter(state='SUCCESS').count() success = executions.filter(state='SUCCESS').count()
return {'total': total, 'success': success} return {'total': total, 'success': success}
@summary.setter
def summary(self, value):
self.__summary = value
@staticmethod
def compute_state_color(states: list, default_count=5):
color = 'green'
states = states[:default_count]
if not states:
return color
if states[0] == 'FAILURE':
color = 'red'
elif 'FAILURE' in states:
color = 'yellow'
return color
@property @property
def state(self): def state(self):
last_five_executions = CeleryTaskExecution.objects.filter(name=self.name).order_by('-date_published')[:5] if self.__state is not None:
return self.__state
last_five_executions = CeleryTaskExecution.objects.filter(
name=self.name
).order_by('-date_published').values('state')[:5]
states = [i['state'] for i in last_five_executions]
color = self.compute_state_color(states)
return color
if len(last_five_executions) > 0: @state.setter
if last_five_executions[0].state == 'FAILURE': def state(self, value):
return "red" self.__state = value
for execution in last_five_executions:
if execution.state == 'FAILURE':
return "yellow"
return "green"
class Meta: class Meta:
verbose_name = _("Celery Task") verbose_name = _("Celery Task")

View File

@ -67,6 +67,7 @@ class JMSPermedInventory(JMSInventory):
'postgresql': ['postgresql'], 'postgresql': ['postgresql'],
'sqlserver': ['sqlserver'], 'sqlserver': ['sqlserver'],
'ssh': ['shell', 'python', 'win_shell', 'raw'], 'ssh': ['shell', 'python', 'win_shell', 'raw'],
'winrm': ['win_shell', 'shell'],
} }
if self.module not in protocol_supported_modules_mapping.get(protocol.name, []): if self.module not in protocol_supported_modules_mapping.get(protocol.name, []):
@ -553,6 +554,15 @@ class JobExecution(JMSOrgBaseModel):
finally: finally:
ssh_tunnel.local_gateway_clean(runner) ssh_tunnel.local_gateway_clean(runner)
def stop(self):
with open(os.path.join(self.private_dir, 'local.pid')) as f:
try:
pid = f.read()
os.kill(int(pid), 9)
except Exception as e:
print(e)
self.set_error('Job stop by "kill -9 {}"'.format(pid))
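
This is the counterpart to the pid hand-off earlier: stop() reads local.pid from the execution's private dir and kills the process with SIGKILL. A standalone version with explicit error handling:

import os
import signal

# private_data_dir is whatever directory local.pid was written to.
def stop_local_job(private_data_dir):
    pid_path = os.path.join(private_data_dir, 'local.pid')
    try:
        with open(pid_path) as f:
            pid = int(f.read().strip())
        os.kill(pid, signal.SIGKILL)  # the "kill -9 <pid>" in the error note
    except (OSError, ValueError) as e:
        print('stop failed:', e)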
class Meta: class Meta:
verbose_name = _("Job Execution") verbose_name = _("Job Execution")
ordering = ['-date_created'] ordering = ['-date_created']

View File

@ -57,6 +57,13 @@ class FileSerializer(serializers.Serializer):
ref_name = "JobFileSerializer" ref_name = "JobFileSerializer"
class JobTaskStopSerializer(serializers.Serializer):
task_id = serializers.CharField(max_length=128)
class Meta:
ref_name = "JobTaskStopSerializer"
class JobExecutionSerializer(BulkOrgResourceModelSerializer): class JobExecutionSerializer(BulkOrgResourceModelSerializer):
creator = ReadableHiddenField(default=serializers.CurrentUserDefault()) creator = ReadableHiddenField(default=serializers.CurrentUserDefault())
job_type = serializers.ReadOnlyField(label=_("Job type")) job_type = serializers.ReadOnlyField(label=_("Job type"))

View File

@ -173,6 +173,9 @@ class Organization(OrgRoleMixin, JMSBaseModel):
def is_default(self): def is_default(self):
return str(self.id) == self.DEFAULT_ID return str(self.id) == self.DEFAULT_ID
def is_system(self):
return str(self.id) == self.SYSTEM_ID
@property @property
def internal(self): def internal(self):
return str(self.id) in self.INTERNAL_IDS return str(self.id) in self.INTERNAL_IDS

View File

@ -87,7 +87,8 @@ class OrgResourceStatisticsRefreshUtil:
if not cache_field_name: if not cache_field_name:
return return
org = getattr(instance, 'org', None) org = getattr(instance, 'org', None)
cls.refresh_org_fields(((org, cache_field_name),)) cache_field_name = tuple(cache_field_name)
cls.refresh_org_fields.delay(org_fields=((org, cache_field_name),))
@receiver(post_save) @receiver(post_save)

View File

@ -6,6 +6,7 @@ from functools import wraps
from inspect import signature from inspect import signature
from werkzeug.local import LocalProxy from werkzeug.local import LocalProxy
from django.conf import settings
from common.local import thread_local from common.local import thread_local
from .models import Organization from .models import Organization
@ -14,7 +15,6 @@ from .models import Organization
def get_org_from_request(request): def get_org_from_request(request):
# The query string has the highest priority # The query string has the highest priority
oid = request.GET.get("oid") oid = request.GET.get("oid")
# Then the header
if not oid: if not oid:
oid = request.META.get("HTTP_X_JMS_ORG") oid = request.META.get("HTTP_X_JMS_ORG")
@ -24,14 +24,33 @@ def get_org_from_request(request):
# Then the session # Then the session
if not oid: if not oid:
oid = request.session.get("oid") oid = request.session.get("oid")
if oid and oid.lower() == 'default':
return Organization.default()
if oid and oid.lower() == 'root':
return Organization.root()
if oid and oid.lower() == 'system':
return Organization.system()
org = Organization.get_instance(oid)
if org and org.internal:
# Built-in organizations are returned directly
return org
if not settings.XPACK_ENABLED:
# Community edition users can only use the default organization
return Organization.default()
if not org and request.user.is_authenticated:
# For enterprise edition users, prefer an organization they have permissions in
org = request.user.orgs.first()
if not org:
org = Organization.default()
if not oid:
oid = Organization.DEFAULT_ID
if oid.lower() == "default":
oid = Organization.DEFAULT_ID
elif oid.lower() == "root":
oid = Organization.ROOT_ID
org = Organization.get_instance(oid, default=Organization.default())
return org return org
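
Condensed, the oid lookup order above is query string, then the X-JMS-ORG header, then the session, before the alias and edition handling. A tiny stand-in request object demonstrates the precedence:

class _Req:  # minimal stand-in for a Django request, for illustration
    GET = {'oid': None}
    META = {'HTTP_X_JMS_ORG': 'org-id-from-header'}
    session = {'oid': 'org-id-from-session'}

def resolve_oid(request):
    # Lookup order from get_org_from_request(): query > header > session.
    return (request.GET.get('oid')
            or request.META.get('HTTP_X_JMS_ORG')
            or request.session.get('oid'))

assert resolve_oid(_Req()) == 'org-id-from-header'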

View File

@ -1,9 +1,11 @@
import abc import abc
from django.conf import settings
from rest_framework.generics import ListAPIView, RetrieveAPIView from rest_framework.generics import ListAPIView, RetrieveAPIView
from assets.api.asset.asset import AssetFilterSet from assets.api.asset.asset import AssetFilterSet
from assets.models import Asset, Node from assets.models import Asset, Node
from common.api.mixin import ExtraFilterFieldsMixin
from common.utils import get_logger, lazyproperty, is_uuid from common.utils import get_logger, lazyproperty, is_uuid
from orgs.utils import tmp_to_root_org from orgs.utils import tmp_to_root_org
from perms import serializers from perms import serializers
@ -37,8 +39,8 @@ class UserPermedAssetRetrieveApi(SelfOrPKUserMixin, RetrieveAPIView):
return asset return asset
class BaseUserPermedAssetsApi(SelfOrPKUserMixin, ListAPIView): class BaseUserPermedAssetsApi(SelfOrPKUserMixin, ExtraFilterFieldsMixin, ListAPIView):
ordering = ('name',) ordering = []
search_fields = ('name', 'address', 'comment') search_fields = ('name', 'address', 'comment')
ordering_fields = ("name", "address") ordering_fields = ("name", "address")
filterset_class = AssetFilterSet filterset_class = AssetFilterSet
@ -47,6 +49,8 @@ class BaseUserPermedAssetsApi(SelfOrPKUserMixin, ListAPIView):
def get_queryset(self): def get_queryset(self):
if getattr(self, 'swagger_fake_view', False): if getattr(self, 'swagger_fake_view', False):
return Asset.objects.none() return Asset.objects.none()
if settings.ASSET_SIZE == 'small':
self.ordering = ['name']
assets = self.get_assets() assets = self.get_assets()
assets = self.serializer_class.setup_eager_loading(assets) assets = self.serializer_class.setup_eager_loading(assets)
return assets return assets

View File

@ -8,9 +8,9 @@ from rest_framework import serializers
from accounts.models import Account from accounts.models import Account
from assets.const import Category, AllTypes from assets.const import Category, AllTypes
from assets.models import Node, Asset, Platform from assets.models import Node, Asset, Platform
from assets.serializers.asset.common import AssetLabelSerializer, AssetProtocolsPermsSerializer from assets.serializers.asset.common import AssetProtocolsPermsSerializer
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
from common.serializers import ResourceLabelsMixin from common.serializers import ResourceLabelsMixin
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
from orgs.mixins.serializers import OrgResourceModelSerializerMixin from orgs.mixins.serializers import OrgResourceModelSerializerMixin
from perms.serializers.permission import ActionChoicesField from perms.serializers.permission import ActionChoicesField

View File

@ -13,7 +13,7 @@ class AssetPermissionUtil(object):
""" 资产授权相关的方法工具 """ """ 资产授权相关的方法工具 """
@timeit @timeit
def get_permissions_for_user(self, user, with_group=True, flat=False): def get_permissions_for_user(self, user, with_group=True, flat=False, with_expired=False):
""" 获取用户的授权规则 """ """ 获取用户的授权规则 """
perm_ids = set() perm_ids = set()
# user # user
@ -25,7 +25,7 @@ class AssetPermissionUtil(object):
groups = user.groups.all() groups = user.groups.all()
group_perm_ids = self.get_permissions_for_user_groups(groups, flat=True) group_perm_ids = self.get_permissions_for_user_groups(groups, flat=True)
perm_ids.update(group_perm_ids) perm_ids.update(group_perm_ids)
perms = self.get_permissions(ids=perm_ids) perms = self.get_permissions(ids=perm_ids, with_expired=with_expired)
if flat: if flat:
return perms.values_list('id', flat=True) return perms.values_list('id', flat=True)
return perms return perms
@ -102,6 +102,8 @@ class AssetPermissionUtil(object):
return model.objects.filter(id__in=ids) return model.objects.filter(id__in=ids)
@staticmethod @staticmethod
def get_permissions(ids): def get_permissions(ids, with_expired=False):
perms = AssetPermission.objects.filter(id__in=ids).valid().order_by('-date_expired') perms = AssetPermission.objects.filter(id__in=ids)
return perms if not with_expired:
perms = perms.valid()
return perms.order_by('-date_expired')

View File

@ -7,10 +7,10 @@ from django.db.models import Q
from rest_framework.utils.encoders import JSONEncoder from rest_framework.utils.encoders import JSONEncoder
from assets.const import AllTypes from assets.const import AllTypes
from assets.models import FavoriteAsset, Asset from assets.models import FavoriteAsset, Asset, Node
from common.utils.common import timeit, get_logger from common.utils.common import timeit, get_logger
from orgs.utils import current_org, tmp_to_root_org from orgs.utils import current_org, tmp_to_root_org
from perms.models import PermNode, UserAssetGrantedTreeNodeRelation from perms.models import PermNode, UserAssetGrantedTreeNodeRelation, AssetPermission
from .permission import AssetPermissionUtil from .permission import AssetPermissionUtil
__all__ = ['AssetPermissionPermAssetUtil', 'UserPermAssetUtil', 'UserPermNodeUtil'] __all__ = ['AssetPermissionPermAssetUtil', 'UserPermAssetUtil', 'UserPermNodeUtil']
@ -21,38 +21,37 @@ logger = get_logger(__name__)
class AssetPermissionPermAssetUtil: class AssetPermissionPermAssetUtil:
def __init__(self, perm_ids): def __init__(self, perm_ids):
self.perm_ids = perm_ids self.perm_ids = set(perm_ids)
def get_all_assets(self): def get_all_assets(self):
node_assets = self.get_perm_nodes_assets() node_assets = self.get_perm_nodes_assets()
direct_assets = self.get_direct_assets() direct_assets = self.get_direct_assets()
# Much faster than the old way of fetching all asset ids and then searching, because the search gets slow when there are many assets # Much faster than the old way of fetching all asset ids and then searching, because the search gets slow when there are many assets
return (node_assets | direct_assets).distinct() return (node_assets | direct_assets).order_by().distinct()
@timeit def get_perm_nodes(self):
def get_perm_nodes_assets(self, flat=False): """ Get all authorized nodes """
""" Get all assets under the authorized nodes """
from assets.models import Node
from ..models import AssetPermission
nodes_ids = AssetPermission.objects \ nodes_ids = AssetPermission.objects \
.filter(id__in=self.perm_ids) \ .filter(id__in=self.perm_ids) \
.values_list('nodes', flat=True) .values_list('nodes', flat=True)
nodes_ids = set(nodes_ids)
nodes = Node.objects.filter(id__in=nodes_ids).only('id', 'key') nodes = Node.objects.filter(id__in=nodes_ids).only('id', 'key')
assets = PermNode.get_nodes_all_assets(*nodes) return nodes
if flat:
return set(assets.values_list('id', flat=True)) @timeit
def get_perm_nodes_assets(self):
""" 获取所有授权节点下的资产 """
nodes = self.get_perm_nodes()
assets = PermNode.get_nodes_all_assets(*nodes, distinct=False)
return assets return assets
@timeit @timeit
def get_direct_assets(self, flat=False): def get_direct_assets(self):
""" 获取直接授权的资产 """ """ 获取直接授权的资产 """
from ..models import AssetPermission asset_ids = AssetPermission.assets.through.objects \
asset_ids = AssetPermission.objects \ .filter(assetpermission_id__in=self.perm_ids) \
.filter(id__in=self.perm_ids) \ .values_list('asset_id', flat=True)
.values_list('assets', flat=True) assets = Asset.objects.filter(id__in=asset_ids)
assets = Asset.objects.filter(id__in=asset_ids).distinct()
if flat:
return set(assets.values_list('id', flat=True))
return assets return assets

View File

@ -72,7 +72,7 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
@timeit @timeit
def refresh_if_need(self, force=False): def refresh_if_need(self, force=False):
built_just_now = cache.get(self.cache_key_time) built_just_now = False if settings.ASSET_SIZE == 'small' else cache.get(self.cache_key_time)
if built_just_now: if built_just_now:
logger.info('Refresh user perm tree just now, pass: {}'.format(built_just_now)) logger.info('Refresh user perm tree just now, pass: {}'.format(built_just_now))
return return
@ -80,12 +80,18 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
if not to_refresh_orgs: if not to_refresh_orgs:
logger.info('Not have to refresh orgs') logger.info('Not have to refresh orgs')
return return
logger.info("Delay refresh user orgs: {} {}".format(self.user, [o.name for o in to_refresh_orgs])) logger.info("Delay refresh user orgs: {} {}".format(self.user, [o.name for o in to_refresh_orgs]))
refresh_user_orgs_perm_tree(user_orgs=((self.user, tuple(to_refresh_orgs)),)) sync = True if settings.ASSET_SIZE == 'small' else False
refresh_user_favorite_assets(users=(self.user,)) refresh_user_orgs_perm_tree.apply(sync=sync, user_orgs=((self.user, tuple(to_refresh_orgs)),))
refresh_user_favorite_assets.apply(sync=sync, users=(self.user,))
@timeit @timeit
def refresh_tree_manual(self): def refresh_tree_manual(self):
"""
Used for manual debugging
:return:
"""
built_just_now = cache.get(self.cache_key_time) built_just_now = cache.get(self.cache_key_time)
if built_just_now: if built_just_now:
logger.info('Refresh just now, pass: {}'.format(built_just_now)) logger.info('Refresh just now, pass: {}'.format(built_just_now))
@ -105,8 +111,9 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
return return
self._clean_user_perm_tree_for_legacy_org() self._clean_user_perm_tree_for_legacy_org()
ttl = settings.PERM_TREE_REGEN_INTERVAL if settings.ASSET_SIZE != 'small':
cache.set(self.cache_key_time, int(time.time()), ttl) ttl = settings.PERM_TREE_REGEN_INTERVAL
cache.set(self.cache_key_time, int(time.time()), ttl)
lock = UserGrantedTreeRebuildLock(self.user.id) lock = UserGrantedTreeRebuildLock(self.user.id)
got = lock.acquire(blocking=False) got = lock.acquire(blocking=False)
@ -187,13 +194,20 @@ class UserPermTreeExpireUtil(_UserPermTreeCacheMixin):
@on_transaction_commit @on_transaction_commit
def expire_perm_tree_for_users_orgs(self, user_ids, org_ids): def expire_perm_tree_for_users_orgs(self, user_ids, org_ids):
user_ids = list(user_ids)
org_ids = [str(oid) for oid in org_ids] org_ids = [str(oid) for oid in org_ids]
with self.client.pipeline() as p: with self.client.pipeline() as p:
for uid in user_ids: for uid in user_ids:
cache_key = self.get_cache_key(uid) cache_key = self.get_cache_key(uid)
p.srem(cache_key, *org_ids) p.srem(cache_key, *org_ids)
p.execute() p.execute()
logger.info('Expire perm tree for users: [{}], orgs: [{}]'.format(user_ids, org_ids)) users_display = ','.join([str(i) for i in user_ids[:3]])
if len(user_ids) > 3:
users_display += '...'
orgs_display = ','.join([str(i) for i in org_ids[:3]])
if len(org_ids) > 3:
orgs_display += '...'
logger.info('Expire perm tree for users: [{}], orgs: [{}]'.format(users_display, orgs_display))
def expire_perm_tree_for_all_user(self): def expire_perm_tree_for_all_user(self):
keys = self.client.keys(self.cache_key_all_user) keys = self.client.keys(self.cache_key_all_user)

View File

@ -1,28 +1,16 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
#
import threading
from django.conf import settings
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework import generics from rest_framework import generics
from rest_framework.generics import CreateAPIView from rest_framework.views import Response
from rest_framework.views import Response, APIView
from common.api import AsyncApiMixin
from common.utils import get_logger from common.utils import get_logger
from orgs.models import Organization
from orgs.utils import current_org
from users.models import User from users.models import User
from ..models import Setting from ..models import Setting
from ..serializers import ( from ..serializers import LDAPUserSerializer
LDAPTestConfigSerializer, LDAPUserSerializer,
LDAPTestLoginSerializer
)
from ..tasks import sync_ldap_user
from ..utils import ( from ..utils import (
LDAPServerUtil, LDAPCacheUtil, LDAPImportUtil, LDAPSyncUtil, LDAPServerUtil, LDAPCacheUtil,
LDAP_USE_CACHE_FLAGS, LDAPTestUtil LDAP_USE_CACHE_FLAGS
) )
logger = get_logger(__file__) logger = get_logger(__file__)
@ -100,49 +88,3 @@ class LDAPUserListApi(generics.ListAPIView):
else: else:
data = {'msg': _('Users are not synchronized, please click the user synchronization button')} data = {'msg': _('Users are not synchronized, please click the user synchronization button')}
return Response(data=data, status=400) return Response(data=data, status=400)
class LDAPUserImportAPI(APIView):
perm_model = Setting
rbac_perms = {
'POST': 'settings.change_auth'
}
def get_orgs(self):
org_ids = self.request.data.get('org_ids')
if org_ids:
orgs = list(Organization.objects.filter(id__in=org_ids))
else:
orgs = [current_org]
return orgs
def get_ldap_users(self):
username_list = self.request.data.get('username_list', [])
cache_police = self.request.query_params.get('cache_police', True)
if '*' in username_list:
users = LDAPServerUtil().search()
elif cache_police in LDAP_USE_CACHE_FLAGS:
users = LDAPCacheUtil().search(search_users=username_list)
else:
users = LDAPServerUtil().search(search_users=username_list)
return users
def post(self, request):
try:
users = self.get_ldap_users()
except Exception as e:
return Response({'error': str(e)}, status=400)
if users is None:
return Response({'msg': _('Get ldap users is None')}, status=400)
orgs = self.get_orgs()
new_users, errors = LDAPImportUtil().perform_import(users, orgs)
if errors:
return Response({'errors': errors}, status=400)
count = users if users is None else len(users)
orgs_name = ', '.join([str(org) for org in orgs])
return Response({
'msg': _('Imported {} users successfully (Organization: {})').format(count, orgs_name)
})

View File

@ -3,6 +3,7 @@ from rest_framework import generics
from rest_framework.permissions import AllowAny from rest_framework.permissions import AllowAny
from authentication.permissions import IsValidUserOrConnectionToken from authentication.permissions import IsValidUserOrConnectionToken
from common.const.choices import COUNTRY_CALLING_CODES
from common.utils import get_logger, lazyproperty from common.utils import get_logger, lazyproperty
from common.utils.timezone import local_now from common.utils.timezone import local_now
from .. import serializers from .. import serializers
@ -24,7 +25,8 @@ class OpenPublicSettingApi(generics.RetrieveAPIView):
def get_object(self): def get_object(self):
return { return {
"XPACK_ENABLED": settings.XPACK_ENABLED, "XPACK_ENABLED": settings.XPACK_ENABLED,
"INTERFACE": self.interface_setting "INTERFACE": self.interface_setting,
"COUNTRY_CALLING_CODES": COUNTRY_CALLING_CODES
} }

View File

@ -43,7 +43,7 @@ class OAuth2SettingSerializer(serializers.Serializer):
) )
AUTH_OAUTH2_ACCESS_TOKEN_METHOD = serializers.ChoiceField( AUTH_OAUTH2_ACCESS_TOKEN_METHOD = serializers.ChoiceField(
default='GET', label=_('Client authentication method'), default='GET', label=_('Client authentication method'),
choices=(('GET', 'GET'), ('POST', 'POST')) choices=(('GET', 'GET'), ('POST', 'POST-DATA'), ('POST_JSON', 'POST-JSON'))
) )
AUTH_OAUTH2_PROVIDER_USERINFO_ENDPOINT = serializers.CharField( AUTH_OAUTH2_PROVIDER_USERINFO_ENDPOINT = serializers.CharField(
required=True, max_length=1024, label=_('Provider userinfo endpoint') required=True, max_length=1024, label=_('Provider userinfo endpoint')
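
The new POST_JSON choice adds a third way to call the token endpoint. How the three methods differ on the wire, as a client-side sketch; the endpoint and parameters are placeholders:

import requests

endpoint = 'https://idp.example.com/oauth2/token'  # placeholder endpoint
params = {'client_id': '...', 'client_secret': '...',
          'code': '...', 'grant_type': 'authorization_code'}

def fetch_token(method):
    if method == 'GET':      # credentials in the query string
        return requests.get(endpoint, params=params)
    if method == 'POST':     # POST-DATA: form-encoded body
        return requests.post(endpoint, data=params)
    return requests.post(endpoint, json=params)  # POST_JSON: JSON body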

View File

@ -22,6 +22,10 @@ class CleaningSerializer(serializers.Serializer):
min_value=MIN_VALUE, max_value=9999, min_value=MIN_VALUE, max_value=9999,
label=_("Operate log keep days (day)"), label=_("Operate log keep days (day)"),
) )
PASSWORD_CHANGE_LOG_KEEP_DAYS = serializers.IntegerField(
min_value=MIN_VALUE, max_value=9999,
label=_("password change log keep days (day)"),
)
FTP_LOG_KEEP_DAYS = serializers.IntegerField( FTP_LOG_KEEP_DAYS = serializers.IntegerField(
min_value=MIN_VALUE, max_value=9999, min_value=MIN_VALUE, max_value=9999,
label=_("FTP log keep days (day)"), label=_("FTP log keep days (day)"),

View File

@ -109,6 +109,7 @@ class TicketSettingSerializer(serializers.Serializer):
PREFIX_TITLE = _('Ticket') PREFIX_TITLE = _('Ticket')
TICKETS_ENABLED = serializers.BooleanField(required=False, default=True, label=_("Enable tickets")) TICKETS_ENABLED = serializers.BooleanField(required=False, default=True, label=_("Enable tickets"))
TICKETS_DIRECT_APPROVE = serializers.BooleanField(required=False, default=False, label=_("No login approval"))
TICKET_AUTHORIZE_DEFAULT_TIME = serializers.IntegerField( TICKET_AUTHORIZE_DEFAULT_TIME = serializers.IntegerField(
min_value=1, max_value=999999, required=False, min_value=1, max_value=999999, required=False,
label=_("Ticket authorize default time") label=_("Ticket authorize default time")

View File

@ -11,6 +11,7 @@ __all__ = [
class PublicSettingSerializer(serializers.Serializer): class PublicSettingSerializer(serializers.Serializer):
XPACK_ENABLED = serializers.BooleanField() XPACK_ENABLED = serializers.BooleanField()
INTERFACE = serializers.DictField() INTERFACE = serializers.DictField()
COUNTRY_CALLING_CODES = serializers.ListField()
class PrivateSettingSerializer(PublicSettingSerializer): class PrivateSettingSerializer(PublicSettingSerializer):
@ -50,6 +51,7 @@ class PrivateSettingSerializer(PublicSettingSerializer):
ANNOUNCEMENT = serializers.DictField() ANNOUNCEMENT = serializers.DictField()
TICKETS_ENABLED = serializers.BooleanField() TICKETS_ENABLED = serializers.BooleanField()
TICKETS_DIRECT_APPROVE = serializers.BooleanField()
CONNECTION_TOKEN_REUSABLE = serializers.BooleanField() CONNECTION_TOKEN_REUSABLE = serializers.BooleanField()
CACHE_LOGIN_PASSWORD_ENABLED = serializers.BooleanField() CACHE_LOGIN_PASSWORD_ENABLED = serializers.BooleanField()
VAULT_ENABLED = serializers.BooleanField() VAULT_ENABLED = serializers.BooleanField()

View File

@ -14,9 +14,13 @@
</ul> </ul>
<b>{% trans "Synced User" %}:</b> <b>{% trans "Synced User" %}:</b>
<ul> <ul>
{% for user in users %} {% if users %}
<li>{{ user }}</li> {% for user in users %}
{% endfor %} <li>{{ user }}</li>
{% endfor %}
{% else %}
<li>{% trans 'No user synchronization required' %}</li>
{% endif %}
</ul> </ul>
{% if errors %} {% if errors %}
<b>{% trans 'Error' %}:</b> <b>{% trans 'Error' %}:</b>

View File

@ -12,7 +12,6 @@ router.register(r'chatai-prompts', api.ChatPromptViewSet, 'chatai-prompt')
urlpatterns = [ urlpatterns = [
path('mail/testing/', api.MailTestingAPI.as_view(), name='mail-testing'), path('mail/testing/', api.MailTestingAPI.as_view(), name='mail-testing'),
path('ldap/users/', api.LDAPUserListApi.as_view(), name='ldap-user-list'), path('ldap/users/', api.LDAPUserListApi.as_view(), name='ldap-user-list'),
path('ldap/users/import/', api.LDAPUserImportAPI.as_view(), name='ldap-user-import'),
path('wecom/testing/', api.WeComTestingAPI.as_view(), name='wecom-testing'), path('wecom/testing/', api.WeComTestingAPI.as_view(), name='wecom-testing'),
path('dingtalk/testing/', api.DingTalkTestingAPI.as_view(), name='dingtalk-testing'), path('dingtalk/testing/', api.DingTalkTestingAPI.as_view(), name='dingtalk-testing'),
path('feishu/testing/', api.FeiShuTestingAPI.as_view(), name='feishu-testing'), path('feishu/testing/', api.FeiShuTestingAPI.as_view(), name='feishu-testing'),

View File

@ -6,6 +6,7 @@ import asyncio
from channels.generic.websocket import AsyncJsonWebsocketConsumer from channels.generic.websocket import AsyncJsonWebsocketConsumer
from django.core.cache import cache from django.core.cache import cache
from django.conf import settings from django.conf import settings
from django.utils.translation import gettext_lazy as _
from common.db.utils import close_old_connections from common.db.utils import close_old_connections
from common.utils import get_logger from common.utils import get_logger
@ -13,9 +14,12 @@ from settings.serializers import (
LDAPTestConfigSerializer, LDAPTestConfigSerializer,
LDAPTestLoginSerializer LDAPTestLoginSerializer
) )
from orgs.models import Organization
from orgs.utils import current_org
from settings.tasks import sync_ldap_user from settings.tasks import sync_ldap_user
from settings.utils import ( from settings.utils import (
LDAPSyncUtil, LDAPTestUtil LDAPServerUtil, LDAPCacheUtil, LDAPImportUtil, LDAPSyncUtil,
LDAP_USE_CACHE_FLAGS, LDAPTestUtil
) )
from .tools import ( from .tools import (
verbose_ping, verbose_telnet, verbose_nmap, verbose_ping, verbose_telnet, verbose_nmap,
@ -27,9 +31,11 @@ logger = get_logger(__name__)
CACHE_KEY_LDAP_TEST_CONFIG_MSG = 'CACHE_KEY_LDAP_TEST_CONFIG_MSG' CACHE_KEY_LDAP_TEST_CONFIG_MSG = 'CACHE_KEY_LDAP_TEST_CONFIG_MSG'
CACHE_KEY_LDAP_TEST_LOGIN_MSG = 'CACHE_KEY_LDAP_TEST_LOGIN_MSG' CACHE_KEY_LDAP_TEST_LOGIN_MSG = 'CACHE_KEY_LDAP_TEST_LOGIN_MSG'
CACHE_KEY_LDAP_SYNC_USER_MSG = 'CACHE_KEY_LDAP_SYNC_USER_MSG' CACHE_KEY_LDAP_SYNC_USER_MSG = 'CACHE_KEY_LDAP_SYNC_USER_MSG'
CACHE_KEY_LDAP_IMPORT_USER_MSG = 'CACHE_KEY_LDAP_IMPORT_USER_MSG'
CACHE_KEY_LDAP_TEST_CONFIG_TASK_STATUS = 'CACHE_KEY_LDAP_TEST_CONFIG_TASK_STATUS' CACHE_KEY_LDAP_TEST_CONFIG_TASK_STATUS = 'CACHE_KEY_LDAP_TEST_CONFIG_TASK_STATUS'
CACHE_KEY_LDAP_TEST_LOGIN_TASK_STATUS = 'CACHE_KEY_LDAP_TEST_LOGIN_TASK_STATUS' CACHE_KEY_LDAP_TEST_LOGIN_TASK_STATUS = 'CACHE_KEY_LDAP_TEST_LOGIN_TASK_STATUS'
CACHE_KEY_LDAP_SYNC_USER_TASK_STATUS = 'CACHE_KEY_LDAP_SYNC_USER_TASK_STATUS' CACHE_KEY_LDAP_SYNC_USER_TASK_STATUS = 'CACHE_KEY_LDAP_SYNC_USER_TASK_STATUS'
CACHE_KEY_LDAP_IMPORT_USER_TASK_STATUS = 'CACHE_KEY_LDAP_IMPORT_USER_TASK_STATUS'
TASK_STATUS_IS_RUNNING = 'RUNNING' TASK_STATUS_IS_RUNNING = 'RUNNING'
TASK_STATUS_IS_OVER = 'OVER' TASK_STATUS_IS_OVER = 'OVER'
@ -117,6 +123,8 @@ class LdapWebsocket(AsyncJsonWebsocketConsumer):
ok, msg = cache.get(CACHE_KEY_LDAP_TEST_CONFIG_MSG) ok, msg = cache.get(CACHE_KEY_LDAP_TEST_CONFIG_MSG)
elif msg_type == 'sync_user': elif msg_type == 'sync_user':
ok, msg = cache.get(CACHE_KEY_LDAP_SYNC_USER_MSG) ok, msg = cache.get(CACHE_KEY_LDAP_SYNC_USER_MSG)
elif msg_type == 'import_user':
ok, msg = cache.get(CACHE_KEY_LDAP_IMPORT_USER_MSG)
else: else:
ok, msg = cache.get(CACHE_KEY_LDAP_TEST_LOGIN_MSG) ok, msg = cache.get(CACHE_KEY_LDAP_TEST_LOGIN_MSG)
await self.send_msg(ok, msg) await self.send_msg(ok, msg)
@ -165,8 +173,8 @@ class LdapWebsocket(AsyncJsonWebsocketConsumer):
cache.set(task_key, TASK_STATUS_IS_OVER, ttl) cache.set(task_key, TASK_STATUS_IS_OVER, ttl)
@staticmethod @staticmethod
def set_task_msg(task_key, ok, msg): def set_task_msg(task_key, ok, msg, ttl=120):
cache.set(task_key, (ok, msg), 120) cache.set(task_key, (ok, msg), ttl)
def run_testing_config(self, data): def run_testing_config(self, data):
while True: while True:
@ -207,3 +215,53 @@ class LdapWebsocket(AsyncJsonWebsocketConsumer):
ok = False if msg else True ok = False if msg else True
self.set_task_status_over(CACHE_KEY_LDAP_SYNC_USER_TASK_STATUS) self.set_task_status_over(CACHE_KEY_LDAP_SYNC_USER_TASK_STATUS)
self.set_task_msg(CACHE_KEY_LDAP_SYNC_USER_MSG, ok, msg) self.set_task_msg(CACHE_KEY_LDAP_SYNC_USER_MSG, ok, msg)
def run_import_user(self, data):
while True:
if self.task_is_over(CACHE_KEY_LDAP_IMPORT_USER_TASK_STATUS):
break
else:
ok, msg = self.import_user(data)
self.set_task_status_over(CACHE_KEY_LDAP_IMPORT_USER_TASK_STATUS, 3)
self.set_task_msg(CACHE_KEY_LDAP_IMPORT_USER_MSG, ok, msg, 3)
def import_user(self, data):
ok = False
org_ids = data.get('org_ids')
username_list = data.get('username_list', [])
cache_police = data.get('cache_police', True)
try:
users = self.get_ldap_users(username_list, cache_police)
if users is None:
msg = _('Get ldap users is None')
orgs = self.get_orgs(org_ids)
new_users, error_msg = LDAPImportUtil().perform_import(users, orgs)
if error_msg:
msg = error_msg
count = users if users is None else len(users)
orgs_name = ', '.join([str(org) for org in orgs])
ok = True
msg = _('Imported {} users successfully (Organization: {})').format(count, orgs_name)
except Exception as e:
msg = str(e)
return ok, msg
@staticmethod
def get_orgs(org_ids):
if org_ids:
orgs = list(Organization.objects.filter(id__in=org_ids))
else:
orgs = [current_org]
return orgs
@staticmethod
def get_ldap_users(username_list, cache_police):
if '*' in username_list:
users = LDAPServerUtil().search()
elif cache_police in LDAP_USE_CACHE_FLAGS:
users = LDAPCacheUtil().search(search_users=username_list)
else:
users = LDAPServerUtil().search(search_users=username_list)
return users

View File

@ -9,7 +9,7 @@ from django.conf import settings
from django.core.files.storage import default_storage from django.core.files.storage import default_storage
from django.http import HttpResponse from django.http import HttpResponse
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _, get_language
from rest_framework import viewsets from rest_framework import viewsets
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.request import Request from rest_framework.request import Request
@ -19,6 +19,8 @@ from rest_framework.serializers import ValidationError
from common.api import JMSBulkModelViewSet from common.api import JMSBulkModelViewSet
from common.serializers import FileSerializer from common.serializers import FileSerializer
from common.utils import is_uuid from common.utils import is_uuid
from common.utils.http import is_true
from common.utils.yml import yaml_load_with_i18n
from terminal import serializers from terminal import serializers
from terminal.models import AppletPublication, Applet from terminal.models import AppletPublication, Applet
@ -106,9 +108,66 @@ class AppletViewSet(DownloadUploadMixin, JMSBulkModelViewSet):
def get_object(self): def get_object(self):
pk = self.kwargs.get('pk') pk = self.kwargs.get('pk')
if not is_uuid(pk): if not is_uuid(pk):
return get_object_or_404(Applet, name=pk) obj = get_object_or_404(Applet, name=pk)
else: else:
return get_object_or_404(Applet, pk=pk) obj = get_object_or_404(Applet, pk=pk)
return self.trans_object(obj)
def get_queryset(self):
queryset = super().get_queryset()
queryset = self.trans_queryset(queryset)
return queryset
@staticmethod
def read_manifest_with_i18n(obj, lang='zh'):
path = os.path.join(obj.path, 'manifest.yml')
if os.path.exists(path):
with open(path, encoding='utf8') as f:
manifest = yaml_load_with_i18n(f, lang)
else:
manifest = {}
return manifest
def trans_queryset(self, queryset):
for obj in queryset:
self.trans_object(obj)
return queryset
@staticmethod
def readme(obj, lang=''):
lang = lang[:2]
readme_file = os.path.join(obj.path, f'README_{lang.upper()}.md')
if os.path.isfile(readme_file):
with open(readme_file, 'r') as f:
return f.read()
return ''
def trans_object(self, obj):
lang = get_language()
manifest = self.read_manifest_with_i18n(obj, lang)
obj.display_name = manifest.get('display_name', obj.display_name)
obj.comment = manifest.get('comment', obj.comment)
obj.readme = self.readme(obj, lang)
return obj
def is_record_found(self, obj, search):
combine_fields = ' '.join([getattr(obj, f, '') for f in self.search_fields])
return search in combine_fields
def filter_queryset(self, queryset):
search = self.request.query_params.get('search')
if search:
queryset = [i for i in queryset if self.is_record_found(i, search)]
for field in self.filterset_fields:
field_value = self.request.query_params.get(field)
if not field_value:
continue
if field in ['is_active', 'builtin']:
field_value = is_true(field_value)
queryset = [i for i in queryset if getattr(i, field, '') == field_value]
return queryset
def perform_destroy(self, instance): def perform_destroy(self, instance):
if not instance.name: if not instance.name:

View File

@ -42,7 +42,7 @@ class SmartEndpointViewMixin:
return endpoint return endpoint
def match_endpoint_by_label(self): def match_endpoint_by_label(self):
return Endpoint.match_by_instance_label(self.target_instance, self.target_protocol) return Endpoint.match_by_instance_label(self.target_instance, self.target_protocol, self.request)
def match_endpoint_by_target_ip(self): def match_endpoint_by_target_ip(self):
target_ip = self.request.GET.get('target_ip', '') # the target_ip parameter is supported for easier testing target_ip = self.request.GET.get('target_ip', '') # the target_ip parameter is supported for easier testing

View File

@@ -18,10 +18,11 @@ from rest_framework.response import Response
 from audits.const import ActionChoices
 from common.api import AsyncApiMixin
-from common.const.http import GET
+from common.const.http import GET, POST
 from common.drf.filters import BaseFilterSet
 from common.drf.filters import DatetimeRangeFilterBackend
 from common.drf.renders import PassthroughRenderer
+from common.permissions import IsServiceAccount
 from common.storage.replay import ReplayStorageHandler
 from common.utils import data_to_json, is_uuid, i18n_fmt
 from common.utils import get_logger, get_object_or_none
@@ -33,6 +34,7 @@ from terminal import serializers
 from terminal.const import TerminalType
 from terminal.models import Session
 from terminal.permissions import IsSessionAssignee
+from terminal.session_lifecycle import lifecycle_events_map, reasons_map
 from terminal.utils import is_session_approver
 
 from users.models import User
@@ -79,6 +81,7 @@ class SessionViewSet(RecordViewLogMixin, OrgBulkModelViewSet):
     serializer_classes = {
         'default': serializers.SessionSerializer,
         'display': serializers.SessionDisplaySerializer,
+        'lifecycle_log': serializers.SessionLifecycleLogSerializer,
     }
     search_fields = [
         "user", "asset", "account", "remote_addr",
@@ -168,6 +171,23 @@ class SessionViewSet(RecordViewLogMixin, OrgBulkModelViewSet):
         count = queryset.count()
         return Response({'count': count})
 
+    @action(methods=[POST], detail=True, permission_classes=[IsServiceAccount], url_path='lifecycle_log',
+            url_name='lifecycle_log')
+    def lifecycle_log(self, request, *args, **kwargs):
+        serializer = self.get_serializer(data=request.data)
+        serializer.is_valid(raise_exception=True)
+        validated_data = serializer.validated_data
+        event = validated_data.pop('event', None)
+        event_class = lifecycle_events_map.get(event, None)
+        if not event_class:
+            return Response({'msg': f'event_name {event} invalid'}, status=400)
+        session = self.get_object()
+        reason = validated_data.pop('reason', None)
+        reason = reasons_map.get(reason, reason)
+        event_obj = event_class(session, reason, **validated_data)
+        activity_log = event_obj.create_activity_log()
+        return Response({'msg': 'ok', 'id': activity_log.id})
+
     def get_queryset(self):
         queryset = super().get_queryset() \
             .prefetch_related('terminal') \
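
The new lifecycle_log action lets terminal components report session lifecycle events to core. A hedged sketch of a client call follows; the URL prefix, auth header, and event/reason names are illustrative assumptions, not confirmed API documentation:

```python
# Hypothetical client call to the lifecycle_log action; the API prefix,
# auth scheme, and the event name below are assumptions for illustration.
import requests

session_id = "00000000-0000-0000-0000-000000000000"  # a real session UUID in practice
url = f"https://jms.example.com/api/v1/terminal/sessions/{session_id}/lifecycle_log/"
payload = {"event": "asset_connect_success", "reason": ""}

resp = requests.post(
    url,
    json=payload,
    headers={"Authorization": "Token <service-account-key>"},  # IsServiceAccount gate
)
# An unknown event name yields 400 with 'event_name ... invalid'; otherwise
# the view records the event and returns the new activity log id.
print(resp.status_code, resp.json())
```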

View File

@@ -0,0 +1,9 @@
+## Selenium Version
+
+- Selenium == 4.4.0
+- Chrome and ChromeDriver versions must match
+- Driver [download address](https://chromedriver.chromium.org/downloads)
+
+## ChangeLog
+
+Refer to [ChangeLog](./ChangeLog) for some important updates.
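
To see why the version pairing matters, here is a minimal sketch (not part of the applet itself) that starts Chrome under Selenium 4.4.0 and checks that the major versions agree; the chromedriver path is an assumed example location:

```python
# Minimal sketch: confirm Chrome and ChromeDriver major versions match.
# The driver path is an assumption; adjust it to your installation.
from selenium import webdriver
from selenium.webdriver.chrome.service import Service

driver = webdriver.Chrome(service=Service("/usr/local/bin/chromedriver"))
browser_ver = driver.capabilities["browserVersion"]
driver_ver = driver.capabilities["chrome"]["chromedriverVersion"].split(" ")[0]
# ChromeDriver refuses mismatched sessions at startup; this double-checks.
assert browser_ver.split(".")[0] == driver_ver.split(".")[0], (browser_ver, driver_ver)
driver.quit()
```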

View File

@@ -0,0 +1,9 @@
+## Selenium Version
+
+- Selenium == 4.4.0
+- Chrome and ChromeDriver versions must match
+- Driver [download address](https://chromedriver.chromium.org/downloads)
+
+## ChangeLog
+
+Refer to [ChangeLog](./ChangeLog) for important updates.

View File

@@ -0,0 +1,4 @@
+## DBeaver
+
+- When connecting to a database application, the driver must be downloaded first. Either install it offline
+  in advance, or install the corresponding driver as prompted when connecting.

View File

@@ -0,0 +1,3 @@
+## DBeaver
+
+- When connecting to a database, the driver must be downloaded. Install it offline in advance, or follow the prompt shown at connection time to install the corresponding driver.

View File

@@ -2,10 +2,10 @@
 import datetime
 
 from django.db import transaction
-from django.utils import timezone
 from django.db.utils import OperationalError
+from django.utils import timezone
 
 from common.utils.common import pretty_string
 
 from .base import CommandBase
@@ -19,9 +19,10 @@ class CommandStore(CommandBase):
         """
         Save the command to the database
         """
+        cmd_input = pretty_string(command['input'])
         self.model.objects.create(
             user=command["user"], asset=command["asset"],
-            account=command["account"], input=command["input"],
+            account=command["account"], input=cmd_input,
             output=command["output"], session=command["session"],
             risk_level=command.get("risk_level", 0), org_id=command["org_id"],
             timestamp=command["timestamp"]
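
The change above runs command input through pretty_string before saving, so oversized input can no longer overflow the record column. A hedged sketch of the idea follows; the real helper lives in common.utils.common and its exact signature and limit may differ:

```python
# Illustrative re-implementation of the truncation idea; this is NOT the
# actual common.utils.common.pretty_string, whose signature may differ.
def pretty_string(data: str, max_length: int = 128) -> str:
    """Trim overlong command input so it fits the storage column."""
    if len(data) <= max_length:
        return data
    return data[:max_length - 3] + '...'

assert pretty_string('ls -al') == 'ls -al'        # short input passes through
assert len(pretty_string('x' * 4096)) == 128      # long input is capped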

View File

@@ -75,7 +75,20 @@ class Endpoint(JMSBaseModel):
         return endpoint
 
     @classmethod
-    def match_by_instance_label(cls, instance, protocol):
+    def handle_endpoint_host(cls, endpoint, request=None):
+        if not endpoint.host and request:
+            # dynamically add the current request host
+            host_port = request.get_host()
+            # IPv6
+            if host_port.startswith('['):
+                host = host_port.split(']:')[0].rstrip(']') + ']'
+            else:
+                host = host_port.split(':')[0]
+            endpoint.host = host
+        return endpoint
+
+    @classmethod
+    def match_by_instance_label(cls, instance, protocol, request=None):
         from assets.models import Asset
         from terminal.models import Session
         if isinstance(instance, Session):
@@ -88,6 +101,7 @@ class Endpoint(JMSBaseModel):
         endpoints = cls.objects.filter(name__in=list(values)).order_by('-date_updated')
         for endpoint in endpoints:
             if endpoint.is_valid_for(instance, protocol):
+                endpoint = cls.handle_endpoint_host(endpoint, request)
                 return endpoint
@@ -130,13 +144,5 @@ class EndpointRule(JMSBaseModel):
             endpoint = endpoint_rule.endpoint
         else:
             endpoint = Endpoint.get_or_create_default(request)
-        if not endpoint.host and request:
-            # dynamically add the current request host
-            host_port = request.get_host()
-            # IPv6
-            if host_port.startswith('['):
-                host = host_port.split(']:')[0].rstrip(']') + ']'
-            else:
-                host = host_port.split(':')[0]
-            endpoint.host = host
+        endpoint = Endpoint.handle_endpoint_host(endpoint, request)
         return endpoint
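
The extracted handle_endpoint_host keeps the old inline behavior: when an endpoint has no host configured, it borrows the host from the current request, stripping the port while preserving IPv6 brackets. A standalone sketch of that parsing, with the function name chosen here for illustration:

```python
# Standalone version of the host parsing in Endpoint.handle_endpoint_host;
# request.get_host() returns "host" or "host:port", with IPv6 hosts bracketed.
def extract_host(host_port: str) -> str:
    if host_port.startswith('['):               # IPv6, e.g. "[::1]:8080"
        return host_port.split(']:')[0].rstrip(']') + ']'
    return host_port.split(':')[0]              # IPv4 or hostname

assert extract_host('[::1]:8080') == '[::1]'
assert extract_host('[2001:db8::1]') == '[2001:db8::1]'
assert extract_host('jms.example.com:443') == 'jms.example.com'
assert extract_host('192.0.2.10') == '192.0.2.10'
```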

Some files were not shown because too many files have changed in this diff.