Compare commits

..

2 Commits

268 changed files with 8927 additions and 28883 deletions

View File

@@ -1,26 +0,0 @@
{
"dry_run": false,
"min_account_age_days": 3,
"max_urls_for_spam": 1,
"min_body_len_for_links": 40,
"spam_words": [
"call now",
"zadzwoń",
"zadzwoń teraz",
"kontakt",
"telefon",
"telefone",
"contato",
"suporte",
"infolinii",
"click here",
"buy now",
"subscribe",
"visit"
],
"bracket_max": 6,
"special_char_density_threshold": 0.12,
"phone_regex": "\\+?\\d[\\d\\-\\s\\(\\)\\.]{6,}\\d",
"labels_for_spam": ["spam"],
"labels_for_review": ["needs-triage"]
}

View File

@@ -31,6 +31,8 @@ jobs:
- name: Set up QEMU - name: Set up QEMU
uses: docker/setup-qemu-action@v3 uses: docker/setup-qemu-action@v3
with:
image: tonistiigi/binfmt:qemu-v7.0.0-28
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@v3

View File

@@ -1,46 +0,0 @@
name: Build and Push Python Base Image
on:
workflow_dispatch:
inputs:
tag:
description: 'Tag to build'
required: true
default: '3.11-slim-bullseye-v1'
type: string
jobs:
build-and-push:
runs-on: ubuntu-22.04
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.ref }}
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
with:
image: tonistiigi/binfmt:qemu-v7.0.0-28
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Extract repository name
id: repo
run: echo "REPO=$(basename ${{ github.repository }})" >> $GITHUB_ENV
- name: Build and push multi-arch image
uses: docker/build-push-action@v6
with:
platforms: linux/amd64,linux/arm64
push: true
file: Dockerfile-python
tags: jumpserver/core-base:python-${{ inputs.tag }}

View File

@@ -1,123 +0,0 @@
name: Cleanup PR Branches
on:
schedule:
# 每天凌晨2点运行
- cron: '0 2 * * *'
workflow_dispatch:
# 允许手动触发
inputs:
dry_run:
description: 'Dry run mode (default: true)'
required: false
default: 'true'
type: boolean
jobs:
cleanup-branches:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0 # 获取所有分支和提交历史
- name: Setup Git
run: |
git config --global user.name "GitHub Actions"
git config --global user.email "actions@github.com"
- name: Get dry run setting
id: dry-run
run: |
if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
echo "dry_run=${{ github.event.inputs.dry_run }}" >> $GITHUB_OUTPUT
else
echo "dry_run=false" >> $GITHUB_OUTPUT
fi
- name: Cleanup branches
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
DRY_RUN: ${{ steps.dry-run.outputs.dry_run }}
run: |
echo "Starting branch cleanup..."
echo "Dry run mode: $DRY_RUN"
# 获取所有本地分支
git fetch --all --prune
# 获取以 pr 或 repr 开头的分支
branches=$(git branch -r | grep -E 'origin/(pr|repr)' | sed 's/origin\///' | grep -v 'HEAD')
echo "Found branches matching pattern:"
echo "$branches"
deleted_count=0
skipped_count=0
for branch in $branches; do
echo ""
echo "Processing branch: $branch"
# 检查分支是否有未合并的PR
pr_info=$(gh pr list --head "$branch" --state open --json number,title,state 2>/dev/null)
if [ $? -eq 0 ] && [ "$pr_info" != "[]" ]; then
echo " ⚠️ Branch has open PR(s), skipping deletion"
echo " PR info: $pr_info"
skipped_count=$((skipped_count + 1))
continue
fi
# 检查分支是否有已合并的PR可选如果PR已合并也可以删除
merged_pr_info=$(gh pr list --head "$branch" --state merged --json number,title,state 2>/dev/null)
if [ $? -eq 0 ] && [ "$merged_pr_info" != "[]" ]; then
echo " ✅ Branch has merged PR(s), safe to delete"
echo " Merged PR info: $merged_pr_info"
else
echo " No PRs found for this branch"
fi
# 执行删除操作
if [ "$DRY_RUN" = "true" ]; then
echo " 🔍 [DRY RUN] Would delete branch: $branch"
deleted_count=$((deleted_count + 1))
else
echo " 🗑️ Deleting branch: $branch"
# 删除远程分支
if git push origin --delete "$branch" 2>/dev/null; then
echo " ✅ Successfully deleted remote branch: $branch"
deleted_count=$((deleted_count + 1))
else
echo " ❌ Failed to delete remote branch: $branch"
fi
fi
done
echo ""
echo "=== Cleanup Summary ==="
echo "Branches processed: $(echo "$branches" | wc -l)"
echo "Branches deleted: $deleted_count"
echo "Branches skipped: $skipped_count"
if [ "$DRY_RUN" = "true" ]; then
echo ""
echo "🔍 This was a DRY RUN - no branches were actually deleted"
echo "To perform actual deletion, run this workflow manually with dry_run=false"
fi
- name: Create summary
if: always()
run: |
echo "## Branch Cleanup Summary" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "**Workflow:** ${{ github.workflow }}" >> $GITHUB_STEP_SUMMARY
echo "**Run ID:** ${{ github.run_id }}" >> $GITHUB_STEP_SUMMARY
echo "**Dry Run:** ${{ steps.dry-run.outputs.dry_run }}" >> $GITHUB_STEP_SUMMARY
echo "**Triggered by:** ${{ github.event_name }}" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "Check the logs above for detailed information about processed branches." >> $GITHUB_STEP_SUMMARY

View File

@@ -1,33 +1,10 @@
on: on: [push, pull_request, release]
push:
pull_request:
types: [opened, synchronize, closed]
release:
types: [created]
name: JumpServer repos generic handler name: JumpServer repos generic handler
jobs: jobs:
handle_pull_request: generic_handler:
if: github.event_name == 'pull_request' name: Run generic handler
runs-on: ubuntu-latest
steps:
- uses: jumpserver/action-generic-handler@master
env:
GITHUB_TOKEN: ${{ secrets.PRIVATE_TOKEN }}
I18N_TOKEN: ${{ secrets.I18N_TOKEN }}
handle_push:
if: github.event_name == 'push'
runs-on: ubuntu-latest
steps:
- uses: jumpserver/action-generic-handler@master
env:
GITHUB_TOKEN: ${{ secrets.PRIVATE_TOKEN }}
I18N_TOKEN: ${{ secrets.I18N_TOKEN }}
handle_release:
if: github.event_name == 'release'
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: jumpserver/action-generic-handler@master - uses: jumpserver/action-generic-handler@master

View File

@@ -1,9 +1,11 @@
name: 🔀 Sync mirror to Gitee name: 🔀 Sync mirror to Gitee
on: on:
schedule: push:
# 每天凌晨3点运行 branches:
- cron: '0 3 * * *' - master
- dev
create:
jobs: jobs:
mirror: mirror:
@@ -12,6 +14,7 @@ jobs:
steps: steps:
- name: mirror - name: mirror
continue-on-error: true continue-on-error: true
if: github.event_name == 'push' || (github.event_name == 'create' && github.event.ref_type == 'tag')
uses: wearerequired/git-mirror-action@v1 uses: wearerequired/git-mirror-action@v1
env: env:
SSH_PRIVATE_KEY: ${{ secrets.GITEE_SSH_PRIVATE_KEY }} SSH_PRIVATE_KEY: ${{ secrets.GITEE_SSH_PRIVATE_KEY }}

View File

@@ -1,4 +1,4 @@
FROM jumpserver/core-base:20251128_025056 AS stage-build FROM jumpserver/core-base:20250827_025554 AS stage-build
ARG VERSION ARG VERSION
@@ -19,7 +19,7 @@ RUN set -ex \
&& python manage.py compilemessages && python manage.py compilemessages
FROM python:3.11-slim-trixie FROM python:3.11-slim-bullseye
ENV LANG=en_US.UTF-8 \ ENV LANG=en_US.UTF-8 \
PATH=/opt/py3/bin:$PATH PATH=/opt/py3/bin:$PATH
@@ -39,7 +39,7 @@ ARG TOOLS=" \
ARG APT_MIRROR=http://deb.debian.org ARG APT_MIRROR=http://deb.debian.org
RUN set -ex \ RUN set -ex \
&& sed -i "s@http://.*.debian.org@${APT_MIRROR}@g" /etc/apt/sources.list.d/debian.sources \ && sed -i "s@http://.*.debian.org@${APT_MIRROR}@g" /etc/apt/sources.list \
&& ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \ && ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \
&& apt-get update > /dev/null \ && apt-get update > /dev/null \
&& apt-get -y install --no-install-recommends ${DEPENDENCIES} \ && apt-get -y install --no-install-recommends ${DEPENDENCIES} \

View File

@@ -1,5 +1,6 @@
FROM python:3.11.14-slim-trixie FROM python:3.11-slim-bullseye
ARG TARGETARCH ARG TARGETARCH
COPY --from=ghcr.io/astral-sh/uv:0.6.14 /uv /uvx /usr/local/bin/
# Install APT dependencies # Install APT dependencies
ARG DEPENDENCIES=" \ ARG DEPENDENCIES=" \
ca-certificates \ ca-certificates \
@@ -21,13 +22,13 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked,id=core \
set -ex \ set -ex \
&& rm -f /etc/apt/apt.conf.d/docker-clean \ && rm -f /etc/apt/apt.conf.d/docker-clean \
&& echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache \ && echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache \
&& sed -i "s@http://.*.debian.org@${APT_MIRROR}@g" /etc/apt/sources.list.d/debian.sources \ && sed -i "s@http://.*.debian.org@${APT_MIRROR}@g" /etc/apt/sources.list \
&& apt-get update > /dev/null \ && apt-get update > /dev/null \
&& apt-get -y install --no-install-recommends ${DEPENDENCIES} \ && apt-get -y install --no-install-recommends ${DEPENDENCIES} \
&& echo "no" | dpkg-reconfigure dash && echo "no" | dpkg-reconfigure dash
# Install bin tools # Install bin tools
ARG CHECK_VERSION=v1.0.5 ARG CHECK_VERSION=v1.0.4
RUN set -ex \ RUN set -ex \
&& wget https://github.com/jumpserver-dev/healthcheck/releases/download/${CHECK_VERSION}/check-${CHECK_VERSION}-linux-${TARGETARCH}.tar.gz \ && wget https://github.com/jumpserver-dev/healthcheck/releases/download/${CHECK_VERSION}/check-${CHECK_VERSION}-linux-${TARGETARCH}.tar.gz \
&& tar -xf check-${CHECK_VERSION}-linux-${TARGETARCH}.tar.gz \ && tar -xf check-${CHECK_VERSION}-linux-${TARGETARCH}.tar.gz \
@@ -40,10 +41,12 @@ RUN set -ex \
WORKDIR /opt/jumpserver WORKDIR /opt/jumpserver
ARG PIP_MIRROR=https://pypi.org/simple ARG PIP_MIRROR=https://pypi.org/simple
ENV POETRY_PYPI_MIRROR_URL=${PIP_MIRROR}
ENV ANSIBLE_COLLECTIONS_PATHS=/opt/py3/lib/python3.11/site-packages/ansible_collections ENV ANSIBLE_COLLECTIONS_PATHS=/opt/py3/lib/python3.11/site-packages/ansible_collections
ENV LANG=en_US.UTF-8 \ ENV LANG=en_US.UTF-8 \
PATH=/opt/py3/bin:$PATH PATH=/opt/py3/bin:$PATH
ENV SETUPTOOLS_SCM_PRETEND_VERSION=3.4.5
ENV UV_LINK_MODE=copy
RUN --mount=type=cache,target=/root/.cache \ RUN --mount=type=cache,target=/root/.cache \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \ --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
@@ -51,7 +54,6 @@ RUN --mount=type=cache,target=/root/.cache \
--mount=type=bind,source=requirements/collections.yml,target=collections.yml \ --mount=type=bind,source=requirements/collections.yml,target=collections.yml \
--mount=type=bind,source=requirements/static_files.sh,target=utils/static_files.sh \ --mount=type=bind,source=requirements/static_files.sh,target=utils/static_files.sh \
set -ex \ set -ex \
&& pip install uv -i${PIP_MIRROR} \
&& uv venv \ && uv venv \
&& uv pip install -i${PIP_MIRROR} -r pyproject.toml \ && uv pip install -i${PIP_MIRROR} -r pyproject.toml \
&& ln -sf $(pwd)/.venv /opt/py3 \ && ln -sf $(pwd)/.venv /opt/py3 \

View File

@@ -13,7 +13,7 @@ ARG TOOLS=" \
nmap \ nmap \
telnet \ telnet \
vim \ vim \
postgresql-client \ postgresql-client-13 \
wget \ wget \
poppler-utils" poppler-utils"

View File

@@ -77,8 +77,7 @@ JumpServer consists of multiple key components, which collectively form the func
| [Luna](https://github.com/jumpserver/luna) | <a href="https://github.com/jumpserver/luna/releases"><img alt="Luna release" src="https://img.shields.io/github/release/jumpserver/luna.svg" /></a> | JumpServer Web Terminal | | [Luna](https://github.com/jumpserver/luna) | <a href="https://github.com/jumpserver/luna/releases"><img alt="Luna release" src="https://img.shields.io/github/release/jumpserver/luna.svg" /></a> | JumpServer Web Terminal |
| [KoKo](https://github.com/jumpserver/koko) | <a href="https://github.com/jumpserver/koko/releases"><img alt="Koko release" src="https://img.shields.io/github/release/jumpserver/koko.svg" /></a> | JumpServer Character Protocol Connector | | [KoKo](https://github.com/jumpserver/koko) | <a href="https://github.com/jumpserver/koko/releases"><img alt="Koko release" src="https://img.shields.io/github/release/jumpserver/koko.svg" /></a> | JumpServer Character Protocol Connector |
| [Lion](https://github.com/jumpserver/lion) | <a href="https://github.com/jumpserver/lion/releases"><img alt="Lion release" src="https://img.shields.io/github/release/jumpserver/lion.svg" /></a> | JumpServer Graphical Protocol Connector | | [Lion](https://github.com/jumpserver/lion) | <a href="https://github.com/jumpserver/lion/releases"><img alt="Lion release" src="https://img.shields.io/github/release/jumpserver/lion.svg" /></a> | JumpServer Graphical Protocol Connector |
| [Chen](https://github.com/jumpserver/chen) | <a href="https://github.com/jumpserver/chen/releases"><img alt="Chen release" src="https://img.shields.io/github/release/jumpserver/chen.svg" /> | JumpServer Web DB | [Chen](https://github.com/jumpserver/chen) | <a href="https://github.com/jumpserver/chen/releases"><img alt="Chen release" src="https://img.shields.io/github/release/jumpserver/chen.svg" /> | JumpServer Web DB |
| [Client](https://github.com/jumpserver/clients) | <a href="https://github.com/jumpserver/clients/releases"><img alt="Clients release" src="https://img.shields.io/github/release/jumpserver/clients.svg" /> | JumpServer Client |
| [Tinker](https://github.com/jumpserver/tinker) | <img alt="Tinker" src="https://img.shields.io/badge/release-private-red" /> | JumpServer Remote Application Connector (Windows) | | [Tinker](https://github.com/jumpserver/tinker) | <img alt="Tinker" src="https://img.shields.io/badge/release-private-red" /> | JumpServer Remote Application Connector (Windows) |
| [Panda](https://github.com/jumpserver/Panda) | <img alt="Panda" src="https://img.shields.io/badge/release-private-red" /> | JumpServer EE Remote Application Connector (Linux) | | [Panda](https://github.com/jumpserver/Panda) | <img alt="Panda" src="https://img.shields.io/badge/release-private-red" /> | JumpServer EE Remote Application Connector (Linux) |
| [Razor](https://github.com/jumpserver/razor) | <img alt="Chen" src="https://img.shields.io/badge/release-private-red" /> | JumpServer EE RDP Proxy Connector | | [Razor](https://github.com/jumpserver/razor) | <img alt="Chen" src="https://img.shields.io/badge/release-private-red" /> | JumpServer EE RDP Proxy Connector |

View File

@@ -1,19 +1,16 @@
from django.conf import settings
from django.db import transaction from django.db import transaction
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework import serializers as drf_serializers
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.generics import ListAPIView, CreateAPIView from rest_framework.generics import ListAPIView, CreateAPIView
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.status import HTTP_200_OK, HTTP_400_BAD_REQUEST from rest_framework.status import HTTP_200_OK
from accounts import serializers from accounts import serializers
from accounts.const import ChangeSecretRecordStatusChoice from accounts.const import ChangeSecretRecordStatusChoice
from accounts.filters import AccountFilterSet, NodeFilterBackend from accounts.filters import AccountFilterSet, NodeFilterBackend
from accounts.mixins import AccountRecordViewLogMixin from accounts.mixins import AccountRecordViewLogMixin
from accounts.models import Account, ChangeSecretRecord from accounts.models import Account, ChangeSecretRecord
from assets.const.gpt import create_or_update_chatx_resources
from assets.models import Asset, Node from assets.models import Asset, Node
from authentication.permissions import UserConfirmation, ConfirmType from authentication.permissions import UserConfirmation, ConfirmType
from common.api.mixin import ExtraFilterFieldsMixin from common.api.mixin import ExtraFilterFieldsMixin
@@ -21,7 +18,6 @@ from common.drf.filters import AttrRulesFilterBackend
from common.permissions import IsValidUser from common.permissions import IsValidUser
from common.utils import lazyproperty, get_logger from common.utils import lazyproperty, get_logger
from orgs.mixins.api import OrgBulkModelViewSet from orgs.mixins.api import OrgBulkModelViewSet
from orgs.utils import tmp_to_root_org
from rbac.permissions import RBACPermission from rbac.permissions import RBACPermission
logger = get_logger(__file__) logger = get_logger(__file__)
@@ -47,7 +43,6 @@ class AccountViewSet(OrgBulkModelViewSet):
'clear_secret': 'accounts.change_account', 'clear_secret': 'accounts.change_account',
'move_to_assets': 'accounts.delete_account', 'move_to_assets': 'accounts.delete_account',
'copy_to_assets': 'accounts.add_account', 'copy_to_assets': 'accounts.add_account',
'chat': 'accounts.view_account',
} }
export_as_zip = True export_as_zip = True
@@ -157,17 +152,10 @@ class AccountViewSet(OrgBulkModelViewSet):
def copy_to_assets(self, request, *args, **kwargs): def copy_to_assets(self, request, *args, **kwargs):
return self._copy_or_move_to_assets(request, move=False) return self._copy_or_move_to_assets(request, move=False)
@action(methods=['get'], detail=False, url_path='chat')
def chat(self, request, *args, **kwargs):
with tmp_to_root_org():
__, account = create_or_update_chatx_resources()
serializer = self.get_serializer(account)
return Response(serializer.data)
class AccountSecretsViewSet(AccountRecordViewLogMixin, AccountViewSet): class AccountSecretsViewSet(AccountRecordViewLogMixin, AccountViewSet):
""" """
因为可能要导出所有账号,所以单独建立了一个 viewset 因为可能要导出所有账号所以单独建立了一个 viewset
""" """
serializer_classes = { serializer_classes = {
'default': serializers.AccountSecretSerializer, 'default': serializers.AccountSecretSerializer,
@@ -186,66 +174,12 @@ class AssetAccountBulkCreateApi(CreateAPIView):
'POST': 'accounts.add_account', 'POST': 'accounts.add_account',
} }
@staticmethod
def get_all_assets(base_payload: dict):
nodes = base_payload.pop('nodes', [])
asset_ids = base_payload.pop('assets', [])
nodes = Node.objects.filter(id__in=nodes).only('id', 'key')
node_asset_ids = Node.get_nodes_all_assets(*nodes).values_list('id', flat=True)
asset_ids = set(asset_ids + list(node_asset_ids))
return Asset.objects.filter(id__in=asset_ids)
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
if hasattr(request.data, "copy"): serializer = self.get_serializer(data=request.data)
base_payload = request.data.copy() serializer.is_valid(raise_exception=True)
else: data = serializer.create(serializer.validated_data)
base_payload = dict(request.data) serializer = serializers.AssetAccountBulkSerializerResultSerializer(data, many=True)
return Response(data=serializer.data, status=HTTP_200_OK)
templates = base_payload.pop("template", None)
assets = self.get_all_assets(base_payload)
if not assets.exists():
error = _("No valid assets found for account creation.")
return Response(
data={
"detail": error,
"code": "no_valid_assets"
},
status=HTTP_400_BAD_REQUEST
)
result = []
errors = []
def handle_one(_payload):
try:
ser = self.get_serializer(data=_payload)
ser.is_valid(raise_exception=True)
data = ser.bulk_create(ser.validated_data, assets)
if isinstance(data, (list, tuple)):
result.extend(data)
else:
result.append(data)
except drf_serializers.ValidationError as e:
errors.extend(list(e.detail))
except Exception as e:
errors.extend([str(e)])
if not templates:
handle_one(base_payload)
else:
if not isinstance(templates, (list, tuple)):
templates = [templates]
for tpl in templates:
payload = dict(base_payload)
payload["template"] = tpl
handle_one(payload)
if errors:
raise drf_serializers.ValidationError(errors)
out_ser = serializers.AssetAccountBulkSerializerResultSerializer(result, many=True)
return Response(data=out_ser.data, status=HTTP_200_OK)
class AccountHistoriesSecretAPI(ExtraFilterFieldsMixin, AccountRecordViewLogMixin, ListAPIView): class AccountHistoriesSecretAPI(ExtraFilterFieldsMixin, AccountRecordViewLogMixin, ListAPIView):

View File

@@ -25,8 +25,7 @@ class IntegrationApplicationViewSet(OrgBulkModelViewSet):
} }
rbac_perms = { rbac_perms = {
'get_once_secret': 'accounts.change_integrationapplication', 'get_once_secret': 'accounts.change_integrationapplication',
'get_account_secret': 'accounts.view_integrationapplication', 'get_account_secret': 'accounts.view_integrationapplication'
'get_sdks_info': 'accounts.view_integrationapplication'
} }
def read_file(self, path): def read_file(self, path):
@@ -37,6 +36,7 @@ class IntegrationApplicationViewSet(OrgBulkModelViewSet):
@action( @action(
['GET'], detail=False, url_path='sdks', ['GET'], detail=False, url_path='sdks',
permission_classes=[IsValidUser]
) )
def get_sdks_info(self, request, *args, **kwargs): def get_sdks_info(self, request, *args, **kwargs):
code_suffix_mapper = { code_suffix_mapper = {
@@ -81,7 +81,4 @@ class IntegrationApplicationViewSet(OrgBulkModelViewSet):
remote_addr=get_request_ip(request), service=service.name, service_id=service.id, remote_addr=get_request_ip(request), service=service.name, service_id=service.id,
account=f'{account.name}({account.username})', asset=f'{asset.name}({asset.address})', account=f'{account.name}({account.username})', asset=f'{asset.name}({asset.address})',
) )
return Response(data={'id': request.user.id, 'secret': account.secret})
# 根据配置决定是否返回密码
secret = account.secret if settings.SECURITY_ACCOUNT_SECRET_READ else None
return Response(data={'id': request.user.id, 'secret': secret})

View File

@@ -1,5 +1,3 @@
from django.conf import settings
from django.utils.translation import gettext_lazy as _
from django_filters import rest_framework as drf_filters from django_filters import rest_framework as drf_filters
from rest_framework import status from rest_framework import status
from rest_framework.decorators import action from rest_framework.decorators import action

View File

@@ -104,7 +104,7 @@ class AutomationExecutionViewSet(
mixins.CreateModelMixin, mixins.ListModelMixin, mixins.CreateModelMixin, mixins.ListModelMixin,
mixins.RetrieveModelMixin, viewsets.GenericViewSet mixins.RetrieveModelMixin, viewsets.GenericViewSet
): ):
search_fields = ('id', 'trigger', 'automation__name') search_fields = ('trigger', 'automation__name')
filterset_fields = ('trigger', 'automation_id', 'automation__name') filterset_fields = ('trigger', 'automation_id', 'automation__name')
filterset_class = AutomationExecutionFilterSet filterset_class = AutomationExecutionFilterSet
serializer_class = serializers.AutomationExecutionSerializer serializer_class = serializers.AutomationExecutionSerializer

View File

@@ -235,8 +235,8 @@ class AccountBackupHandler:
except Exception as e: except Exception as e:
error = str(e) error = str(e)
print(f'\033[31m>>> {error}\033[0m') print(f'\033[31m>>> {error}\033[0m')
self.manager.status = Status.error self.execution.status = Status.error
self.manager.summary['error'] = error self.execution.summary['error'] = error
def backup_by_obj_storage(self): def backup_by_obj_storage(self):
object_id = self.execution.snapshot.get('id') object_id = self.execution.snapshot.get('id')

View File

@@ -105,6 +105,10 @@ class BaseChangeSecretPushManager(AccountBasePlaybookManager):
h['account']['mode'] = 'sysdba' if account.privileged else None h['account']['mode'] = 'sysdba' if account.privileged else None
return h return h
def add_extra_params(self, host, **kwargs):
host['ssh_params'] = {}
return host
def host_callback(self, host, asset=None, account=None, automation=None, path_dir=None, **kwargs): def host_callback(self, host, asset=None, account=None, automation=None, path_dir=None, **kwargs):
host = super().host_callback( host = super().host_callback(
host, asset=asset, account=account, automation=automation, host, asset=asset, account=account, automation=automation,
@@ -113,18 +117,7 @@ class BaseChangeSecretPushManager(AccountBasePlaybookManager):
if host.get('error'): if host.get('error'):
return host return host
inventory_hosts = [] host = self.add_extra_params(host, automation=automation)
if asset.type == HostTypes.WINDOWS:
if self.secret_type == SecretType.SSH_KEY:
host['error'] = _("Windows does not support SSH key authentication")
return host
new_secret = self.get_secret(account)
if '>' in new_secret or '^' in new_secret:
host['error'] = _("Windows password cannot contain special characters like > ^")
return host
host['ssh_params'] = {}
accounts = self.get_accounts(account) accounts = self.get_accounts(account)
existing_ids = set(map(str, accounts.values_list('id', flat=True))) existing_ids = set(map(str, accounts.values_list('id', flat=True)))
missing_ids = set(map(str, self.account_ids)) - existing_ids missing_ids = set(map(str, self.account_ids)) - existing_ids
@@ -140,6 +133,11 @@ class BaseChangeSecretPushManager(AccountBasePlaybookManager):
if asset.type == HostTypes.WINDOWS: if asset.type == HostTypes.WINDOWS:
accounts = accounts.filter(secret_type=SecretType.PASSWORD) accounts = accounts.filter(secret_type=SecretType.PASSWORD)
inventory_hosts = []
if asset.type == HostTypes.WINDOWS and self.secret_type == SecretType.SSH_KEY:
print(f'Windows {asset} does not support ssh key push')
return inventory_hosts
for account in accounts: for account in accounts:
h = deepcopy(host) h = deepcopy(host)
h['name'] += '(' + account.username + ')' # To distinguish different accounts h['name'] += '(' + account.username + ')' # To distinguish different accounts

View File

@@ -5,14 +5,12 @@
tasks: tasks:
- name: Test SQLServer connection - name: Test SQLServer connection
mssql_script: community.general.mssql_script:
login_user: "{{ jms_account.username }}" login_user: "{{ jms_account.username }}"
login_password: "{{ jms_account.secret }}" login_password: "{{ jms_account.secret }}"
login_host: "{{ jms_asset.address }}" login_host: "{{ jms_asset.address }}"
login_port: "{{ jms_asset.port }}" login_port: "{{ jms_asset.port }}"
name: '{{ jms_asset.spec_info.db_name }}' name: '{{ jms_asset.spec_info.db_name }}'
encryption: "{{ jms_asset.encryption | default(None) }}"
tds_version: "{{ jms_asset.tds_version | default(None) }}"
script: | script: |
SELECT @@version SELECT @@version
register: db_info register: db_info
@@ -25,53 +23,45 @@
var: info var: info
- name: Check whether SQLServer User exist - name: Check whether SQLServer User exist
mssql_script: community.general.mssql_script:
login_user: "{{ jms_account.username }}" login_user: "{{ jms_account.username }}"
login_password: "{{ jms_account.secret }}" login_password: "{{ jms_account.secret }}"
login_host: "{{ jms_asset.address }}" login_host: "{{ jms_asset.address }}"
login_port: "{{ jms_asset.port }}" login_port: "{{ jms_asset.port }}"
name: '{{ jms_asset.spec_info.db_name }}' name: '{{ jms_asset.spec_info.db_name }}'
encryption: "{{ jms_asset.encryption | default(None) }}"
tds_version: "{{ jms_asset.tds_version | default(None) }}"
script: "SELECT 1 from sys.sql_logins WHERE name='{{ account.username }}';" script: "SELECT 1 from sys.sql_logins WHERE name='{{ account.username }}';"
when: db_info is succeeded when: db_info is succeeded
register: user_exist register: user_exist
- name: Change SQLServer password - name: Change SQLServer password
mssql_script: community.general.mssql_script:
login_user: "{{ jms_account.username }}" login_user: "{{ jms_account.username }}"
login_password: "{{ jms_account.secret }}" login_password: "{{ jms_account.secret }}"
login_host: "{{ jms_asset.address }}" login_host: "{{ jms_asset.address }}"
login_port: "{{ jms_asset.port }}" login_port: "{{ jms_asset.port }}"
name: '{{ jms_asset.spec_info.db_name }}' name: '{{ jms_asset.spec_info.db_name }}'
encryption: "{{ jms_asset.encryption | default(None) }}"
tds_version: "{{ jms_asset.tds_version | default(None) }}"
script: "ALTER LOGIN {{ account.username }} WITH PASSWORD = '{{ account.secret }}', DEFAULT_DATABASE = {{ jms_asset.spec_info.db_name }}; select @@version" script: "ALTER LOGIN {{ account.username }} WITH PASSWORD = '{{ account.secret }}', DEFAULT_DATABASE = {{ jms_asset.spec_info.db_name }}; select @@version"
ignore_errors: true ignore_errors: true
when: user_exist.query_results[0] | length != 0 when: user_exist.query_results[0] | length != 0
- name: Add SQLServer user - name: Add SQLServer user
mssql_script: community.general.mssql_script:
login_user: "{{ jms_account.username }}" login_user: "{{ jms_account.username }}"
login_password: "{{ jms_account.secret }}" login_password: "{{ jms_account.secret }}"
login_host: "{{ jms_asset.address }}" login_host: "{{ jms_asset.address }}"
login_port: "{{ jms_asset.port }}" login_port: "{{ jms_asset.port }}"
name: '{{ jms_asset.spec_info.db_name }}' name: '{{ jms_asset.spec_info.db_name }}'
encryption: "{{ jms_asset.encryption | default(None) }}"
tds_version: "{{ jms_asset.tds_version | default(None) }}"
script: "CREATE LOGIN {{ account.username }} WITH PASSWORD = '{{ account.secret }}', DEFAULT_DATABASE = {{ jms_asset.spec_info.db_name }}; CREATE USER {{ account.username }} FOR LOGIN {{ account.username }}; select @@version" script: "CREATE LOGIN {{ account.username }} WITH PASSWORD = '{{ account.secret }}', DEFAULT_DATABASE = {{ jms_asset.spec_info.db_name }}; CREATE USER {{ account.username }} FOR LOGIN {{ account.username }}; select @@version"
ignore_errors: true ignore_errors: true
when: user_exist.query_results[0] | length == 0 when: user_exist.query_results[0] | length == 0
- name: Verify password - name: Verify password
mssql_script: community.general.mssql_script:
login_user: "{{ account.username }}" login_user: "{{ account.username }}"
login_password: "{{ account.secret }}" login_password: "{{ account.secret }}"
login_host: "{{ jms_asset.address }}" login_host: "{{ jms_asset.address }}"
login_port: "{{ jms_asset.port }}" login_port: "{{ jms_asset.port }}"
name: '{{ jms_asset.spec_info.db_name }}' name: '{{ jms_asset.spec_info.db_name }}'
encryption: "{{ jms_asset.encryption | default(None) }}"
tds_version: "{{ jms_asset.tds_version | default(None) }}"
script: | script: |
SELECT @@version SELECT @@version
when: check_conn_after_change when: check_conn_after_change

View File

@@ -18,7 +18,6 @@
uid: "{{ params.uid | int if params.uid | length > 0 else omit }}" uid: "{{ params.uid | int if params.uid | length > 0 else omit }}"
shell: "{{ params.shell if params.shell | length > 0 else omit }}" shell: "{{ params.shell if params.shell | length > 0 else omit }}"
home: "{{ params.home if params.home | length > 0 else '/home/' + account.username }}" home: "{{ params.home if params.home | length > 0 else '/home/' + account.username }}"
group: "{{ params.group if params.group | length > 0 else omit }}"
groups: "{{ params.groups if params.groups | length > 0 else omit }}" groups: "{{ params.groups if params.groups | length > 0 else omit }}"
append: "{{ true if params.groups | length > 0 else false }}" append: "{{ true if params.groups | length > 0 else false }}"
expires: -1 expires: -1

View File

@@ -28,12 +28,6 @@ params:
default: '' default: ''
help_text: "{{ 'Params home help text' | trans }}" help_text: "{{ 'Params home help text' | trans }}"
- name: group
type: str
label: "{{ 'Params group label' | trans }}"
default: ''
help_text: "{{ 'Params group help text' | trans }}"
- name: groups - name: groups
type: str type: str
label: "{{ 'Params groups label' | trans }}" label: "{{ 'Params groups label' | trans }}"
@@ -67,11 +61,6 @@ i18n:
ja: 'デフォルトのホームディレクトリ /home/{アカウントユーザ名}' ja: 'デフォルトのホームディレクトリ /home/{アカウントユーザ名}'
en: 'Default home directory /home/{account username}' en: 'Default home directory /home/{account username}'
Params group help text:
zh: '请输入用户组(名字或数字),只能输入一个(需填写已存在的用户组)'
ja: 'ユーザー グループ (名前または番号) を入力してください。入力できるのは 1 つだけです (既存のユーザー グループを入力する必要があります)'
en: 'Please enter a user group (name or number), only one can be entered (must fill in an existing user group)'
Params groups help text: Params groups help text:
zh: '请输入用户组,多个用户组使用逗号分隔(需填写已存在的用户组)' zh: '请输入用户组,多个用户组使用逗号分隔(需填写已存在的用户组)'
ja: 'グループを入力してください。複数のグループはコンマで区切ってください(既存のグループを入力してください)' ja: 'グループを入力してください。複数のグループはコンマで区切ってください(既存のグループを入力してください)'
@@ -97,11 +86,6 @@ i18n:
ja: 'グループ' ja: 'グループ'
en: 'Groups' en: 'Groups'
Params group label:
zh: '主组'
ja: '主组'
en: 'Main group'
Params uid label: Params uid label:
zh: '用户ID' zh: '用户ID'
ja: 'ユーザーID' ja: 'ユーザーID'

View File

@@ -18,7 +18,6 @@
uid: "{{ params.uid | int if params.uid | length > 0 else omit }}" uid: "{{ params.uid | int if params.uid | length > 0 else omit }}"
shell: "{{ params.shell if params.shell | length > 0 else omit }}" shell: "{{ params.shell if params.shell | length > 0 else omit }}"
home: "{{ params.home if params.home | length > 0 else '/home/' + account.username }}" home: "{{ params.home if params.home | length > 0 else '/home/' + account.username }}"
group: "{{ params.group if params.group | length > 0 else omit }}"
groups: "{{ params.groups if params.groups | length > 0 else omit }}" groups: "{{ params.groups if params.groups | length > 0 else omit }}"
append: "{{ true if params.groups | length > 0 else false }}" append: "{{ true if params.groups | length > 0 else false }}"
expires: -1 expires: -1

View File

@@ -30,12 +30,6 @@ params:
default: '' default: ''
help_text: "{{ 'Params home help text' | trans }}" help_text: "{{ 'Params home help text' | trans }}"
- name: group
type: str
label: "{{ 'Params group label' | trans }}"
default: ''
help_text: "{{ 'Params group help text' | trans }}"
- name: groups - name: groups
type: str type: str
label: "{{ 'Params groups label' | trans }}" label: "{{ 'Params groups label' | trans }}"
@@ -69,11 +63,6 @@ i18n:
ja: 'デフォルトのホームディレクトリ /home/{アカウントユーザ名}' ja: 'デフォルトのホームディレクトリ /home/{アカウントユーザ名}'
en: 'Default home directory /home/{account username}' en: 'Default home directory /home/{account username}'
Params group help text:
zh: '请输入用户组(名字或数字),只能输入一个(需填写已存在的用户组)'
ja: 'ユーザー グループ (名前または番号) を入力してください。入力できるのは 1 つだけです (既存のユーザー グループを入力する必要があります)'
en: 'Please enter a user group (name or number), only one can be entered (must fill in an existing user group)'
Params groups help text: Params groups help text:
zh: '请输入用户组,多个用户组使用逗号分隔(需填写已存在的用户组)' zh: '请输入用户组,多个用户组使用逗号分隔(需填写已存在的用户组)'
ja: 'グループを入力してください。複数のグループはコンマで区切ってください(既存のグループを入力してください)' ja: 'グループを入力してください。複数のグループはコンマで区切ってください(既存のグループを入力してください)'
@@ -99,11 +88,6 @@ i18n:
ja: 'グループ' ja: 'グループ'
en: 'Groups' en: 'Groups'
Params group label:
zh: '主组'
ja: '主组'
en: 'Main group'
Params uid label: Params uid label:
zh: '用户ID' zh: '用户ID'
ja: 'ユーザーID' ja: 'ユーザーID'

View File

@@ -5,11 +5,14 @@ from django.conf import settings
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from xlsxwriter import Workbook from xlsxwriter import Workbook
from assets.automations.methods import platform_automation_methods as asset_methods
from assets.const import AutomationTypes as AssetAutomationTypes
from accounts.automations.methods import platform_automation_methods as account_methods
from accounts.const import ( from accounts.const import (
AutomationTypes, SecretStrategy, ChangeSecretRecordStatusChoice AutomationTypes, SecretStrategy, ChangeSecretRecordStatusChoice
) )
from accounts.models import ChangeSecretRecord from accounts.models import ChangeSecretRecord
from accounts.notifications import ChangeSecretExecutionTaskMsg from accounts.notifications import ChangeSecretExecutionTaskMsg, ChangeSecretReportMsg
from accounts.serializers import ChangeSecretRecordBackUpSerializer from accounts.serializers import ChangeSecretRecordBackUpSerializer
from common.utils import get_logger from common.utils import get_logger
from common.utils.file import encrypt_and_compress_zip_file from common.utils.file import encrypt_and_compress_zip_file
@@ -22,6 +25,22 @@ logger = get_logger(__name__)
class ChangeSecretManager(BaseChangeSecretPushManager): class ChangeSecretManager(BaseChangeSecretPushManager):
ansible_account_prefer = '' ansible_account_prefer = ''
def get_method_id_meta_mapper(self):
return {
method["id"]: method for method in self.platform_automation_methods
}
@property
def platform_automation_methods(self):
return asset_methods + account_methods
def add_extra_params(self, host, **kwargs):
host = super().add_extra_params(host, **kwargs)
automation = kwargs.get('automation')
for extra_type in [AssetAutomationTypes.ping, AutomationTypes.verify_account]:
host[f"{extra_type}_params"] = self.get_params(automation, extra_type)
return host
@classmethod @classmethod
def method_type(cls): def method_type(cls):
return AutomationTypes.change_secret return AutomationTypes.change_secret
@@ -94,6 +113,10 @@ class ChangeSecretManager(BaseChangeSecretPushManager):
if not recipients: if not recipients:
return return
context = self.get_report_context()
for user in recipients:
ChangeSecretReportMsg(user, context).publish()
if not records: if not records:
return return

View File

@@ -0,0 +1,36 @@
- hosts: website
gather_facts: no
vars:
ansible_python_interpreter: "{{ local_python_interpreter }}"
tasks:
- name: Test privileged account
website_ping:
login_host: "{{ jms_asset.address }}"
login_user: "{{ jms_account.username }}"
login_password: "{{ jms_account.secret }}"
steps: "{{ ping_params.steps }}"
load_state: "{{ ping_params.load_state }}"
- name: "Change {{ account.username }} password"
website_user:
login_host: "{{ jms_asset.address }}"
login_user: "{{ jms_account.username }}"
login_password: "{{ jms_account.secret }}"
steps: "{{ params.steps }}"
load_state: "{{ params.load_state }}"
name: "{{ account.username }}"
password: "{{ account.secret }}"
ignore_errors: true
register: change_secret_result
- name: "Verify {{ account.username }} password"
website_ping:
login_host: "{{ jms_asset.address }}"
login_user: "{{ account.username }}"
login_password: "{{ account.secret }}"
steps: "{{ verify_account_params.steps }}"
load_state: "{{ verify_account_params.load_state }}"
when:
- check_conn_after_change or change_secret_result.failed | default(false)
delegate_to: localhost

View File

@@ -0,0 +1,51 @@
id: change_account_website
name: "{{ 'Website account change secret' | trans }}"
category: web
type:
- website
method: change_secret
priority: 50
params:
- name: load_state
type: choice
label: "{{ 'Load state' | trans }}"
choices:
- [ networkidle, "{{ 'Network idle' | trans }}" ]
- [ domcontentloaded, "{{ 'Dom content loaded' | trans }}" ]
- [ load, "{{ 'Load completed' | trans }}" ]
default: 'load'
- name: steps
type: list
default: [ ]
label: "{{ 'Steps' | trans }}"
help_text: "{{ 'Params step help text' | trans }}"
i18n:
Website account change secret:
zh: 使用 Playwright 模拟浏览器变更账号密码
ja: Playwright を使用してブラウザをシミュレートし、アカウントのパスワードを変更します
en: Use Playwright to simulate a browser for account password change.
Load state:
zh: 加载状态检测
en: Load state detection
ja: ロード状態の検出
Steps:
zh: 步骤
en: Steps
ja: 手順
Network idle:
zh: 网络空闲
en: Network idle
ja: ネットワークが空いた状態
Dom content loaded:
zh: 文档内容加载完成
en: Dom content loaded
ja: ドキュメントの内容がロードされた状態
Load completed:
zh: 全部加载完成
en: All load completed
ja: すべてのロードが完了した状態
Params step help text:
zh: 根据配置决定任务执行步骤
ja: 設定に基づいてタスクの実行ステップを決定する
en: Determine task execution steps based on configuration

View File

@@ -5,14 +5,12 @@
tasks: tasks:
- name: Test SQLServer connection - name: Test SQLServer connection
mssql_script: community.general.mssql_script:
login_user: "{{ jms_account.username }}" login_user: "{{ jms_account.username }}"
login_password: "{{ jms_account.secret }}" login_password: "{{ jms_account.secret }}"
login_host: "{{ jms_asset.address }}" login_host: "{{ jms_asset.address }}"
login_port: "{{ jms_asset.port }}" login_port: "{{ jms_asset.port }}"
name: '{{ jms_asset.spec_info.db_name }}' name: '{{ jms_asset.spec_info.db_name }}'
encryption: "{{ jms_asset.encryption | default(None) }}"
tds_version: "{{ jms_asset.tds_version | default(None) }}"
script: | script: |
SELECT SELECT
l.name, l.name,

View File

@@ -5,14 +5,12 @@
tasks: tasks:
- name: Test SQLServer connection - name: Test SQLServer connection
mssql_script: community.general.mssql_script:
login_user: "{{ jms_account.username }}" login_user: "{{ jms_account.username }}"
login_password: "{{ jms_account.secret }}" login_password: "{{ jms_account.secret }}"
login_host: "{{ jms_asset.address }}" login_host: "{{ jms_asset.address }}"
login_port: "{{ jms_asset.port }}" login_port: "{{ jms_asset.port }}"
name: '{{ jms_asset.spec_info.db_name }}' name: '{{ jms_asset.spec_info.db_name }}'
encryption: "{{ jms_asset.encryption | default(None) }}"
tds_version: "{{ jms_asset.tds_version | default(None) }}"
script: | script: |
SELECT @@version SELECT @@version
register: db_info register: db_info
@@ -25,55 +23,47 @@
var: info var: info
- name: Check whether SQLServer User exist - name: Check whether SQLServer User exist
mssql_script: community.general.mssql_script:
login_user: "{{ jms_account.username }}" login_user: "{{ jms_account.username }}"
login_password: "{{ jms_account.secret }}" login_password: "{{ jms_account.secret }}"
login_host: "{{ jms_asset.address }}" login_host: "{{ jms_asset.address }}"
login_port: "{{ jms_asset.port }}" login_port: "{{ jms_asset.port }}"
name: '{{ jms_asset.spec_info.db_name }}' name: '{{ jms_asset.spec_info.db_name }}'
encryption: "{{ jms_asset.encryption | default(None) }}"
tds_version: "{{ jms_asset.tds_version | default(None) }}"
script: "SELECT 1 from sys.sql_logins WHERE name='{{ account.username }}';" script: "SELECT 1 from sys.sql_logins WHERE name='{{ account.username }}';"
when: db_info is succeeded when: db_info is succeeded
register: user_exist register: user_exist
- name: Change SQLServer password - name: Change SQLServer password
mssql_script: community.general.mssql_script:
login_user: "{{ jms_account.username }}" login_user: "{{ jms_account.username }}"
login_password: "{{ jms_account.secret }}" login_password: "{{ jms_account.secret }}"
login_host: "{{ jms_asset.address }}" login_host: "{{ jms_asset.address }}"
login_port: "{{ jms_asset.port }}" login_port: "{{ jms_asset.port }}"
name: '{{ jms_asset.spec_info.db_name }}' name: '{{ jms_asset.spec_info.db_name }}'
encryption: "{{ jms_asset.encryption | default(None) }}"
tds_version: "{{ jms_asset.tds_version | default(None) }}"
script: "ALTER LOGIN {{ account.username }} WITH PASSWORD = '{{ account.secret }}', DEFAULT_DATABASE = {{ jms_asset.spec_info.db_name }}; select @@version" script: "ALTER LOGIN {{ account.username }} WITH PASSWORD = '{{ account.secret }}', DEFAULT_DATABASE = {{ jms_asset.spec_info.db_name }}; select @@version"
ignore_errors: true ignore_errors: true
when: user_exist.query_results[0] | length != 0 when: user_exist.query_results[0] | length != 0
register: change_info register: change_info
- name: Add SQLServer user - name: Add SQLServer user
mssql_script: community.general.mssql_script:
login_user: "{{ jms_account.username }}" login_user: "{{ jms_account.username }}"
login_password: "{{ jms_account.secret }}" login_password: "{{ jms_account.secret }}"
login_host: "{{ jms_asset.address }}" login_host: "{{ jms_asset.address }}"
login_port: "{{ jms_asset.port }}" login_port: "{{ jms_asset.port }}"
name: '{{ jms_asset.spec_info.db_name }}' name: '{{ jms_asset.spec_info.db_name }}'
encryption: "{{ jms_asset.encryption | default(None) }}"
tds_version: "{{ jms_asset.tds_version | default(None) }}"
script: "CREATE LOGIN [{{ account.username }}] WITH PASSWORD = '{{ account.secret }}'; CREATE USER [{{ account.username }}] FOR LOGIN [{{ account.username }}]; select @@version" script: "CREATE LOGIN [{{ account.username }}] WITH PASSWORD = '{{ account.secret }}'; CREATE USER [{{ account.username }}] FOR LOGIN [{{ account.username }}]; select @@version"
ignore_errors: true ignore_errors: true
when: user_exist.query_results[0] | length == 0 when: user_exist.query_results[0] | length == 0
register: change_info register: change_info
- name: Verify password - name: Verify password
mssql_script: community.general.mssql_script:
login_user: "{{ account.username }}" login_user: "{{ account.username }}"
login_password: "{{ account.secret }}" login_password: "{{ account.secret }}"
login_host: "{{ jms_asset.address }}" login_host: "{{ jms_asset.address }}"
login_port: "{{ jms_asset.port }}" login_port: "{{ jms_asset.port }}"
name: '{{ jms_asset.spec_info.db_name }}' name: '{{ jms_asset.spec_info.db_name }}'
encryption: "{{ jms_asset.encryption | default(None) }}"
tds_version: "{{ jms_asset.tds_version | default(None) }}"
script: | script: |
SELECT @@version SELECT @@version
when: check_conn_after_change when: check_conn_after_change

View File

@@ -18,7 +18,6 @@
uid: "{{ params.uid | int if params.uid | length > 0 else omit }}" uid: "{{ params.uid | int if params.uid | length > 0 else omit }}"
shell: "{{ params.shell if params.shell | length > 0 else omit }}" shell: "{{ params.shell if params.shell | length > 0 else omit }}"
home: "{{ params.home if params.home | length > 0 else '/home/' + account.username }}" home: "{{ params.home if params.home | length > 0 else '/home/' + account.username }}"
group: "{{ params.group if params.group | length > 0 else omit }}"
groups: "{{ params.groups if params.groups | length > 0 else omit }}" groups: "{{ params.groups if params.groups | length > 0 else omit }}"
append: "{{ true if params.groups | length > 0 else false }}" append: "{{ true if params.groups | length > 0 else false }}"
expires: -1 expires: -1

View File

@@ -28,12 +28,6 @@ params:
default: '' default: ''
help_text: "{{ 'Params home help text' | trans }}" help_text: "{{ 'Params home help text' | trans }}"
- name: group
type: str
label: "{{ 'Params group label' | trans }}"
default: ''
help_text: "{{ 'Params group help text' | trans }}"
- name: groups - name: groups
type: str type: str
label: "{{ 'Params groups label' | trans }}" label: "{{ 'Params groups label' | trans }}"
@@ -67,11 +61,6 @@ i18n:
ja: 'デフォルトのホームディレクトリ /home/{アカウントユーザ名}' ja: 'デフォルトのホームディレクトリ /home/{アカウントユーザ名}'
en: 'Default home directory /home/{account username}' en: 'Default home directory /home/{account username}'
Params group help text:
zh: '请输入用户组(名字或数字),只能输入一个(需填写已存在的用户组)'
ja: 'ユーザー グループ (名前または番号) を入力してください。入力できるのは 1 つだけです (既存のユーザー グループを入力する必要があります)'
en: 'Please enter a user group (name or number), only one can be entered (must fill in an existing user group)'
Params groups help text: Params groups help text:
zh: '请输入用户组,多个用户组使用逗号分隔(需填写已存在的用户组)' zh: '请输入用户组,多个用户组使用逗号分隔(需填写已存在的用户组)'
ja: 'グループを入力してください。複数のグループはコンマで区切ってください(既存のグループを入力してください)' ja: 'グループを入力してください。複数のグループはコンマで区切ってください(既存のグループを入力してください)'
@@ -97,11 +86,6 @@ i18n:
ja: 'グループ' ja: 'グループ'
en: 'Groups' en: 'Groups'
Params group label:
zh: '主组'
ja: '主组'
en: 'Main group'
Params uid label: Params uid label:
zh: '用户ID' zh: '用户ID'
ja: 'ユーザーID' ja: 'ユーザーID'

View File

@@ -18,7 +18,6 @@
uid: "{{ params.uid | int if params.uid | length > 0 else omit }}" uid: "{{ params.uid | int if params.uid | length > 0 else omit }}"
shell: "{{ params.shell if params.shell | length > 0 else omit }}" shell: "{{ params.shell if params.shell | length > 0 else omit }}"
home: "{{ params.home if params.home | length > 0 else '/home/' + account.username }}" home: "{{ params.home if params.home | length > 0 else '/home/' + account.username }}"
group: "{{ params.group if params.group | length > 0 else omit }}"
groups: "{{ params.groups if params.groups | length > 0 else omit }}" groups: "{{ params.groups if params.groups | length > 0 else omit }}"
append: "{{ true if params.groups | length > 0 else false }}" append: "{{ true if params.groups | length > 0 else false }}"
expires: -1 expires: -1

View File

@@ -30,12 +30,6 @@ params:
default: '' default: ''
help_text: "{{ 'Params home help text' | trans }}" help_text: "{{ 'Params home help text' | trans }}"
- name: group
type: str
label: "{{ 'Params group label' | trans }}"
default: ''
help_text: "{{ 'Params group help text' | trans }}"
- name: groups - name: groups
type: str type: str
label: "{{ 'Params groups label' | trans }}" label: "{{ 'Params groups label' | trans }}"
@@ -69,11 +63,6 @@ i18n:
ja: 'デフォルトのホームディレクトリ /home/{アカウントユーザ名}' ja: 'デフォルトのホームディレクトリ /home/{アカウントユーザ名}'
en: 'Default home directory /home/{account username}' en: 'Default home directory /home/{account username}'
Params group help text:
zh: '请输入用户组(名字或数字),只能输入一个(需填写已存在的用户组)'
ja: 'ユーザー グループ (名前または番号) を入力してください。入力できるのは 1 つだけです (既存のユーザー グループを入力する必要があります)'
en: 'Please enter a user group (name or number), only one can be entered (must fill in an existing user group)'
Params groups help text: Params groups help text:
zh: '请输入用户组,多个用户组使用逗号分隔(需填写已存在的用户组)' zh: '请输入用户组,多个用户组使用逗号分隔(需填写已存在的用户组)'
ja: 'グループを入力してください。複数のグループはコンマで区切ってください(既存のグループを入力してください)' ja: 'グループを入力してください。複数のグループはコンマで区切ってください(既存のグループを入力してください)'
@@ -95,14 +84,9 @@ i18n:
en: 'Home' en: 'Home'
Params groups label: Params groups label:
zh: '附加组' zh: '用户组'
ja: '追加グループ' ja: 'グループ'
en: 'Additional Group' en: 'Groups'
Params group label:
zh: '主组'
ja: '主组'
en: 'Main group'
Params uid label: Params uid label:
zh: '用户ID' zh: '用户ID'

View File

@@ -5,13 +5,11 @@
tasks: tasks:
- name: "Remove account" - name: "Remove account"
mssql_script: community.general.mssql_script:
login_user: "{{ jms_account.username }}" login_user: "{{ jms_account.username }}"
login_password: "{{ jms_account.secret }}" login_password: "{{ jms_account.secret }}"
login_host: "{{ jms_asset.address }}" login_host: "{{ jms_asset.address }}"
login_port: "{{ jms_asset.port }}" login_port: "{{ jms_asset.port }}"
name: "{{ jms_asset.spec_info.db_name }}" name: "{{ jms_asset.spec_info.db_name }}"
encryption: "{{ jms_asset.encryption | default(None) }}"
tds_version: "{{ jms_asset.tds_version | default(None) }}"
script: "DROP LOGIN {{ account.username }}; select @@version" script: "DROP LOGIN {{ account.username }}; select @@version"

View File

@@ -5,13 +5,11 @@
tasks: tasks:
- name: Verify account - name: Verify account
mssql_script: community.general.mssql_script:
login_user: "{{ account.username }}" login_user: "{{ account.username }}"
login_password: "{{ account.secret }}" login_password: "{{ account.secret }}"
login_host: "{{ jms_asset.address }}" login_host: "{{ jms_asset.address }}"
login_port: "{{ jms_asset.port }}" login_port: "{{ jms_asset.port }}"
name: '{{ jms_asset.spec_info.db_name }}' name: '{{ jms_asset.spec_info.db_name }}'
encryption: "{{ jms_asset.encryption | default(None) }}"
tds_version: "{{ jms_asset.tds_version | default(None) }}"
script: | script: |
SELECT @@version SELECT @@version

View File

@@ -0,0 +1,13 @@
- hosts: website
gather_facts: no
vars:
ansible_python_interpreter: "{{ local_python_interpreter }}"
tasks:
- name: Verify account
website_ping:
login_host: "{{ jms_asset.address }}"
login_user: "{{ account.username }}"
login_password: "{{ account.secret }}"
steps: "{{ params.steps }}"
load_state: "{{ params.load_state }}"

View File

@@ -0,0 +1,50 @@
id: verify_account_website
name: "{{ 'Website account verify' | trans }}"
category: web
type:
- website
method: verify_account
priority: 50
params:
- name: load_state
type: choice
label: "{{ 'Load state' | trans }}"
choices:
- [ networkidle, "{{ 'Network idle' | trans }}" ]
- [ domcontentloaded, "{{ 'Dom content loaded' | trans }}" ]
- [ load, "{{ 'Load completed' | trans }}" ]
default: 'load'
- name: steps
type: list
label: "{{ 'Steps' | trans }}"
help_text: "{{ 'Params step help text' | trans }}"
default: []
i18n:
Website account verify:
zh: 使用 Playwright 模拟浏览器验证账号
ja: Playwright を使用してブラウザをシミュレートし、アカウントの検証を行います
en: Use Playwright to simulate a browser for account verification.
Load state:
zh: 加载状态检测
en: Load state detection
ja: ロード状態の検出
Steps:
zh: 步骤
en: Steps
ja: 手順
Network idle:
zh: 网络空闲
en: Network idle
ja: ネットワークが空いた状態
Dom content loaded:
zh: 文档内容加载完成
en: Dom content loaded
ja: ドキュメントの内容がロードされた状態
Load completed:
zh: 全部加载完成
en: All load completed
ja: すべてのロードが完了した状態
Params step help text:
zh: 配置步骤,根据配置决定任务执行步骤
ja: パラメータを設定し、設定に基づいてタスクの実行手順を決定します
en: Configure steps, and determine the task execution steps based on the configuration.

View File

@@ -234,7 +234,7 @@ class AutomationExecutionFilterSet(DaysExecutionFilterMixin, BaseFilterSet):
class Meta: class Meta:
model = AutomationExecution model = AutomationExecution
fields = ["id", "days", 'trigger', 'automation__name'] fields = ["days", 'trigger', 'automation__name']
class PushAccountRecordFilterSet(SecretRecordMixin, UUIDFilterMixin, BaseFilterSet): class PushAccountRecordFilterSet(SecretRecordMixin, UUIDFilterMixin, BaseFilterSet):

View File

@@ -81,9 +81,7 @@ class VaultModelMixin(models.Model):
def mark_secret_save_to_vault(self): def mark_secret_save_to_vault(self):
self._secret = self._secret_save_to_vault_mark self._secret = self._secret_save_to_vault_mark
self.skip_history_when_saving = True self.skip_history_when_saving = True
# Avoid calling overridden `save()` on concrete models (e.g. AccountTemplate) self.save()
# which may mutate `secret/_secret` again and cause post_save recursion.
super(VaultModelMixin, self).save(update_fields=['_secret'])
@property @property
def secret_has_save_to_vault(self): def secret_has_save_to_vault(self):

View File

@@ -14,7 +14,7 @@ from accounts.models import Account, AccountTemplate, GatheredAccount
from accounts.tasks import push_accounts_to_assets_task from accounts.tasks import push_accounts_to_assets_task
from assets.const import Category, AllTypes from assets.const import Category, AllTypes
from assets.models import Asset from assets.models import Asset
from common.serializers import SecretReadableMixin, SecretReadableCheckMixin, CommonBulkModelSerializer from common.serializers import SecretReadableMixin
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
from common.utils import get_logger from common.utils import get_logger
from .base import BaseAccountSerializer, AuthValidateMixin from .base import BaseAccountSerializer, AuthValidateMixin
@@ -253,8 +253,6 @@ class AccountSerializer(AccountCreateUpdateSerializerMixin, BaseAccountSerialize
'source_id': {'required': False, 'allow_null': True}, 'source_id': {'required': False, 'allow_null': True},
} }
fields_unimport_template = ['params'] fields_unimport_template = ['params']
# 手动判断唯一性校验
validators = []
@classmethod @classmethod
def setup_eager_loading(cls, queryset): def setup_eager_loading(cls, queryset):
@@ -265,21 +263,6 @@ class AccountSerializer(AccountCreateUpdateSerializerMixin, BaseAccountSerialize
) )
return queryset return queryset
def validate(self, attrs):
instance = getattr(self, "instance", None)
if instance:
return super().validate(attrs)
field_errors = {}
for _fields in Account._meta.unique_together:
lookup = {field: attrs.get(field) for field in _fields}
if Account.objects.filter(**lookup).exists():
verbose_names = ', '.join([str(Account._meta.get_field(f).verbose_name) for f in _fields])
msg_template = _('Account already exists. Field(s): {fields} must be unique.')
field_errors[_fields[0]] = msg_template.format(fields=verbose_names)
raise serializers.ValidationError(field_errors)
return attrs
class AccountDetailSerializer(AccountSerializer): class AccountDetailSerializer(AccountSerializer):
has_secret = serializers.BooleanField(label=_("Has secret"), read_only=True) has_secret = serializers.BooleanField(label=_("Has secret"), read_only=True)
@@ -292,26 +275,26 @@ class AccountDetailSerializer(AccountSerializer):
class AssetAccountBulkSerializerResultSerializer(serializers.Serializer): class AssetAccountBulkSerializerResultSerializer(serializers.Serializer):
asset = serializers.CharField(read_only=True, label=_('Asset')) asset = serializers.CharField(read_only=True, label=_('Asset'))
account = serializers.CharField(read_only=True, label=_('Account'))
state = serializers.CharField(read_only=True, label=_('State')) state = serializers.CharField(read_only=True, label=_('State'))
error = serializers.CharField(read_only=True, label=_('Error')) error = serializers.CharField(read_only=True, label=_('Error'))
changed = serializers.BooleanField(read_only=True, label=_('Changed')) changed = serializers.BooleanField(read_only=True, label=_('Changed'))
class AssetAccountBulkSerializer( class AssetAccountBulkSerializer(
AccountCreateUpdateSerializerMixin, AuthValidateMixin, CommonBulkModelSerializer AccountCreateUpdateSerializerMixin, AuthValidateMixin, serializers.ModelSerializer
): ):
su_from_username = serializers.CharField( su_from_username = serializers.CharField(
max_length=128, required=False, write_only=True, allow_null=True, label=_("Su from"), max_length=128, required=False, write_only=True, allow_null=True, label=_("Su from"),
allow_blank=True, allow_blank=True,
) )
assets = serializers.PrimaryKeyRelatedField(queryset=Asset.objects, many=True, label=_('Assets'))
class Meta: class Meta:
model = Account model = Account
fields = [ fields = [
'name', 'username', 'secret', 'secret_type', 'secret_reset', 'name', 'username', 'secret', 'secret_type', 'passphrase',
'passphrase', 'privileged', 'is_active', 'comment', 'template', 'privileged', 'is_active', 'comment', 'template',
'on_invalid', 'push_now', 'params', 'on_invalid', 'push_now', 'params', 'assets',
'su_from_username', 'source', 'source_id', 'su_from_username', 'source', 'source_id',
] ]
extra_kwargs = { extra_kwargs = {
@@ -393,7 +376,8 @@ class AssetAccountBulkSerializer(
handler = self._handle_err_create handler = self._handle_err_create
return handler return handler
def perform_bulk_create(self, vd, assets): def perform_bulk_create(self, vd):
assets = vd.pop('assets')
on_invalid = vd.pop('on_invalid', 'skip') on_invalid = vd.pop('on_invalid', 'skip')
secret_type = vd.get('secret_type', 'password') secret_type = vd.get('secret_type', 'password')
@@ -401,7 +385,8 @@ class AssetAccountBulkSerializer(
vd['name'] = vd.get('username') vd['name'] = vd.get('username')
create_handler = self.get_create_handler(on_invalid) create_handler = self.get_create_handler(on_invalid)
secret_type_supports = Asset.get_secret_type_assets(assets, secret_type) asset_ids = [asset.id for asset in assets]
secret_type_supports = Asset.get_secret_type_assets(asset_ids, secret_type)
_results = {} _results = {}
for asset in assets: for asset in assets:
@@ -409,7 +394,6 @@ class AssetAccountBulkSerializer(
_results[asset] = { _results[asset] = {
'error': _('Asset does not support this secret type: %s') % secret_type, 'error': _('Asset does not support this secret type: %s') % secret_type,
'state': 'error', 'state': 'error',
'account': vd['name'],
} }
continue continue
@@ -419,13 +403,13 @@ class AssetAccountBulkSerializer(
self.clean_auth_fields(vd) self.clean_auth_fields(vd)
instance, changed, state = self.perform_create(vd, create_handler) instance, changed, state = self.perform_create(vd, create_handler)
_results[asset] = { _results[asset] = {
'changed': changed, 'instance': instance.id, 'state': state, 'account': vd['name'] 'changed': changed, 'instance': instance.id, 'state': state
} }
except serializers.ValidationError as e: except serializers.ValidationError as e:
_results[asset] = {'error': e.detail[0], 'state': 'error', 'account': vd['name']} _results[asset] = {'error': e.detail[0], 'state': 'error'}
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
_results[asset] = {'error': str(e), 'state': 'error', 'account': vd['name']} _results[asset] = {'error': str(e), 'state': 'error'}
results = [{'asset': asset, **result} for asset, result in _results.items()] results = [{'asset': asset, **result} for asset, result in _results.items()]
state_score = {'created': 3, 'updated': 2, 'skipped': 1, 'error': 0} state_score = {'created': 3, 'updated': 2, 'skipped': 1, 'error': 0}
@@ -442,8 +426,7 @@ class AssetAccountBulkSerializer(
errors.append({ errors.append({
'error': _('Account has exist'), 'error': _('Account has exist'),
'state': 'error', 'state': 'error',
'asset': str(result['asset']), 'asset': str(result['asset'])
'account': result.get('account'),
}) })
if errors: if errors:
raise serializers.ValidationError(errors) raise serializers.ValidationError(errors)
@@ -462,23 +445,17 @@ class AssetAccountBulkSerializer(
account_ids = [str(_id) for _id in accounts.values_list('id', flat=True)] account_ids = [str(_id) for _id in accounts.values_list('id', flat=True)]
push_accounts_to_assets_task.delay(account_ids, params) push_accounts_to_assets_task.delay(account_ids, params)
def bulk_create(self, validated_data, assets): def create(self, validated_data):
if not assets:
raise serializers.ValidationError(
{'assets': _('At least one asset or node must be specified')},
{'nodes': _('At least one asset or node must be specified')}
)
params = validated_data.pop('params', None) params = validated_data.pop('params', None)
push_now = validated_data.pop('push_now', False) push_now = validated_data.pop('push_now', False)
results = self.perform_bulk_create(validated_data, assets) results = self.perform_bulk_create(validated_data)
self.push_accounts_if_need(results, push_now, params) self.push_accounts_if_need(results, push_now, params)
for res in results: for res in results:
res['asset'] = str(res['asset']) res['asset'] = str(res['asset'])
return results return results
class AccountSecretSerializer(SecretReadableCheckMixin, SecretReadableMixin, AccountSerializer): class AccountSecretSerializer(SecretReadableMixin, AccountSerializer):
spec_info = serializers.DictField(label=_('Spec info'), read_only=True) spec_info = serializers.DictField(label=_('Spec info'), read_only=True)
class Meta(AccountSerializer.Meta): class Meta(AccountSerializer.Meta):
@@ -491,10 +468,9 @@ class AccountSecretSerializer(SecretReadableCheckMixin, SecretReadableMixin, Acc
exclude_backup_fields = [ exclude_backup_fields = [
'passphrase', 'push_now', 'params', 'spec_info' 'passphrase', 'push_now', 'params', 'spec_info'
] ]
secret_fields = ['secret']
class AccountHistorySerializer(SecretReadableCheckMixin, serializers.ModelSerializer): class AccountHistorySerializer(serializers.ModelSerializer):
secret_type = LabeledChoiceField(choices=SecretType.choices, label=_('Secret type')) secret_type = LabeledChoiceField(choices=SecretType.choices, label=_('Secret type'))
secret = serializers.CharField(label=_('Secret'), read_only=True) secret = serializers.CharField(label=_('Secret'), read_only=True)
id = serializers.IntegerField(label=_('ID'), source='history_id', read_only=True) id = serializers.IntegerField(label=_('ID'), source='history_id', read_only=True)
@@ -510,7 +486,6 @@ class AccountHistorySerializer(SecretReadableCheckMixin, serializers.ModelSerial
'history_user': {'label': _('User')}, 'history_user': {'label': _('User')},
'history_date': {'label': _('Date')}, 'history_date': {'label': _('Date')},
} }
secret_fields = ['secret']
class AccountTaskSerializer(serializers.Serializer): class AccountTaskSerializer(serializers.Serializer):

View File

@@ -2,7 +2,7 @@ from django.utils.translation import gettext_lazy as _
from rest_framework import serializers from rest_framework import serializers
from accounts.models import AccountTemplate from accounts.models import AccountTemplate
from common.serializers import SecretReadableMixin, SecretReadableCheckMixin from common.serializers import SecretReadableMixin
from common.serializers.fields import ObjectRelatedField from common.serializers.fields import ObjectRelatedField
from .base import BaseAccountSerializer from .base import BaseAccountSerializer
@@ -62,11 +62,10 @@ class AccountDetailTemplateSerializer(AccountTemplateSerializer):
fields = AccountTemplateSerializer.Meta.fields + ['spec_info'] fields = AccountTemplateSerializer.Meta.fields + ['spec_info']
class AccountTemplateSecretSerializer(SecretReadableCheckMixin, SecretReadableMixin, AccountDetailTemplateSerializer): class AccountTemplateSecretSerializer(SecretReadableMixin, AccountDetailTemplateSerializer):
class Meta(AccountDetailTemplateSerializer.Meta): class Meta(AccountDetailTemplateSerializer.Meta):
fields = AccountDetailTemplateSerializer.Meta.fields fields = AccountDetailTemplateSerializer.Meta.fields
extra_kwargs = { extra_kwargs = {
**AccountDetailTemplateSerializer.Meta.extra_kwargs, **AccountDetailTemplateSerializer.Meta.extra_kwargs,
'secret': {'write_only': False}, 'secret': {'write_only': False},
} }
secret_fields = ['secret']

View File

@@ -79,7 +79,7 @@ class VaultSignalHandler(object):
else: else:
vault_client.update(instance) vault_client.update(instance)
except Exception as e: except Exception as e:
logger.exception('Vault save failed: %s', e) logger.error('Vault save failed: {}'.format(e))
raise VaultException() raise VaultException()
@staticmethod @staticmethod
@@ -87,7 +87,7 @@ class VaultSignalHandler(object):
try: try:
vault_client.delete(instance) vault_client.delete(instance)
except Exception as e: except Exception as e:
logger.exception('Vault delete failed: %s', e) logger.error('Vault delete failed: {}'.format(e))
raise VaultException() raise VaultException()

View File

@@ -6,7 +6,7 @@ from django.utils.translation import gettext_lazy as _
from accounts.backends import vault_client from accounts.backends import vault_client
from accounts.const import VaultTypeChoices from accounts.const import VaultTypeChoices
from accounts.models import AccountTemplate, Account from accounts.models import Account, AccountTemplate
from common.utils import get_logger from common.utils import get_logger
from orgs.utils import tmp_to_root_org from orgs.utils import tmp_to_root_org

View File

@@ -0,0 +1,36 @@
{% load i18n %}
<h3>{% trans 'Task name' %}: {{ name }}</h3>
<h3>{% trans 'Task execution id' %}: {{ execution_id }}</h3>
<p>{% trans 'Respectful' %} {{ recipient }}</p>
<p>{% trans 'Hello! The following is the failure of changing the password of your assets or pushing the account. Please check and handle it in time.' %}</p>
<table style="width: 100%; border-collapse: collapse; max-width: 100%; text-align: left; margin-top: 20px;">
<caption></caption>
<thead>
<tr style="background-color: #f2f2f2;">
<th style="border: 1px solid #ddd; padding: 10px;">{% trans 'Asset' %}</th>
<th style="border: 1px solid #ddd; padding: 10px;">{% trans 'Account' %}</th>
<th style="border: 1px solid #ddd; padding: 10px;">{% trans 'Error' %}</th>
</tr>
</thead>
<tbody>
{% for asset_name, account_username, error in asset_account_errors %}
<tr>
<td style="border: 1px solid #ddd; padding: 10px;">{{ asset_name }}</td>
<td style="border: 1px solid #ddd; padding: 10px;">{{ account_username }}</td>
<td style="border: 1px solid #ddd; padding: 10px;">
<div style="
max-width: 90%;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
display: block;"
title="{{ error }}"
>
{{ error }}
</div>
</td>
</tr>
{% endfor %}
</tbody>
</table>

View File

@@ -3,4 +3,3 @@ from .connect_method import *
from .login_acl import * from .login_acl import *
from .login_asset_acl import * from .login_asset_acl import *
from .login_asset_check import * from .login_asset_check import *
from .data_masking import *

View File

@@ -1,20 +0,0 @@
from orgs.mixins.api import OrgBulkModelViewSet
from .common import ACLUserFilterMixin
from ..models import DataMaskingRule
from .. import serializers
__all__ = ['DataMaskingRuleViewSet']
class DataMaskingRuleFilter(ACLUserFilterMixin):
class Meta:
model = DataMaskingRule
fields = ('name',)
class DataMaskingRuleViewSet(OrgBulkModelViewSet):
model = DataMaskingRule
filterset_class = DataMaskingRuleFilter
search_fields = ('name',)
serializer_class = serializers.DataMaskingRuleSerializer

View File

@@ -8,7 +8,7 @@ __all__ = ['LoginAssetACLViewSet']
class LoginAssetACLFilter(ACLUserAssetFilterMixin): class LoginAssetACLFilter(ACLUserAssetFilterMixin):
class Meta: class Meta:
model = models.LoginAssetACL model = models.LoginAssetACL
fields = ['name', 'action'] fields = ['name', ]
class LoginAssetACLViewSet(OrgBulkModelViewSet): class LoginAssetACLViewSet(OrgBulkModelViewSet):

View File

@@ -1,45 +0,0 @@
# Generated by Django 4.1.13 on 2025-10-07 16:16
import common.db.fields
from django.conf import settings
import django.core.validators
from django.db import migrations, models
import uuid
class Migration(migrations.Migration):
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
('acls', '0002_auto_20210926_1047'),
]
operations = [
migrations.CreateModel(
name='DataMaskingRule',
fields=[
('created_by', models.CharField(blank=True, max_length=128, null=True, verbose_name='Created by')),
('updated_by', models.CharField(blank=True, max_length=128, null=True, verbose_name='Updated by')),
('date_created', models.DateTimeField(auto_now_add=True, null=True, verbose_name='Date created')),
('date_updated', models.DateTimeField(auto_now=True, verbose_name='Date updated')),
('comment', models.TextField(blank=True, default='', verbose_name='Comment')),
('id', models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False)),
('org_id', models.CharField(blank=True, db_index=True, default='', max_length=36, verbose_name='Organization')),
('priority', models.IntegerField(default=50, help_text='1-100, the lower the value will be match first', validators=[django.core.validators.MinValueValidator(1), django.core.validators.MaxValueValidator(100)], verbose_name='Priority')),
('action', models.CharField(default='reject', max_length=64, verbose_name='Action')),
('is_active', models.BooleanField(default=True, verbose_name='Active')),
('users', common.db.fields.JSONManyToManyField(default=dict, to='users.User', verbose_name='Users')),
('assets', common.db.fields.JSONManyToManyField(default=dict, to='assets.Asset', verbose_name='Assets')),
('accounts', models.JSONField(default=list, verbose_name='Accounts')),
('name', models.CharField(max_length=128, verbose_name='Name')),
('fields_pattern', models.CharField(default='password', max_length=128, verbose_name='Fields pattern')),
('masking_method', models.CharField(choices=[('fixed_char', 'Fixed Character Replacement'), ('hide_middle', 'Hide Middle Characters'), ('keep_prefix', 'Keep Prefix Only'), ('keep_suffix', 'Keep Suffix Only')], default='fixed_char', max_length=32, verbose_name='Masking Method')),
('mask_pattern', models.CharField(blank=True, default='######', max_length=128, null=True, verbose_name='Mask Pattern')),
('reviewers', models.ManyToManyField(blank=True, to=settings.AUTH_USER_MODEL, verbose_name='Reviewers')),
],
options={
'verbose_name': 'Data Masking Rule',
'unique_together': {('org_id', 'name')},
},
),
]

View File

@@ -2,4 +2,3 @@ from .command_acl import *
from .connect_method import * from .connect_method import *
from .login_acl import * from .login_acl import *
from .login_asset_acl import * from .login_asset_acl import *
from .data_masking import *

View File

@@ -1,42 +0,0 @@
from django.db import models
from acls.models import UserAssetAccountBaseACL
from common.utils import get_logger
from django.utils.translation import gettext_lazy as _
logger = get_logger(__file__)
__all__ = ['MaskingMethod', 'DataMaskingRule']
class MaskingMethod(models.TextChoices):
fixed_char = "fixed_char", _("Fixed Character Replacement") # 固定字符替换
hide_middle = "hide_middle", _("Hide Middle Characters") # 隐藏中间几位
keep_prefix = "keep_prefix", _("Keep Prefix Only") # 只保留前缀
keep_suffix = "keep_suffix", _("Keep Suffix Only") # 只保留后缀
class DataMaskingRule(UserAssetAccountBaseACL):
name = models.CharField(max_length=128, verbose_name=_("Name"))
fields_pattern = models.CharField(max_length=128, default='password', verbose_name=_("Fields pattern"))
masking_method = models.CharField(
max_length=32,
choices=MaskingMethod.choices,
default=MaskingMethod.fixed_char,
verbose_name=_("Masking Method"),
)
mask_pattern = models.CharField(
max_length=128,
verbose_name=_("Mask Pattern"),
default="######",
blank=True,
null=True,
)
def __str__(self):
return self.name
class Meta:
unique_together = [('org_id', 'name')]
verbose_name = _("Data Masking Rule")

View File

@@ -1,52 +1,30 @@
from django.utils import timezone from django.template.loader import render_to_string
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from accounts.models import Account from accounts.models import Account
from acls.models import LoginACL, LoginAssetACL
from assets.models import Asset from assets.models import Asset
from audits.models import UserLoginLog from audits.models import UserLoginLog
from common.views.template import custom_render_to_string
from notifications.notifications import UserMessage from notifications.notifications import UserMessage
from users.models import User from users.models import User
class UserLoginReminderMsg(UserMessage): class UserLoginReminderMsg(UserMessage):
subject = _('User login reminder') subject = _('User login reminder')
template_name = 'acls/user_login_reminder.html'
contexts = [
{"name": "city", "label": _('Login city'), "default": "Shanghai"},
{"name": "username", "label": _('User'), "default": "john"},
{"name": "ip", "label": "IP", "default": "192.168.1.1"},
{"name": "recipient_name", "label": _("Recipient name"), "default": "John"},
{"name": "recipient_username", "label": _("Recipient username"), "default": "john"},
{"name": "user_agent", "label": _('User agent'), "default": "Mozilla/5.0"},
{"name": "acl_name", "label": _('ACL name'), "default": "login acl"},
{"name": "login_from", "label": _('Login from'), "default": "web"},
{"name": "time", "label": _('Login time'), "default": "2025-01-01 12:00:00"},
]
def __init__(self, user, user_log: UserLoginLog, acl: LoginACL): def __init__(self, user, user_log: UserLoginLog):
self.user_log = user_log self.user_log = user_log
self.acl_name = str(acl)
self.login_from = user_log.get_type_display()
now = timezone.localtime(user_log.datetime)
self.time = now.strftime('%Y-%m-%d %H:%M:%S')
super().__init__(user) super().__init__(user)
def get_html_msg(self) -> dict: def get_html_msg(self) -> dict:
user_log = self.user_log user_log = self.user_log
context = { context = {
'ip': user_log.ip, 'ip': user_log.ip,
'time': self.time,
'city': user_log.city, 'city': user_log.city,
'acl_name': self.acl_name,
'login_from': self.login_from,
'username': user_log.username, 'username': user_log.username,
'recipient_name': self.user.name, 'recipient': self.user,
'recipient_username': self.user.username,
'user_agent': user_log.user_agent, 'user_agent': user_log.user_agent,
} }
message = custom_render_to_string(self.template_name, context) message = render_to_string('acls/user_login_reminder.html', context)
return { return {
'subject': str(self.subject), 'subject': str(self.subject),
@@ -62,55 +40,24 @@ class UserLoginReminderMsg(UserMessage):
class AssetLoginReminderMsg(UserMessage): class AssetLoginReminderMsg(UserMessage):
subject = _('User login alert for asset') subject = _('User login alert for asset')
template_name = 'acls/asset_login_reminder.html'
contexts = [
{"name": "city", "label": _('Login city'), "default": "Shanghai"},
{"name": "username", "label": _('User'), "default": "john"},
{"name": "name", "label": _('Name'), "default": "John"},
{"name": "asset", "label": _('Asset'), "default": "dev server"},
{"name": "recipient_name", "label": _('Recipient name'), "default": "John"},
{"name": "recipient_username", "label": _('Recipient username'), "default": "john"},
{"name": "account", "label": _('Account Input username'), "default": "root"},
{"name": "account_name", "label": _('Account name'), "default": "root"},
{"name": "acl_name", "label": _('ACL name'), "default": "login acl"},
{"name": "ip", "label": "IP", "default": "192.168.1.1"},
{"name": "login_from", "label": _('Login from'), "default": "web"},
{"name": "time", "label": _('Login time'), "default": "2025-01-01 12:00:00"}
]
def __init__( def __init__(self, user, asset: Asset, login_user: User, account: Account, input_username):
self, user, asset: Asset, login_user: User,
account: Account, acl: LoginAssetACL,
ip, input_username, login_from
):
self.ip = ip
self.asset = asset self.asset = asset
self.login_user = login_user self.login_user = login_user
self.account = account self.account = account
self.acl_name = str(acl)
self.login_from = login_from
self.login_user = login_user
self.input_username = input_username self.input_username = input_username
now = timezone.localtime(timezone.now())
self.time = now.strftime('%Y-%m-%d %H:%M:%S')
super().__init__(user) super().__init__(user)
def get_html_msg(self) -> dict: def get_html_msg(self) -> dict:
context = { context = {
'ip': self.ip, 'recipient': self.user,
'time': self.time,
'login_from': self.login_from,
'recipient_name': self.user.name,
'recipient_username': self.user.username,
'username': self.login_user.username, 'username': self.login_user.username,
'name': self.login_user.name, 'name': self.login_user.name,
'asset': str(self.asset), 'asset': str(self.asset),
'account': self.input_username, 'account': self.input_username,
'account_name': self.account.name, 'account_name': self.account.name,
'acl_name': self.acl_name,
} }
message = custom_render_to_string(self.template_name, context) message = render_to_string('acls/asset_login_reminder.html', context)
return { return {
'subject': str(self.subject), 'subject': str(self.subject),

View File

@@ -3,4 +3,3 @@ from .connect_method import *
from .login_acl import * from .login_acl import *
from .login_asset_acl import * from .login_asset_acl import *
from .login_asset_check import * from .login_asset_check import *
from .data_masking import *

View File

@@ -90,7 +90,7 @@ class BaseACLSerializer(ActionAclSerializer, serializers.Serializer):
fields_small = fields_mini + [ fields_small = fields_mini + [
"is_active", "priority", "action", "is_active", "priority", "action",
"date_created", "date_updated", "date_created", "date_updated",
"comment", "created_by" "comment", "created_by", "org_id",
] ]
fields_m2m = ["reviewers", ] fields_m2m = ["reviewers", ]
fields = fields_small + fields_m2m fields = fields_small + fields_m2m
@@ -100,20 +100,6 @@ class BaseACLSerializer(ActionAclSerializer, serializers.Serializer):
'reviewers': {'label': _('Recipients')}, 'reviewers': {'label': _('Recipients')},
} }
class BaseUserACLSerializer(BaseACLSerializer):
users = JSONManyToManyField(label=_('User'))
class Meta(BaseACLSerializer.Meta):
fields = BaseACLSerializer.Meta.fields + ['users']
class BaseUserAssetAccountACLSerializer(BaseUserACLSerializer):
assets = JSONManyToManyField(label=_('Asset'))
accounts = serializers.ListField(label=_('Account'))
class Meta(BaseUserACLSerializer.Meta):
fields = BaseUserACLSerializer.Meta.fields + ['assets', 'accounts', 'org_id']
def validate_reviewers(self, reviewers): def validate_reviewers(self, reviewers):
action = self.initial_data.get('action') action = self.initial_data.get('action')
if not action and self.instance: if not action and self.instance:
@@ -133,3 +119,18 @@ class BaseUserAssetAccountACLSerializer(BaseUserACLSerializer):
) )
raise serializers.ValidationError(error) raise serializers.ValidationError(error)
return valid_reviewers return valid_reviewers
class BaseUserACLSerializer(BaseACLSerializer):
users = JSONManyToManyField(label=_('User'))
class Meta(BaseACLSerializer.Meta):
fields = BaseACLSerializer.Meta.fields + ['users']
class BaseUserAssetAccountACLSerializer(BaseUserACLSerializer):
assets = JSONManyToManyField(label=_('Asset'))
accounts = serializers.ListField(label=_('Account'))
class Meta(BaseUserACLSerializer.Meta):
fields = BaseUserACLSerializer.Meta.fields + ['assets', 'accounts']

View File

@@ -1,19 +0,0 @@
from django.utils.translation import gettext_lazy as _
from acls.models import MaskingMethod, DataMaskingRule
from common.serializers.fields import LabeledChoiceField
from common.serializers.mixin import CommonBulkModelSerializer
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from .base import BaseUserAssetAccountACLSerializer as BaseSerializer
__all__ = ['DataMaskingRuleSerializer']
class DataMaskingRuleSerializer(BaseSerializer, BulkOrgResourceModelSerializer):
masking_method = LabeledChoiceField(
choices=MaskingMethod.choices, default=MaskingMethod.fixed_char, label=_('Masking Method')
)
class Meta(BaseSerializer.Meta):
model = DataMaskingRule
fields = BaseSerializer.Meta.fields + ['fields_pattern', 'masking_method', 'mask_pattern']

View File

@@ -17,7 +17,7 @@ class LoginACLSerializer(BaseUserACLSerializer, CommonBulkModelSerializer):
class Meta(BaseUserACLSerializer.Meta): class Meta(BaseUserACLSerializer.Meta):
model = LoginACL model = LoginACL
fields = list((set(BaseUserACLSerializer.Meta.fields) | {'rules'})) fields = list((set(BaseUserACLSerializer.Meta.fields) | {'rules'}) - {'org_id'})
action_choices_exclude = [ action_choices_exclude = [
ActionChoices.warning, ActionChoices.warning,
ActionChoices.notify_and_warn, ActionChoices.notify_and_warn,

View File

@@ -1,17 +1,13 @@
{% load i18n %} {% load i18n %}
<h3>{% trans 'Dear' %}: {{ recipient_name }}[{{ recipient_username }}]</h3> <h3>{% trans 'Dear' %}: {{ recipient.name }}[{{ recipient.username }}]</h3>
<hr> <hr>
<p>{% trans 'We would like to inform you that a user has recently logged into the following asset:' %}<p> <p>{% trans 'We would like to inform you that a user has recently logged into the following asset:' %}<p>
<p><strong>{% trans 'Asset details' %}:</strong></p> <p><strong>{% trans 'Asset details' %}:</strong></p>
<ul> <ul>
<li><strong>{% trans 'User' %}:</strong> [{{ name }}({{ username }})]</li> <li><strong>{% trans 'User' %}:</strong> [{{ name }}({{ username }})]</li>
<li><strong>IP:</strong> [{{ ip }}]</li>
<li><strong>{% trans 'Assets' %}:</strong> [{{ asset }}]</li> <li><strong>{% trans 'Assets' %}:</strong> [{{ asset }}]</li>
<li><strong>{% trans 'Account' %}:</strong> [{{ account_name }}({{ account }})]</li> <li><strong>{% trans 'Account' %}:</strong> [{{ account_name }}({{ account }})]</li>
<li><strong>{% trans 'Login asset acl' %}:</strong> [{{ acl_name }}]</li>
<li><strong>{% trans 'Login from' %}:</strong> [{{ login_from }}]</li>
<li><strong>{% trans 'Time' %}:</strong> [{{ time }}]</li>
</ul> </ul>
<hr> <hr>

View File

@@ -1,6 +1,6 @@
{% load i18n %} {% load i18n %}
<h3>{% trans 'Dear' %}: {{ recipient_name }}[{{ recipient_username }}]</h3> <h3>{% trans 'Dear' %}: {{ recipient.name }}[{{ recipient.username }}]</h3>
<hr> <hr>
<p>{% trans 'We would like to inform you that a user has recently logged:' %}<p> <p>{% trans 'We would like to inform you that a user has recently logged:' %}<p>
<p><strong>{% trans 'User details' %}:</strong></p> <p><strong>{% trans 'User details' %}:</strong></p>
@@ -8,10 +8,7 @@
<li><strong>{% trans 'User' %}:</strong> [{{ username }}]</li> <li><strong>{% trans 'User' %}:</strong> [{{ username }}]</li>
<li><strong>IP:</strong> [{{ ip }}]</li> <li><strong>IP:</strong> [{{ ip }}]</li>
<li><strong>{% trans 'Login city' %}:</strong> [{{ city }}]</li> <li><strong>{% trans 'Login city' %}:</strong> [{{ city }}]</li>
<li><strong>{% trans 'Login from' %}:</strong> [{{ login_from }}]</li>
<li><strong>{% trans 'User agent' %}:</strong> [{{ user_agent }}]</li> <li><strong>{% trans 'User agent' %}:</strong> [{{ user_agent }}]</li>
<li><strong>{% trans 'Login acl' %}:</strong> [{{ acl_name }}]</li>
<li><strong>{% trans 'Time' %}:</strong> [{{ time }}]</li>
</ul> </ul>
<hr> <hr>

View File

@@ -11,7 +11,6 @@ router.register(r'login-asset-acls', api.LoginAssetACLViewSet, 'login-asset-acl'
router.register(r'command-filter-acls', api.CommandFilterACLViewSet, 'command-filter-acl') router.register(r'command-filter-acls', api.CommandFilterACLViewSet, 'command-filter-acl')
router.register(r'command-groups', api.CommandGroupViewSet, 'command-group') router.register(r'command-groups', api.CommandGroupViewSet, 'command-group')
router.register(r'connect-method-acls', api.ConnectMethodACLViewSet, 'connect-method-acl') router.register(r'connect-method-acls', api.ConnectMethodACLViewSet, 'connect-method-acl')
router.register(r'data-masking-rules', api.DataMaskingRuleViewSet, 'data-masking-rule')
urlpatterns = [ urlpatterns = [
path('login-asset/check/', api.LoginAssetCheckAPI.as_view(), name='login-asset-check'), path('login-asset/check/', api.LoginAssetCheckAPI.as_view(), name='login-asset-check'),

View File

@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from collections import defaultdict
from django.conf import settings from django.conf import settings
from django.db import transaction
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from django_filters import rest_framework as drf_filters from django_filters import rest_framework as drf_filters
@@ -112,7 +113,7 @@ class BaseAssetViewSet(OrgBulkModelViewSet):
("accounts", AccountSerializer), ("accounts", AccountSerializer),
) )
rbac_perms = ( rbac_perms = (
("match", "assets.view_asset"), ("match", "assets.match_asset"),
("platform", "assets.view_platform"), ("platform", "assets.view_platform"),
("gateways", "assets.view_gateway"), ("gateways", "assets.view_gateway"),
("accounts", "assets.view_account"), ("accounts", "assets.view_account"),
@@ -180,17 +181,32 @@ class AssetViewSet(SuggestionMixin, BaseAssetViewSet):
def sync_platform_protocols(self, request, *args, **kwargs): def sync_platform_protocols(self, request, *args, **kwargs):
platform_id = request.data.get('platform_id') platform_id = request.data.get('platform_id')
platform = get_object_or_404(Platform, pk=platform_id) platform = get_object_or_404(Platform, pk=platform_id)
asset_ids = list(platform.assets.values_list('id', flat=True)) assets = platform.assets.all()
platform_protocols = list(platform.protocols.values('name', 'port'))
with transaction.atomic(): platform_protocols = {
if asset_ids: p['name']: p['port']
Protocol.objects.filter(asset_id__in=asset_ids).delete() for p in platform.protocols.values('name', 'port')
if asset_ids and platform_protocols: }
asset_protocols_map = defaultdict(set)
protocols = assets.prefetch_related('protocols').values_list(
'id', 'protocols__name'
)
for asset_id, protocol in protocols:
asset_id = str(asset_id)
asset_protocols_map[asset_id].add(protocol)
objs = [] objs = []
for aid in asset_ids: for asset_id, protocols in asset_protocols_map.items():
for p in platform_protocols: protocol_names = set(platform_protocols) - protocols
objs.append(Protocol(name=p['name'], port=p['port'], asset_id=aid)) if not protocol_names:
continue
for name in protocol_names:
objs.append(
Protocol(
name=name,
port=platform_protocols[name],
asset_id=asset_id,
)
)
Protocol.objects.bulk_create(objs) Protocol.objects.bulk_create(objs)
return Response(status=status.HTTP_200_OK) return Response(status=status.HTTP_200_OK)

View File

@@ -16,6 +16,7 @@ class CategoryViewSet(ListModelMixin, JMSGenericViewSet):
'types': TypeSerializer, 'types': TypeSerializer,
} }
permission_classes = (IsValidUser,) permission_classes = (IsValidUser,)
default_limit = None
def get_queryset(self): def get_queryset(self):
return AllTypes.categories() return AllTypes.categories()

View File

@@ -14,7 +14,7 @@ class FavoriteAssetViewSet(BulkModelViewSet):
serializer_class = FavoriteAssetSerializer serializer_class = FavoriteAssetSerializer
permission_classes = (IsValidUser,) permission_classes = (IsValidUser,)
filterset_fields = ['asset'] filterset_fields = ['asset']
page_no_limit = True default_limit = None
def dispatch(self, request, *args, **kwargs): def dispatch(self, request, *args, **kwargs):
with tmp_to_root_org(): with tmp_to_root_org():

View File

@@ -43,7 +43,7 @@ class NodeViewSet(SuggestionMixin, OrgBulkModelViewSet):
search_fields = ('full_value',) search_fields = ('full_value',)
serializer_class = serializers.NodeSerializer serializer_class = serializers.NodeSerializer
rbac_perms = { rbac_perms = {
'match': 'assets.view_node', 'match': 'assets.match_node',
'check_assets_amount_task': 'assets.change_node' 'check_assets_amount_task': 'assets.change_node'
} }

View File

@@ -43,7 +43,7 @@ class AssetPlatformViewSet(JMSModelViewSet):
'ops_methods': 'assets.view_platform', 'ops_methods': 'assets.view_platform',
'filter_nodes_assets': 'assets.view_platform', 'filter_nodes_assets': 'assets.view_platform',
} }
page_no_limit = True default_limit = None
def get_queryset(self): def get_queryset(self):
# 因为没有走分页逻辑,所以需要这里 prefetch # 因为没有走分页逻辑,所以需要这里 prefetch
@@ -112,10 +112,8 @@ class PlatformProtocolViewSet(JMSModelViewSet):
class PlatformAutomationMethodsApi(generics.ListAPIView): class PlatformAutomationMethodsApi(generics.ListAPIView):
permission_classes = (IsValidUser,)
queryset = PlatformAutomation.objects.none() queryset = PlatformAutomation.objects.none()
rbac_perms = {
'list': 'assets.view_platform'
}
@staticmethod @staticmethod
def automation_methods(): def automation_methods():

View File

@@ -1,8 +1,8 @@
from rest_framework.generics import ListAPIView from rest_framework.generics import ListAPIView
from assets import serializers from assets import serializers
from assets.const import Protocol
from common.permissions import IsValidUser from common.permissions import IsValidUser
from assets.models import Protocol
__all__ = ['ProtocolListApi'] __all__ = ['ProtocolListApi']

View File

@@ -126,7 +126,7 @@ class BaseManager:
self.execution.save() self.execution.save()
def print_summary(self): def print_summary(self):
content = "\nSummary: \n" content = "\nSummery: \n"
for k, v in self.summary.items(): for k, v in self.summary.items():
content += f"\t - {k}: {v}\n" content += f"\t - {k}: {v}\n"
content += "\t - Using: {}s\n".format(self.duration) content += "\t - Using: {}s\n".format(self.duration)
@@ -201,14 +201,17 @@ class PlaybookPrepareMixin:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# example: {'gather_fact_windows': {'id': 'gather_fact_windows', 'name': '', 'method': 'gather_fact', ...} } # example: {'gather_fact_windows': {'id': 'gather_fact_windows', 'name': '', 'method': 'gather_fact', ...} }
self.method_id_meta_mapper = { self.method_id_meta_mapper = self.get_method_id_meta_mapper()
# 根据执行方式就行分组, 不同资产的改密、推送等操作可能会使用不同的执行方式
# 然后根据执行方式分组, 再根据 bulk_size 分组, 生成不同的 playbook
self.playbooks = []
def get_method_id_meta_mapper(self):
return {
method["id"]: method method["id"]: method
for method in self.platform_automation_methods for method in self.platform_automation_methods
if method["method"] == self.__class__.method_type() if method["method"] == self.__class__.method_type()
} }
# 根据执行方式就行分组, 不同资产的改密、推送等操作可能会使用不同的执行方式
# 然后根据执行方式分组, 再根据 bulk_size 分组, 生成不同的 playbook
self.playbooks = []
@classmethod @classmethod
def method_type(cls): def method_type(cls):

View File

@@ -6,13 +6,11 @@
tasks: tasks:
- name: Test SQLServer connection - name: Test SQLServer connection
mssql_script: community.general.mssql_script:
login_user: "{{ jms_account.username }}" login_user: "{{ jms_account.username }}"
login_password: "{{ jms_account.secret }}" login_password: "{{ jms_account.secret }}"
login_host: "{{ jms_asset.address }}" login_host: "{{ jms_asset.address }}"
login_port: "{{ jms_asset.port }}" login_port: "{{ jms_asset.port }}"
name: '{{ jms_asset.spec_info.db_name }}' name: '{{ jms_asset.spec_info.db_name }}'
encryption: "{{ jms_asset.encryption | default(None) }}"
tds_version: "{{ jms_asset.tds_version | default(None) }}"
script: | script: |
SELECT @@version SELECT @@version

View File

@@ -0,0 +1,13 @@
- hosts: website
gather_facts: no
vars:
ansible_python_interpreter: "{{ local_python_interpreter }}"
tasks:
- name: Test Website connection
website_ping:
login_host: "{{ jms_asset.address }}"
login_user: "{{ jms_account.username }}"
login_password: "{{ jms_account.secret }}"
steps: "{{ params.steps }}"
load_state: "{{ params.load_state }}"

View File

@@ -0,0 +1,50 @@
id: website_ping
name: "{{ 'Website ping' | trans }}"
method: ping
category:
- web
type:
- website
params:
- name: load_state
type: choice
label: "{{ 'Load state' | trans }}"
choices:
- [ networkidle, "{{ 'Network idle' | trans }}" ]
- [ domcontentloaded, "{{ 'Dom content loaded' | trans }}" ]
- [ load, "{{ 'Load completed' | trans }}" ]
default: 'load'
- name: steps
type: list
default: []
label: "{{ 'Steps' | trans }}"
help_text: "{{ 'Params step help text' | trans }}"
i18n:
Website ping:
zh: 使用 Playwright 模拟浏览器测试可连接性
en: Use Playwright to simulate a browser for connectivity testing
ja: Playwright を使用してブラウザをシミュレートし、接続性テストを実行する
Load state:
zh: 加载状态检测
en: Load state detection
ja: ロード状態の検出
Steps:
zh: 步骤
en: Steps
ja: 手順
Network idle:
zh: 网络空闲
en: Network idle
ja: ネットワークが空いた状態
Dom content loaded:
zh: 文档内容加载完成
en: Dom content loaded
ja: ドキュメントの内容がロードされた状態
Load completed:
zh: 全部加载完成
en: All load completed
ja: すべてのロードが完了した状態
Params step help text:
zh: 配置步骤,根据配置决定任务执行步骤
ja: パラメータを設定し、設定に基づいてタスクの実行手順を決定します
en: Configure steps, and determine the task execution steps based on the configuration.

View File

@@ -1,6 +1,5 @@
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from orgs.models import Organization
from .base import BaseType from .base import BaseType
@@ -53,41 +52,3 @@ class GPTTypes(BaseType):
return [ return [
cls.CHATGPT, cls.CHATGPT,
] ]
CHATX_NAME = 'ChatX'
def create_or_update_chatx_resources(chatx_name=CHATX_NAME, org_id=Organization.SYSTEM_ID):
from django.apps import apps
platform_model = apps.get_model('assets', 'Platform')
asset_model = apps.get_model('assets', 'Asset')
account_model = apps.get_model('accounts', 'Account')
platform, __ = platform_model.objects.get_or_create(
name=chatx_name,
defaults={
'internal': True,
'type': chatx_name,
'category': 'ai',
}
)
asset, __ = asset_model.objects.get_or_create(
address=chatx_name,
defaults={
'name': chatx_name,
'platform': platform,
'org_id': org_id
}
)
account, __ = account_model.objects.get_or_create(
username=chatx_name,
defaults={
'name': chatx_name,
'asset': asset,
'org_id': org_id
}
)
return asset, account

View File

@@ -250,12 +250,6 @@ class Protocol(ChoicesMixin, models.TextChoices):
'default': False, 'default': False,
'label': _('Auth username') 'label': _('Auth username')
}, },
'enable_cluster_mode': {
'type': 'bool',
'default': False,
'label': _('Enable cluster mode'),
'help_text': _('Enable if this Redis instance is part of a cluster')
},
} }
}, },
} }
@@ -268,14 +262,6 @@ class Protocol(ChoicesMixin, models.TextChoices):
'port_from_addr': True, 'port_from_addr': True,
'required': True, 'required': True,
'secret_types': ['token'], 'secret_types': ['token'],
'setting': {
'namespace': {
'type': 'str',
'required': False,
'default': '',
'label': _('Namespace')
}
}
}, },
cls.http: { cls.http: {
'port': 80, 'port': 80,

View File

@@ -20,13 +20,17 @@ class WebTypes(BaseType):
def _get_automation_constrains(cls) -> dict: def _get_automation_constrains(cls) -> dict:
constrains = { constrains = {
'*': { '*': {
'ansible_enabled': False, 'ansible_enabled': True,
'ping_enabled': False, 'ansible_config': {
'ansible_connection': 'local',
},
'ping_enabled': True,
'gather_facts_enabled': False, 'gather_facts_enabled': False,
'verify_account_enabled': False, 'verify_account_enabled': True,
'change_secret_enabled': False, 'change_secret_enabled': True,
'push_account_enabled': False, 'push_account_enabled': False,
'gather_accounts_enabled': False, 'gather_accounts_enabled': False,
'remove_account_enabled': False,
} }
} }
return constrains return constrains

View File

@@ -408,7 +408,8 @@ class Asset(NodesRelationMixin, LabeledMixin, AbsConnectivity, JSONFilterMixin,
return tree_node return tree_node
@staticmethod @staticmethod
def get_secret_type_assets(assets, secret_type): def get_secret_type_assets(asset_ids, secret_type):
assets = Asset.objects.filter(id__in=asset_ids)
asset_protocol = assets.prefetch_related('protocols').values_list('id', 'protocols__name') asset_protocol = assets.prefetch_related('protocols').values_list('id', 'protocols__name')
protocol_secret_types_map = const.Protocol.protocol_secret_types() protocol_secret_types_map = const.Protocol.protocol_secret_types()
asset_secret_types_mapp = defaultdict(set) asset_secret_types_mapp = defaultdict(set)

View File

@@ -28,8 +28,7 @@ class MyAsset(JMSBaseModel):
@staticmethod @staticmethod
def set_asset_custom_value(assets, user): def set_asset_custom_value(assets, user):
asset_ids = [asset.id for asset in assets] my_assets = MyAsset.objects.filter(asset__in=assets, user=user).all()
my_assets = MyAsset.objects.filter(asset_id__in=asset_ids, user=user).all()
customs = {my_asset.asset.id: my_asset.custom_to_dict() for my_asset in my_assets} customs = {my_asset.asset.id: my_asset.custom_to_dict() for my_asset in my_assets}
for asset in assets: for asset in assets:
custom = customs.get(asset.id) custom = customs.get(asset.id)

View File

@@ -59,10 +59,7 @@ class DatabaseSerializer(AssetSerializer):
if not platform: if not platform:
return return
if platform.type in [ if platform.type in ['mysql', 'mariadb']:
'mysql', 'mariadb', 'oracle', 'sqlserver',
'db2', 'dameng', 'clickhouse', 'redis'
]:
db_field.required = False db_field.required = False
db_field.allow_blank = True db_field.allow_blank = True
db_field.allow_null = True db_field.allow_null = True

View File

@@ -26,13 +26,4 @@ class WebSerializer(AssetSerializer):
'submit_selector': { 'submit_selector': {
'default': 'id=login_button', 'default': 'id=login_button',
}, },
'script': {
'default': [],
} }
}
def to_internal_value(self, data):
data = data.copy()
if data.get('script') in ("", None):
data.pop('script', None)
return super().to_internal_value(data)

View File

@@ -84,7 +84,6 @@ class PlatformAutomationSerializer(serializers.ModelSerializer):
class PlatformProtocolSerializer(serializers.ModelSerializer): class PlatformProtocolSerializer(serializers.ModelSerializer):
setting = MethodSerializer(required=False, label=_("Setting")) setting = MethodSerializer(required=False, label=_("Setting"))
port_from_addr = serializers.BooleanField(label=_("Port from addr"), read_only=True) port_from_addr = serializers.BooleanField(label=_("Port from addr"), read_only=True)
port = serializers.IntegerField(label=_("Port"), required=False, min_value=0, max_value=65535)
class Meta: class Meta:
model = PlatformProtocol model = PlatformProtocol

View File

@@ -43,7 +43,7 @@ from .serializers import (
OperateLogSerializer, OperateLogActionDetailSerializer, OperateLogSerializer, OperateLogActionDetailSerializer,
PasswordChangeLogSerializer, ActivityUnionLogSerializer, PasswordChangeLogSerializer, ActivityUnionLogSerializer,
FileSerializer, UserSessionSerializer, JobsAuditSerializer, FileSerializer, UserSessionSerializer, JobsAuditSerializer,
ServiceAccessLogSerializer, OperateLogFullSerializer ServiceAccessLogSerializer
) )
from .utils import construct_userlogin_usernames, record_operate_log_and_activity_log from .utils import construct_userlogin_usernames, record_operate_log_and_activity_log
@@ -256,9 +256,7 @@ class OperateLogViewSet(OrgReadonlyModelViewSet):
def get_serializer_class(self): def get_serializer_class(self):
if self.is_action_detail: if self.is_action_detail:
return OperateLogActionDetailSerializer return OperateLogActionDetailSerializer
elif self.request.query_params.get('format'): return super().get_serializer_class()
return OperateLogFullSerializer
return OperateLogSerializer
def get_queryset(self): def get_queryset(self):
current_org_id = str(current_org.id) current_org_id = str(current_org.id)

View File

@@ -23,8 +23,6 @@ logger = get_logger(__name__)
class OperatorLogHandler(metaclass=Singleton): class OperatorLogHandler(metaclass=Singleton):
CACHE_KEY = 'OPERATOR_LOG_CACHE_KEY' CACHE_KEY = 'OPERATOR_LOG_CACHE_KEY'
SYSTEM_OBJECTS = frozenset({"Role"})
PREFER_CURRENT_ELSE_USER = frozenset({"SSOToken"})
def __init__(self): def __init__(self):
self.log_client = self.get_storage_client() self.log_client = self.get_storage_client()
@@ -144,21 +142,13 @@ class OperatorLogHandler(metaclass=Singleton):
after = self.__data_processing(after) after = self.__data_processing(after)
return before, after return before, after
def get_org_id(self, user, object_name): @staticmethod
if object_name in self.SYSTEM_OBJECTS: def get_org_id(object_name):
return Organization.SYSTEM_ID system_obj = ('Role',)
org_id = get_current_org_id()
current = get_current_org_id() if object_name in system_obj:
current_id = str(current) if current else None org_id = Organization.SYSTEM_ID
return org_id
if object_name in self.PREFER_CURRENT_ELSE_USER:
if current_id and current_id != Organization.DEFAULT_ID:
return current_id
org = user.orgs.distinct().first()
return str(org.id) if org else Organization.DEFAULT_ID
return current_id or Organization.DEFAULT_ID
def create_or_update_operate_log( def create_or_update_operate_log(
self, action, resource_type, resource=None, resource_display=None, self, action, resource_type, resource=None, resource_display=None,
@@ -178,7 +168,7 @@ class OperatorLogHandler(metaclass=Singleton):
# 前后都没变化,没必要生成日志,除非手动强制保存 # 前后都没变化,没必要生成日志,除非手动强制保存
return return
org_id = self.get_org_id(user, object_name) org_id = self.get_org_id(object_name)
data = { data = {
'id': log_id, "user": str(user), 'action': action, 'id': log_id, "user": str(user), 'action': action,
'resource_type': str(resource_type), 'org_id': org_id, 'resource_type': str(resource_type), 'org_id': org_id,

View File

@@ -127,21 +127,6 @@ class OperateLogSerializer(BulkOrgResourceModelSerializer):
return i18n_trans(instance.resource) return i18n_trans(instance.resource)
class DiffFieldSerializer(serializers.JSONField):
def to_file_representation(self, value):
row = getattr(self, '_row') or {}
attrs = {'diff': value, 'resource_type': row.get('resource_type')}
instance = type('OperateLog', (), attrs)
return OperateLogStore.convert_diff_friendly(instance)
class OperateLogFullSerializer(OperateLogSerializer):
diff = DiffFieldSerializer(label=_("Diff"))
class Meta(OperateLogSerializer.Meta):
fields = OperateLogSerializer.Meta.fields + ['diff']
class PasswordChangeLogSerializer(serializers.ModelSerializer): class PasswordChangeLogSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = models.PasswordChangeLog model = models.PasswordChangeLog

View File

@@ -116,7 +116,7 @@ def send_login_info_to_reviewers(instance: UserLoginLog | str, auth_acl_id):
reviewers = acl.reviewers.all() reviewers = acl.reviewers.all()
for reviewer in reviewers: for reviewer in reviewers:
UserLoginReminderMsg(reviewer, instance, acl).publish_async() UserLoginReminderMsg(reviewer, instance).publish_async()
@receiver(post_auth_success) @receiver(post_auth_success)

View File

@@ -47,21 +47,20 @@ def on_m2m_changed(sender, action, instance, reverse, model, pk_set, **kwargs):
objs = model.objects.filter(pk__in=pk_set) objs = model.objects.filter(pk__in=pk_set)
objs_display = [str(o) for o in objs] objs_display = [str(o) for o in objs]
action = M2M_ACTION[action] action = M2M_ACTION[action]
changed_field = current_instance.get(field_name, {}) changed_field = current_instance.get(field_name, [])
changed_value = changed_field.get('value', [])
after, before, before_value = None, None, None after, before, before_value = None, None, None
if action == ActionChoices.create: if action == ActionChoices.create:
before_value = list(set(changed_value) - set(objs_display)) before_value = list(set(changed_field) - set(objs_display))
elif action == ActionChoices.delete: elif action == ActionChoices.delete:
before_value = list(set(changed_value).symmetric_difference(set(objs_display))) before_value = list(
set(changed_field).symmetric_difference(set(objs_display))
)
if changed_field: if changed_field:
after = {field_name: changed_field} after = {field_name: changed_field}
if before_value: if before_value:
before_change_field = changed_field.copy() before = {field_name: before_value}
before_change_field['value'] = before_value
before = {field_name: before_change_field}
if sorted(str(before)) == sorted(str(after)): if sorted(str(before)) == sorted(str(after)):
return return

View File

@@ -16,4 +16,3 @@ from .sso import *
from .temp_token import * from .temp_token import *
from .token import * from .token import *
from .face import * from .face import *
from .access_token import *

View File

@@ -1,47 +0,0 @@
from django.shortcuts import get_object_or_404
from django.utils.translation import gettext as _
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.status import HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND
from oauth2_provider.models import get_access_token_model
from common.api import JMSModelViewSet
from rbac.permissions import RBACPermission
from ..serializers import AccessTokenSerializer
AccessToken = get_access_token_model()
class AccessTokenViewSet(JMSModelViewSet):
"""
OAuth2 Access Token 管理视图集
用户只能查看和撤销自己的 access token
"""
serializer_class = AccessTokenSerializer
permission_classes = [RBACPermission]
http_method_names = ['get', 'options', 'delete']
rbac_perms = {
'revoke': 'oauth2_provider.delete_accesstoken',
}
def get_queryset(self):
"""只返回当前用户的 access token按创建时间倒序"""
return AccessToken.objects.filter(user=self.request.user).order_by('-created')
@action(methods=['DELETE'], detail=True, url_path='revoke')
def revoke(self, request, *args, **kwargs):
"""
撤销 access token 及其关联的 refresh token
如果 token 不存在或不属于当前用户,返回 404
"""
token = get_object_or_404(
AccessToken.objects.filter(user=request.user),
id=kwargs['pk']
)
# 优先撤销 refresh token会自动撤销关联的 access token
token_to_revoke = token.refresh_token if token.refresh_token else token
token_to_revoke.revoke()
return Response(status=HTTP_204_NO_CONTENT)

View File

@@ -69,8 +69,6 @@ class RDPFileClientProtocolURLMixin:
'autoreconnection enabled:i': '1', 'autoreconnection enabled:i': '1',
'bookmarktype:i': '3', 'bookmarktype:i': '3',
'use redirection server name:i': '0', 'use redirection server name:i': '0',
'bitmapcachepersistenable:i': '0',
'bitmapcachesize:i': '1500',
} }
# copy from # copy from
@@ -78,6 +76,7 @@ class RDPFileClientProtocolURLMixin:
rdp_low_speed_broadband_option = { rdp_low_speed_broadband_option = {
"connection type:i": 2, "connection type:i": 2,
"disable wallpaper:i": 1, "disable wallpaper:i": 1,
"bitmapcachepersistenable:i": 1,
"disable full window drag:i": 1, "disable full window drag:i": 1,
"disable menu anims:i": 1, "disable menu anims:i": 1,
"allow font smoothing:i": 0, "allow font smoothing:i": 0,
@@ -88,6 +87,7 @@ class RDPFileClientProtocolURLMixin:
rdp_high_speed_broadband_option = { rdp_high_speed_broadband_option = {
"connection type:i": 4, "connection type:i": 4,
"disable wallpaper:i": 0, "disable wallpaper:i": 0,
"bitmapcachepersistenable:i": 1,
"disable full window drag:i": 1, "disable full window drag:i": 1,
"disable menu anims:i": 0, "disable menu anims:i": 0,
"allow font smoothing:i": 0, "allow font smoothing:i": 0,
@@ -219,18 +219,8 @@ class RDPFileClientProtocolURLMixin:
} }
}) })
else: else:
if connect_method_dict['type'] == 'virtual_app':
endpoint_protocol = 'vnc'
token_protocol = 'vnc'
data.update({
'protocol': 'vnc',
})
else:
endpoint_protocol = connect_method_dict['endpoint_protocol']
token_protocol = token.protocol
endpoint = self.get_smart_endpoint( endpoint = self.get_smart_endpoint(
protocol=endpoint_protocol, protocol=connect_method_dict['endpoint_protocol'],
asset=asset asset=asset
) )
data.update({ data.update({
@@ -246,7 +236,7 @@ class RDPFileClientProtocolURLMixin:
}, },
'endpoint': { 'endpoint': {
'host': endpoint.host, 'host': endpoint.host,
'port': endpoint.get_port(token.asset, token_protocol), 'port': endpoint.get_port(token.asset, token.protocol),
} }
}) })
return data return data
@@ -372,7 +362,6 @@ class ConnectionTokenViewSet(AuthFaceMixin, ExtraActionApiMixin, RootOrgViewMixi
self.validate_serializer(serializer) self.validate_serializer(serializer)
return super().perform_create(serializer) return super().perform_create(serializer)
def _insert_connect_options(self, data, user): def _insert_connect_options(self, data, user):
connect_options = data.pop('connect_options', {}) connect_options = data.pop('connect_options', {})
default_name_opts = { default_name_opts = {
@@ -386,7 +375,7 @@ class ConnectionTokenViewSet(AuthFaceMixin, ExtraActionApiMixin, RootOrgViewMixi
for name in default_name_opts.keys(): for name in default_name_opts.keys():
value = preferences.get(name, default_name_opts[name]) value = preferences.get(name, default_name_opts[name])
connect_options[name] = value connect_options[name] = value
connect_options['lang'] = getattr(user, 'lang') or settings.LANGUAGE_CODE connect_options['lang'] = getattr(user, 'lang', settings.LANGUAGE_CODE)
data['connect_options'] = connect_options data['connect_options'] = connect_options
@staticmethod @staticmethod
@@ -442,7 +431,7 @@ class ConnectionTokenViewSet(AuthFaceMixin, ExtraActionApiMixin, RootOrgViewMixi
if account.username != AliasAccount.INPUT: if account.username != AliasAccount.INPUT:
data['input_username'] = '' data['input_username'] = ''
ticket = self._validate_acl(user, asset, account, connect_method, protocol) ticket = self._validate_acl(user, asset, account, connect_method)
if ticket: if ticket:
data['from_ticket'] = ticket data['from_ticket'] = ticket
@@ -481,7 +470,7 @@ class ConnectionTokenViewSet(AuthFaceMixin, ExtraActionApiMixin, RootOrgViewMixi
after=after, object_name=object_name after=after, object_name=object_name
) )
def _validate_acl(self, user, asset, account, connect_method, protocol): def _validate_acl(self, user, asset, account, connect_method):
from acls.models import LoginAssetACL from acls.models import LoginAssetACL
kwargs = {'user': user, 'asset': asset, 'account': account} kwargs = {'user': user, 'asset': asset, 'account': account}
if account.username == AliasAccount.INPUT: if account.username == AliasAccount.INPUT:
@@ -534,15 +523,9 @@ class ConnectionTokenViewSet(AuthFaceMixin, ExtraActionApiMixin, RootOrgViewMixi
return return
self._record_operate_log(acl, asset) self._record_operate_log(acl, asset)
os = get_request_os(self.request) if self.request else 'windows'
method = ConnectMethodUtil.get_connect_method(
connect_method, protocol=protocol, os=os
)
login_from = method['label'] if method else connect_method
for reviewer in reviewers: for reviewer in reviewers:
AssetLoginReminderMsg( AssetLoginReminderMsg(
reviewer, asset, user, account, acl, reviewer, asset, user, account, self.input_username
ip, self.input_username, login_from
).publish_async() ).publish_async()
def create_face_verify(self, response): def create_face_verify(self, response):
@@ -575,9 +558,7 @@ class SuperConnectionTokenViewSet(ConnectionTokenViewSet):
rbac_perms = { rbac_perms = {
'create': 'authentication.add_superconnectiontoken', 'create': 'authentication.add_superconnectiontoken',
'renewal': 'authentication.add_superconnectiontoken', 'renewal': 'authentication.add_superconnectiontoken',
'list': 'authentication.view_superconnectiontoken',
'check': 'authentication.view_superconnectiontoken', 'check': 'authentication.view_superconnectiontoken',
'retrieve': 'authentication.view_superconnectiontoken',
'get_secret_detail': 'authentication.view_superconnectiontokensecret', 'get_secret_detail': 'authentication.view_superconnectiontokensecret',
'get_applet_info': 'authentication.view_superconnectiontoken', 'get_applet_info': 'authentication.view_superconnectiontoken',
'release_applet_account': 'authentication.view_superconnectiontoken', 'release_applet_account': 'authentication.view_superconnectiontoken',
@@ -585,12 +566,7 @@ class SuperConnectionTokenViewSet(ConnectionTokenViewSet):
} }
def get_queryset(self): def get_queryset(self):
return ConnectionToken.objects.none() return ConnectionToken.objects.all()
def get_object(self):
pk = self.kwargs.get(self.lookup_field)
token = get_object_or_404(ConnectionToken, pk=pk)
return token
def get_user(self, serializer): def get_user(self, serializer):
return serializer.validated_data.get('user') return serializer.validated_data.get('user')

View File

@@ -67,9 +67,8 @@ class UserResetPasswordSendCodeApi(CreateAPIView):
code = random_string(settings.SMS_CODE_LENGTH, lower=False, upper=False) code = random_string(settings.SMS_CODE_LENGTH, lower=False, upper=False)
subject = '%s: %s' % (get_login_title(), _('Forgot password')) subject = '%s: %s' % (get_login_title(), _('Forgot password'))
tip = _('The validity period of the verification code is {} minute').format(settings.VERIFY_CODE_TTL // 60)
context = { context = {
'user': user, 'title': subject, 'code': code, 'tip': tip, 'user': user, 'title': subject, 'code': code,
} }
message = render_to_string('authentication/_msg_reset_password_code.html', context) message = render_to_string('authentication/_msg_reset_password_code.html', context)
content = {'subject': subject, 'message': message} content = {'subject': subject, 'message': message}

View File

@@ -1,5 +1,6 @@
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.contrib.auth.backends import ModelBackend from django.contrib.auth.backends import ModelBackend
from django.views import View
from common.utils import get_logger from common.utils import get_logger
from users.models import User from users.models import User
@@ -24,10 +25,7 @@ class JMSBaseAuthBackend:
""" """
# 三方用户认证完成后,在后续的 get_user 获取逻辑中,也应该需要检查用户是否有效 # 三方用户认证完成后,在后续的 get_user 获取逻辑中,也应该需要检查用户是否有效
is_valid = getattr(user, 'is_valid', None) is_valid = getattr(user, 'is_valid', None)
if not is_valid: return is_valid or is_valid is None
logger.info("User %s is not valid", getattr(user, "username", "<unknown>"))
return False
return True
# allow user to authenticate # allow user to authenticate
def username_allow_authenticate(self, username): def username_allow_authenticate(self, username):
@@ -65,3 +63,11 @@ class JMSBaseAuthBackend:
class JMSModelBackend(JMSBaseAuthBackend, ModelBackend): class JMSModelBackend(JMSBaseAuthBackend, ModelBackend):
def user_can_authenticate(self, user): def user_can_authenticate(self, user):
return True return True
class BaseAuthCallbackClientView(View):
http_method_names = ['get']
def get(self, request):
from authentication.views.utils import redirect_to_guard_view
return redirect_to_guard_view(query_string='next=client')

View File

@@ -1,22 +1,51 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import threading
from django.conf import settings from django.conf import settings
from django.contrib.auth import get_user_model
from django_cas_ng.backends import CASBackend as _CASBackend from django_cas_ng.backends import CASBackend as _CASBackend
from common.utils import get_logger from common.utils import get_logger
from ..base import JMSBaseAuthBackend from ..base import JMSBaseAuthBackend
__all__ = ['CASBackend'] __all__ = ['CASBackend', 'CASUserDoesNotExist']
logger = get_logger(__name__) logger = get_logger(__name__)
class CASUserDoesNotExist(Exception):
"""Exception raised when a CAS user does not exist."""
pass
class CASBackend(JMSBaseAuthBackend, _CASBackend): class CASBackend(JMSBaseAuthBackend, _CASBackend):
@staticmethod @staticmethod
def is_enabled(): def is_enabled():
return settings.AUTH_CAS return settings.AUTH_CAS
def authenticate(self, request, ticket, service): def authenticate(self, request, ticket, service):
# 这里做个hack ,让父类始终走CAS_CREATE_USER=True的逻辑然后调用 authentication/mixins.py 中的 custom_get_or_create 方法 UserModel = get_user_model()
settings.CAS_CREATE_USER = True manager = UserModel._default_manager
return super().authenticate(request, ticket, service) original_get_by_natural_key = manager.get_by_natural_key
thread_local = threading.local()
thread_local.thread_id = threading.get_ident()
logger.debug(f"CASBackend.authenticate: thread_id={thread_local.thread_id}")
def get_by_natural_key(self, username):
logger.debug(f"CASBackend.get_by_natural_key: thread_id={threading.get_ident()}, username={username}")
if threading.get_ident() != thread_local.thread_id:
return original_get_by_natural_key(username)
try:
user = original_get_by_natural_key(username)
except UserModel.DoesNotExist:
raise CASUserDoesNotExist(username)
return user
try:
manager.get_by_natural_key = get_by_natural_key.__get__(manager, type(manager))
user = super().authenticate(request, ticket=ticket, service=service)
finally:
manager.get_by_natural_key = original_get_by_natural_key
return user

View File

@@ -3,10 +3,11 @@
import django_cas_ng.views import django_cas_ng.views
from django.urls import path from django.urls import path
from .views import CASLoginView from .views import CASLoginView, CASCallbackClientView
urlpatterns = [ urlpatterns = [
path('login/', CASLoginView.as_view(), name='cas-login'), path('login/', CASLoginView.as_view(), name='cas-login'),
path('logout/', django_cas_ng.views.LogoutView.as_view(), name='cas-logout'), path('logout/', django_cas_ng.views.LogoutView.as_view(), name='cas-logout'),
path('callback/', django_cas_ng.views.CallbackView.as_view(), name='cas-proxy-callback') path('callback/', django_cas_ng.views.CallbackView.as_view(), name='cas-proxy-callback'),
path('login/client', CASCallbackClientView.as_view(), name='cas-proxy-callback-client'),
] ]

View File

@@ -3,20 +3,31 @@ from django.http import HttpResponseRedirect
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django_cas_ng.views import LoginView from django_cas_ng.views import LoginView
from authentication.views.mixins import FlashMessageMixin from authentication.backends.base import BaseAuthCallbackClientView
from common.utils import FlashMessageUtil
from .backends import CASUserDoesNotExist
__all__ = ['LoginView'] __all__ = ['LoginView']
class CASLoginView(LoginView, FlashMessageMixin): class CASLoginView(LoginView):
def get(self, request): def get(self, request):
try: try:
resp = super().get(request) resp = super().get(request)
except PermissionDenied:
resp = HttpResponseRedirect('/')
error_message = getattr(request, 'error_message', '')
if error_message:
response = self.get_failed_response('/', title=_('CAS Error'), msg=error_message)
return response
else:
return resp return resp
except PermissionDenied:
return HttpResponseRedirect('/')
except CASUserDoesNotExist as e:
message_data = {
'title': _('User does not exist: {}').format(e),
'error': _(
'CAS login was successful, but no corresponding local user was found in the system, and automatic '
'user creation is disabled in the CAS authentication configuration. Login failed.'),
'interval': 10,
'redirect_url': '/',
}
return FlashMessageUtil.gen_and_redirect_to(message_data)
class CASCallbackClientView(BaseAuthCallbackClientView):
pass

View File

@@ -69,8 +69,6 @@ class AccessTokenAuthentication(authentication.BaseAuthentication):
msg = _('Invalid token header. Sign string should not contain invalid characters.') msg = _('Invalid token header. Sign string should not contain invalid characters.')
raise exceptions.AuthenticationFailed(msg) raise exceptions.AuthenticationFailed(msg)
user, header = self.authenticate_credentials(token) user, header = self.authenticate_credentials(token)
if not user:
return None
after_authenticate_update_date(user) after_authenticate_update_date(user)
return user, header return user, header
@@ -79,6 +77,10 @@ class AccessTokenAuthentication(authentication.BaseAuthentication):
model = get_user_model() model = get_user_model()
user_id = cache.get(token) user_id = cache.get(token)
user = get_object_or_none(model, id=user_id) user = get_object_or_none(model, id=user_id)
if not user:
msg = _('Invalid token or cache refreshed.')
raise exceptions.AuthenticationFailed(msg)
return user, None return user, None
def authenticate_header(self, request): def authenticate_header(self, request):
@@ -108,7 +110,7 @@ class SessionAuthentication(authentication.SessionAuthentication):
user = getattr(request._request, 'user', None) user = getattr(request._request, 'user', None)
# Unauthenticated, CSRF validation not required # Unauthenticated, CSRF validation not required
if not user or not user.is_active or not user.is_valid: if not user or not user.is_active:
return None return None
try: try:
@@ -134,7 +136,7 @@ class SignatureAuthentication(signature.SignatureAuthentication):
# example implementation: # example implementation:
try: try:
key = AccessKey.objects.get(id=key_id) key = AccessKey.objects.get(id=key_id)
if not key.is_valid: if not key.is_active:
return None, None return None, None
user, secret = key.user, str(key.secret) user, secret = key.user, str(key.secret)
after_authenticate_update_date(user, key) after_authenticate_update_date(user, key)

View File

@@ -7,5 +7,6 @@ from . import views
urlpatterns = [ urlpatterns = [
path('login/', views.OAuth2AuthRequestView.as_view(), name='login'), path('login/', views.OAuth2AuthRequestView.as_view(), name='login'),
path('callback/', views.OAuth2AuthCallbackView.as_view(), name='login-callback'), path('callback/', views.OAuth2AuthCallbackView.as_view(), name='login-callback'),
path('callback/client/', views.OAuth2AuthCallbackClientView.as_view(), name='login-callback-client'),
path('logout/', views.OAuth2EndSessionView.as_view(), name='logout') path('logout/', views.OAuth2EndSessionView.as_view(), name='logout')
] ]

View File

@@ -3,37 +3,29 @@ from django.contrib import auth
from django.http import HttpResponseRedirect from django.http import HttpResponseRedirect
from django.urls import reverse from django.urls import reverse
from django.utils.http import urlencode from django.utils.http import urlencode
from django.utils.translation import gettext_lazy as _
from django.views import View from django.views import View
from authentication.decorators import pre_save_next_to_session, redirect_to_pre_save_next_after_auth from authentication.backends.base import BaseAuthCallbackClientView
from authentication.mixins import authenticate from authentication.mixins import authenticate
from authentication.utils import build_absolute_uri from authentication.utils import build_absolute_uri
from authentication.views.mixins import FlashMessageMixin from authentication.views.mixins import FlashMessageMixin
from common.utils import get_logger, safe_next_url from common.utils import get_logger
logger = get_logger(__file__) logger = get_logger(__file__)
class OAuth2AuthRequestView(View): class OAuth2AuthRequestView(View):
@pre_save_next_to_session()
def get(self, request): def get(self, request):
log_prompt = "Process OAuth2 GET requests: {}" log_prompt = "Process OAuth2 GET requests: {}"
logger.debug(log_prompt.format('Start')) logger.debug(log_prompt.format('Start'))
request_params = request.GET.dict()
request_params.pop('next', None)
query = urlencode(request_params)
redirect_uri = build_absolute_uri(
request, path=reverse(settings.AUTH_OAUTH2_AUTH_LOGIN_CALLBACK_URL_NAME)
)
redirect_uri = f"{redirect_uri}?{query}"
query_dict = { query_dict = {
'client_id': settings.AUTH_OAUTH2_CLIENT_ID, 'response_type': 'code', 'client_id': settings.AUTH_OAUTH2_CLIENT_ID, 'response_type': 'code',
'scope': settings.AUTH_OAUTH2_SCOPE, 'scope': settings.AUTH_OAUTH2_SCOPE,
'redirect_uri': redirect_uri 'redirect_uri': build_absolute_uri(
request, path=reverse(settings.AUTH_OAUTH2_AUTH_LOGIN_CALLBACK_URL_NAME)
)
} }
if '?' in settings.AUTH_OAUTH2_PROVIDER_AUTHORIZATION_ENDPOINT: if '?' in settings.AUTH_OAUTH2_PROVIDER_AUTHORIZATION_ENDPOINT:
@@ -52,7 +44,6 @@ class OAuth2AuthRequestView(View):
class OAuth2AuthCallbackView(View, FlashMessageMixin): class OAuth2AuthCallbackView(View, FlashMessageMixin):
http_method_names = ['get', ] http_method_names = ['get', ]
@redirect_to_pre_save_next_after_auth
def get(self, request): def get(self, request):
""" Processes GET requests. """ """ Processes GET requests. """
log_prompt = "Process GET requests [OAuth2AuthCallbackView]: {}" log_prompt = "Process GET requests [OAuth2AuthCallbackView]: {}"
@@ -67,17 +58,19 @@ class OAuth2AuthCallbackView(View, FlashMessageMixin):
logger.debug(log_prompt.format('Login: {}'.format(user))) logger.debug(log_prompt.format('Login: {}'.format(user)))
auth.login(self.request, user) auth.login(self.request, user)
logger.debug(log_prompt.format('Redirect')) logger.debug(log_prompt.format('Redirect'))
return HttpResponseRedirect(settings.AUTH_OAUTH2_AUTHENTICATION_REDIRECT_URI) return HttpResponseRedirect(
else: settings.AUTH_OAUTH2_AUTHENTICATION_REDIRECT_URI
if getattr(request, 'error_message', ''): )
response = self.get_failed_response('/', title=_('OAuth2 Error'), msg=request.error_message)
return response
logger.debug(log_prompt.format('Redirect')) logger.debug(log_prompt.format('Redirect'))
redirect_url = settings.AUTH_OAUTH2_PROVIDER_END_SESSION_ENDPOINT or '/' redirect_url = settings.AUTH_OAUTH2_PROVIDER_END_SESSION_ENDPOINT or '/'
return HttpResponseRedirect(redirect_url) return HttpResponseRedirect(redirect_url)
class OAuth2AuthCallbackClientView(BaseAuthCallbackClientView):
pass
class OAuth2EndSessionView(View): class OAuth2EndSessionView(View):
http_method_names = ['get', 'post', ] http_method_names = ['get', 'post', ]

View File

@@ -1,20 +0,0 @@
from django.db.models.signals import post_delete
from django.dispatch import receiver
from django.core.cache import cache
from django.conf import settings
from oauth2_provider.models import get_application_model
from .utils import clear_oauth2_authorization_server_view_cache
__all__ = ['on_oauth2_provider_application_deleted']
Application = get_application_model()
@receiver(post_delete, sender=Application)
def on_oauth2_provider_application_deleted(sender, instance, **kwargs):
if instance.name == settings.OAUTH2_PROVIDER_JUMPSERVER_CLIENT_NAME:
clear_oauth2_authorization_server_view_cache()

View File

@@ -1,14 +0,0 @@
# -*- coding: utf-8 -*-
#
from django.urls import path
from oauth2_provider import views as op_views
from . import views
urlpatterns = [
path("authorize/", op_views.AuthorizationView.as_view(), name="authorize"),
path("token/", op_views.TokenView.as_view(), name="token"),
path("revoke/", op_views.RevokeTokenView.as_view(), name="revoke-token"),
path(".well-known/oauth-authorization-server", views.OAuthAuthorizationServerView.as_view(), name="oauth-authorization-server"),
]

View File

@@ -1,31 +0,0 @@
from django.conf import settings
from django.core.cache import cache
from oauth2_provider.models import get_application_model
from common.utils import get_logger
logger = get_logger(__name__)
def get_or_create_jumpserver_client_application():
"""Auto get or create OAuth2 JumpServer Client application."""
Application = get_application_model()
application, created = Application.objects.get_or_create(
name=settings.OAUTH2_PROVIDER_JUMPSERVER_CLIENT_NAME,
defaults={
'client_type': Application.CLIENT_PUBLIC,
'authorization_grant_type': Application.GRANT_AUTHORIZATION_CODE,
'redirect_uris': settings.OAUTH2_PROVIDER_CLIENT_REDIRECT_URI,
'skip_authorization': True,
}
)
return application
CACHE_OAUTH_SERVER_VIEW_KEY_PREFIX = 'oauth2_provider_metadata'
def clear_oauth2_authorization_server_view_cache():
logger.info("Clearing OAuth2 Authorization Server Metadata view cache")
cache_key = f'views.decorators.cache.cache_page.{CACHE_OAUTH_SERVER_VIEW_KEY_PREFIX}.GET*'
cache.delete_pattern(cache_key)

View File

@@ -1,77 +0,0 @@
from django.views.generic import View
from django.http import JsonResponse
from django.utils.decorators import method_decorator
from django.views.decorators.cache import cache_page
from django.views.decorators.csrf import csrf_exempt
from django.conf import settings
from django.urls import reverse
from oauth2_provider.settings import oauth2_settings
from typing import List, Dict, Any
from .utils import get_or_create_jumpserver_client_application, CACHE_OAUTH_SERVER_VIEW_KEY_PREFIX
@method_decorator(csrf_exempt, name='dispatch')
@method_decorator(cache_page(timeout=60 * 60 * 24, key_prefix=CACHE_OAUTH_SERVER_VIEW_KEY_PREFIX), name='dispatch')
class OAuthAuthorizationServerView(View):
"""
OAuth 2.0 Authorization Server Metadata Endpoint
RFC 8414: https://datatracker.ietf.org/doc/html/rfc8414
This endpoint provides machine-readable information about the
OAuth 2.0 authorization server's configuration.
"""
def get_base_url(self, request) -> str:
scheme = 'https' if request.is_secure() else 'http'
host = request.get_host()
return f"{scheme}://{host}"
def get_supported_scopes(self) -> List[str]:
scopes_config = oauth2_settings.SCOPES
if isinstance(scopes_config, dict):
return list(scopes_config.keys())
return []
def get_metadata(self, request) -> Dict[str, Any]:
base_url = self.get_base_url(request)
application = get_or_create_jumpserver_client_application()
metadata = {
"issuer": base_url,
"client_id": application.client_id if application else "Not found any application.",
"authorization_endpoint": base_url + reverse('authentication:oauth2-provider:authorize'),
"token_endpoint": base_url + reverse('authentication:oauth2-provider:token'),
"revocation_endpoint": base_url + reverse('authentication:oauth2-provider:revoke-token'),
"response_types_supported": ["code"],
"grant_types_supported": ["authorization_code", "refresh_token"],
"scopes_supported": self.get_supported_scopes(),
"token_endpoint_auth_methods_supported": ["none"],
"revocation_endpoint_auth_methods_supported": ["none"],
"code_challenge_methods_supported": ["S256"],
"response_modes_supported": ["query"],
}
if hasattr(oauth2_settings, 'ACCESS_TOKEN_EXPIRE_SECONDS'):
metadata["token_expires_in"] = oauth2_settings.ACCESS_TOKEN_EXPIRE_SECONDS
if hasattr(oauth2_settings, 'REFRESH_TOKEN_EXPIRE_SECONDS'):
if oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS:
metadata["refresh_token_expires_in"] = oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS
return metadata
def get(self, request, *args, **kwargs):
metadata = self.get_metadata(request)
response = JsonResponse(metadata)
self.add_cors_headers(response)
return response
def options(self, request, *args, **kwargs):
response = JsonResponse({})
self.add_cors_headers(response)
return response
@staticmethod
def add_cors_headers(response):
response['Access-Control-Allow-Origin'] = '*'
response['Access-Control-Allow-Methods'] = 'GET, OPTIONS'
response['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
response['Access-Control-Max-Age'] = '3600'

View File

@@ -15,5 +15,6 @@ from . import views
urlpatterns = [ urlpatterns = [
path('login/', views.OIDCAuthRequestView.as_view(), name='login'), path('login/', views.OIDCAuthRequestView.as_view(), name='login'),
path('callback/', views.OIDCAuthCallbackView.as_view(), name='login-callback'), path('callback/', views.OIDCAuthCallbackView.as_view(), name='login-callback'),
path('callback/client/', views.OIDCAuthCallbackClientView.as_view(), name='login-callback-client'),
path('logout/', views.OIDCEndSessionView.as_view(), name='logout'), path('logout/', views.OIDCEndSessionView.as_view(), name='logout'),
] ]

View File

@@ -25,11 +25,11 @@ from django.utils.http import urlencode
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django.views.generic import View from django.views.generic import View
from authentication.decorators import pre_save_next_to_session, redirect_to_pre_save_next_after_auth
from authentication.utils import build_absolute_uri_for_oidc from authentication.utils import build_absolute_uri_for_oidc
from authentication.views.mixins import FlashMessageMixin from authentication.views.mixins import FlashMessageMixin
from common.utils import safe_next_url from common.utils import safe_next_url
from .utils import get_logger from .utils import get_logger
from ..base import BaseAuthCallbackClientView
logger = get_logger(__file__) logger = get_logger(__file__)
@@ -58,7 +58,6 @@ class OIDCAuthRequestView(View):
b = base64.urlsafe_b64encode(h) b = base64.urlsafe_b64encode(h)
return b.decode('ascii')[:-1] return b.decode('ascii')[:-1]
@pre_save_next_to_session()
def get(self, request): def get(self, request):
""" Processes GET requests. """ """ Processes GET requests. """
@@ -67,9 +66,8 @@ class OIDCAuthRequestView(View):
# Defines common parameters used to bootstrap the authentication request. # Defines common parameters used to bootstrap the authentication request.
logger.debug(log_prompt.format('Construct request params')) logger.debug(log_prompt.format('Construct request params'))
request_params = request.GET.dict() authentication_request_params = request.GET.dict()
request_params.pop('next', None) authentication_request_params.update({
request_params.update({
'scope': settings.AUTH_OPENID_SCOPES, 'scope': settings.AUTH_OPENID_SCOPES,
'response_type': 'code', 'response_type': 'code',
'client_id': settings.AUTH_OPENID_CLIENT_ID, 'client_id': settings.AUTH_OPENID_CLIENT_ID,
@@ -82,7 +80,7 @@ class OIDCAuthRequestView(View):
code_verifier = self.gen_code_verifier() code_verifier = self.gen_code_verifier()
code_challenge_method = settings.AUTH_OPENID_CODE_CHALLENGE_METHOD or 'S256' code_challenge_method = settings.AUTH_OPENID_CODE_CHALLENGE_METHOD or 'S256'
code_challenge = self.gen_code_challenge(code_verifier, code_challenge_method) code_challenge = self.gen_code_challenge(code_verifier, code_challenge_method)
request_params.update({ authentication_request_params.update({
'code_challenge_method': code_challenge_method, 'code_challenge_method': code_challenge_method,
'code_challenge': code_challenge 'code_challenge': code_challenge
}) })
@@ -93,7 +91,7 @@ class OIDCAuthRequestView(View):
if settings.AUTH_OPENID_USE_STATE: if settings.AUTH_OPENID_USE_STATE:
logger.debug(log_prompt.format('Use state')) logger.debug(log_prompt.format('Use state'))
state = get_random_string(settings.AUTH_OPENID_STATE_LENGTH) state = get_random_string(settings.AUTH_OPENID_STATE_LENGTH)
request_params.update({'state': state}) authentication_request_params.update({'state': state})
request.session['oidc_auth_state'] = state request.session['oidc_auth_state'] = state
# Nonces should be used too! In that case the generated nonce is stored both in the # Nonces should be used too! In that case the generated nonce is stored both in the
@@ -101,12 +99,17 @@ class OIDCAuthRequestView(View):
if settings.AUTH_OPENID_USE_NONCE: if settings.AUTH_OPENID_USE_NONCE:
logger.debug(log_prompt.format('Use nonce')) logger.debug(log_prompt.format('Use nonce'))
nonce = get_random_string(settings.AUTH_OPENID_NONCE_LENGTH) nonce = get_random_string(settings.AUTH_OPENID_NONCE_LENGTH)
request_params.update({'nonce': nonce, }) authentication_request_params.update({'nonce': nonce, })
request.session['oidc_auth_nonce'] = nonce request.session['oidc_auth_nonce'] = nonce
# Stores the "next" URL in the session if applicable.
logger.debug(log_prompt.format('Stores next url in the session'))
next_url = request.GET.get('next')
request.session['oidc_auth_next_url'] = safe_next_url(next_url, request=request)
# Redirects the user to authorization endpoint. # Redirects the user to authorization endpoint.
logger.debug(log_prompt.format('Construct redirect url')) logger.debug(log_prompt.format('Construct redirect url'))
query = urlencode(request_params) query = urlencode(authentication_request_params)
redirect_url = '{url}?{query}'.format( redirect_url = '{url}?{query}'.format(
url=settings.AUTH_OPENID_PROVIDER_AUTHORIZATION_ENDPOINT, query=query) url=settings.AUTH_OPENID_PROVIDER_AUTHORIZATION_ENDPOINT, query=query)
@@ -126,14 +129,11 @@ class OIDCAuthCallbackView(View, FlashMessageMixin):
http_method_names = ['get', ] http_method_names = ['get', ]
@redirect_to_pre_save_next_after_auth
def get(self, request): def get(self, request):
""" Processes GET requests. """ """ Processes GET requests. """
log_prompt = "Process GET requests [OIDCAuthCallbackView]: {}" log_prompt = "Process GET requests [OIDCAuthCallbackView]: {}"
logger.debug(log_prompt.format('Start')) logger.debug(log_prompt.format('Start'))
callback_params = request.GET callback_params = request.GET
error_title = _("OpenID Error")
# Retrieve the state value that was previously generated. No state means that we cannot # Retrieve the state value that was previously generated. No state means that we cannot
# authenticate the user (so a failure should be returned). # authenticate the user (so a failure should be returned).
@@ -166,14 +166,16 @@ class OIDCAuthCallbackView(View, FlashMessageMixin):
raise SuspiciousOperation('Invalid OpenID Connect callback state value') raise SuspiciousOperation('Invalid OpenID Connect callback state value')
# Authenticates the end-user. # Authenticates the end-user.
next_url = request.session.get('oidc_auth_next_url', None)
code_verifier = request.session.get('oidc_auth_code_verifier', None) code_verifier = request.session.get('oidc_auth_code_verifier', None)
logger.debug(log_prompt.format('Process authenticate')) logger.debug(log_prompt.format('Process authenticate'))
try: try:
user = auth.authenticate(nonce=nonce, request=request, code_verifier=code_verifier) user = auth.authenticate(nonce=nonce, request=request, code_verifier=code_verifier)
except IntegrityError as e: except IntegrityError as e:
title = _("OpenID Error")
msg = _('Please check if a user with the same username or email already exists') msg = _('Please check if a user with the same username or email already exists')
logger.error(e, exc_info=True) logger.error(e, exc_info=True)
response = self.get_failed_response('/', error_title, msg) response = self.get_failed_response('/', title, msg)
return response return response
if user: if user:
logger.debug(log_prompt.format('Login: {}'.format(user))) logger.debug(log_prompt.format('Login: {}'.format(user)))
@@ -189,7 +191,10 @@ class OIDCAuthCallbackView(View, FlashMessageMixin):
callback_params.get('session_state', None) callback_params.get('session_state', None)
logger.debug(log_prompt.format('Redirect')) logger.debug(log_prompt.format('Redirect'))
return HttpResponseRedirect(settings.AUTH_OPENID_AUTHENTICATION_REDIRECT_URI) return HttpResponseRedirect(
next_url or settings.AUTH_OPENID_AUTHENTICATION_REDIRECT_URI
)
if 'error' in callback_params: if 'error' in callback_params:
logger.debug( logger.debug(
log_prompt.format('Error in callback params: {}'.format(callback_params['error'])) log_prompt.format('Error in callback params: {}'.format(callback_params['error']))
@@ -200,12 +205,13 @@ class OIDCAuthCallbackView(View, FlashMessageMixin):
# OpenID Connect Provider authenticate endpoint. # OpenID Connect Provider authenticate endpoint.
logger.debug(log_prompt.format('Logout')) logger.debug(log_prompt.format('Logout'))
auth.logout(request) auth.logout(request)
redirect_url = settings.AUTH_OPENID_AUTHENTICATION_FAILURE_REDIRECT_URI
if not user and getattr(request, 'error_message', ''):
response = self.get_failed_response(redirect_url, title=error_title, msg=request.error_message)
return response
logger.debug(log_prompt.format('Redirect')) logger.debug(log_prompt.format('Redirect'))
return HttpResponseRedirect(redirect_url) return HttpResponseRedirect(settings.AUTH_OPENID_AUTHENTICATION_FAILURE_REDIRECT_URI)
class OIDCAuthCallbackClientView(BaseAuthCallbackClientView):
pass
class OIDCEndSessionView(View): class OIDCEndSessionView(View):

View File

@@ -8,5 +8,6 @@ urlpatterns = [
path('login/', views.Saml2AuthRequestView.as_view(), name='saml2-login'), path('login/', views.Saml2AuthRequestView.as_view(), name='saml2-login'),
path('logout/', views.Saml2EndSessionView.as_view(), name='saml2-logout'), path('logout/', views.Saml2EndSessionView.as_view(), name='saml2-logout'),
path('callback/', views.Saml2AuthCallbackView.as_view(), name='saml2-callback'), path('callback/', views.Saml2AuthCallbackView.as_view(), name='saml2-callback'),
path('callback/client/', views.Saml2AuthCallbackClientView.as_view(), name='saml2-callback-client'),
path('metadata/', views.Saml2AuthMetadataView.as_view(), name='saml2-metadata'), path('metadata/', views.Saml2AuthMetadataView.as_view(), name='saml2-metadata'),
] ]

View File

@@ -17,8 +17,9 @@ from onelogin.saml2.idp_metadata_parser import (
) )
from authentication.views.mixins import FlashMessageMixin from authentication.views.mixins import FlashMessageMixin
from common.utils import get_logger, safe_next_url from common.utils import get_logger
from .settings import JmsSaml2Settings from .settings import JmsSaml2Settings
from ..base import BaseAuthCallbackClientView
logger = get_logger(__file__) logger = get_logger(__file__)
@@ -207,16 +208,13 @@ class Saml2AuthRequestView(View, PrepareRequestMixin):
log_prompt = "Process SAML GET requests: {}" log_prompt = "Process SAML GET requests: {}"
logger.debug(log_prompt.format('Start')) logger.debug(log_prompt.format('Start'))
request_params = request.GET.dict()
try: try:
saml_instance = self.init_saml_auth(request) saml_instance = self.init_saml_auth(request)
except OneLogin_Saml2_Error as error: except OneLogin_Saml2_Error as error:
logger.error(log_prompt.format('Init saml auth error: %s' % error)) logger.error(log_prompt.format('Init saml auth error: %s' % error))
return HttpResponse(error, status=412) return HttpResponse(error, status=412)
next_url = request_params.get('next') or settings.AUTH_SAML2_PROVIDER_AUTHORIZATION_ENDPOINT next_url = settings.AUTH_SAML2_PROVIDER_AUTHORIZATION_ENDPOINT
next_url = safe_next_url(next_url, request=request)
url = saml_instance.login(return_to=next_url) url = saml_instance.login(return_to=next_url)
logger.debug(log_prompt.format('Redirect login url')) logger.debug(log_prompt.format('Redirect login url'))
return HttpResponseRedirect(url) return HttpResponseRedirect(url)
@@ -254,7 +252,6 @@ class Saml2AuthCallbackView(View, PrepareRequestMixin, FlashMessageMixin):
def post(self, request): def post(self, request):
log_prompt = "Process SAML2 POST requests: {}" log_prompt = "Process SAML2 POST requests: {}"
post_data = request.POST post_data = request.POST
error_title = _("SAML2 Error")
try: try:
saml_instance = self.init_saml_auth(request) saml_instance = self.init_saml_auth(request)
@@ -282,24 +279,20 @@ class Saml2AuthCallbackView(View, PrepareRequestMixin, FlashMessageMixin):
try: try:
user = auth.authenticate(request=request, saml_user_data=saml_user_data) user = auth.authenticate(request=request, saml_user_data=saml_user_data)
except IntegrityError as e: except IntegrityError as e:
title = _("SAML2 Error")
msg = _('Please check if a user with the same username or email already exists') msg = _('Please check if a user with the same username or email already exists')
logger.error(e, exc_info=True) logger.error(e, exc_info=True)
response = self.get_failed_response('/', error_title, msg) response = self.get_failed_response('/', title, msg)
return response return response
if user and user.is_valid: if user and user.is_valid:
logger.debug(log_prompt.format('Login: {}'.format(user))) logger.debug(log_prompt.format('Login: {}'.format(user)))
auth.login(self.request, user) auth.login(self.request, user)
if not user and getattr(request, 'error_message', ''):
response = self.get_failed_response('/', title=error_title, msg=request.error_message)
return response
logger.debug(log_prompt.format('Redirect')) logger.debug(log_prompt.format('Redirect'))
relay_state = post_data.get('RelayState') redir = post_data.get('RelayState')
if not relay_state or len(relay_state) == 0: if not redir or len(redir) == 0:
relay_state = "/" redir = "/"
next_url = saml_instance.redirect_to(relay_state) next_url = saml_instance.redirect_to(redir)
next_url = safe_next_url(next_url, request=request)
return HttpResponseRedirect(next_url) return HttpResponseRedirect(next_url)
@csrf_exempt @csrf_exempt
@@ -307,6 +300,10 @@ class Saml2AuthCallbackView(View, PrepareRequestMixin, FlashMessageMixin):
return super().dispatch(*args, **kwargs) return super().dispatch(*args, **kwargs)
class Saml2AuthCallbackClientView(BaseAuthCallbackClientView):
pass
class Saml2AuthMetadataView(View, PrepareRequestMixin): class Saml2AuthMetadataView(View, PrepareRequestMixin):
def get(self, request): def get(self, request):

View File

@@ -1,9 +1,6 @@
from django.db.models import TextChoices from django.db.models import TextChoices
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
USER_LOGIN_GUARD_VIEW_REDIRECT_FIELD = 'next'
RSA_PRIVATE_KEY = 'rsa_private_key' RSA_PRIVATE_KEY = 'rsa_private_key'
RSA_PUBLIC_KEY = 'rsa_public_key' RSA_PUBLIC_KEY = 'rsa_public_key'

View File

@@ -1,193 +0,0 @@
"""
This module provides decorators to handle redirect URLs during the authentication flow:
1. pre_save_next_to_session: Captures and stores the intended next URL before redirecting to auth provider
2. redirect_to_pre_save_next_after_auth: Redirects to the stored next URL after successful authentication
3. post_save_next_to_session: Copies the stored next URL to session['next'] after view execution
"""
from urllib.parse import urlparse
from django.http import HttpResponseRedirect
from django.urls import reverse
from django.utils.translation import gettext_lazy as _
from functools import wraps
from common.utils import get_logger, safe_next_url
from .const import USER_LOGIN_GUARD_VIEW_REDIRECT_FIELD
logger = get_logger(__file__)
__all__ = [
'pre_save_next_to_session', 'redirect_to_pre_save_next_after_auth',
'post_save_next_to_session_if_guard_redirect'
]
# Session key for storing the redirect URL after authentication
AUTH_SESSION_NEXT_URL_KEY = 'auth_next_url'
def pre_save_next_to_session(get_next_url=None):
"""
Decorator to capture and store the 'next' parameter into session BEFORE view execution.
This decorator is applied to the authentication request view to preserve the user's
intended destination URL before redirecting to the authentication provider.
Args:
get_next_url: Optional callable that extracts the next URL from request.
Default: lambda req: req.GET.get('next')
Usage:
# Use default (request.GET.get('next'))
@pre_save_next_to_session()
def get(self, request):
pass
# Custom extraction from POST data
@pre_save_next_to_session(get_next_url=lambda req: req.POST.get('next'))
def post(self, request):
pass
# Custom extraction from both GET and POST
@pre_save_next_to_session(
get_next_url=lambda req: req.GET.get('next') or req.POST.get('next')
)
def get(self, request):
pass
Example flow:
User accesses: /auth/oauth2/?next=/dashboard/
↓ (decorator saves '/dashboard/' to session)
Redirected to OAuth2 provider for authentication
"""
# Default function to extract next URL from request.GET
if get_next_url is None:
get_next_url = lambda req: req.GET.get('next')
def decorator(view_func):
@wraps(view_func)
def wrapper(self, request, *args, **kwargs):
next_url = get_next_url(request)
if next_url:
request.session[AUTH_SESSION_NEXT_URL_KEY] = next_url
logger.debug(f"[Auth] Saved next_url to session: {next_url}")
return view_func(self, request, *args, **kwargs)
return wrapper
return decorator
def redirect_to_pre_save_next_after_auth(view_func):
"""
Decorator to redirect to the previously saved 'next' URL after successful authentication.
This decorator is applied to the authentication callback view. After the user successfully
authenticates, if a 'next' URL was previously saved in the session (by pre_save_next_to_session),
the user will be redirected to that URL instead of the default redirect location.
Conditions for redirect:
- User must be authenticated (request.user.is_authenticated)
- Session must contain the saved next URL (AUTH_SESSION_NEXT_URL_KEY)
- The next URL must not be '/' (avoid unnecessary redirects)
- The next URL must pass security validation (safe_next_url)
If any condition fails, returns the original view response.
Usage:
@redirect_to_pre_save_next_after_auth
def get(self, request):
# Process authentication callback
if user_authenticated:
auth.login(request, user)
return HttpResponseRedirect(default_url)
Example flow:
User redirected back from OAuth2 provider: /auth/oauth2/callback/?code=xxx
↓ (view processes authentication, user becomes authenticated)
Decorator checks session for saved next URL
↓ (finds '/dashboard/' in session)
Redirects to: /dashboard/
↓ (clears saved URL from session)
"""
@wraps(view_func)
def wrapper(self, request, *args, **kwargs):
# Execute the original view method first
response = view_func(self, request, *args, **kwargs)
# Check if user has been authenticated
if request.user and request.user.is_authenticated:
# Check if session contains a saved next URL
saved_next_url = request.session.get(AUTH_SESSION_NEXT_URL_KEY)
if saved_next_url and saved_next_url != '/':
# Validate the URL for security
safe_url = safe_next_url(saved_next_url, request=request)
if safe_url:
# Clear the saved URL from session (one-time use)
request.session.pop(AUTH_SESSION_NEXT_URL_KEY, None)
logger.debug(f"[Auth] Redirecting authenticated user to saved next_url: {safe_url}")
return HttpResponseRedirect(safe_url)
# Return the original response if no redirect conditions are met
return response
return wrapper
def post_save_next_to_session_if_guard_redirect(view_func):
"""
Decorator to copy AUTH_SESSION_NEXT_URL_KEY to session['next'] after view execution,
but only if redirecting to login-guard view.
This decorator is applied AFTER view execution. It copies the value from
AUTH_SESSION_NEXT_URL_KEY (internal storage) to 'next' (standard session key)
for use by downstream code.
Only sets the 'next' session key when the response is a redirect to guard-view
(i.e., response with redirect status code and location path matching login-guard view URL).
Usage:
@post_save_next_to_session_if_guard_redirect
def get(self, request):
# Process the request and return response
if some_condition:
return self.redirect_to_guard_view() # Decorator will copy next to session
return HttpResponseRedirect(url) # Decorator won't copy if not to guard-view
Example flow:
View executes and returns redirect to guard view
↓ (response is redirect with 'login-guard' in Location)
Decorator checks if response is redirect to guard-view and session has saved next URL
↓ (copies AUTH_SESSION_NEXT_URL_KEY to session['next'])
User is redirected to guard-view with 'next' available in session
"""
@wraps(view_func)
def wrapper(self, request, *args, **kwargs):
# Execute the original view method
response = view_func(self, request, *args, **kwargs)
# Check if response is a redirect to guard view
# Redirect responses typically have status codes 301, 302, 303, 307, 308
is_guard_redirect = False
if hasattr(response, 'status_code') and response.status_code in (301, 302, 303, 307, 308):
# Check if the redirect location is to guard view
location = response.get('Location', '')
if location:
# Extract path from location URL (handle both absolute and relative URLs)
parsed_url = urlparse(location)
path = parsed_url.path
# Check if path matches guard view URL pattern
guard_view_url = reverse('authentication:login-guard')
if path == guard_view_url:
is_guard_redirect = True
# Only set 'next' if response is a redirect to guard view
if is_guard_redirect:
# Copy AUTH_SESSION_NEXT_URL_KEY to 'next' if it exists
saved_next_url = request.session.get(AUTH_SESSION_NEXT_URL_KEY)
if saved_next_url:
# 这里 'next' 是 UserLoginGuardView.redirect_field_name
request.session[USER_LOGIN_GUARD_VIEW_REDIRECT_FIELD] = saved_next_url
logger.debug(f"[Auth] Copied {AUTH_SESSION_NEXT_URL_KEY} to 'next' in session: {saved_next_url}")
return response
return wrapper

View File

@@ -114,12 +114,12 @@ class BlockMFAError(AuthFailedNeedLogMixin, AuthFailedError):
super().__init__(username=username, request=request, ip=ip) super().__init__(username=username, request=request, ip=ip)
class BlockLoginError(AuthFailedNeedLogMixin, AuthFailedNeedBlockMixin, AuthFailedError): class BlockLoginError(AuthFailedNeedBlockMixin, AuthFailedError):
error = 'block_login' error = 'block_login'
def __init__(self, username, ip, request): def __init__(self, username, ip):
self.msg = const.block_user_login_msg.format(settings.SECURITY_LOGIN_LIMIT_TIME) self.msg = const.block_user_login_msg.format(settings.SECURITY_LOGIN_LIMIT_TIME)
super().__init__(username=username, ip=ip, request=request) super().__init__(username=username, ip=ip)
class SessionEmptyError(AuthFailedError): class SessionEmptyError(AuthFailedError):

Some files were not shown because too many files have changed in this diff Show More