Merge pull request #12734 from jumpserver/master

v3.10.4 (branch-v3.10)
This commit is contained in:
Bryan 2024-02-29 16:39:55 +08:00 committed by GitHub
commit eedc2f1b41
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
122 changed files with 2611 additions and 1422 deletions

View File

@ -19,11 +19,11 @@ ARG BUILD_DEPENDENCIES=" \
ARG DEPENDENCIES=" \
freetds-dev \
libpq-dev \
libffi-dev \
libjpeg-dev \
libkrb5-dev \
libldap2-dev \
libpq-dev \
libsasl2-dev \
libssl-dev \
libxml2-dev \
@ -75,6 +75,7 @@ ENV LANG=zh_CN.UTF-8 \
ARG DEPENDENCIES=" \
libjpeg-dev \
libpq-dev \
libx11-dev \
freerdp2-dev \
libxmlsec1-openssl"

View File

@ -1,11 +1,12 @@
from django.db.models import Q
from rest_framework.generics import CreateAPIView
from accounts import serializers
from accounts.models import Account
from accounts.permissions import AccountTaskActionPermission
from accounts.tasks import (
remove_accounts_task, verify_accounts_connectivity_task, push_accounts_to_assets_task
)
from assets.exceptions import NotSupportedTemporarilyError
from authentication.permissions import UserConfirmation, ConfirmType
__all__ = [
@ -26,25 +27,35 @@ class AccountsTaskCreateAPI(CreateAPIView):
]
return super().get_permissions()
def perform_create(self, serializer):
data = serializer.validated_data
accounts = data.get('accounts', [])
params = data.get('params')
@staticmethod
def get_account_ids(data, action):
account_type = 'gather_accounts' if action == 'remove' else 'accounts'
accounts = data.get(account_type, [])
account_ids = [str(a.id) for a in accounts]
if data['action'] == 'push':
task = push_accounts_to_assets_task.delay(account_ids, params)
elif data['action'] == 'remove':
gather_accounts = data.get('gather_accounts', [])
gather_account_ids = [str(a.id) for a in gather_accounts]
task = remove_accounts_task.delay(gather_account_ids)
if action == 'remove':
return account_ids
assets = data.get('assets', [])
asset_ids = [str(a.id) for a in assets]
ids = Account.objects.filter(
Q(id__in=account_ids) | Q(asset_id__in=asset_ids)
).distinct().values_list('id', flat=True)
return [str(_id) for _id in ids]
def perform_create(self, serializer):
data = serializer.validated_data
action = data['action']
ids = self.get_account_ids(data, action)
if action == 'push':
task = push_accounts_to_assets_task.delay(ids, data.get('params'))
elif action == 'remove':
task = remove_accounts_task.delay(ids)
elif action == 'verify':
task = verify_accounts_connectivity_task.delay(ids)
else:
account = accounts[0]
asset = account.asset
if not asset.auto_config['ansible_enabled'] or \
not asset.auto_config['ping_enabled']:
raise NotSupportedTemporarilyError()
task = verify_accounts_connectivity_task.delay(account_ids)
raise ValueError(f"Invalid action: {action}")
data = getattr(serializer, '_data', {})
data["task"] = task.id

View File

@ -168,9 +168,8 @@ class AccountBackupHandler:
if not user.secret_key:
attachment_list = []
else:
password = user.secret_key.encode('utf8')
attachment = os.path.join(PATH, f'{plan_name}-{local_now_filename()}-{time.time()}.zip')
encrypt_and_compress_zip_file(attachment, password, files)
encrypt_and_compress_zip_file(attachment, user.secret_key, files)
attachment_list = [attachment, ]
AccountBackupExecutionTaskMsg(plan_name, user).publish(attachment_list)
print('邮件已发送至{}({})'.format(user, user.email))
@ -191,7 +190,6 @@ class AccountBackupHandler:
attachment = os.path.join(PATH, f'{plan_name}-{local_now_filename()}-{time.time()}.zip')
if password:
print('\033[32m>>> 使用加密密码对文件进行加密中\033[0m')
password = password.encode('utf8')
encrypt_and_compress_zip_file(attachment, password, files)
else:
zip_files(attachment, files)

View File

@ -7,6 +7,7 @@ type:
- all
method: change_secret
protocol: ssh
priority: 50
params:
- name: commands
type: list

View File

@ -39,3 +39,4 @@
login_host: "{{ jms_asset.address }}"
login_port: "{{ jms_asset.port }}"
login_database: "{{ jms_asset.spec_info.db_name }}"
mode: "{{ account.mode }}"

View File

@ -5,6 +5,7 @@ method: change_secret
category: host
type:
- windows
priority: 49
params:
- name: groups
type: str

View File

@ -4,6 +4,7 @@ from copy import deepcopy
from django.conf import settings
from django.utils import timezone
from django.utils.translation import gettext_lazy as _
from xlsxwriter import Workbook
from accounts.const import AutomationTypes, SecretType, SSHKeyStrategy, SecretStrategy
@ -118,6 +119,10 @@ class ChangeSecretManager(AccountBasePlaybookManager):
else:
new_secret = self.get_secret(secret_type)
if new_secret is None:
print(f'new_secret is None, account: {account}')
continue
if self.record_id is None:
recorder = ChangeSecretRecord(
asset=asset, account=account, execution=self.execution,
@ -183,17 +188,33 @@ class ChangeSecretManager(AccountBasePlaybookManager):
return False
return True
@staticmethod
def get_summary(recorders):
total, succeed, failed = 0, 0, 0
for recorder in recorders:
if recorder.status == 'success':
succeed += 1
else:
failed += 1
total += 1
summary = _('Success: %s, Failed: %s, Total: %s') % (succeed, failed, total)
return summary
def run(self, *args, **kwargs):
if self.secret_type and not self.check_secret():
return
super().run(*args, **kwargs)
recorders = list(self.name_recorder_mapper.values())
summary = self.get_summary(recorders)
print(summary, end='')
if self.record_id:
return
recorders = self.name_recorder_mapper.values()
recorders = list(recorders)
self.send_recorder_mail(recorders)
def send_recorder_mail(self, recorders):
self.send_recorder_mail(recorders, summary)
def send_recorder_mail(self, recorders, summary):
recipients = self.execution.recipients
if not recorders or not recipients:
return
@ -209,11 +230,10 @@ class ChangeSecretManager(AccountBasePlaybookManager):
for user in recipients:
attachments = []
if user.secret_key:
password = user.secret_key.encode('utf8')
attachment = os.path.join(path, f'{name}-{local_now_filename()}-{time.time()}.zip')
encrypt_and_compress_zip_file(attachment, password, [filename])
encrypt_and_compress_zip_file(attachment, user.secret_key, [filename])
attachments = [attachment]
ChangeSecretExecutionTaskMsg(name, user).publish(attachments)
ChangeSecretExecutionTaskMsg(name, user, summary).publish(attachments)
os.remove(filename)
@staticmethod

View File

@ -1,9 +1,10 @@
- hosts: demo
gather_facts: no
tasks:
- name: Gather posix account
- name: Gather windows account
ansible.builtin.win_shell: net user
register: result
ignore_errors: true
- name: Define info by set_fact
ansible.builtin.set_fact:

View File

@ -39,3 +39,4 @@
login_host: "{{ jms_asset.address }}"
login_port: "{{ jms_asset.port }}"
login_database: "{{ jms_asset.spec_info.db_name }}"
mode: "{{ account.mode }}"

View File

@ -5,6 +5,7 @@ method: push_account
category: host
type:
- windows
priority: 49
params:
- name: groups
type: str

View File

@ -6,6 +6,7 @@ type:
- windows
method: verify_account
protocol: rdp
priority: 1
i18n:
Windows rdp account verify:

View File

@ -7,6 +7,7 @@ type:
- all
method: verify_account
protocol: ssh
priority: 50
i18n:
SSH account verify:

View File

@ -51,6 +51,9 @@ class VerifyAccountManager(AccountBasePlaybookManager):
h['name'] += '(' + account.username + ')'
self.host_account_mapper[h['name']] = account
secret = account.secret
if secret is None:
print(f'account {account.name} secret is None')
continue
private_key_path = None
if account.secret_type == SecretType.SSH_KEY:
@ -62,7 +65,7 @@ class VerifyAccountManager(AccountBasePlaybookManager):
'name': account.name,
'username': account.username,
'secret_type': account.secret_type,
'secret': account.escape_jinja2_syntax(secret),
'secret': account.escape_jinja2_syntax(secret),
'private_key_path': private_key_path,
'become': account.get_ansible_become_auth(),
}

View File

@ -52,6 +52,7 @@ class AccountFilterSet(BaseFilterSet):
class GatheredAccountFilterSet(BaseFilterSet):
node_id = drf_filters.CharFilter(method='filter_nodes')
asset_id = drf_filters.CharFilter(field_name='asset_id', lookup_expr='exact')
asset_name = drf_filters.CharFilter(field_name='asset__name', lookup_expr='icontains')
@staticmethod
def filter_nodes(queryset, name, value):

View File

@ -54,20 +54,23 @@ class AccountBackupByObjStorageExecutionTaskMsg(object):
class ChangeSecretExecutionTaskMsg(object):
subject = _('Notification of implementation result of encryption change plan')
def __init__(self, name: str, user: User):
def __init__(self, name: str, user: User, summary):
self.name = name
self.user = user
self.summary = summary
@property
def message(self):
name = self.name
if self.user.secret_key:
return _('{} - The encryption change task has been completed. '
'See the attachment for details').format(name)
default_message = _('{} - The encryption change task has been completed. '
'See the attachment for details').format(name)
else:
return _("{} - The encryption change task has been completed: the encryption "
"password has not been set - please go to personal information -> "
"file encryption password to set the encryption password").format(name)
default_message = _("{} - The encryption change task has been completed: the encryption "
"password has not been set - please go to personal information -> "
"set encryption password in preferences").format(name)
return self.summary + '\n' + default_message
def publish(self, attachments=None):
send_mail_attachment_async(

View File

@ -58,7 +58,7 @@ class AccountCreateUpdateSerializerMixin(serializers.Serializer):
for data in initial_data:
if not data.get('asset') and not self.instance:
raise serializers.ValidationError({'asset': UniqueTogetherValidator.missing_message})
asset = data.get('asset') or self.instance.asset
asset = data.get('asset') or getattr(self.instance, 'asset', None)
self.from_template_if_need(data)
self.set_uniq_name_if_need(data, asset)
@ -455,12 +455,14 @@ class AccountHistorySerializer(serializers.ModelSerializer):
class AccountTaskSerializer(serializers.Serializer):
ACTION_CHOICES = (
('test', 'test'),
('verify', 'verify'),
('push', 'push'),
('remove', 'remove'),
)
action = serializers.ChoiceField(choices=ACTION_CHOICES, write_only=True)
assets = serializers.PrimaryKeyRelatedField(
queryset=Asset.objects, required=False, allow_empty=True, many=True
)
accounts = serializers.PrimaryKeyRelatedField(
queryset=Account.objects, required=False, allow_empty=True, many=True
)

View File

@ -63,7 +63,7 @@ def create_accounts_activities(account, action='create'):
def on_account_create_by_template(sender, instance, created=False, **kwargs):
if not created or instance.source != 'template':
return
push_accounts_if_need(accounts=(instance,))
push_accounts_if_need.delay(accounts=(instance,))
create_accounts_activities(instance, action='create')

View File

@ -41,21 +41,21 @@ class UserLoginReminderMsg(UserMessage):
class AssetLoginReminderMsg(UserMessage):
subject = _('Asset login reminder')
def __init__(self, user, asset: Asset, login_user: User, account_username):
def __init__(self, user, asset: Asset, login_user: User, account: Account, input_username):
self.asset = asset
self.login_user = login_user
self.account_username = account_username
self.account = account
self.input_username = input_username
super().__init__(user)
def get_html_msg(self) -> dict:
account = Account.objects.get(asset=self.asset, username=self.account_username)
context = {
'recipient': self.user,
'username': self.login_user.username,
'name': self.login_user.name,
'asset': str(self.asset),
'account': self.account_username,
'account_name': account.name,
'account': self.input_username,
'account_name': self.account.name,
}
message = render_to_string('acls/asset_login_reminder.html', context)

View File

@ -92,6 +92,7 @@ class AssetViewSet(SuggestionMixin, OrgBulkModelViewSet):
model = Asset
filterset_class = AssetFilterSet
search_fields = ("name", "address", "comment")
ordering = ('name',)
ordering_fields = ('name', 'address', 'connectivity', 'platform', 'date_updated', 'date_created')
serializer_classes = (
("default", serializers.AssetSerializer),

View File

@ -48,7 +48,7 @@ class AssetPermUserListApi(BaseAssetPermUserOrUserGroupListApi):
def get_queryset(self):
perms = self.get_asset_related_perms()
users = User.objects.filter(
users = User.get_queryset().filter(
Q(assetpermissions__in=perms) | Q(groups__assetpermissions__in=perms)
).distinct()
return users

View File

@ -1,2 +1,2 @@
from .endpoint import ExecutionManager
from .methods import platform_automation_methods, filter_platform_methods
from .methods import platform_automation_methods, filter_platform_methods, sorted_methods

View File

@ -68,6 +68,10 @@ def filter_platform_methods(category, tp_name, method=None, methods=None):
return methods
def sorted_methods(methods):
return sorted(methods, key=lambda x: x.get('priority', 10))
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
platform_automation_methods = get_platform_automation_methods(BASE_DIR)

View File

@ -7,6 +7,7 @@ type:
- windows
method: ping
protocol: rdp
priority: 1
i18n:
Ping by pyfreerdp:

View File

@ -7,6 +7,7 @@ type:
- all
method: ping
protocol: ssh
priority: 50
i18n:
Ping by paramiko:

View File

@ -90,7 +90,7 @@ class AllTypes(ChoicesMixin):
@classmethod
def set_automation_methods(cls, category, tp_name, constraints):
from assets.automations import filter_platform_methods
from assets.automations import filter_platform_methods, sorted_methods
automation = constraints.get('automation', {})
automation_methods = {}
platform_automation_methods = cls.get_automation_methods()
@ -101,6 +101,7 @@ class AllTypes(ChoicesMixin):
methods = filter_platform_methods(
category, tp_name, item_name, methods=platform_automation_methods
)
methods = sorted_methods(methods)
methods = [{'name': m['name'], 'id': m['id']} for m in methods]
automation_methods[item_name + '_methods'] = methods
automation.update(automation_methods)

View File

@ -12,6 +12,6 @@ class Migration(migrations.Migration):
operations = [
migrations.AlterModelOptions(
name='asset',
options={'ordering': ['name'], 'permissions': [('refresh_assethardwareinfo', 'Can refresh asset hardware info'), ('test_assetconnectivity', 'Can test asset connectivity'), ('match_asset', 'Can match asset'), ('change_assetnodes', 'Can change asset nodes')], 'verbose_name': 'Asset'},
options={'ordering': [], 'permissions': [('refresh_assethardwareinfo', 'Can refresh asset hardware info'), ('test_assetconnectivity', 'Can test asset connectivity'), ('match_asset', 'Can match asset'), ('change_assetnodes', 'Can change asset nodes')], 'verbose_name': 'Asset'},
),
]

View File

@ -348,7 +348,7 @@ class Asset(NodesRelationMixin, LabeledMixin, AbsConnectivity, JSONFilterMixin,
class Meta:
unique_together = [('org_id', 'name')]
verbose_name = _("Asset")
ordering = ["name", ]
ordering = []
permissions = [
('refresh_assethardwareinfo', _('Can refresh asset hardware info')),
('test_assetconnectivity', _('Can test asset connectivity')),

View File

@ -429,7 +429,7 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
@classmethod
@timeit
def get_nodes_all_assets(cls, *nodes):
def get_nodes_all_assets(cls, *nodes, distinct=True):
from .asset import Asset
node_ids = set()
descendant_node_query = Q()
@ -439,7 +439,10 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
if descendant_node_query:
_ids = Node.objects.order_by().filter(descendant_node_query).values_list('id', flat=True)
node_ids.update(_ids)
return Asset.objects.order_by().filter(nodes__id__in=node_ids).distinct()
assets = Asset.objects.order_by().filter(nodes__id__in=node_ids)
if distinct:
assets = assets.distinct()
return assets
def get_all_asset_ids(self):
asset_ids = self.get_all_asset_ids_by_node_key(org_id=self.org_id, node_key=self.key)

View File

@ -63,13 +63,13 @@ def on_asset_create(sender, instance=None, created=False, **kwargs):
return
logger.info("Asset create signal recv: {}".format(instance))
ensure_asset_has_node(assets=(instance,))
ensure_asset_has_node.delay(assets=(instance,))
# 获取资产硬件信息
auto_config = instance.auto_config
if auto_config.get('ping_enabled'):
logger.debug('Asset {} ping enabled, test connectivity'.format(instance.name))
test_assets_connectivity_handler(assets=(instance,))
test_assets_connectivity_handler.delay(assets=(instance,))
if auto_config.get('gather_facts_enabled'):
logger.debug('Asset {} gather facts enabled, gather facts'.format(instance.name))
gather_assets_facts_handler(assets=(instance,))

View File

@ -2,14 +2,16 @@
#
from operator import add, sub
from django.conf import settings
from django.db.models.signals import m2m_changed
from django.dispatch import receiver
from assets.models import Asset, Node
from common.const.signals import PRE_CLEAR, POST_ADD, PRE_REMOVE
from common.decorators import on_transaction_commit, merge_delay_run
from common.signals import django_ready
from common.utils import get_logger
from orgs.utils import tmp_to_org
from orgs.utils import tmp_to_org, tmp_to_root_org
from ..tasks import check_node_assets_amount_task
logger = get_logger(__file__)
@ -34,7 +36,7 @@ def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
node_ids = [instance.id]
else:
node_ids = list(pk_set)
update_nodes_assets_amount(node_ids=node_ids)
update_nodes_assets_amount.delay(node_ids=node_ids)
@merge_delay_run(ttl=30)
@ -52,3 +54,18 @@ def update_nodes_assets_amount(node_ids=()):
node.assets_amount = node.get_assets_amount()
Node.objects.bulk_update(nodes, ['assets_amount'])
@receiver(django_ready)
def set_assets_size_to_setting(sender, **kwargs):
from assets.models import Asset
try:
with tmp_to_root_org():
amount = Asset.objects.order_by().count()
except:
amount = 0
if amount > 20000:
settings.ASSET_SIZE = 'large'
elif amount > 2000:
settings.ASSET_SIZE = 'medium'

View File

@ -44,18 +44,18 @@ def on_node_post_create(sender, instance, created, update_fields, **kwargs):
need_expire = False
if need_expire:
expire_node_assets_mapping(org_ids=(instance.org_id,))
expire_node_assets_mapping.delay(org_ids=(instance.org_id,))
@receiver(post_delete, sender=Node)
def on_node_post_delete(sender, instance, **kwargs):
expire_node_assets_mapping(org_ids=(instance.org_id,))
expire_node_assets_mapping.delay(org_ids=(instance.org_id,))
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_node_asset_change(sender, instance, action='pre_remove', **kwargs):
if action.startswith('post'):
expire_node_assets_mapping(org_ids=(instance.org_id,))
expire_node_assets_mapping.delay(org_ids=(instance.org_id,))
@receiver(django_ready)

View File

@ -20,6 +20,7 @@ from common.const.http import GET, POST
from common.drf.filters import DatetimeRangeFilterBackend
from common.permissions import IsServiceAccount
from common.plugins.es import QuerySet as ESQuerySet
from common.sessions.cache import user_session_manager
from common.storage.ftp_file import FTPFileStorageHandler
from common.utils import is_uuid, get_logger, lazyproperty
from orgs.mixins.api import OrgReadonlyModelViewSet, OrgModelViewSet
@ -30,7 +31,7 @@ from terminal.models import default_storage
from users.models import User
from .backends import TYPE_ENGINE_MAPPING
from .const import ActivityChoices
from .filters import UserSessionFilterSet
from .filters import UserSessionFilterSet, OperateLogFilterSet
from .models import (
FTPLog, UserLoginLog, OperateLog, PasswordChangeLog,
ActivityLog, JobLog, UserSession
@ -204,10 +205,7 @@ class OperateLogViewSet(OrgReadonlyModelViewSet):
date_range_filter_fields = [
('datetime', ('date_from', 'date_to'))
]
filterset_fields = [
'user', 'action', 'resource_type', 'resource',
'remote_addr'
]
filterset_class = OperateLogFilterSet
search_fields = ['resource', 'user']
ordering = ['-datetime']
@ -289,8 +287,7 @@ class UserSessionViewSet(CommonApiMixin, viewsets.ModelViewSet):
return Response(status=status.HTTP_200_OK)
keys = queryset.values_list('key', flat=True)
session_store_cls = import_module(settings.SESSION_ENGINE).SessionStore
for key in keys:
session_store_cls(key).delete()
user_session_manager.decrement_or_remove(key)
queryset.delete()
return Response(status=status.HTTP_200_OK)

View File

@ -1,12 +1,13 @@
from django.core.cache import cache
from django.apps import apps
from django.utils import translation
from django_filters import rest_framework as drf_filters
from rest_framework import filters
from rest_framework.compat import coreapi, coreschema
from common.drf.filters import BaseFilterSet
from notifications.ws import WS_SESSION_KEY
from common.sessions.cache import user_session_manager
from orgs.utils import current_org
from .models import UserSession
from .models import UserSession, OperateLog
__all__ = ['CurrentOrgMembersFilter']
@ -41,15 +42,32 @@ class UserSessionFilterSet(BaseFilterSet):
@staticmethod
def filter_is_active(queryset, name, is_active):
redis_client = cache.client.get_client()
members = redis_client.smembers(WS_SESSION_KEY)
members = [member.decode('utf-8') for member in members]
keys = user_session_manager.get_active_keys()
if is_active:
queryset = queryset.filter(key__in=members)
queryset = queryset.filter(key__in=keys)
else:
queryset = queryset.exclude(key__in=members)
queryset = queryset.exclude(key__in=keys)
return queryset
class Meta:
model = UserSession
fields = ['id', 'ip', 'city', 'type']
class OperateLogFilterSet(BaseFilterSet):
resource_type = drf_filters.CharFilter(method='filter_resource_type')
@staticmethod
def filter_resource_type(queryset, name, resource_type):
current_lang = translation.get_language()
with translation.override(current_lang):
mapper = {str(m._meta.verbose_name): m._meta.verbose_name_raw for m in apps.get_models()}
tp = mapper.get(resource_type)
queryset = queryset.filter(resource_type=tp)
return queryset
class Meta:
model = OperateLog
fields = [
'user', 'action', 'resource', 'remote_addr'
]

View File

@ -4,15 +4,15 @@ from datetime import timedelta
from importlib import import_module
from django.conf import settings
from django.core.cache import caches, cache
from django.core.cache import caches
from django.db import models
from django.db.models import Q
from django.utils import timezone
from django.utils.translation import gettext, gettext_lazy as _
from common.db.encoder import ModelJSONFieldEncoder
from common.sessions.cache import user_session_manager
from common.utils import lazyproperty, i18n_trans
from notifications.ws import WS_SESSION_KEY
from ops.models import JobExecution
from orgs.mixins.models import OrgModelMixin, Organization
from orgs.utils import current_org
@ -278,8 +278,7 @@ class UserSession(models.Model):
@property
def is_active(self):
redis_client = cache.client.get_client()
return redis_client.sismember(WS_SESSION_KEY, self.key)
return user_session_manager.check_active(self.key)
@property
def date_expired(self):

View File

@ -23,7 +23,7 @@ class JobLogSerializer(JobExecutionSerializer):
class Meta:
model = models.JobLog
read_only_fields = [
"id", "material", "time_cost", 'date_start',
"id", "material", 'job_type', "time_cost", 'date_start',
'date_finished', 'date_created',
'is_finished', 'is_success',
'task_id', 'creator_name'

View File

@ -19,7 +19,7 @@ from ops.celery.decorator import (
from ops.models import CeleryTaskExecution
from terminal.models import Session, Command
from terminal.backends import server_replay_storage
from .models import UserLoginLog, OperateLog, FTPLog, ActivityLog
from .models import UserLoginLog, OperateLog, FTPLog, ActivityLog, PasswordChangeLog
logger = get_logger(__name__)
@ -38,6 +38,14 @@ def clean_operation_log_period():
OperateLog.objects.filter(datetime__lt=expired_day).delete()
def clean_password_change_log_period():
now = timezone.now()
days = get_log_keep_day('PASSWORD_CHANGE_LOG_KEEP_DAYS')
expired_day = now - datetime.timedelta(days=days)
PasswordChangeLog.objects.filter(datetime__lt=expired_day).delete()
logger.info("Clean password change log done")
def clean_activity_log_period():
now = timezone.now()
days = get_log_keep_day('ACTIVITY_LOG_KEEP_DAYS')
@ -109,6 +117,7 @@ def clean_audits_log_period():
clean_activity_log_period()
clean_celery_tasks_period()
clean_expired_session_period()
clean_password_change_log_period()
@shared_task(verbose_name=_('Upload FTP file to external storage'))

View File

@ -205,7 +205,7 @@ class RDPFileClientProtocolURLMixin:
return data
def get_smart_endpoint(self, protocol, asset=None):
endpoint = Endpoint.match_by_instance_label(asset, protocol)
endpoint = Endpoint.match_by_instance_label(asset, protocol, self.request)
if not endpoint:
target_ip = asset.get_target_ip() if asset else ''
endpoint = EndpointRule.match_endpoint(
@ -443,7 +443,7 @@ class ConnectionTokenViewSet(ExtraActionApiMixin, RootOrgViewMixin, JMSModelView
self._record_operate_log(acl, asset)
for reviewer in reviewers:
AssetLoginReminderMsg(
reviewer, asset, user, self.input_username
reviewer, asset, user, account, self.input_username
).publish_async()
def create(self, request, *args, **kwargs):

View File

@ -10,6 +10,7 @@ from rest_framework import authentication, exceptions
from common.auth import signature
from common.decorators import merge_delay_run
from common.utils import get_object_or_none, get_request_ip_or_data, contains_ip
from users.models import User
from ..models import AccessKey, PrivateToken
@ -19,22 +20,23 @@ def date_more_than(d, seconds):
@merge_delay_run(ttl=60)
def update_token_last_used(tokens=()):
for token in tokens:
token.date_last_used = timezone.now()
token.save(update_fields=['date_last_used'])
access_keys_ids = [token.id for token in tokens if isinstance(token, AccessKey)]
private_token_keys = [token.key for token in tokens if isinstance(token, PrivateToken)]
if len(access_keys_ids) > 0:
AccessKey.objects.filter(id__in=access_keys_ids).update(date_last_used=timezone.now())
if len(private_token_keys) > 0:
PrivateToken.objects.filter(key__in=private_token_keys).update(date_last_used=timezone.now())
@merge_delay_run(ttl=60)
def update_user_last_used(users=()):
for user in users:
user.date_api_key_last_used = timezone.now()
user.save(update_fields=['date_api_key_last_used'])
User.objects.filter(id__in=users).update(date_api_key_last_used=timezone.now())
def after_authenticate_update_date(user, token=None):
update_user_last_used(users=(user,))
update_user_last_used.delay(users=(user.id,))
if token:
update_token_last_used(tokens=(token,))
update_token_last_used.delay(tokens=(token,))
class AccessTokenAuthentication(authentication.BaseAuthentication):

View File

@ -98,16 +98,19 @@ class OAuth2Backend(JMSModelBackend):
access_token_url = '{url}{separator}{query}'.format(
url=settings.AUTH_OAUTH2_ACCESS_TOKEN_ENDPOINT, separator=separator, query=urlencode(query_dict)
)
# token_method -> get, post(post_data), post_json
token_method = settings.AUTH_OAUTH2_ACCESS_TOKEN_METHOD.lower()
requests_func = getattr(requests, token_method, requests.get)
logger.debug(log_prompt.format('Call the access token endpoint[method: %s]' % token_method))
headers = {
'Accept': 'application/json'
}
if token_method == 'post':
access_token_response = requests_func(access_token_url, headers=headers, data=query_dict)
if token_method.startswith('post'):
body_key = 'json' if token_method.endswith('json') else 'data'
access_token_response = requests.post(
access_token_url, headers=headers, **{body_key: query_dict}
)
else:
access_token_response = requests_func(access_token_url, headers=headers)
access_token_response = requests.get(access_token_url, headers=headers)
try:
access_token_response.raise_for_status()
access_token_response_data = access_token_response.json()

View File

@ -18,7 +18,7 @@ class EncryptedField(forms.CharField):
class UserLoginForm(forms.Form):
days_auto_login = int(settings.SESSION_COOKIE_AGE / 3600 / 24)
disable_days_auto_login = settings.SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE \
disable_days_auto_login = settings.SESSION_EXPIRE_AT_BROWSER_CLOSE \
or days_auto_login < 1
username = forms.CharField(

View File

@ -142,23 +142,7 @@ class SessionCookieMiddleware(MiddlewareMixin):
return response
response.set_cookie(key, value)
@staticmethod
def set_cookie_session_expire(request, response):
if not request.session.get('auth_session_expiration_required'):
return
value = 'age'
if settings.SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE or \
not request.session.get('auto_login', False):
value = 'close'
age = request.session.get_expiry_age()
expire_timestamp = request.session.get_expiry_date().timestamp()
response.set_cookie('jms_session_expire_timestamp', expire_timestamp)
response.set_cookie('jms_session_expire', value, max_age=age)
request.session.pop('auth_session_expiration_required', None)
def process_response(self, request, response: HttpResponse):
self.set_cookie_session_prefix(request, response)
self.set_cookie_public_key(request, response)
self.set_cookie_session_expire(request, response)
return response

View File

@ -37,9 +37,6 @@ def on_user_auth_login_success(sender, user, request, **kwargs):
UserSession.objects.filter(key=session_key).delete()
cache.set(lock_key, request.session.session_key, None)
# 标记登录,设置 cookie前端可以控制刷新, Middleware 会拦截这个生成 cookie
request.session['auth_session_expiration_required'] = 1
@receiver(cas_user_authenticated)
def on_cas_user_login_success(sender, request, user, **kwargs):

View File

@ -407,6 +407,15 @@
$('#password-hidden').val(passwordEncrypted); //返回给密码输入input
$('#login-form').submit(); //post提交
}
function checkHealth() {
let url = "{% url 'health' %}";
requestApi({
url: url,
method: "GET",
flash_message: false,
})
}
setInterval(checkHealth, 30 * 1000);
</script>
</html>

View File

@ -70,11 +70,12 @@ class DingTalkQRMixin(DingTalkBaseMixin, View):
self.request.session[DINGTALK_STATE_SESSION_KEY] = state
params = {
'appid': settings.DINGTALK_APPKEY,
'client_id': settings.DINGTALK_APPKEY,
'response_type': 'code',
'scope': 'snsapi_login',
'scope': 'openid',
'state': state,
'redirect_uri': redirect_uri,
'prompt': 'consent'
}
url = URL.QR_CONNECT + '?' + urlencode(params)
return url

View File

@ -19,3 +19,17 @@ class Status(models.TextChoices):
failed = 'failed', _("Failed")
error = 'error', _("Error")
canceled = 'canceled', _("Canceled")
COUNTRY_CALLING_CODES = [
{'name': 'China(中国)', 'value': '+86'},
{'name': 'HongKong(中国香港)', 'value': '+852'},
{'name': 'Macao(中国澳门)', 'value': '+853'},
{'name': 'Taiwan(中国台湾)', 'value': '+886'},
{'name': 'America(America)', 'value': '+1'}, {'name': 'Russia(Россия)', 'value': '+7'},
{'name': 'France(français)', 'value': '+33'},
{'name': 'Britain(Britain)', 'value': '+44'},
{'name': 'Germany(Deutschland)', 'value': '+49'},
{'name': 'Japan(日本)', 'value': '+81'}, {'name': 'Korea(한국)', 'value': '+82'},
{'name': 'India(भारत)', 'value': '+91'}
]

View File

@ -362,11 +362,15 @@ class RelatedManager:
if name is None or val is None:
continue
if custom_attr_filter:
custom_filter_q = None
spec_attr_filter = getattr(to_model, "get_{}_filter_attr_q".format(name), None)
if spec_attr_filter:
custom_filter_q = spec_attr_filter(val, match)
elif custom_attr_filter:
custom_filter_q = custom_attr_filter(name, val, match)
if custom_filter_q:
filters.append(custom_filter_q)
continue
if custom_filter_q:
filters.append(custom_filter_q)
continue
if match == 'ip_in':
q = cls.get_ip_in_q(name, val)
@ -464,11 +468,15 @@ class JSONManyToManyDescriptor:
rule_value = rule.get('value', '')
rule_match = rule.get('match', 'exact')
if custom_attr_filter:
q = custom_attr_filter(rule['name'], rule_value, rule_match)
if q:
custom_q &= q
continue
custom_filter_q = None
spec_attr_filter = getattr(to_model, "get_filter_{}_attr_q".format(rule['name']), None)
if spec_attr_filter:
custom_filter_q = spec_attr_filter(rule_value, rule_match)
elif custom_attr_filter:
custom_filter_q = custom_attr_filter(rule['name'], rule_value, rule_match)
if custom_filter_q:
custom_q &= custom_filter_q
continue
if rule_match == 'in':
res &= value in rule_value or '*' in rule_value
@ -517,7 +525,6 @@ class JSONManyToManyDescriptor:
res &= rule_value.issubset(value)
else:
res &= bool(value & rule_value)
else:
logging.error("unknown match: {}".format(rule['match']))
res &= False

View File

@ -3,6 +3,7 @@
import asyncio
import functools
import inspect
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor
@ -101,7 +102,11 @@ def run_debouncer_func(cache_key, org, ttl, func, *args, **kwargs):
first_run_time = current
if current - first_run_time > ttl:
_loop_debouncer_func_args_cache.pop(cache_key, None)
_loop_debouncer_func_task_time_cache.pop(cache_key, None)
executor.submit(run_func_partial, *args, **kwargs)
logger.debug('pid {} executor submit run {}'.format(
os.getpid(), func.__name__, ))
return
loop = _loop_thread.get_loop()
@ -133,13 +138,26 @@ class Debouncer(object):
return await self.loop.run_in_executor(self.executor, func)
ignore_err_exceptions = (
"(3101, 'Plugin instructed the server to rollback the current transaction.')",
)
def _run_func_with_org(key, org, func, *args, **kwargs):
from orgs.utils import set_current_org
try:
set_current_org(org)
func(*args, **kwargs)
with transaction.atomic():
set_current_org(org)
func(*args, **kwargs)
except Exception as e:
logger.error('delay run error: %s' % e)
msg = str(e)
log_func = logger.error
if msg in ignore_err_exceptions:
log_func = logger.info
pid = os.getpid()
thread_name = threading.current_thread()
log_func('pid {} thread {} delay run {} error: {}'.format(
pid, thread_name, func.__name__, msg))
_loop_debouncer_func_task_cache.pop(key, None)
_loop_debouncer_func_args_cache.pop(key, None)
_loop_debouncer_func_task_time_cache.pop(key, None)
@ -181,6 +199,32 @@ def merge_delay_run(ttl=5, key=None):
:return:
"""
def delay(func, *args, **kwargs):
from orgs.utils import get_current_org
suffix_key_func = key if key else default_suffix_key
org = get_current_org()
func_name = f'{func.__module__}_{func.__name__}'
key_suffix = suffix_key_func(*args, **kwargs)
cache_key = f'MERGE_DELAY_RUN_{func_name}_{key_suffix}'
cache_kwargs = _loop_debouncer_func_args_cache.get(cache_key, {})
for k, v in kwargs.items():
if not isinstance(v, (tuple, list, set)):
raise ValueError('func kwargs value must be list or tuple: %s %s' % (func.__name__, v))
v = set(v)
if k not in cache_kwargs:
cache_kwargs[k] = v
else:
cache_kwargs[k] = cache_kwargs[k].union(v)
_loop_debouncer_func_args_cache[cache_key] = cache_kwargs
run_debouncer_func(cache_key, org, ttl, func, *args, **cache_kwargs)
def apply(func, sync=False, *args, **kwargs):
if sync:
return func(*args, **kwargs)
else:
return delay(func, *args, **kwargs)
def inner(func):
sigs = inspect.signature(func)
if len(sigs.parameters) != 1:
@ -188,27 +232,12 @@ def merge_delay_run(ttl=5, key=None):
param = list(sigs.parameters.values())[0]
if not isinstance(param.default, tuple):
raise ValueError('func default must be tuple: %s' % param.default)
suffix_key_func = key if key else default_suffix_key
func.delay = functools.partial(delay, func)
func.apply = functools.partial(apply, func)
@functools.wraps(func)
def wrapper(*args, **kwargs):
from orgs.utils import get_current_org
org = get_current_org()
func_name = f'{func.__module__}_{func.__name__}'
key_suffix = suffix_key_func(*args, **kwargs)
cache_key = f'MERGE_DELAY_RUN_{func_name}_{key_suffix}'
cache_kwargs = _loop_debouncer_func_args_cache.get(cache_key, {})
for k, v in kwargs.items():
if not isinstance(v, (tuple, list, set)):
raise ValueError('func kwargs value must be list or tuple: %s %s' % (func.__name__, v))
v = set(v)
if k not in cache_kwargs:
cache_kwargs[k] = v
else:
cache_kwargs[k] = cache_kwargs[k].union(v)
_loop_debouncer_func_args_cache[cache_key] = cache_kwargs
run_debouncer_func(cache_key, org, ttl, func, *args, **cache_kwargs)
return func(*args, **kwargs)
return wrapper

View File

@ -6,7 +6,7 @@ import logging
from django.core.cache import cache
from django.core.exceptions import ImproperlyConfigured
from django.db.models import Q, Count
from django.db.models import Q
from django_filters import rest_framework as drf_filters
from rest_framework import filters
from rest_framework.compat import coreapi, coreschema
@ -180,36 +180,30 @@ class LabelFilterBackend(filters.BaseFilterBackend):
]
@staticmethod
def filter_resources(resources, labels_id):
def parse_label_ids(labels_id):
from labels.models import Label
label_ids = [i.strip() for i in labels_id.split(',')]
cleaned = []
args = []
for label_id in label_ids:
kwargs = {}
if ':' in label_id:
k, v = label_id.split(':', 1)
kwargs['label__name'] = k.strip()
kwargs['name'] = k.strip()
if v != '*':
kwargs['label__value'] = v.strip()
kwargs['value'] = v.strip()
args.append(kwargs)
else:
kwargs['label_id'] = label_id
args.append(kwargs)
cleaned.append(label_id)
if len(args) == 1:
resources = resources.filter(**args[0])
return resources
q = Q()
for kwarg in args:
q |= Q(**kwarg)
resources = resources.filter(q) \
.values('res_id') \
.order_by('res_id') \
.annotate(count=Count('res_id', distinct=True)) \
.values('res_id', 'count') \
.filter(count=len(args))
return resources
if len(args) != 0:
q = Q()
for kwarg in args:
q |= Q(**kwarg)
ids = Label.objects.filter(q).values_list('id', flat=True)
cleaned.extend(list(ids))
return cleaned
def filter_queryset(self, request, queryset, view):
labels_id = request.query_params.get('labels')
@ -230,7 +224,8 @@ class LabelFilterBackend(filters.BaseFilterBackend):
resources = labeled_resource_cls.objects.filter(
res_type__app_label=app_label, res_type__model=model_name,
)
resources = self.filter_resources(resources, labels_id)
label_ids = self.parse_label_ids(labels_id)
resources = model.filter_resources_by_labels(resources, label_ids)
res_ids = resources.values_list('res_id', flat=True)
queryset = queryset.filter(id__in=set(res_ids))
return queryset

View File

@ -87,7 +87,7 @@ class BaseFileRenderer(BaseRenderer):
if value is None:
return '-'
pk = str(value.get('id', '') or value.get('pk', ''))
name = value.get('name') or value.get('display_name', '')
name = value.get('display_name', '') or value.get('name', '')
return '{}({})'.format(name, pk)
@staticmethod

View File

@ -28,9 +28,10 @@ class ErrorCode:
class URL:
QR_CONNECT = 'https://oapi.dingtalk.com/connect/qrconnect'
QR_CONNECT = 'https://login.dingtalk.com/oauth2/auth'
OAUTH_CONNECT = 'https://oapi.dingtalk.com/connect/oauth2/sns_authorize'
GET_USER_INFO_BY_CODE = 'https://oapi.dingtalk.com/sns/getuserinfo_bycode'
GET_USER_ACCESSTOKEN = 'https://api.dingtalk.com/v1.0/oauth2/userAccessToken'
GET_USER_INFO = 'https://api.dingtalk.com/v1.0/contact/users/me'
GET_TOKEN = 'https://oapi.dingtalk.com/gettoken'
SEND_MESSAGE_BY_TEMPLATE = 'https://oapi.dingtalk.com/topapi/message/corpconversation/sendbytemplate'
SEND_MESSAGE = 'https://oapi.dingtalk.com/topapi/message/corpconversation/asyncsend_v2'
@ -72,8 +73,9 @@ class DingTalkRequests(BaseRequest):
def get(self, url, params=None,
with_token=False, with_sign=False,
check_errcode_is_0=True,
**kwargs):
**kwargs) -> dict:
pass
get = as_request(get)
def post(self, url, json=None, params=None,
@ -81,6 +83,7 @@ class DingTalkRequests(BaseRequest):
check_errcode_is_0=True,
**kwargs) -> dict:
pass
post = as_request(post)
def _add_sign(self, kwargs: dict):
@ -123,17 +126,22 @@ class DingTalk:
)
def get_userinfo_bycode(self, code):
# https://developers.dingtalk.com/document/app/obtain-the-user-information-based-on-the-sns-temporary-authorization?spm=ding_open_doc.document.0.0.3a256573y8Y7yg#topic-1995619
body = {
"tmp_auth_code": code
'clientId': self._appid,
'clientSecret': self._appsecret,
'code': code,
'grantType': 'authorization_code'
}
data = self._request.post(URL.GET_USER_ACCESSTOKEN, json=body, check_errcode_is_0=False)
token = data['accessToken']
data = self._request.post(URL.GET_USER_INFO_BY_CODE, json=body, with_sign=True)
return data['user_info']
user = self._request.get(URL.GET_USER_INFO,
headers={'x-acs-dingtalk-access-token': token}, check_errcode_is_0=False)
return user
def get_user_id_by_code(self, code):
user_info = self.get_userinfo_bycode(code)
unionid = user_info['unionid']
unionid = user_info['unionId']
userid = self.get_userid_by_unionid(unionid)
return userid, None

View File

View File

@ -0,0 +1,56 @@
import re
from django.contrib.sessions.backends.cache import (
SessionStore as DjangoSessionStore
)
from django.core.cache import cache
from jumpserver.utils import get_current_request
class SessionStore(DjangoSessionStore):
ignore_urls = [
r'^/api/v1/users/profile/'
]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ignore_pattern = re.compile('|'.join(self.ignore_urls))
def save(self, *args, **kwargs):
request = get_current_request()
if request is None or not self.ignore_pattern.match(request.path):
super().save(*args, **kwargs)
class RedisUserSessionManager:
JMS_SESSION_KEY = 'jms_session_key'
def __init__(self):
self.client = cache.client.get_client()
def add_or_increment(self, session_key):
self.client.hincrby(self.JMS_SESSION_KEY, session_key, 1)
def decrement_or_remove(self, session_key):
new_count = self.client.hincrby(self.JMS_SESSION_KEY, session_key, -1)
if new_count <= 0:
self.client.hdel(self.JMS_SESSION_KEY, session_key)
def check_active(self, session_key):
count = self.client.hget(self.JMS_SESSION_KEY, session_key)
count = 0 if count is None else int(count.decode('utf-8'))
return count > 0
def get_active_keys(self):
session_keys = []
for k, v in self.client.hgetall(self.JMS_SESSION_KEY).items():
count = int(v.decode('utf-8'))
if count <= 0:
continue
key = k.decode('utf-8')
session_keys.append(key)
return session_keys
user_session_manager = RedisUserSessionManager()

View File

@ -69,7 +69,7 @@ def digest_sql_query():
for query in queries:
sql = query['sql']
print(" # {}: {}".format(query['time'], sql[:1000]))
print(" # {}: {}".format(query['time'], sql[:1000]))
if len(queries) < 3:
continue
print("- Table: {}".format(table_name))

View File

@ -21,6 +21,8 @@ def encrypt_and_compress_zip_file(filename, secret_password, encrypted_filenames
with pyzipper.AESZipFile(
filename, 'w', compression=pyzipper.ZIP_LZMA, encryption=pyzipper.WZ_AES
) as zf:
if secret_password and isinstance(secret_password, str):
secret_password = secret_password.encode('utf8')
zf.setpassword(secret_password)
for encrypted_filename in encrypted_filenames:
with open(encrypted_filename, 'rb') as f:

View File

@ -547,7 +547,6 @@ class Config(dict):
'REFERER_CHECK_ENABLED': False,
'SESSION_ENGINE': 'cache',
'SESSION_SAVE_EVERY_REQUEST': True,
'SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE': False,
'SERVER_REPLAY_STORAGE': {},
'SECURITY_DATA_CRYPTO_ALGO': None,
'GMSSL_ENABLED': False,
@ -564,8 +563,10 @@ class Config(dict):
'FTP_LOG_KEEP_DAYS': 180,
'CLOUD_SYNC_TASK_EXECUTION_KEEP_DAYS': 180,
'JOB_EXECUTION_KEEP_DAYS': 180,
'PASSWORD_CHANGE_LOG_KEEP_DAYS': 999,
'TICKETS_ENABLED': True,
'TICKETS_DIRECT_APPROVE': False,
# 废弃的
'DEFAULT_ORG_SHOW_ALL_USERS': True,
@ -606,7 +607,9 @@ class Config(dict):
'GPT_MODEL': 'gpt-3.5-turbo',
'VIRTUAL_APP_ENABLED': False,
'FILE_UPLOAD_SIZE_LIMIT_MB': 200
'FILE_UPLOAD_SIZE_LIMIT_MB': 200,
'TICKET_APPLY_ASSET_SCOPE': 'all'
}
old_config_map = {
@ -701,7 +704,8 @@ class Config(dict):
def compatible_redis(self):
redis_config = {
'REDIS_PASSWORD': quote(str(self.REDIS_PASSWORD)),
'REDIS_PASSWORD': str(self.REDIS_PASSWORD),
'REDIS_PASSWORD_QUOTE': quote(str(self.REDIS_PASSWORD)),
}
for key, value in redis_config.items():
self[key] = value

View File

@ -66,11 +66,6 @@ class RequestMiddleware:
def __call__(self, request):
set_current_request(request)
response = self.get_response(request)
is_request_api = request.path.startswith('/api')
if not settings.SESSION_EXPIRE_AT_BROWSER_CLOSE and \
not is_request_api:
age = request.session.get_expiry_age()
request.session.set_expiry(age)
return response

View File

@ -3,6 +3,7 @@
path_perms_map = {
'xpack': '*',
'settings': '*',
'img': '*',
'replay': 'default',
'applets': 'terminal.view_applet',
'virtual_apps': 'terminal.view_virtualapp',

View File

@ -234,11 +234,9 @@ CSRF_COOKIE_NAME = '{}csrftoken'.format(SESSION_COOKIE_NAME_PREFIX)
SESSION_COOKIE_NAME = '{}sessionid'.format(SESSION_COOKIE_NAME_PREFIX)
SESSION_COOKIE_AGE = CONFIG.SESSION_COOKIE_AGE
SESSION_EXPIRE_AT_BROWSER_CLOSE = True
# 自定义的配置SESSION_EXPIRE_AT_BROWSER_CLOSE 始终为 True, 下面这个来控制是否强制关闭后过期 cookie
SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE = CONFIG.SESSION_EXPIRE_AT_BROWSER_CLOSE_FORCE
SESSION_SAVE_EVERY_REQUEST = CONFIG.SESSION_SAVE_EVERY_REQUEST
SESSION_ENGINE = "django.contrib.sessions.backends.{}".format(CONFIG.SESSION_ENGINE)
SESSION_EXPIRE_AT_BROWSER_CLOSE = CONFIG.SESSION_EXPIRE_AT_BROWSER_CLOSE
SESSION_ENGINE = "common.sessions.{}".format(CONFIG.SESSION_ENGINE)
MESSAGE_STORAGE = 'django.contrib.messages.storage.cookie.CookieStorage'
# Database
@ -408,7 +406,7 @@ if REDIS_SENTINEL_SERVICE_NAME and REDIS_SENTINELS:
else:
REDIS_LOCATION_NO_DB = '%(protocol)s://:%(password)s@%(host)s:%(port)s/{}' % {
'protocol': REDIS_PROTOCOL,
'password': CONFIG.REDIS_PASSWORD,
'password': CONFIG.REDIS_PASSWORD_QUOTE,
'host': CONFIG.REDIS_HOST,
'port': CONFIG.REDIS_PORT,
}

View File

@ -122,11 +122,11 @@ WS_LISTEN_PORT = CONFIG.WS_LISTEN_PORT
LOGIN_LOG_KEEP_DAYS = CONFIG.LOGIN_LOG_KEEP_DAYS
TASK_LOG_KEEP_DAYS = CONFIG.TASK_LOG_KEEP_DAYS
OPERATE_LOG_KEEP_DAYS = CONFIG.OPERATE_LOG_KEEP_DAYS
PASSWORD_CHANGE_LOG_KEEP_DAYS = CONFIG.PASSWORD_CHANGE_LOG_KEEP_DAYS
ACTIVITY_LOG_KEEP_DAYS = CONFIG.ACTIVITY_LOG_KEEP_DAYS
FTP_LOG_KEEP_DAYS = CONFIG.FTP_LOG_KEEP_DAYS
CLOUD_SYNC_TASK_EXECUTION_KEEP_DAYS = CONFIG.CLOUD_SYNC_TASK_EXECUTION_KEEP_DAYS
JOB_EXECUTION_KEEP_DAYS = CONFIG.JOB_EXECUTION_KEEP_DAYS
ORG_CHANGE_TO_URL = CONFIG.ORG_CHANGE_TO_URL
WINDOWS_SKIP_ALL_MANUAL_PASSWORD = CONFIG.WINDOWS_SKIP_ALL_MANUAL_PASSWORD
@ -137,6 +137,7 @@ CHANGE_AUTH_PLAN_SECURE_MODE_ENABLED = CONFIG.CHANGE_AUTH_PLAN_SECURE_MODE_ENABL
DATETIME_DISPLAY_FORMAT = '%Y-%m-%d %H:%M:%S'
TICKETS_ENABLED = CONFIG.TICKETS_ENABLED
TICKETS_DIRECT_APPROVE = CONFIG.TICKETS_DIRECT_APPROVE
REFERER_CHECK_ENABLED = CONFIG.REFERER_CHECK_ENABLED
CONNECTION_TOKEN_ENABLED = CONFIG.CONNECTION_TOKEN_ENABLED
@ -214,6 +215,9 @@ PERM_TREE_REGEN_INTERVAL = CONFIG.PERM_TREE_REGEN_INTERVAL
MAGNUS_ORACLE_PORTS = CONFIG.MAGNUS_ORACLE_PORTS
LIMIT_SUPER_PRIV = CONFIG.LIMIT_SUPER_PRIV
# Asset account may be too many
ASSET_SIZE = 'small'
# Chat AI
CHAT_AI_ENABLED = CONFIG.CHAT_AI_ENABLED
GPT_API_KEY = CONFIG.GPT_API_KEY
@ -224,3 +228,5 @@ GPT_MODEL = CONFIG.GPT_MODEL
VIRTUAL_APP_ENABLED = CONFIG.VIRTUAL_APP_ENABLED
FILE_UPLOAD_SIZE_LIMIT_MB = CONFIG.FILE_UPLOAD_SIZE_LIMIT_MB
TICKET_APPLY_ASSET_SCOPE = CONFIG.TICKET_APPLY_ASSET_SCOPE

View File

@ -82,7 +82,6 @@ BOOTSTRAP3 = {
# Django channels support websocket
REDIS_LAYERS_HOST = {
'db': CONFIG.REDIS_DB_WS,
'password': CONFIG.REDIS_PASSWORD or None,
}
REDIS_LAYERS_SSL_PARAMS = {}
@ -97,6 +96,7 @@ if REDIS_USE_SSL:
if REDIS_SENTINEL_SERVICE_NAME and REDIS_SENTINELS:
REDIS_LAYERS_HOST['sentinels'] = REDIS_SENTINELS
REDIS_LAYERS_HOST['password'] = CONFIG.REDIS_PASSWORD or None
REDIS_LAYERS_HOST['master_name'] = REDIS_SENTINEL_SERVICE_NAME
REDIS_LAYERS_HOST['sentinel_kwargs'] = {
'password': REDIS_SENTINEL_PASSWORD,
@ -111,7 +111,7 @@ else:
# More info see: https://github.com/django/channels_redis/issues/334
# REDIS_LAYERS_HOST['address'] = (CONFIG.REDIS_HOST, CONFIG.REDIS_PORT)
REDIS_LAYERS_ADDRESS = '{protocol}://:{password}@{host}:{port}/{db}'.format(
protocol=REDIS_PROTOCOL, password=CONFIG.REDIS_PASSWORD,
protocol=REDIS_PROTOCOL, password=CONFIG.REDIS_PASSWORD_QUOTE,
host=CONFIG.REDIS_HOST, port=CONFIG.REDIS_PORT, db=CONFIG.REDIS_DB_WS
)
REDIS_LAYERS_HOST['address'] = REDIS_LAYERS_ADDRESS
@ -153,7 +153,7 @@ if REDIS_SENTINEL_SERVICE_NAME and REDIS_SENTINELS:
else:
CELERY_BROKER_URL = CELERY_BROKER_URL_FORMAT % {
'protocol': REDIS_PROTOCOL,
'password': CONFIG.REDIS_PASSWORD,
'password': CONFIG.REDIS_PASSWORD_QUOTE,
'host': CONFIG.REDIS_HOST,
'port': CONFIG.REDIS_PORT,
'db': CONFIG.REDIS_DB_CELERY,
@ -187,6 +187,7 @@ ANSIBLE_LOG_DIR = os.path.join(PROJECT_DIR, 'data', 'ansible')
REDIS_HOST = CONFIG.REDIS_HOST
REDIS_PORT = CONFIG.REDIS_PORT
REDIS_PASSWORD = CONFIG.REDIS_PASSWORD
REDIS_PASSWORD_QUOTE = CONFIG.REDIS_PASSWORD_QUOTE
DJANGO_REDIS_SCAN_ITERSIZE = 1000

View File

@ -1,6 +1,6 @@
from django.contrib.contenttypes.fields import GenericRelation
from django.db import models
from django.db.models import OneToOneField
from django.db.models import OneToOneField, Count
from common.utils import lazyproperty
from .models import LabeledResource
@ -36,3 +36,37 @@ class LabeledMixin(models.Model):
@res_labels.setter
def res_labels(self, value):
self.real.labels.set(value, bulk=False)
@classmethod
def filter_resources_by_labels(cls, resources, label_ids):
return cls._get_filter_res_by_labels_m2m_all(resources, label_ids)
@classmethod
def _get_filter_res_by_labels_m2m_in(cls, resources, label_ids):
return resources.filter(label_id__in=label_ids)
@classmethod
def _get_filter_res_by_labels_m2m_all(cls, resources, label_ids):
if len(label_ids) == 1:
return cls._get_filter_res_by_labels_m2m_in(resources, label_ids)
resources = resources.filter(label_id__in=label_ids) \
.values('res_id') \
.order_by('res_id') \
.annotate(count=Count('res_id', distinct=True)) \
.values('res_id', 'count') \
.filter(count=len(label_ids))
return resources
@classmethod
def get_labels_filter_attr_q(cls, value, match):
resources = LabeledResource.objects.all()
if not value:
return None
if match != 'm2m_all':
resources = cls._get_filter_res_by_labels_m2m_in(resources, value)
else:
resources = cls._get_filter_res_by_labels_m2m_all(resources, value)
res_ids = set(resources.values_list('res_id', flat=True))
return models.Q(id__in=res_ids)

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7879f4eeb499e920ad6c4bfdb0b1f334936ca344c275be056f12fcf7485f2bf6
size 170948
oid sha256:d04781f4f0b0de3ac5f707febb222e239553d6103bca0cec41ab2fd5ab044571
size 173799

File diff suppressed because it is too large Load Diff

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:19d3a111cc245f9a9d36b860fd95447df916ad66c918bef672bacdad6bc77a8f
size 140119
oid sha256:e66a6fa05d25f1c502f95001b5ff0d0a310affd32eac939fd7b840845028074f
size 142298

File diff suppressed because it is too large Load Diff

View File

@ -1,28 +1,32 @@
import json
import time
from threading import Thread
from channels.generic.websocket import JsonWebsocketConsumer
from django.core.cache import cache
from django.conf import settings
from common.db.utils import safe_db_connection
from common.sessions.cache import user_session_manager
from common.utils import get_logger
from .signal_handlers import new_site_msg_chan
from .site_msg import SiteMessageUtil
logger = get_logger(__name__)
WS_SESSION_KEY = 'ws_session_key'
class SiteMsgWebsocket(JsonWebsocketConsumer):
sub = None
refresh_every_seconds = 10
@property
def session(self):
return self.scope['session']
def connect(self):
user = self.scope["user"]
if user.is_authenticated:
self.accept()
session = self.scope['session']
redis_client = cache.client.get_client()
redis_client.sadd(WS_SESSION_KEY, session.session_key)
user_session_manager.add_or_increment(self.session.session_key)
self.sub = self.watch_recv_new_site_msg()
else:
self.close()
@ -66,6 +70,32 @@ class SiteMsgWebsocket(JsonWebsocketConsumer):
if not self.sub:
return
self.sub.unsubscribe()
session = self.scope['session']
redis_client = cache.client.get_client()
redis_client.srem(WS_SESSION_KEY, session.session_key)
user_session_manager.decrement_or_remove(self.session.session_key)
if self.should_delete_session():
thread = Thread(target=self.delay_delete_session)
thread.start()
def should_delete_session(self):
return (self.session.modified or settings.SESSION_SAVE_EVERY_REQUEST) and \
not self.session.is_empty() and \
self.session.get_expire_at_browser_close() and \
not user_session_manager.check_active(self.session.session_key)
def delay_delete_session(self):
timeout = 6
check_interval = 0.5
start_time = time.time()
while time.time() - start_time < timeout:
time.sleep(check_interval)
if user_session_manager.check_active(self.session.session_key):
return
self.delete_session()
def delete_session(self):
try:
self.session.delete()
except Exception as e:
logger.info(f'delete session error: {e}')

View File

@ -1,3 +1,4 @@
import os
from collections import defaultdict
from functools import reduce
@ -29,6 +30,8 @@ class DefaultCallback:
)
self.status = 'running'
self.finished = False
self.local_pid = 0
self.private_data_dir = None
@property
def host_results(self):
@ -45,6 +48,9 @@ class DefaultCallback:
event = data.get('event', None)
if not event:
return
pid = data.get('pid', None)
if pid:
self.write_pid(pid)
event_data = data.get('event_data', {})
host = event_data.get('remote_addr', '')
task = event_data.get('task', '')
@ -152,3 +158,11 @@ class DefaultCallback:
def status_handler(self, data, **kwargs):
status = data.get('status', '')
self.status = self.STATUS_MAPPER.get(status, 'unknown')
rc = kwargs.get('runner_config', None)
self.private_data_dir = rc.private_data_dir if rc else '/tmp/'
def write_pid(self, pid):
pid_filepath = os.path.join(self.private_data_dir, 'local.pid')
with open(pid_filepath, 'w') as f:
f.write(str(pid))

View File

@ -2,6 +2,7 @@
#
import os
import re
from collections import defaultdict
from celery.result import AsyncResult
from django.shortcuts import get_object_or_404
@ -166,16 +167,58 @@ class CeleryTaskViewSet(
i.next_exec_time = now + next_run_at
return queryset
def generate_summary_state(self, execution_qs):
model = self.get_queryset().model
executions = execution_qs.order_by('-date_published').values('name', 'state')
summary_state_dict = defaultdict(
lambda: {
'states': [], 'state': 'green',
'summary': {'total': 0, 'success': 0}
}
)
for execution in executions:
name = execution['name']
state = execution['state']
summary = summary_state_dict[name]['summary']
summary['total'] += 1
summary['success'] += 1 if state == 'SUCCESS' else 0
states = summary_state_dict[name].get('states')
if states is not None and len(states) >= 5:
color = model.compute_state_color(states)
summary_state_dict[name]['state'] = color
summary_state_dict[name].pop('states', None)
elif isinstance(states, list):
states.append(state)
return summary_state_dict
def loading_summary_state(self, queryset):
if isinstance(queryset, list):
names = [i.name for i in queryset]
execution_qs = CeleryTaskExecution.objects.filter(name__in=names)
else:
execution_qs = CeleryTaskExecution.objects.all()
summary_state_dict = self.generate_summary_state(execution_qs)
for i in queryset:
i.summary = summary_state_dict.get(i.name, {}).get('summary', {})
i.state = summary_state_dict.get(i.name, {}).get('state', 'green')
return queryset
def list(self, request, *args, **kwargs):
queryset = self.filter_queryset(self.get_queryset())
page = self.paginate_queryset(queryset)
if page is not None:
page = self.generate_execute_time(page)
page = self.loading_summary_state(page)
serializer = self.get_serializer(page, many=True)
return self.get_paginated_response(serializer.data)
queryset = self.generate_execute_time(queryset)
queryset = self.loading_summary_state(queryset)
serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data)

View File

@ -1,9 +1,11 @@
import json
import os
from celery.result import AsyncResult
from django.conf import settings
from django.db import transaction
from django.db.models import Count
from django.http import Http404
from django.shortcuts import get_object_or_404
from django.utils._os import safe_join
from django.utils.translation import gettext_lazy as _
@ -14,9 +16,10 @@ from rest_framework.views import APIView
from assets.models import Asset
from common.const.http import POST
from common.permissions import IsValidUser
from ops.celery import app
from ops.const import Types
from ops.models import Job, JobExecution
from ops.serializers.job import JobSerializer, JobExecutionSerializer, FileSerializer
from ops.serializers.job import JobSerializer, JobExecutionSerializer, FileSerializer, JobTaskStopSerializer
__all__ = [
'JobViewSet', 'JobExecutionViewSet', 'JobRunVariableHelpAPIView',
@ -187,6 +190,33 @@ class JobExecutionViewSet(OrgBulkModelViewSet):
queryset = queryset.filter(creator=self.request.user)
return queryset
@action(methods=[POST], detail=False, serializer_class=JobTaskStopSerializer, permission_classes=[IsValidUser, ],
url_path='stop')
def stop(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
if not serializer.is_valid():
return Response({'error': serializer.errors}, status=400)
task_id = serializer.validated_data['task_id']
try:
instance = get_object_or_404(JobExecution, task_id=task_id, creator=request.user)
except Http404:
return Response(
{'error': _('The task is being created and cannot be interrupted. Please try again later.')},
status=400
)
task = AsyncResult(task_id, app=app)
inspect = app.control.inspect()
for worker in inspect.registered().keys():
if task_id not in [at['id'] for at in inspect.active().get(worker, [])]:
# 在队列中未执行使用revoke执行
task.revoke(terminate=True)
instance.set_error('Job stop by "revoke task {}"'.format(task_id))
return Response({'task_id': task_id}, status=200)
instance.stop()
return Response({'task_id': task_id}, status=200)
class JobAssetDetail(APIView):
rbac_perms = {

View File

@ -15,6 +15,9 @@ class CeleryTask(models.Model):
name = models.CharField(max_length=1024, verbose_name=_('Name'))
date_last_publish = models.DateTimeField(null=True, verbose_name=_("Date last publish"))
__summary = None
__state = None
@property
def meta(self):
task = app.tasks.get(self.name, None)
@ -25,23 +28,43 @@ class CeleryTask(models.Model):
@property
def summary(self):
if self.__summary is not None:
return self.__summary
executions = CeleryTaskExecution.objects.filter(name=self.name)
total = executions.count()
success = executions.filter(state='SUCCESS').count()
return {'total': total, 'success': success}
@summary.setter
def summary(self, value):
self.__summary = value
@staticmethod
def compute_state_color(states: list, default_count=5):
color = 'green'
states = states[:default_count]
if not states:
return color
if states[0] == 'FAILURE':
color = 'red'
elif 'FAILURE' in states:
color = 'yellow'
return color
@property
def state(self):
last_five_executions = CeleryTaskExecution.objects.filter(name=self.name).order_by('-date_published')[:5]
if self.__state is not None:
return self.__state
last_five_executions = CeleryTaskExecution.objects.filter(
name=self.name
).order_by('-date_published').values('state')[:5]
states = [i['state'] for i in last_five_executions]
color = self.compute_state_color(states)
return color
if len(last_five_executions) > 0:
if last_five_executions[0].state == 'FAILURE':
return "red"
for execution in last_five_executions:
if execution.state == 'FAILURE':
return "yellow"
return "green"
@state.setter
def state(self, value):
self.__state = value
class Meta:
verbose_name = _("Celery Task")

View File

@ -67,6 +67,7 @@ class JMSPermedInventory(JMSInventory):
'postgresql': ['postgresql'],
'sqlserver': ['sqlserver'],
'ssh': ['shell', 'python', 'win_shell', 'raw'],
'winrm': ['win_shell', 'shell'],
}
if self.module not in protocol_supported_modules_mapping.get(protocol.name, []):
@ -553,6 +554,15 @@ class JobExecution(JMSOrgBaseModel):
finally:
ssh_tunnel.local_gateway_clean(runner)
def stop(self):
with open(os.path.join(self.private_dir, 'local.pid')) as f:
try:
pid = f.read()
os.kill(int(pid), 9)
except Exception as e:
print(e)
self.set_error('Job stop by "kill -9 {}"'.format(pid))
class Meta:
verbose_name = _("Job Execution")
ordering = ['-date_created']

View File

@ -57,6 +57,13 @@ class FileSerializer(serializers.Serializer):
ref_name = "JobFileSerializer"
class JobTaskStopSerializer(serializers.Serializer):
task_id = serializers.CharField(max_length=128)
class Meta:
ref_name = "JobTaskStopSerializer"
class JobExecutionSerializer(BulkOrgResourceModelSerializer):
creator = ReadableHiddenField(default=serializers.CurrentUserDefault())
job_type = serializers.ReadOnlyField(label=_("Job type"))

View File

@ -173,6 +173,9 @@ class Organization(OrgRoleMixin, JMSBaseModel):
def is_default(self):
return str(self.id) == self.DEFAULT_ID
def is_system(self):
return str(self.id) == self.SYSTEM_ID
@property
def internal(self):
return str(self.id) in self.INTERNAL_IDS

View File

@ -87,7 +87,8 @@ class OrgResourceStatisticsRefreshUtil:
if not cache_field_name:
return
org = getattr(instance, 'org', None)
cls.refresh_org_fields(((org, cache_field_name),))
cache_field_name = tuple(cache_field_name)
cls.refresh_org_fields.delay(org_fields=((org, cache_field_name),))
@receiver(post_save)

View File

@ -6,6 +6,7 @@ from functools import wraps
from inspect import signature
from werkzeug.local import LocalProxy
from django.conf import settings
from common.local import thread_local
from .models import Organization
@ -14,7 +15,6 @@ from .models import Organization
def get_org_from_request(request):
# query中优先级最高
oid = request.GET.get("oid")
# 其次header
if not oid:
oid = request.META.get("HTTP_X_JMS_ORG")
@ -24,14 +24,33 @@ def get_org_from_request(request):
# 其次session
if not oid:
oid = request.session.get("oid")
if oid and oid.lower() == 'default':
return Organization.default()
if oid and oid.lower() == 'root':
return Organization.root()
if oid and oid.lower() == 'system':
return Organization.system()
org = Organization.get_instance(oid)
if org and org.internal:
# 内置组织直接返回
return org
if not settings.XPACK_ENABLED:
# 社区版用户只能使用默认组织
return Organization.default()
if not org and request.user.is_authenticated:
# 企业版用户优先从自己有权限的组织中获取
org = request.user.orgs.first()
if not org:
org = Organization.default()
if not oid:
oid = Organization.DEFAULT_ID
if oid.lower() == "default":
oid = Organization.DEFAULT_ID
elif oid.lower() == "root":
oid = Organization.ROOT_ID
org = Organization.get_instance(oid, default=Organization.default())
return org

View File

@ -1,9 +1,11 @@
import abc
from django.conf import settings
from rest_framework.generics import ListAPIView, RetrieveAPIView
from assets.api.asset.asset import AssetFilterSet
from assets.models import Asset, Node
from common.api.mixin import ExtraFilterFieldsMixin
from common.utils import get_logger, lazyproperty, is_uuid
from orgs.utils import tmp_to_root_org
from perms import serializers
@ -37,8 +39,8 @@ class UserPermedAssetRetrieveApi(SelfOrPKUserMixin, RetrieveAPIView):
return asset
class BaseUserPermedAssetsApi(SelfOrPKUserMixin, ListAPIView):
ordering = ('name',)
class BaseUserPermedAssetsApi(SelfOrPKUserMixin, ExtraFilterFieldsMixin, ListAPIView):
ordering = []
search_fields = ('name', 'address', 'comment')
ordering_fields = ("name", "address")
filterset_class = AssetFilterSet
@ -47,6 +49,8 @@ class BaseUserPermedAssetsApi(SelfOrPKUserMixin, ListAPIView):
def get_queryset(self):
if getattr(self, 'swagger_fake_view', False):
return Asset.objects.none()
if settings.ASSET_SIZE == 'small':
self.ordering = ['name']
assets = self.get_assets()
assets = self.serializer_class.setup_eager_loading(assets)
return assets

View File

@ -8,9 +8,9 @@ from rest_framework import serializers
from accounts.models import Account
from assets.const import Category, AllTypes
from assets.models import Node, Asset, Platform
from assets.serializers.asset.common import AssetLabelSerializer, AssetProtocolsPermsSerializer
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
from assets.serializers.asset.common import AssetProtocolsPermsSerializer
from common.serializers import ResourceLabelsMixin
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
from orgs.mixins.serializers import OrgResourceModelSerializerMixin
from perms.serializers.permission import ActionChoicesField

View File

@ -13,7 +13,7 @@ class AssetPermissionUtil(object):
""" 资产授权相关的方法工具 """
@timeit
def get_permissions_for_user(self, user, with_group=True, flat=False):
def get_permissions_for_user(self, user, with_group=True, flat=False, with_expired=False):
""" 获取用户的授权规则 """
perm_ids = set()
# user
@ -25,7 +25,7 @@ class AssetPermissionUtil(object):
groups = user.groups.all()
group_perm_ids = self.get_permissions_for_user_groups(groups, flat=True)
perm_ids.update(group_perm_ids)
perms = self.get_permissions(ids=perm_ids)
perms = self.get_permissions(ids=perm_ids, with_expired=with_expired)
if flat:
return perms.values_list('id', flat=True)
return perms
@ -102,6 +102,8 @@ class AssetPermissionUtil(object):
return model.objects.filter(id__in=ids)
@staticmethod
def get_permissions(ids):
perms = AssetPermission.objects.filter(id__in=ids).valid().order_by('-date_expired')
return perms
def get_permissions(ids, with_expired=False):
perms = AssetPermission.objects.filter(id__in=ids)
if not with_expired:
perms = perms.valid()
return perms.order_by('-date_expired')

View File

@ -7,10 +7,10 @@ from django.db.models import Q
from rest_framework.utils.encoders import JSONEncoder
from assets.const import AllTypes
from assets.models import FavoriteAsset, Asset
from assets.models import FavoriteAsset, Asset, Node
from common.utils.common import timeit, get_logger
from orgs.utils import current_org, tmp_to_root_org
from perms.models import PermNode, UserAssetGrantedTreeNodeRelation
from perms.models import PermNode, UserAssetGrantedTreeNodeRelation, AssetPermission
from .permission import AssetPermissionUtil
__all__ = ['AssetPermissionPermAssetUtil', 'UserPermAssetUtil', 'UserPermNodeUtil']
@ -21,38 +21,37 @@ logger = get_logger(__name__)
class AssetPermissionPermAssetUtil:
def __init__(self, perm_ids):
self.perm_ids = perm_ids
self.perm_ids = set(perm_ids)
def get_all_assets(self):
node_assets = self.get_perm_nodes_assets()
direct_assets = self.get_direct_assets()
# 比原来的查到所有 asset id 再搜索块很多,因为当资产量大的时候,搜索会很慢
return (node_assets | direct_assets).distinct()
return (node_assets | direct_assets).order_by().distinct()
@timeit
def get_perm_nodes_assets(self, flat=False):
""" 获取所有授权节点下的资产 """
from assets.models import Node
from ..models import AssetPermission
def get_perm_nodes(self):
""" 获取所有授权节点 """
nodes_ids = AssetPermission.objects \
.filter(id__in=self.perm_ids) \
.values_list('nodes', flat=True)
nodes_ids = set(nodes_ids)
nodes = Node.objects.filter(id__in=nodes_ids).only('id', 'key')
assets = PermNode.get_nodes_all_assets(*nodes)
if flat:
return set(assets.values_list('id', flat=True))
return nodes
@timeit
def get_perm_nodes_assets(self):
""" 获取所有授权节点下的资产 """
nodes = self.get_perm_nodes()
assets = PermNode.get_nodes_all_assets(*nodes, distinct=False)
return assets
@timeit
def get_direct_assets(self, flat=False):
def get_direct_assets(self):
""" 获取直接授权的资产 """
from ..models import AssetPermission
asset_ids = AssetPermission.objects \
.filter(id__in=self.perm_ids) \
.values_list('assets', flat=True)
assets = Asset.objects.filter(id__in=asset_ids).distinct()
if flat:
return set(assets.values_list('id', flat=True))
asset_ids = AssetPermission.assets.through.objects \
.filter(assetpermission_id__in=self.perm_ids) \
.values_list('asset_id', flat=True)
assets = Asset.objects.filter(id__in=asset_ids)
return assets

View File

@ -72,7 +72,7 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
@timeit
def refresh_if_need(self, force=False):
built_just_now = cache.get(self.cache_key_time)
built_just_now = False if settings.ASSET_SIZE == 'small' else cache.get(self.cache_key_time)
if built_just_now:
logger.info('Refresh user perm tree just now, pass: {}'.format(built_just_now))
return
@ -80,12 +80,18 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
if not to_refresh_orgs:
logger.info('Not have to refresh orgs')
return
logger.info("Delay refresh user orgs: {} {}".format(self.user, [o.name for o in to_refresh_orgs]))
refresh_user_orgs_perm_tree(user_orgs=((self.user, tuple(to_refresh_orgs)),))
refresh_user_favorite_assets(users=(self.user,))
sync = True if settings.ASSET_SIZE == 'small' else False
refresh_user_orgs_perm_tree.apply(sync=sync, user_orgs=((self.user, tuple(to_refresh_orgs)),))
refresh_user_favorite_assets.apply(sync=sync, users=(self.user,))
@timeit
def refresh_tree_manual(self):
"""
用来手动 debug
:return:
"""
built_just_now = cache.get(self.cache_key_time)
if built_just_now:
logger.info('Refresh just now, pass: {}'.format(built_just_now))
@ -105,8 +111,9 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
return
self._clean_user_perm_tree_for_legacy_org()
ttl = settings.PERM_TREE_REGEN_INTERVAL
cache.set(self.cache_key_time, int(time.time()), ttl)
if settings.ASSET_SIZE != 'small':
ttl = settings.PERM_TREE_REGEN_INTERVAL
cache.set(self.cache_key_time, int(time.time()), ttl)
lock = UserGrantedTreeRebuildLock(self.user.id)
got = lock.acquire(blocking=False)
@ -187,13 +194,20 @@ class UserPermTreeExpireUtil(_UserPermTreeCacheMixin):
@on_transaction_commit
def expire_perm_tree_for_users_orgs(self, user_ids, org_ids):
user_ids = list(user_ids)
org_ids = [str(oid) for oid in org_ids]
with self.client.pipeline() as p:
for uid in user_ids:
cache_key = self.get_cache_key(uid)
p.srem(cache_key, *org_ids)
p.execute()
logger.info('Expire perm tree for users: [{}], orgs: [{}]'.format(user_ids, org_ids))
users_display = ','.join([str(i) for i in user_ids[:3]])
if len(user_ids) > 3:
users_display += '...'
orgs_display = ','.join([str(i) for i in org_ids[:3]])
if len(org_ids) > 3:
orgs_display += '...'
logger.info('Expire perm tree for users: [{}], orgs: [{}]'.format(users_display, orgs_display))
def expire_perm_tree_for_all_user(self):
keys = self.client.keys(self.cache_key_all_user)

View File

@ -1,28 +1,16 @@
# -*- coding: utf-8 -*-
#
import threading
from django.conf import settings
from django.utils.translation import gettext_lazy as _
from rest_framework import generics
from rest_framework.generics import CreateAPIView
from rest_framework.views import Response, APIView
from rest_framework.views import Response
from common.api import AsyncApiMixin
from common.utils import get_logger
from orgs.models import Organization
from orgs.utils import current_org
from users.models import User
from ..models import Setting
from ..serializers import (
LDAPTestConfigSerializer, LDAPUserSerializer,
LDAPTestLoginSerializer
)
from ..tasks import sync_ldap_user
from ..serializers import LDAPUserSerializer
from ..utils import (
LDAPServerUtil, LDAPCacheUtil, LDAPImportUtil, LDAPSyncUtil,
LDAP_USE_CACHE_FLAGS, LDAPTestUtil
LDAPServerUtil, LDAPCacheUtil,
LDAP_USE_CACHE_FLAGS
)
logger = get_logger(__file__)
@ -100,49 +88,3 @@ class LDAPUserListApi(generics.ListAPIView):
else:
data = {'msg': _('Users are not synchronized, please click the user synchronization button')}
return Response(data=data, status=400)
class LDAPUserImportAPI(APIView):
perm_model = Setting
rbac_perms = {
'POST': 'settings.change_auth'
}
def get_orgs(self):
org_ids = self.request.data.get('org_ids')
if org_ids:
orgs = list(Organization.objects.filter(id__in=org_ids))
else:
orgs = [current_org]
return orgs
def get_ldap_users(self):
username_list = self.request.data.get('username_list', [])
cache_police = self.request.query_params.get('cache_police', True)
if '*' in username_list:
users = LDAPServerUtil().search()
elif cache_police in LDAP_USE_CACHE_FLAGS:
users = LDAPCacheUtil().search(search_users=username_list)
else:
users = LDAPServerUtil().search(search_users=username_list)
return users
def post(self, request):
try:
users = self.get_ldap_users()
except Exception as e:
return Response({'error': str(e)}, status=400)
if users is None:
return Response({'msg': _('Get ldap users is None')}, status=400)
orgs = self.get_orgs()
new_users, errors = LDAPImportUtil().perform_import(users, orgs)
if errors:
return Response({'errors': errors}, status=400)
count = users if users is None else len(users)
orgs_name = ', '.join([str(org) for org in orgs])
return Response({
'msg': _('Imported {} users successfully (Organization: {})').format(count, orgs_name)
})

View File

@ -3,6 +3,7 @@ from rest_framework import generics
from rest_framework.permissions import AllowAny
from authentication.permissions import IsValidUserOrConnectionToken
from common.const.choices import COUNTRY_CALLING_CODES
from common.utils import get_logger, lazyproperty
from common.utils.timezone import local_now
from .. import serializers
@ -24,7 +25,8 @@ class OpenPublicSettingApi(generics.RetrieveAPIView):
def get_object(self):
return {
"XPACK_ENABLED": settings.XPACK_ENABLED,
"INTERFACE": self.interface_setting
"INTERFACE": self.interface_setting,
"COUNTRY_CALLING_CODES": COUNTRY_CALLING_CODES
}

View File

@ -43,7 +43,7 @@ class OAuth2SettingSerializer(serializers.Serializer):
)
AUTH_OAUTH2_ACCESS_TOKEN_METHOD = serializers.ChoiceField(
default='GET', label=_('Client authentication method'),
choices=(('GET', 'GET'), ('POST', 'POST'))
choices=(('GET', 'GET'), ('POST', 'POST-DATA'), ('POST_JSON', 'POST-JSON'))
)
AUTH_OAUTH2_PROVIDER_USERINFO_ENDPOINT = serializers.CharField(
required=True, max_length=1024, label=_('Provider userinfo endpoint')

View File

@ -22,6 +22,10 @@ class CleaningSerializer(serializers.Serializer):
min_value=MIN_VALUE, max_value=9999,
label=_("Operate log keep days (day)"),
)
PASSWORD_CHANGE_LOG_KEEP_DAYS = serializers.IntegerField(
min_value=MIN_VALUE, max_value=9999,
label=_("password change log keep days (day)"),
)
FTP_LOG_KEEP_DAYS = serializers.IntegerField(
min_value=MIN_VALUE, max_value=9999,
label=_("FTP log keep days (day)"),

View File

@ -109,6 +109,7 @@ class TicketSettingSerializer(serializers.Serializer):
PREFIX_TITLE = _('Ticket')
TICKETS_ENABLED = serializers.BooleanField(required=False, default=True, label=_("Enable tickets"))
TICKETS_DIRECT_APPROVE = serializers.BooleanField(required=False, default=False, label=_("No login approval"))
TICKET_AUTHORIZE_DEFAULT_TIME = serializers.IntegerField(
min_value=1, max_value=999999, required=False,
label=_("Ticket authorize default time")

View File

@ -11,6 +11,7 @@ __all__ = [
class PublicSettingSerializer(serializers.Serializer):
XPACK_ENABLED = serializers.BooleanField()
INTERFACE = serializers.DictField()
COUNTRY_CALLING_CODES = serializers.ListField()
class PrivateSettingSerializer(PublicSettingSerializer):
@ -50,6 +51,7 @@ class PrivateSettingSerializer(PublicSettingSerializer):
ANNOUNCEMENT = serializers.DictField()
TICKETS_ENABLED = serializers.BooleanField()
TICKETS_DIRECT_APPROVE = serializers.BooleanField()
CONNECTION_TOKEN_REUSABLE = serializers.BooleanField()
CACHE_LOGIN_PASSWORD_ENABLED = serializers.BooleanField()
VAULT_ENABLED = serializers.BooleanField()

View File

@ -14,9 +14,13 @@
</ul>
<b>{% trans "Synced User" %}:</b>
<ul>
{% for user in users %}
<li>{{ user }}</li>
{% endfor %}
{% if users %}
{% for user in users %}
<li>{{ user }}</li>
{% endfor %}
{% else %}
<li>{% trans 'No user synchronization required' %}</li>
{% endif %}
</ul>
{% if errors %}
<b>{% trans 'Error' %}:</b>

View File

@ -12,7 +12,6 @@ router.register(r'chatai-prompts', api.ChatPromptViewSet, 'chatai-prompt')
urlpatterns = [
path('mail/testing/', api.MailTestingAPI.as_view(), name='mail-testing'),
path('ldap/users/', api.LDAPUserListApi.as_view(), name='ldap-user-list'),
path('ldap/users/import/', api.LDAPUserImportAPI.as_view(), name='ldap-user-import'),
path('wecom/testing/', api.WeComTestingAPI.as_view(), name='wecom-testing'),
path('dingtalk/testing/', api.DingTalkTestingAPI.as_view(), name='dingtalk-testing'),
path('feishu/testing/', api.FeiShuTestingAPI.as_view(), name='feishu-testing'),

View File

@ -6,6 +6,7 @@ import asyncio
from channels.generic.websocket import AsyncJsonWebsocketConsumer
from django.core.cache import cache
from django.conf import settings
from django.utils.translation import gettext_lazy as _
from common.db.utils import close_old_connections
from common.utils import get_logger
@ -13,9 +14,12 @@ from settings.serializers import (
LDAPTestConfigSerializer,
LDAPTestLoginSerializer
)
from orgs.models import Organization
from orgs.utils import current_org
from settings.tasks import sync_ldap_user
from settings.utils import (
LDAPSyncUtil, LDAPTestUtil
LDAPServerUtil, LDAPCacheUtil, LDAPImportUtil, LDAPSyncUtil,
LDAP_USE_CACHE_FLAGS, LDAPTestUtil
)
from .tools import (
verbose_ping, verbose_telnet, verbose_nmap,
@ -27,9 +31,11 @@ logger = get_logger(__name__)
CACHE_KEY_LDAP_TEST_CONFIG_MSG = 'CACHE_KEY_LDAP_TEST_CONFIG_MSG'
CACHE_KEY_LDAP_TEST_LOGIN_MSG = 'CACHE_KEY_LDAP_TEST_LOGIN_MSG'
CACHE_KEY_LDAP_SYNC_USER_MSG = 'CACHE_KEY_LDAP_SYNC_USER_MSG'
CACHE_KEY_LDAP_IMPORT_USER_MSG = 'CACHE_KEY_LDAP_IMPORT_USER_MSG'
CACHE_KEY_LDAP_TEST_CONFIG_TASK_STATUS = 'CACHE_KEY_LDAP_TEST_CONFIG_TASK_STATUS'
CACHE_KEY_LDAP_TEST_LOGIN_TASK_STATUS = 'CACHE_KEY_LDAP_TEST_LOGIN_TASK_STATUS'
CACHE_KEY_LDAP_SYNC_USER_TASK_STATUS = 'CACHE_KEY_LDAP_SYNC_USER_TASK_STATUS'
CACHE_KEY_LDAP_IMPORT_USER_TASK_STATUS = 'CACHE_KEY_LDAP_IMPORT_USER_TASK_STATUS'
TASK_STATUS_IS_RUNNING = 'RUNNING'
TASK_STATUS_IS_OVER = 'OVER'
@ -117,6 +123,8 @@ class LdapWebsocket(AsyncJsonWebsocketConsumer):
ok, msg = cache.get(CACHE_KEY_LDAP_TEST_CONFIG_MSG)
elif msg_type == 'sync_user':
ok, msg = cache.get(CACHE_KEY_LDAP_SYNC_USER_MSG)
elif msg_type == 'import_user':
ok, msg = cache.get(CACHE_KEY_LDAP_IMPORT_USER_MSG)
else:
ok, msg = cache.get(CACHE_KEY_LDAP_TEST_LOGIN_MSG)
await self.send_msg(ok, msg)
@ -165,8 +173,8 @@ class LdapWebsocket(AsyncJsonWebsocketConsumer):
cache.set(task_key, TASK_STATUS_IS_OVER, ttl)
@staticmethod
def set_task_msg(task_key, ok, msg):
cache.set(task_key, (ok, msg), 120)
def set_task_msg(task_key, ok, msg, ttl=120):
cache.set(task_key, (ok, msg), ttl)
def run_testing_config(self, data):
while True:
@ -207,3 +215,53 @@ class LdapWebsocket(AsyncJsonWebsocketConsumer):
ok = False if msg else True
self.set_task_status_over(CACHE_KEY_LDAP_SYNC_USER_TASK_STATUS)
self.set_task_msg(CACHE_KEY_LDAP_SYNC_USER_MSG, ok, msg)
def run_import_user(self, data):
while True:
if self.task_is_over(CACHE_KEY_LDAP_IMPORT_USER_TASK_STATUS):
break
else:
ok, msg = self.import_user(data)
self.set_task_status_over(CACHE_KEY_LDAP_IMPORT_USER_TASK_STATUS, 3)
self.set_task_msg(CACHE_KEY_LDAP_IMPORT_USER_MSG, ok, msg, 3)
def import_user(self, data):
ok = False
org_ids = data.get('org_ids')
username_list = data.get('username_list', [])
cache_police = data.get('cache_police', True)
try:
users = self.get_ldap_users(username_list, cache_police)
if users is None:
msg = _('Get ldap users is None')
orgs = self.get_orgs(org_ids)
new_users, error_msg = LDAPImportUtil().perform_import(users, orgs)
if error_msg:
msg = error_msg
count = users if users is None else len(users)
orgs_name = ', '.join([str(org) for org in orgs])
ok = True
msg = _('Imported {} users successfully (Organization: {})').format(count, orgs_name)
except Exception as e:
msg = str(e)
return ok, msg
@staticmethod
def get_orgs(org_ids):
if org_ids:
orgs = list(Organization.objects.filter(id__in=org_ids))
else:
orgs = [current_org]
return orgs
@staticmethod
def get_ldap_users(username_list, cache_police):
if '*' in username_list:
users = LDAPServerUtil().search()
elif cache_police in LDAP_USE_CACHE_FLAGS:
users = LDAPCacheUtil().search(search_users=username_list)
else:
users = LDAPServerUtil().search(search_users=username_list)
return users

View File

@ -9,7 +9,7 @@ from django.conf import settings
from django.core.files.storage import default_storage
from django.http import HttpResponse
from django.shortcuts import get_object_or_404
from django.utils.translation import gettext as _
from django.utils.translation import gettext as _, get_language
from rest_framework import viewsets
from rest_framework.decorators import action
from rest_framework.request import Request
@ -19,6 +19,8 @@ from rest_framework.serializers import ValidationError
from common.api import JMSBulkModelViewSet
from common.serializers import FileSerializer
from common.utils import is_uuid
from common.utils.http import is_true
from common.utils.yml import yaml_load_with_i18n
from terminal import serializers
from terminal.models import AppletPublication, Applet
@ -106,9 +108,66 @@ class AppletViewSet(DownloadUploadMixin, JMSBulkModelViewSet):
def get_object(self):
pk = self.kwargs.get('pk')
if not is_uuid(pk):
return get_object_or_404(Applet, name=pk)
obj = get_object_or_404(Applet, name=pk)
else:
return get_object_or_404(Applet, pk=pk)
obj = get_object_or_404(Applet, pk=pk)
return self.trans_object(obj)
def get_queryset(self):
queryset = super().get_queryset()
queryset = self.trans_queryset(queryset)
return queryset
@staticmethod
def read_manifest_with_i18n(obj, lang='zh'):
path = os.path.join(obj.path, 'manifest.yml')
if os.path.exists(path):
with open(path, encoding='utf8') as f:
manifest = yaml_load_with_i18n(f, lang)
else:
manifest = {}
return manifest
def trans_queryset(self, queryset):
for obj in queryset:
self.trans_object(obj)
return queryset
@staticmethod
def readme(obj, lang=''):
lang = lang[:2]
readme_file = os.path.join(obj.path, f'README_{lang.upper()}.md')
if os.path.isfile(readme_file):
with open(readme_file, 'r') as f:
return f.read()
return ''
def trans_object(self, obj):
lang = get_language()
manifest = self.read_manifest_with_i18n(obj, lang)
obj.display_name = manifest.get('display_name', obj.display_name)
obj.comment = manifest.get('comment', obj.comment)
obj.readme = self.readme(obj, lang)
return obj
def is_record_found(self, obj, search):
combine_fields = ' '.join([getattr(obj, f, '') for f in self.search_fields])
return search in combine_fields
def filter_queryset(self, queryset):
search = self.request.query_params.get('search')
if search:
queryset = [i for i in queryset if self.is_record_found(i, search)]
for field in self.filterset_fields:
field_value = self.request.query_params.get(field)
if not field_value:
continue
if field in ['is_active', 'builtin']:
field_value = is_true(field_value)
queryset = [i for i in queryset if getattr(i, field, '') == field_value]
return queryset
def perform_destroy(self, instance):
if not instance.name:

View File

@ -42,7 +42,7 @@ class SmartEndpointViewMixin:
return endpoint
def match_endpoint_by_label(self):
return Endpoint.match_by_instance_label(self.target_instance, self.target_protocol)
return Endpoint.match_by_instance_label(self.target_instance, self.target_protocol, self.request)
def match_endpoint_by_target_ip(self):
target_ip = self.request.GET.get('target_ip', '') # 支持target_ip参数用来方便测试

View File

@ -18,10 +18,11 @@ from rest_framework.response import Response
from audits.const import ActionChoices
from common.api import AsyncApiMixin
from common.const.http import GET
from common.const.http import GET, POST
from common.drf.filters import BaseFilterSet
from common.drf.filters import DatetimeRangeFilterBackend
from common.drf.renders import PassthroughRenderer
from common.permissions import IsServiceAccount
from common.storage.replay import ReplayStorageHandler
from common.utils import data_to_json, is_uuid, i18n_fmt
from common.utils import get_logger, get_object_or_none
@ -33,6 +34,7 @@ from terminal import serializers
from terminal.const import TerminalType
from terminal.models import Session
from terminal.permissions import IsSessionAssignee
from terminal.session_lifecycle import lifecycle_events_map, reasons_map
from terminal.utils import is_session_approver
from users.models import User
@ -79,6 +81,7 @@ class SessionViewSet(RecordViewLogMixin, OrgBulkModelViewSet):
serializer_classes = {
'default': serializers.SessionSerializer,
'display': serializers.SessionDisplaySerializer,
'lifecycle_log': serializers.SessionLifecycleLogSerializer,
}
search_fields = [
"user", "asset", "account", "remote_addr",
@ -168,6 +171,23 @@ class SessionViewSet(RecordViewLogMixin, OrgBulkModelViewSet):
count = queryset.count()
return Response({'count': count})
@action(methods=[POST], detail=True, permission_classes=[IsServiceAccount], url_path='lifecycle_log',
url_name='lifecycle_log')
def lifecycle_log(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
validated_data = serializer.validated_data
event = validated_data.pop('event', None)
event_class = lifecycle_events_map.get(event, None)
if not event_class:
return Response({'msg': f'event_name {event} invalid'}, status=400)
session = self.get_object()
reason = validated_data.pop('reason', None)
reason = reasons_map.get(reason, reason)
event_obj = event_class(session, reason, **validated_data)
activity_log = event_obj.create_activity_log()
return Response({'msg': 'ok', 'id': activity_log.id})
def get_queryset(self):
queryset = super().get_queryset() \
.prefetch_related('terminal') \

View File

@ -0,0 +1,9 @@
## Selenium Version
- Selenium == 4.4.0
- Chrome and ChromeDriver versions must match
- Driver [download address](https://chromedriver.chromium.org/downloads)
## ChangeLog
Refer to [ChangeLog](./ChangeLog) for some important updates.

View File

@ -0,0 +1,9 @@
## Selenium バージョン
- Selenium == 4.4.0
- Chrome と ChromeDriver のバージョンは一致している必要があります
- ドライバ [ダウンロードアドレス](https://chromedriver.chromium.org/downloads)
## 変更ログ
重要な更新については、[変更ログ](./ChangeLog) を参照してください

View File

@ -0,0 +1,4 @@
## DBeaver
- When connecting to a database application, it is necessary to download the driver. You can either install it offline
in advance or install the corresponding driver as prompted when connecting.

View File

@ -0,0 +1,3 @@
## DBeaver
- データベースに接続する際には、ドライバをダウンロードする必要があります。事前にオフラインでインストールするか、接続時に表示される指示に従って該当するドライバをインストールしてください。

View File

@ -2,10 +2,10 @@
import datetime
from django.db import transaction
from django.utils import timezone
from django.db.utils import OperationalError
from common.utils.common import pretty_string
from django.utils import timezone
from common.utils.common import pretty_string
from .base import CommandBase
@ -19,9 +19,10 @@ class CommandStore(CommandBase):
"""
保存命令到数据库
"""
cmd_input = pretty_string(command['input'])
self.model.objects.create(
user=command["user"], asset=command["asset"],
account=command["account"], input=command["input"],
account=command["account"], input=cmd_input,
output=command["output"], session=command["session"],
risk_level=command.get("risk_level", 0), org_id=command["org_id"],
timestamp=command["timestamp"]

View File

@ -75,7 +75,20 @@ class Endpoint(JMSBaseModel):
return endpoint
@classmethod
def match_by_instance_label(cls, instance, protocol):
def handle_endpoint_host(cls, endpoint, request=None):
if not endpoint.host and request:
# 动态添加 current request host
host_port = request.get_host()
# IPv6
if host_port.startswith('['):
host = host_port.split(']:')[0].rstrip(']') + ']'
else:
host = host_port.split(':')[0]
endpoint.host = host
return endpoint
@classmethod
def match_by_instance_label(cls, instance, protocol, request=None):
from assets.models import Asset
from terminal.models import Session
if isinstance(instance, Session):
@ -88,6 +101,7 @@ class Endpoint(JMSBaseModel):
endpoints = cls.objects.filter(name__in=list(values)).order_by('-date_updated')
for endpoint in endpoints:
if endpoint.is_valid_for(instance, protocol):
endpoint = cls.handle_endpoint_host(endpoint, request)
return endpoint
@ -130,13 +144,5 @@ class EndpointRule(JMSBaseModel):
endpoint = endpoint_rule.endpoint
else:
endpoint = Endpoint.get_or_create_default(request)
if not endpoint.host and request:
# 动态添加 current request host
host_port = request.get_host()
# IPv6
if host_port.startswith('['):
host = host_port.split(']:')[0].rstrip(']') + ']'
else:
host = host_port.split(':')[0]
endpoint.host = host
endpoint = Endpoint.handle_endpoint_host(endpoint, request)
return endpoint

Some files were not shown because too many files have changed in this diff Show More