Compare commits

..

2 Commits
v4.7 ... v4.3.0

Author SHA1 Message Date
Bai
0d16879d26 perf: Sticky 2024-10-24 14:29:38 +08:00
feng
4b981fd93c fix: Error subpub_msg log 2024-10-17 15:17:12 +08:00
294 changed files with 6539 additions and 23523 deletions

View File

@@ -1,24 +0,0 @@
name: Publish Release to Discord
on:
release:
types: [published]
jobs:
send_discord_notification:
runs-on: ubuntu-latest
if: startsWith(github.event.release.tag_name, 'v4.')
steps:
- name: Send release notification to Discord
env:
WEBHOOK_URL: ${{ secrets.DISCORD_CHANGELOG_WEBHOOK }}
run: |
# 获取标签名称和 release body
TAG_NAME="${{ github.event.release.tag_name }}"
RELEASE_BODY="${{ github.event.release.body }}"
# 使用 jq 构建 JSON 数据,以确保安全传递
JSON_PAYLOAD=$(jq -n --arg tag "# JumpServer $TAG_NAME Released! 🚀" --arg body "$RELEASE_BODY" '{content: "\($tag)\n\($body)"}')
# 使用 curl 发送 JSON 数据
curl -X POST -H "Content-Type: application/json" -d "$JSON_PAYLOAD" "$WEBHOOK_URL"

View File

@@ -1,24 +0,0 @@
name: Auto update docs changelog
on:
release:
types: [published]
jobs:
update_docs_changelog:
runs-on: ubuntu-latest
if: startsWith(github.event.release.tag_name, 'v4.')
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Update docs changelog
env:
TAG_NAME: ${{ github.event.release.tag_name }}
DOCS_TOKEN: ${{ secrets.DOCS_TOKEN }}
run: |
git config --global user.name 'BaiJiangJie'
git config --global user.email 'jiangjie.bai@fit2cloud.com'
git clone https://$DOCS_TOKEN@github.com/jumpservice/documentation.git
cd documentation/utils
bash update_changelog.sh

View File

@@ -1,28 +0,0 @@
name: LLM Code Review
permissions:
contents: read
pull-requests: write
on:
pull_request:
types: [opened, reopened, synchronize]
jobs:
llm-code-review:
runs-on: ubuntu-latest
steps:
- uses: fit2cloud/LLM-CodeReview-Action@main
env:
GITHUB_TOKEN: ${{ secrets.FIT2CLOUDRD_LLM_CODE_REVIEW_TOKEN }}
OPENAI_API_KEY: ${{ secrets.ALIYUN_LLM_API_KEY }}
LANGUAGE: English
OPENAI_API_ENDPOINT: https://dashscope.aliyuncs.com/compatible-mode/v1
MODEL: qwen2-1.5b-instruct
PROMPT: "Please check the following code differences for any irregularities, potential issues, or optimization suggestions, and provide your answers in English."
top_p: 1
temperature: 1
# max_tokens: 10000
MAX_PATCH_LENGTH: 10000
IGNORE_PATTERNS: "/node_modules,*.md,/dist,/.github"
FILE_PATTERNS: "*.java,*.go,*.py,*.vue,*.ts,*.js,*.css,*.scss,*.html"

View File

@@ -1,40 +0,0 @@
name: Translate README
on:
workflow_dispatch:
inputs:
target_langs:
description: "Target Languages"
required: false
default: "zh-hans,zh-hant,ja,pt-br"
gen_dir_path:
description: "Generate Dir Name"
required: false
default: "readmes/"
push_branch:
description: "Push Branch"
required: false
default: "pr@dev@translate_readme"
prompt:
description: "AI Translate Prompt"
required: false
default: ""
gpt_mode:
description: "GPT Mode"
required: false
default: "gpt-4o-mini"
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Auto Translate
uses: jumpserver-dev/action-translate-readme@main
env:
GITHUB_TOKEN: ${{ secrets.PRIVATE_TOKEN }}
OPENAI_API_KEY: ${{ secrets.GPT_API_TOKEN }}
GPT_MODE: ${{ github.event.inputs.gpt_mode }}
TARGET_LANGUAGES: ${{ github.event.inputs.target_langs }}
PUSH_BRANCH: ${{ github.event.inputs.push_branch }}
GEN_DIR_PATH: ${{ github.event.inputs.gen_dir_path }}
PROMPT: ${{ github.event.inputs.prompt }}

1
.gitignore vendored
View File

@@ -45,4 +45,3 @@ test.py
.history/ .history/
.test/ .test/
*.mo *.mo
apps.iml

View File

@@ -1,4 +1,4 @@
FROM jumpserver/core-base:20241210_070105 AS stage-build FROM jumpserver/core-base:20240924_031841 AS stage-build
ARG VERSION ARG VERSION
@@ -28,7 +28,6 @@ ARG DEPENDENCIES=" \
libx11-dev" libx11-dev"
ARG TOOLS=" \ ARG TOOLS=" \
cron \
ca-certificates \ ca-certificates \
default-libmysqlclient-dev \ default-libmysqlclient-dev \
openssh-client \ openssh-client \
@@ -36,20 +35,19 @@ ARG TOOLS=" \
bubblewrap" bubblewrap"
ARG APT_MIRROR=http://deb.debian.org ARG APT_MIRROR=http://deb.debian.org
RUN set -ex \ RUN set -ex \
&& rm -f /etc/apt/apt.conf.d/docker-clean \
&& sed -i "s@http://.*.debian.org@${APT_MIRROR}@g" /etc/apt/sources.list \ && sed -i "s@http://.*.debian.org@${APT_MIRROR}@g" /etc/apt/sources.list \
&& ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \ && ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime \
&& apt-get update > /dev/null \ && apt-get update > /dev/null \
&& apt-get -y install --no-install-recommends ${DEPENDENCIES} \ && apt-get -y install --no-install-recommends ${DEPENDENCIES} \
&& apt-get -y install --no-install-recommends ${TOOLS} \ && apt-get -y install --no-install-recommends ${TOOLS} \
&& apt-get clean \
&& mkdir -p /root/.ssh/ \ && mkdir -p /root/.ssh/ \
&& echo "Host *\n\tStrictHostKeyChecking no\n\tUserKnownHostsFile /dev/null\n\tCiphers +aes128-cbc\n\tKexAlgorithms +diffie-hellman-group1-sha1\n\tHostKeyAlgorithms +ssh-rsa" > /root/.ssh/config \ && echo "Host *\n\tStrictHostKeyChecking no\n\tUserKnownHostsFile /dev/null\n\tCiphers +aes128-cbc\n\tKexAlgorithms +diffie-hellman-group1-sha1\n\tHostKeyAlgorithms +ssh-rsa" > /root/.ssh/config \
&& echo "no" | dpkg-reconfigure dash \ && echo "no" | dpkg-reconfigure dash \
&& apt-get clean all \ && sed -i "s@# export @export @g" ~/.bashrc \
&& rm -rf /var/lib/apt/lists/* \ && sed -i "s@# alias @alias @g" ~/.bashrc
&& echo "0 3 * * * root find /tmp -type f -mtime +1 -size +1M -exec rm -f {} \; && date > /tmp/clean.log" > /etc/cron.d/cleanup_tmp \
&& chmod 0644 /etc/cron.d/cleanup_tmp
COPY --from=stage-build /opt /opt COPY --from=stage-build /opt /opt
COPY --from=stage-build /usr/local/bin /usr/local/bin COPY --from=stage-build /usr/local/bin /usr/local/bin

View File

@@ -15,8 +15,8 @@ ARG DEPENDENCIES=" \
libldap2-dev \ libldap2-dev \
libsasl2-dev" libsasl2-dev"
ARG APT_MIRROR=http://deb.debian.org
ARG APT_MIRROR=http://deb.debian.org
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked,id=core \ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked,id=core \
--mount=type=cache,target=/var/lib/apt,sharing=locked,id=core \ --mount=type=cache,target=/var/lib/apt,sharing=locked,id=core \
set -ex \ set -ex \
@@ -27,8 +27,9 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked,id=core \
&& apt-get -y install --no-install-recommends ${DEPENDENCIES} \ && apt-get -y install --no-install-recommends ${DEPENDENCIES} \
&& echo "no" | dpkg-reconfigure dash && echo "no" | dpkg-reconfigure dash
# Install bin tools # Install bin tools
ARG CHECK_VERSION=v1.0.4 ARG CHECK_VERSION=v1.0.3
RUN set -ex \ RUN set -ex \
&& wget https://github.com/jumpserver-dev/healthcheck/releases/download/${CHECK_VERSION}/check-${CHECK_VERSION}-linux-${TARGETARCH}.tar.gz \ && wget https://github.com/jumpserver-dev/healthcheck/releases/download/${CHECK_VERSION}/check-${CHECK_VERSION}-linux-${TARGETARCH}.tar.gz \
&& tar -xf check-${CHECK_VERSION}-linux-${TARGETARCH}.tar.gz \ && tar -xf check-${CHECK_VERSION}-linux-${TARGETARCH}.tar.gz \
@@ -37,24 +38,23 @@ RUN set -ex \
&& chmod 755 /usr/local/bin/check \ && chmod 755 /usr/local/bin/check \
&& rm -f check-${CHECK_VERSION}-linux-${TARGETARCH}.tar.gz && rm -f check-${CHECK_VERSION}-linux-${TARGETARCH}.tar.gz
# Install Python dependencies # Install Python dependencies
WORKDIR /opt/jumpserver WORKDIR /opt/jumpserver
ARG PIP_MIRROR=https://pypi.org/simple ARG PIP_MIRROR=https://pypi.org/simple
ENV POETRY_PYPI_MIRROR_URL=${PIP_MIRROR}
ENV ANSIBLE_COLLECTIONS_PATHS=/opt/py3/lib/python3.11/site-packages/ansible_collections ENV ANSIBLE_COLLECTIONS_PATHS=/opt/py3/lib/python3.11/site-packages/ansible_collections
RUN --mount=type=cache,target=/root/.cache \ RUN --mount=type=cache,target=/root/.cache,sharing=locked,id=core \
--mount=type=bind,source=poetry.lock,target=poetry.lock \ --mount=type=bind,source=poetry.lock,target=poetry.lock \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \ --mount=type=bind,source=pyproject.toml,target=pyproject.toml \
--mount=type=bind,source=utils/clean_site_packages.sh,target=clean_site_packages.sh \ --mount=type=bind,source=utils/clean_site_packages.sh,target=clean_site_packages.sh \
--mount=type=bind,source=requirements/collections.yml,target=collections.yml \ --mount=type=bind,source=requirements/collections.yml,target=collections.yml \
set -ex \ set -ex \
&& python3 -m venv /opt/py3 \ && python3 -m venv /opt/py3 \
&& pip install poetry poetry-plugin-pypi-mirror -i ${PIP_MIRROR} \ && pip install poetry -i ${PIP_MIRROR} \
&& . /opt/py3/bin/activate \
&& poetry config virtualenvs.create false \ && poetry config virtualenvs.create false \
&& poetry install --no-cache --only main \ && . /opt/py3/bin/activate \
&& poetry install --only main \
&& ansible-galaxy collection install -r collections.yml --force --ignore-certs \ && ansible-galaxy collection install -r collections.yml --force --ignore-certs \
&& bash clean_site_packages.sh \ && bash clean_site_packages.sh
&& poetry cache clear pypi --all

View File

@@ -15,20 +15,21 @@ ARG TOOLS=" \
vim \ vim \
wget" wget"
ARG APT_MIRROR=http://deb.debian.org
RUN set -ex \ RUN set -ex \
&& rm -f /etc/apt/apt.conf.d/docker-clean \
&& echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache \
&& sed -i "s@http://.*.debian.org@${APT_MIRROR}@g" /etc/apt/sources.list \
&& apt-get update \ && apt-get update \
&& apt-get -y install --no-install-recommends ${TOOLS} \ && apt-get -y install --no-install-recommends ${TOOLS} \
&& apt-get clean all \ && echo "no" | dpkg-reconfigure dash
&& rm -rf /var/lib/apt/lists/*
WORKDIR /opt/jumpserver WORKDIR /opt/jumpserver
ARG PIP_MIRROR=https://pypi.org/simple ARG PIP_MIRROR=https://pypi.org/simple
ENV POETRY_PYPI_MIRROR_URL=${PIP_MIRROR}
COPY poetry.lock pyproject.toml ./ COPY poetry.lock pyproject.toml ./
RUN set -ex \ RUN set -ex \
&& . /opt/py3/bin/activate \ && . /opt/py3/bin/activate \
&& pip install poetry poetry-plugin-pypi-mirror -i ${PIP_MIRROR} \ && pip install poetry -i ${PIP_MIRROR} \
&& poetry install --only xpack \ && poetry install --only xpack
&& poetry cache clear pypi --all

View File

@@ -10,8 +10,7 @@
[![][github-release-shield]][github-release-link] [![][github-release-shield]][github-release-link]
[![][github-stars-shield]][github-stars-link] [![][github-stars-shield]][github-stars-link]
[English](/README.md) · [中文(简体)](/readmes/README.zh-hans.md) · [中文(繁體)](/readmes/README.zh-hant.md) · [日本語](/readmes/README.ja.md) · [Português (Brasil)](/readmes/README.pt-br.md) **English** · [简体中文](./README.zh-CN.md)
</div> </div>
<br/> <br/>
@@ -69,13 +68,10 @@ JumpServer consists of multiple key components, which collectively form the func
| [KoKo](https://github.com/jumpserver/koko) | <a href="https://github.com/jumpserver/koko/releases"><img alt="Koko release" src="https://img.shields.io/github/release/jumpserver/koko.svg" /></a> | JumpServer Character Protocol Connector | | [KoKo](https://github.com/jumpserver/koko) | <a href="https://github.com/jumpserver/koko/releases"><img alt="Koko release" src="https://img.shields.io/github/release/jumpserver/koko.svg" /></a> | JumpServer Character Protocol Connector |
| [Lion](https://github.com/jumpserver/lion) | <a href="https://github.com/jumpserver/lion/releases"><img alt="Lion release" src="https://img.shields.io/github/release/jumpserver/lion.svg" /></a> | JumpServer Graphical Protocol Connector | | [Lion](https://github.com/jumpserver/lion) | <a href="https://github.com/jumpserver/lion/releases"><img alt="Lion release" src="https://img.shields.io/github/release/jumpserver/lion.svg" /></a> | JumpServer Graphical Protocol Connector |
| [Chen](https://github.com/jumpserver/chen) | <a href="https://github.com/jumpserver/chen/releases"><img alt="Chen release" src="https://img.shields.io/github/release/jumpserver/chen.svg" /> | JumpServer Web DB | | [Chen](https://github.com/jumpserver/chen) | <a href="https://github.com/jumpserver/chen/releases"><img alt="Chen release" src="https://img.shields.io/github/release/jumpserver/chen.svg" /> | JumpServer Web DB |
| [Razor](https://github.com/jumpserver/razor) | <img alt="Chen" src="https://img.shields.io/badge/release-private-red" /> | JumpServer EE RDP Proxy Connector | | [Razor](https://github.com/jumpserver/razor) | <img alt="Chen" src="https://img.shields.io/badge/release-private-red" /> | JumpServer EE RDP Proxy Connector |
| [Tinker](https://github.com/jumpserver/tinker) | <img alt="Tinker" src="https://img.shields.io/badge/release-private-red" /> | JumpServer EE Remote Application Connector (Windows) | | [Tinker](https://github.com/jumpserver/tinker) | <img alt="Tinker" src="https://img.shields.io/badge/release-private-red" /> | JumpServer EE Remote Application Connector (Windows) |
| [Panda](https://github.com/jumpserver/Panda) | <img alt="Panda" src="https://img.shields.io/badge/release-private-red" /> | JumpServer EE Remote Application Connector (Linux) | | [Panda](https://github.com/jumpserver/Panda) | <img alt="Panda" src="https://img.shields.io/badge/release-private-red" /> | JumpServer EE Remote Application Connector (Linux) |
| [Magnus](https://github.com/jumpserver/magnus) | <img alt="Magnus" src="https://img.shields.io/badge/release-private-red" /> | JumpServer EE Database Proxy Connector | | [Magnus](https://github.com/jumpserver/magnus) | <img alt="Magnus" src="https://img.shields.io/badge/release-private-red" /> | JumpServer EE Database Proxy Connector |
| [Nec](https://github.com/jumpserver/nec) | <img alt="Nec" src="https://img.shields.io/badge/release-private-red" /> | JumpServer EE VNC Proxy Connector |
| [Facelive](https://github.com/jumpserver/facelive) | <img alt="Facelive" src="https://img.shields.io/badge/release-private-red" /> | JumpServer EE Facial Recognition |
## Contributing ## Contributing
@@ -89,7 +85,7 @@ JumpServer is a mission critical product. Please refer to the Basic Security Rec
## License ## License
Copyright (c) 2014-2025 FIT2CLOUD, All rights reserved. Copyright (c) 2014-2024 飞致云 FIT2CLOUD, All rights reserved.
Licensed under The GNU General Public License version 3 (GPLv3) (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at Licensed under The GNU General Public License version 3 (GPLv3) (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

View File

@@ -109,7 +109,7 @@ JumpServer是一款安全产品请参考 [基本安全建议](https://docs.ju
## License & Copyright ## License & Copyright
Copyright (c) 2014-2024 飞致云, All rights reserved. Copyright (c) 2014-2024 飞致云 FIT2CLOUD, All rights reserved.
Licensed under The GNU General Public License version 3 (GPLv3) (the "License"); you may not use this file except in Licensed under The GNU General Public License version 3 (GPLv3) (the "License"); you may not use this file except in
compliance with the License. You may obtain a copy of the License at compliance with the License. You may obtain a copy of the License at

View File

@@ -30,6 +30,6 @@
login_user: "{{ account.username }}" login_user: "{{ account.username }}"
login_password: "{{ account.secret }}" login_password: "{{ account.secret }}"
login_secret_type: "{{ account.secret_type }}" login_secret_type: "{{ account.secret_type }}"
gateway_args: "{{ jms_gateway | default({}) }}" gateway_args: "{{ jms_gateway | default(None) }}"
when: account.secret_type == "password" when: account.secret_type == "password"
delegate_to: localhost delegate_to: localhost

View File

@@ -160,10 +160,6 @@ class ChangeSecretManager(AccountBasePlaybookManager):
ChangeSecretRecord.objects.bulk_create(records) ChangeSecretRecord.objects.bulk_create(records)
return inventory_hosts return inventory_hosts
@staticmethod
def require_update_version(account, recorder):
return account.secret != recorder.new_secret
def on_host_success(self, host, result): def on_host_success(self, host, result):
recorder = self.name_recorder_mapper.get(host) recorder = self.name_recorder_mapper.get(host)
if not recorder: if not recorder:
@@ -175,8 +171,6 @@ class ChangeSecretManager(AccountBasePlaybookManager):
if not account: if not account:
print("Account not found, deleted ?") print("Account not found, deleted ?")
return return
version_update_required = self.require_update_version(account, recorder)
account.secret = recorder.new_secret account.secret = recorder.new_secret
account.date_updated = timezone.now() account.date_updated = timezone.now()
@@ -186,10 +180,7 @@ class ChangeSecretManager(AccountBasePlaybookManager):
while retry_count < max_retries: while retry_count < max_retries:
try: try:
recorder.save() recorder.save()
account_update_fields = ['secret', 'date_updated'] account.save(update_fields=['secret', 'version', 'date_updated'])
if version_update_required:
account_update_fields.append('version')
account.save(update_fields=account_update_fields)
break break
except Exception as e: except Exception as e:
retry_count += 1 retry_count += 1

View File

@@ -30,6 +30,6 @@
login_user: "{{ account.username }}" login_user: "{{ account.username }}"
login_password: "{{ account.secret }}" login_password: "{{ account.secret }}"
login_secret_type: "{{ account.secret_type }}" login_secret_type: "{{ account.secret_type }}"
gateway_args: "{{ jms_gateway | default({}) }}" gateway_args: "{{ jms_gateway | default(None) }}"
when: account.secret_type == "password" when: account.secret_type == "password"
delegate_to: localhost delegate_to: localhost

View File

@@ -8,11 +8,6 @@ logger = get_logger(__name__)
class PushAccountManager(ChangeSecretManager, AccountBasePlaybookManager): class PushAccountManager(ChangeSecretManager, AccountBasePlaybookManager):
@staticmethod
def require_update_version(account, recorder):
account.skip_history_when_saving = True
return False
@classmethod @classmethod
def method_type(cls): def method_type(cls):
return AutomationTypes.push_account return AutomationTypes.push_account

View File

@@ -1,18 +1,19 @@
from importlib import import_module from importlib import import_module
from django.utils.functional import LazyObject, empty from django.utils.functional import LazyObject
from common.utils import get_logger from common.utils import get_logger
from ..const import VaultTypeChoices from ..const import VaultTypeChoices
__all__ = ['vault_client', 'get_vault_client', 'refresh_vault_client'] __all__ = ['vault_client', 'get_vault_client']
logger = get_logger(__file__) logger = get_logger(__file__)
def get_vault_client(raise_exception=False, **kwargs): def get_vault_client(raise_exception=False, **kwargs):
tp = kwargs.get('VAULT_BACKEND') if kwargs.get('VAULT_ENABLED') else VaultTypeChoices.local enabled = kwargs.get('VAULT_ENABLED')
tp = 'hcp' if enabled else 'local'
try: try:
module_path = f'apps.accounts.backends.{tp}.main' module_path = f'apps.accounts.backends.{tp}.main'
client = import_module(module_path).Vault(**kwargs) client = import_module(module_path).Vault(**kwargs)
@@ -38,7 +39,3 @@ class VaultClient(LazyObject):
""" 为了安全, 页面修改配置, 重启服务后才会重新初始化 vault_client """ """ 为了安全, 页面修改配置, 重启服务后才会重新初始化 vault_client """
vault_client = VaultClient() vault_client = VaultClient()
def refresh_vault_client():
vault_client._wrapped = empty

View File

@@ -1 +0,0 @@
from .main import *

View File

@@ -1,16 +0,0 @@
from .service import AmazonSecretsManagerClient
from ..base.vault import BaseVault
from ..utils.mixins import GeneralVaultMixin
from ...const import VaultTypeChoices
class Vault(GeneralVaultMixin, BaseVault):
type = VaultTypeChoices.aws
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.client = AmazonSecretsManagerClient(
region_name=kwargs.get('VAULT_AWS_REGION_NAME'),
access_key_id=kwargs.get('VAULT_AWS_ACCESS_KEY_ID'),
secret_key=kwargs.get('VAULT_AWS_ACCESS_SECRET_KEY'),
)

View File

@@ -1,56 +0,0 @@
import boto3
from common.utils import get_logger, random_string
logger = get_logger(__name__)
__all__ = ['AmazonSecretsManagerClient']
class AmazonSecretsManagerClient(object):
def __init__(self, region_name, access_key_id, secret_key):
self.client = boto3.client(
'secretsmanager', region_name=region_name,
aws_access_key_id=access_key_id, aws_secret_access_key=secret_key,
)
self.empty_secret = '#{empty}#'
def is_active(self):
try:
secret_id = f'jumpserver/test-{random_string(12)}'
self.create(secret_id, 'secret')
self.get(secret_id)
self.update(secret_id, 'secret')
self.delete(secret_id)
except Exception as e:
return False, f'Vault is not reachable: {e}'
else:
return True, ''
def get(self, name, version=''):
params = {'SecretId': name}
if version:
params['VersionStage'] = version
try:
secret = self.client.get_secret_value(**params)['SecretString']
return secret if secret != self.empty_secret else ''
except Exception: # noqa
return ''
def create(self, name, secret):
self.client.create_secret(Name=name, SecretString=secret or self.empty_secret)
def update(self, name, secret):
self.client.update_secret(SecretId=name, SecretString=secret or self.empty_secret)
def delete(self, name):
self.client.delete_secret(SecretId=name)
def update_metadata(self, name, metadata: dict):
tags = [{'Key': k, 'Value': v} for k, v in metadata.items()]
try:
self.client.tag_resource(SecretId=name, Tags=tags)
except Exception as e:
logger.error(f'update_metadata: {name} {str(e)}')

View File

@@ -1 +0,0 @@
from .main import *

View File

@@ -1,33 +0,0 @@
from ..base.entries import BaseEntry
class AzureBaseEntry(BaseEntry):
@property
def full_path(self):
return self.path_spec
class AccountEntry(AzureBaseEntry):
@property
def path_spec(self):
# 长度 0-127
account_id = str(self.instance.id)[:18]
path = f'assets-{self.instance.asset_id}-accounts-{account_id}'
return path
class AccountTemplateEntry(AzureBaseEntry):
@property
def path_spec(self):
path = f'account-templates-{self.instance.id}'
return path
class HistoricalAccountEntry(AzureBaseEntry):
@property
def path_spec(self):
path = f'accounts-{self.instance.instance.id}-histories-{self.instance.history_id}'
return path

View File

@@ -1,17 +0,0 @@
from .service import AZUREVaultClient
from ..base.vault import BaseVault
from ..utils.mixins import GeneralVaultMixin
from ...const import VaultTypeChoices
class Vault(GeneralVaultMixin, BaseVault):
type = VaultTypeChoices.azure
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.client = AZUREVaultClient(
vault_url=kwargs.get('VAULT_AZURE_HOST'),
tenant_id=kwargs.get('VAULT_AZURE_TENANT_ID'),
client_id=kwargs.get('VAULT_AZURE_CLIENT_ID'),
client_secret=kwargs.get('VAULT_AZURE_CLIENT_SECRET')
)

View File

@@ -1,58 +0,0 @@
# -*- coding: utf-8 -*-
#
from azure.core.exceptions import ResourceNotFoundError, ClientAuthenticationError
from azure.identity import ClientSecretCredential
from azure.keyvault.secrets import SecretClient
from common.utils import get_logger
logger = get_logger(__name__)
__all__ = ['AZUREVaultClient']
class AZUREVaultClient(object):
def __init__(self, vault_url, tenant_id, client_id, client_secret):
authentication_endpoint = 'https://login.microsoftonline.com/' \
if ('azure.net' in vault_url) else 'https://login.chinacloudapi.cn/'
credentials = ClientSecretCredential(
client_id=client_id, client_secret=client_secret, tenant_id=tenant_id, authority=authentication_endpoint
)
self.client = SecretClient(vault_url=vault_url, credential=credentials)
def is_active(self):
try:
self.client.set_secret('jumpserver', '666')
except (ResourceNotFoundError, ClientAuthenticationError) as e:
logger.error(str(e))
return False, f'Vault is not reachable: {e}'
else:
return True, ''
def get(self, name, version=None):
try:
secret = self.client.get_secret(name, version)
return secret.value
except (ResourceNotFoundError, ClientAuthenticationError) as e:
return ''
def create(self, name, secret):
if not secret:
secret = ''
self.client.set_secret(name, secret)
def update(self, name, secret):
if not secret:
secret = ''
self.client.set_secret(name, secret)
def delete(self, name):
self.client.begin_delete_secret(name)
def update_metadata(self, name, metadata: dict):
try:
self.client.update_secret_properties(name, tags=metadata)
except (ResourceNotFoundError, ClientAuthenticationError) as e:
logger.error(f'update_metadata: {name} {str(e)}')

View File

@@ -0,0 +1,74 @@
from abc import ABC, abstractmethod
from django.forms.models import model_to_dict
__all__ = ['BaseVault']
class BaseVault(ABC):
def __init__(self, *args, **kwargs):
self.enabled = kwargs.get('VAULT_ENABLED')
def get(self, instance):
""" 返回 secret 值 """
return self._get(instance)
def create(self, instance):
if not instance.secret_has_save_to_vault:
self._create(instance)
self._clean_db_secret(instance)
self.save_metadata(instance)
if instance.is_sync_metadata:
self.save_metadata(instance)
def update(self, instance):
if not instance.secret_has_save_to_vault:
self._update(instance)
self._clean_db_secret(instance)
self.save_metadata(instance)
if instance.is_sync_metadata:
self.save_metadata(instance)
def delete(self, instance):
self._delete(instance)
def save_metadata(self, instance):
metadata = model_to_dict(instance, fields=[
'name', 'username', 'secret_type',
'connectivity', 'su_from', 'privileged'
])
metadata = {k: str(v)[:500] for k, v in metadata.items() if v}
return self._save_metadata(instance, metadata)
# -------- abstractmethod -------- #
@abstractmethod
def _get(self, instance):
raise NotImplementedError
@abstractmethod
def _create(self, instance):
raise NotImplementedError
@abstractmethod
def _update(self, instance):
raise NotImplementedError
@abstractmethod
def _delete(self, instance):
raise NotImplementedError
@abstractmethod
def _clean_db_secret(self, instance):
raise NotImplementedError
@abstractmethod
def _save_metadata(self, instance, metadata):
raise NotImplementedError
@abstractmethod
def is_active(self, *args, **kwargs) -> (bool, str):
raise NotImplementedError

View File

@@ -1,109 +0,0 @@
import importlib
import inspect
from abc import ABC, abstractmethod
from django.forms.models import model_to_dict
from .entries import BaseEntry
from ...const import VaultTypeChoices
class BaseVault(ABC):
def __init__(self, *args, **kwargs):
self.enabled = kwargs.get('VAULT_ENABLED')
self._entry_classes = {}
self._load_entries()
def _load_entries_import_module(self, module_name):
module = importlib.import_module(module_name)
for name, obj in inspect.getmembers(module, inspect.isclass):
self._entry_classes.setdefault(name, obj)
def _load_entries(self):
if self.type == VaultTypeChoices.local:
return
module_name = f'accounts.backends.{self.type}.entries'
if importlib.util.find_spec(module_name): # noqa
self._load_entries_import_module(module_name)
base_module = 'accounts.backends.base.entries'
self._load_entries_import_module(base_module)
@property
@abstractmethod
def type(self):
raise NotImplementedError
def get(self, instance):
""" 返回 secret 值 """
return self._get(self.build_entry(instance))
def create(self, instance):
if not instance.secret_has_save_to_vault:
entry = self.build_entry(instance)
self._create(entry)
self._clean_db_secret(instance)
self.save_metadata(entry)
def update(self, instance):
entry = self.build_entry(instance)
if not instance.secret_has_save_to_vault:
self._update(entry)
self._clean_db_secret(instance)
self.save_metadata(entry)
if instance.is_sync_metadata:
self.save_metadata(entry)
def delete(self, instance):
entry = self.build_entry(instance)
self._delete(entry)
def save_metadata(self, entry):
metadata = model_to_dict(entry.instance, fields=[
'name', 'username', 'secret_type',
'connectivity', 'su_from', 'privileged'
])
metadata = {k: str(v)[:500] for k, v in metadata.items() if v}
return self._save_metadata(entry, metadata)
def build_entry(self, instance):
if self.type == VaultTypeChoices.local:
return BaseEntry(instance)
entry_class_name = f'{instance.__class__.__name__}Entry'
entry_class = self._entry_classes.get(entry_class_name)
if not entry_class:
raise Exception(f'Entry class {entry_class_name} is not found')
return entry_class(instance)
def _clean_db_secret(self, instance):
instance.is_sync_metadata = False
instance.mark_secret_save_to_vault()
# -------- abstractmethod -------- #
@abstractmethod
def _get(self, instance):
raise NotImplementedError
@abstractmethod
def _create(self, entry):
raise NotImplementedError
@abstractmethod
def _update(self, entry):
raise NotImplementedError
@abstractmethod
def _delete(self, entry):
raise NotImplementedError
@abstractmethod
def _save_metadata(self, instance, metadata):
raise NotImplementedError
@abstractmethod
def is_active(self, *args, **kwargs) -> (bool, str):
raise NotImplementedError

View File

@@ -1,18 +1,19 @@
import sys
from abc import ABC from abc import ABC
from common.db.utils import Encryptor from common.db.utils import Encryptor
from common.utils import lazyproperty from common.utils import lazyproperty
current_module = sys.modules[__name__]
__all__ = ['build_entry']
class BaseEntry(ABC): class BaseEntry(ABC):
def __init__(self, instance): def __init__(self, instance):
self.instance = instance self.instance = instance
@property
def path_base(self):
path = f'orgs/{self.instance.org_id}'
return path
@lazyproperty @lazyproperty
def full_path(self): def full_path(self):
path_base = self.path_base path_base = self.path_base
@@ -20,24 +21,32 @@ class BaseEntry(ABC):
path = f'{path_base}/{path_spec}' path = f'{path_base}/{path_spec}'
return path return path
@property
def path_base(self):
path = f'orgs/{self.instance.org_id}'
return path
@property @property
def path_spec(self): def path_spec(self):
raise NotImplementedError raise NotImplementedError
def get_encrypt_secret(self): def to_internal_data(self):
secret = getattr(self.instance, '_secret', None) secret = getattr(self.instance, '_secret', None)
if secret is not None: if secret is not None:
secret = Encryptor(secret).encrypt() secret = Encryptor(secret).encrypt()
return secret data = {'secret': secret}
return data
@staticmethod @staticmethod
def get_decrypt_secret(secret): def to_external_data(data):
secret = data.pop('secret', None)
if secret is not None: if secret is not None:
secret = Encryptor(secret).decrypt() secret = Encryptor(secret).decrypt()
return secret return secret
class AccountEntry(BaseEntry): class AccountEntry(BaseEntry):
@property @property
def path_spec(self): def path_spec(self):
path = f'assets/{self.instance.asset_id}/accounts/{self.instance.id}' path = f'assets/{self.instance.asset_id}/accounts/{self.instance.id}'
@@ -45,6 +54,7 @@ class AccountEntry(BaseEntry):
class AccountTemplateEntry(BaseEntry): class AccountTemplateEntry(BaseEntry):
@property @property
def path_spec(self): def path_spec(self):
path = f'account-templates/{self.instance.id}' path = f'account-templates/{self.instance.id}'
@@ -52,12 +62,23 @@ class AccountTemplateEntry(BaseEntry):
class HistoricalAccountEntry(BaseEntry): class HistoricalAccountEntry(BaseEntry):
@property @property
def path_base(self): def path_base(self):
path = f'accounts/{self.instance.instance.id}' account = self.instance.instance
path = f'accounts/{account.id}/'
return path return path
@property @property
def path_spec(self): def path_spec(self):
path = f'histories/{self.instance.history_id}' path = f'histories/{self.instance.history_id}'
return path return path
def build_entry(instance) -> BaseEntry:
class_name = instance.__class__.__name__
entry_class_name = f'{class_name}Entry'
entry_class = getattr(current_module, entry_class_name, None)
if not entry_class:
raise Exception(f'Entry class {entry_class_name} is not found')
return entry_class(instance)

View File

@@ -1,18 +1,14 @@
from common.db.utils import get_logger from common.db.utils import get_logger
from .entries import build_entry
from .service import VaultKVClient from .service import VaultKVClient
from ..base.vault import BaseVault from ..base import BaseVault
from ...const import VaultTypeChoices
logger = get_logger(__name__)
__all__ = ['Vault'] __all__ = ['Vault']
logger = get_logger(__name__)
class Vault(BaseVault): class Vault(BaseVault):
type = VaultTypeChoices.hcp
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.client = VaultKVClient( self.client = VaultKVClient(
@@ -24,25 +20,34 @@ class Vault(BaseVault):
def is_active(self): def is_active(self):
return self.client.is_active() return self.client.is_active()
def _get(self, entry): def _get(self, instance):
entry = build_entry(instance)
# TODO: get data 是不是层数太多了 # TODO: get data 是不是层数太多了
data = self.client.get(path=entry.full_path).get('data', {}) data = self.client.get(path=entry.full_path).get('data', {})
data = entry.get_decrypt_secret(data.get('secret')) data = entry.to_external_data(data)
return data return data
def _create(self, entry): def _create(self, instance):
data = {'secret': entry.get_encrypt_secret()} entry = build_entry(instance)
data = entry.to_internal_data()
self.client.create(path=entry.full_path, data=data) self.client.create(path=entry.full_path, data=data)
def _update(self, entry): def _update(self, instance):
data = {'secret': entry.get_encrypt_secret()} entry = build_entry(instance)
data = entry.to_internal_data()
self.client.patch(path=entry.full_path, data=data) self.client.patch(path=entry.full_path, data=data)
def _delete(self, entry): def _delete(self, instance):
entry = build_entry(instance)
self.client.delete(path=entry.full_path) self.client.delete(path=entry.full_path)
def _save_metadata(self, entry, metadata): def _clean_db_secret(self, instance):
instance.is_sync_metadata = False
instance.mark_secret_save_to_vault()
def _save_metadata(self, instance, metadata):
try: try:
entry = build_entry(instance)
self.client.update_metadata(path=entry.full_path, metadata=metadata) self.client.update_metadata(path=entry.full_path, metadata=metadata)
except Exception as e: except Exception as e:
logger.error(f'save metadata error: {e}') logger.error(f'save metadata error: {e}')

View File

@@ -1,6 +1,5 @@
from common.utils import get_logger from common.utils import get_logger
from ..base.vault import BaseVault from ..base import BaseVault
from ...const import VaultTypeChoices
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -8,28 +7,27 @@ __all__ = ['Vault']
class Vault(BaseVault): class Vault(BaseVault):
type = VaultTypeChoices.local
def is_active(self): def is_active(self):
return True, '' return True, ''
def _get(self, entry): def _get(self, instance):
secret = getattr(entry.instance, '_secret', None) secret = getattr(instance, '_secret', None)
return secret return secret
def _create(self, entry): def _create(self, instance):
""" Ignore """ """ Ignore """
pass pass
def _update(self, entry): def _update(self, instance):
""" Ignore """ """ Ignore """
pass pass
def _delete(self, entry): def _delete(self, instance):
""" Ignore """ """ Ignore """
pass pass
def _save_metadata(self, entry, metadata): def _save_metadata(self, instance, metadata):
""" Ignore """ """ Ignore """
pass pass

View File

@@ -1,32 +0,0 @@
from common.utils import get_logger
logger = get_logger(__name__)
class GeneralVaultMixin(object):
client = None
def is_active(self):
return self.client.is_active()
def _get(self, entry):
secret = self.client.get(name=entry.full_path)
return entry.get_decrypt_secret(secret)
def _create(self, entry):
secret = entry.get_encrypt_secret()
self.client.create(name=entry.full_path, secret=secret)
def _update(self, entry):
secret = entry.get_encrypt_secret()
self.client.update(name=entry.full_path, secret=secret)
def _delete(self, entry):
self.client.delete(name=entry.full_path)
def _save_metadata(self, entry, metadata):
try:
self.client.update_metadata(name=entry.full_path, metadata=metadata)
except Exception as e:
logger.error(f'save metadata error: {e}')

View File

@@ -49,9 +49,9 @@ class SecretStrategy(models.TextChoices):
class SSHKeyStrategy(models.TextChoices): class SSHKeyStrategy(models.TextChoices):
set_jms = 'set_jms', _('Replace (Replace only keys pushed by JumpServer) ')
set = 'set', _('Empty and append SSH KEY')
add = 'add', _('Append SSH KEY') add = 'add', _('Append SSH KEY')
set = 'set', _('Empty and append SSH KEY')
set_jms = 'set_jms', _('Replace (Replace only keys pushed by JumpServer) ')
class TriggerChoice(models.TextChoices, TreeChoices): class TriggerChoice(models.TextChoices, TreeChoices):

View File

@@ -7,5 +7,3 @@ __all__ = ['VaultTypeChoices']
class VaultTypeChoices(models.TextChoices): class VaultTypeChoices(models.TextChoices):
local = 'local', _('Database') local = 'local', _('Database')
hcp = 'hcp', _('HCP Vault') hcp = 'hcp', _('HCP Vault')
azure = 'azure', _('Azure Key Vault')
aws = 'aws', _('Amazon Secrets Manager')

View File

@@ -1,8 +0,0 @@
from common.exceptions import JMSException
from django.utils.translation import gettext_lazy as _
class VaultException(JMSException):
default_detail = _(
'Vault operation failed. Please retry or check your account information on Vault.'
)

View File

@@ -50,7 +50,7 @@ class Migration(migrations.Migration):
('secret', common.db.fields.EncryptTextField(blank=True, null=True, verbose_name='Secret')), ('secret', common.db.fields.EncryptTextField(blank=True, null=True, verbose_name='Secret')),
('secret_strategy', models.CharField(choices=[('specific', 'Specific secret'), ('random', 'Random generate')], default='specific', max_length=16, verbose_name='Secret strategy')), ('secret_strategy', models.CharField(choices=[('specific', 'Specific secret'), ('random', 'Random generate')], default='specific', max_length=16, verbose_name='Secret strategy')),
('password_rules', models.JSONField(default=dict, verbose_name='Password rules')), ('password_rules', models.JSONField(default=dict, verbose_name='Password rules')),
('ssh_key_change_strategy', models.CharField(choices=[('set_jms', 'Replace (Replace only keys pushed by JumpServer) '), ('set', 'Empty and append SSH KEY'), ('add', 'Append SSH KEY')], default='set_jms', max_length=16, verbose_name='SSH key change strategy')), ('ssh_key_change_strategy', models.CharField(choices=[('add', 'Append SSH KEY'), ('set', 'Empty and append SSH KEY'), ('set_jms', 'Replace (Replace only keys pushed by JumpServer) ')], default='add', max_length=16, verbose_name='SSH key change strategy')),
], ],
options={ options={
'verbose_name': 'Change secret automation', 'verbose_name': 'Change secret automation',
@@ -76,7 +76,7 @@ class Migration(migrations.Migration):
('secret', common.db.fields.EncryptTextField(blank=True, null=True, verbose_name='Secret')), ('secret', common.db.fields.EncryptTextField(blank=True, null=True, verbose_name='Secret')),
('secret_strategy', models.CharField(choices=[('specific', 'Specific secret'), ('random', 'Random generate')], default='specific', max_length=16, verbose_name='Secret strategy')), ('secret_strategy', models.CharField(choices=[('specific', 'Specific secret'), ('random', 'Random generate')], default='specific', max_length=16, verbose_name='Secret strategy')),
('password_rules', models.JSONField(default=dict, verbose_name='Password rules')), ('password_rules', models.JSONField(default=dict, verbose_name='Password rules')),
('ssh_key_change_strategy', models.CharField(choices=[('set_jms', 'Replace (Replace only keys pushed by JumpServer) '), ('set', 'Empty and append SSH KEY'), ('add', 'Append SSH KEY')], default='set_jms', max_length=16, verbose_name='SSH key change strategy')), ('ssh_key_change_strategy', models.CharField(choices=[('add', 'Append SSH KEY'), ('set', 'Empty and append SSH KEY'), ('set_jms', 'Replace (Replace only keys pushed by JumpServer) ')], default='add', max_length=16, verbose_name='SSH key change strategy')),
('triggers', models.JSONField(default=list, max_length=16, verbose_name='Triggers')), ('triggers', models.JSONField(default=list, max_length=16, verbose_name='Triggers')),
('username', models.CharField(max_length=128, verbose_name='Username')), ('username', models.CharField(max_length=128, verbose_name='Username')),
('action', models.CharField(max_length=16, verbose_name='Action')), ('action', models.CharField(max_length=16, verbose_name='Action')),

View File

@@ -14,17 +14,13 @@ from common.db import fields
from common.db.encoder import ModelJSONFieldEncoder from common.db.encoder import ModelJSONFieldEncoder
from common.utils import get_logger, lazyproperty from common.utils import get_logger, lazyproperty
from ops.mixin import PeriodTaskModelMixin from ops.mixin import PeriodTaskModelMixin
from orgs.mixins.models import OrgModelMixin, JMSOrgBaseModel, OrgManager from orgs.mixins.models import OrgModelMixin, JMSOrgBaseModel
__all__ = ['AccountBackupAutomation', 'AccountBackupExecution'] __all__ = ['AccountBackupAutomation', 'AccountBackupExecution']
logger = get_logger(__file__) logger = get_logger(__file__)
class BaseBackupAutomationManager(OrgManager):
pass
class AccountBackupAutomation(PeriodTaskModelMixin, JMSOrgBaseModel): class AccountBackupAutomation(PeriodTaskModelMixin, JMSOrgBaseModel):
types = models.JSONField(default=list) types = models.JSONField(default=list)
backup_type = models.CharField(max_length=128, choices=AccountBackupType.choices, backup_type = models.CharField(max_length=128, choices=AccountBackupType.choices,
@@ -51,8 +47,6 @@ class AccountBackupAutomation(PeriodTaskModelMixin, JMSOrgBaseModel):
max_length=4096, blank=True, null=True, verbose_name=_('Zip encrypt password') max_length=4096, blank=True, null=True, verbose_name=_('Zip encrypt password')
) )
objects = BaseBackupAutomationManager.from_queryset(models.QuerySet)()
def __str__(self): def __str__(self):
return f'{self.name}({self.org_id})' return f'{self.name}({self.org_id})'

View File

@@ -51,7 +51,7 @@ class AutomationExecution(AssetAutomationExecution):
class ChangeSecretMixin(SecretWithRandomMixin): class ChangeSecretMixin(SecretWithRandomMixin):
ssh_key_change_strategy = models.CharField( ssh_key_change_strategy = models.CharField(
choices=SSHKeyStrategy.choices, max_length=16, choices=SSHKeyStrategy.choices, max_length=16,
default=SSHKeyStrategy.set_jms, verbose_name=_('SSH key change strategy') default=SSHKeyStrategy.add, verbose_name=_('SSH key change strategy')
) )
get_all_assets: callable # get all assets get_all_assets: callable # get all assets

View File

@@ -2,7 +2,7 @@ from django.conf import settings
from django.db import models from django.db import models
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from accounts.const import AutomationTypes, SecretType from accounts.const import AutomationTypes
from accounts.models import Account from accounts.models import Account
from .base import AccountBaseAutomation from .base import AccountBaseAutomation
from .change_secret import ChangeSecretMixin from .change_secret import ChangeSecretMixin
@@ -23,8 +23,7 @@ class PushAccountAutomation(ChangeSecretMixin, AccountBaseAutomation):
create_usernames = set(usernames) - set(account_usernames) create_usernames = set(usernames) - set(account_usernames)
create_account_objs = [ create_account_objs = [
Account( Account(
name=f"{username}-{secret_type}" if secret_type != SecretType.PASSWORD else username, name=f'{username}-{secret_type}', username=username,
username=username,
secret_type=secret_type, asset=asset, secret_type=secret_type, asset=asset,
) )
for username in create_usernames for username in create_usernames

View File

@@ -80,7 +80,6 @@ class VaultModelMixin(models.Model):
def mark_secret_save_to_vault(self): def mark_secret_save_to_vault(self):
self._secret = self._secret_save_to_vault_mark self._secret = self._secret_save_to_vault_mark
self.skip_history_when_saving = True
self.save() self.save()
@property @property

View File

@@ -385,7 +385,7 @@ class AssetAccountBulkSerializer(
_results = {} _results = {}
for asset in assets: for asset in assets:
if asset not in secret_type_supports and asset.category != Category.CUSTOM: if asset not in secret_type_supports:
_results[asset] = { _results[asset] = {
'error': _('Asset does not support this secret type: %s') % secret_type, 'error': _('Asset does not support this secret type: %s') % secret_type,
'state': 'error', 'state': 'error',

View File

@@ -63,26 +63,6 @@ class ChangeSecretAutomationSerializer(AuthValidateMixin, BaseAutomationSerializ
)}, )},
}} }}
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.set_ssh_key_change_strategy_choices()
def set_ssh_key_change_strategy_choices(self):
ssh_key_change_strategy = self.fields.get("ssh_key_change_strategy")
if not ssh_key_change_strategy:
return
ssh_key_change_strategy._choices.pop(SSHKeyStrategy.add, None)
def to_representation(self, instance):
data = super().to_representation(instance)
ssh_strategy_value = data.get('ssh_key_change_strategy', {}).get('value')
if ssh_strategy_value == SSHKeyStrategy.add:
data['ssh_key_change_strategy'] = {
'label': SSHKeyStrategy.set_jms.label,
'value': SSHKeyStrategy.set_jms.value
}
return data
@property @property
def model_type(self): def model_type(self):
return AutomationTypes.change_secret return AutomationTypes.change_secret

View File

@@ -3,18 +3,14 @@ from collections import defaultdict
from django.db.models.signals import post_delete from django.db.models.signals import post_delete
from django.db.models.signals import pre_save, post_save from django.db.models.signals import pre_save, post_save
from django.dispatch import receiver from django.dispatch import receiver
from django.utils.functional import LazyObject
from django.utils.translation import gettext_noop from django.utils.translation import gettext_noop
from accounts.backends import vault_client, refresh_vault_client from accounts.backends import vault_client
from accounts.const import Source from accounts.const import Source
from audits.const import ActivityChoices from audits.const import ActivityChoices
from audits.signal_handlers import create_activities from audits.signal_handlers import create_activities
from common.decorators import merge_delay_run from common.decorators import merge_delay_run
from common.signals import django_ready
from common.utils import get_logger, i18n_fmt from common.utils import get_logger, i18n_fmt
from common.utils.connection import RedisPubSub
from .exceptions import VaultException
from .models import Account, AccountTemplate from .models import Account, AccountTemplate
from .tasks.push_account import push_accounts_to_assets_task from .tasks.push_account import push_accounts_to_assets_task
@@ -23,9 +19,6 @@ logger = get_logger(__name__)
@receiver(pre_save, sender=Account) @receiver(pre_save, sender=Account)
def on_account_pre_save(sender, instance, **kwargs): def on_account_pre_save(sender, instance, **kwargs):
if getattr(instance, 'skip_history_when_saving', False):
return
if instance.version == 0: if instance.version == 0:
instance.version = 1 instance.version = 1
else: else:
@@ -69,7 +62,7 @@ def create_accounts_activities(account, action='create'):
@receiver(post_save, sender=Account) @receiver(post_save, sender=Account)
def on_account_create_by_template(sender, instance, created=False, **kwargs): def on_account_create_by_template(sender, instance, created=False, **kwargs):
if not created: if not created or instance.source != Source.TEMPLATE:
return return
push_accounts_if_need.delay(accounts=(instance,)) push_accounts_if_need.delay(accounts=(instance,))
create_accounts_activities(instance, action='create') create_accounts_activities(instance, action='create')
@@ -85,39 +78,16 @@ class VaultSignalHandler(object):
@staticmethod @staticmethod
def save_to_vault(sender, instance, created, **kwargs): def save_to_vault(sender, instance, created, **kwargs):
try: if created:
if created: vault_client.create(instance)
vault_client.create(instance) else:
else: vault_client.update(instance)
vault_client.update(instance)
except Exception as e:
logger.error('Vault save failed: {}'.format(e))
raise VaultException()
@staticmethod @staticmethod
def delete_to_vault(sender, instance, **kwargs): def delete_to_vault(sender, instance, **kwargs):
try: vault_client.delete(instance)
vault_client.delete(instance)
except Exception as e:
logger.error('Vault delete failed: {}'.format(e))
raise VaultException()
for model in (Account, AccountTemplate, Account.history.model): for model in (Account, AccountTemplate, Account.history.model):
post_save.connect(VaultSignalHandler.save_to_vault, sender=model) post_save.connect(VaultSignalHandler.save_to_vault, sender=model)
post_delete.connect(VaultSignalHandler.delete_to_vault, sender=model) post_delete.connect(VaultSignalHandler.delete_to_vault, sender=model)
class VaultPubSub(LazyObject):
def _setup(self):
self._wrapped = RedisPubSub('refresh_vault')
vault_pub_sub = VaultPubSub()
@receiver(django_ready)
def subscribe_vault_change(sender, **kwargs):
logger.debug("Start subscribe vault change")
vault_pub_sub.subscribe(lambda name: refresh_vault_client())

View File

@@ -5,7 +5,6 @@ from celery import shared_task
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from accounts.backends import vault_client from accounts.backends import vault_client
from accounts.const import VaultTypeChoices
from accounts.models import Account, AccountTemplate from accounts.models import Account, AccountTemplate
from common.utils import get_logger from common.utils import get_logger
from orgs.utils import tmp_to_root_org from orgs.utils import tmp_to_root_org
@@ -40,9 +39,6 @@ def sync_secret_to_vault():
# 这里不能判断 settings.VAULT_ENABLED, 必须判断当前 vault_client 的类型 # 这里不能判断 settings.VAULT_ENABLED, 必须判断当前 vault_client 的类型
print('\033[35m>>> 当前 Vault 功能未开启, 不需要同步') print('\033[35m>>> 当前 Vault 功能未开启, 不需要同步')
return return
if VaultTypeChoices.local == vault_client.type:
print('\033[31m>>> 当前第三方 Vault 客户端初始化失败,数据存储在本地数据库')
return
failed, skipped, succeeded = 0, 0, 0 failed, skipped, succeeded = 0, 0, 0
to_sync_models = [Account, AccountTemplate, Account.history.model] to_sync_models = [Account, AccountTemplate, Account.history.model]
@@ -52,8 +48,7 @@ def sync_secret_to_vault():
for model in to_sync_models: for model in to_sync_models:
instances += list(model.objects.all()) instances += list(model.objects.all())
max_workers = 1 if VaultTypeChoices.azure == vault_client.type else 10 with ThreadPoolExecutor(max_workers=10) as executor:
with ThreadPoolExecutor(max_workers=max_workers) as executor:
tasks = [executor.submit(sync_instance, instance) for instance in instances] tasks = [executor.submit(sync_instance, instance) for instance in instances]
for future in as_completed(tasks): for future in as_completed(tasks):

View File

@@ -9,5 +9,3 @@ class ActionChoices(models.TextChoices):
warning = 'warning', _('Warn') warning = 'warning', _('Warn')
notice = 'notice', _('Notify') notice = 'notice', _('Notify')
notify_and_warn = 'notify_and_warn', _('Notify and warn') notify_and_warn = 'notify_and_warn', _('Notify and warn')
face_verify = 'face_verify', _('Face Verify')
face_online = 'face_online', _('Face Online')

View File

@@ -70,13 +70,6 @@ class ActionAclSerializer(serializers.Serializer):
return return
if not settings.XPACK_LICENSE_IS_VALID: if not settings.XPACK_LICENSE_IS_VALID:
field_action._choices.pop(ActionChoices.review, None) field_action._choices.pop(ActionChoices.review, None)
if not (
settings.XPACK_LICENSE_IS_VALID and
settings.XPACK_LICENSE_EDITION_ULTIMATE and
settings.FACE_RECOGNITION_ENABLED
):
field_action._choices.pop(ActionChoices.face_verify, None)
field_action._choices.pop(ActionChoices.face_online, None)
for choice in self.Meta.action_choices_exclude: for choice in self.Meta.action_choices_exclude:
field_action._choices.pop(choice, None) field_action._choices.pop(choice, None)

View File

@@ -32,9 +32,7 @@ class CommandFilterACLSerializer(BaseSerializer, BulkOrgResourceModelSerializer)
class Meta(BaseSerializer.Meta): class Meta(BaseSerializer.Meta):
model = CommandFilterACL model = CommandFilterACL
fields = BaseSerializer.Meta.fields + ['command_groups'] fields = BaseSerializer.Meta.fields + ['command_groups']
action_choices_exclude = [ActionChoices.notice, action_choices_exclude = [ActionChoices.notice]
ActionChoices.face_verify,
ActionChoices.face_online]
class CommandReviewSerializer(serializers.Serializer): class CommandReviewSerializer(serializers.Serializer):

View File

@@ -4,7 +4,6 @@ from common.serializers import MethodSerializer
from orgs.mixins.serializers import BulkOrgResourceModelSerializer from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from .base import BaseUserACLSerializer from .base import BaseUserACLSerializer
from .rules import RuleSerializer from .rules import RuleSerializer
from ..const import ActionChoices
from ..models import LoginACL from ..models import LoginACL
__all__ = ["LoginACLSerializer"] __all__ = ["LoginACLSerializer"]
@@ -18,7 +17,6 @@ class LoginACLSerializer(BaseUserACLSerializer, BulkOrgResourceModelSerializer):
class Meta(BaseUserACLSerializer.Meta): class Meta(BaseUserACLSerializer.Meta):
model = LoginACL model = LoginACL
fields = BaseUserACLSerializer.Meta.fields + ['rules', ] fields = BaseUserACLSerializer.Meta.fields + ['rules', ]
action_choices_exclude = [ActionChoices.face_online, ActionChoices.face_verify]
def get_rules_serializer(self): def get_rules_serializer(self):
return RuleSerializer() return RuleSerializer()

View File

@@ -123,10 +123,6 @@ class AssetViewSet(SuggestionMixin, OrgBulkModelViewSet):
NodeFilterBackend, AttrRulesFilterBackend NodeFilterBackend, AttrRulesFilterBackend
] ]
def perform_destroy(self, instance):
instance.accounts.update(su_from_id=None)
instance.delete()
def get_queryset(self): def get_queryset(self):
queryset = super().get_queryset() queryset = super().get_queryset()
if queryset.model is not Asset: if queryset.model is not Asset:

View File

@@ -1,10 +1,10 @@
from django.db.models import Subquery, OuterRef, Count, Value from django.db.models import Count
from django.db.models.functions import Coalesce
from django_filters import rest_framework as filters from django_filters import rest_framework as filters
from rest_framework import generics from rest_framework import generics
from rest_framework import serializers from rest_framework import serializers
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.response import Response from rest_framework.response import Response
from assets.const import AllTypes from assets.const import AllTypes
from assets.models import Platform, Node, Asset, PlatformProtocol from assets.models import Platform, Node, Asset, PlatformProtocol
from assets.serializers import PlatformSerializer, PlatformProtocolSerializer, PlatformListSerializer from assets.serializers import PlatformSerializer, PlatformProtocolSerializer, PlatformListSerializer
@@ -42,10 +42,7 @@ class AssetPlatformViewSet(JMSModelViewSet):
def get_queryset(self): def get_queryset(self):
# 因为没有走分页逻辑,所以需要这里 prefetch # 因为没有走分页逻辑,所以需要这里 prefetch
asset_count_subquery = Asset.objects.filter(platform=OuterRef('pk')).values('platform').annotate( queryset = super().get_queryset().annotate(assets_amount=Count('assets')).prefetch_related(
count=Count('id')).values('count')
queryset = super().get_queryset().annotate(
assets_amount=Coalesce(Subquery(asset_count_subquery), Value(0))).prefetch_related(
'protocols', 'automation', 'labels', 'labels__label' 'protocols', 'automation', 'labels', 'labels__label'
) )
queryset = queryset.filter(type__in=AllTypes.get_types_values()) queryset = queryset.filter(type__in=AllTypes.get_types_values())

View File

@@ -3,7 +3,6 @@ from collections import defaultdict
from copy import deepcopy from copy import deepcopy
from django.conf import settings from django.conf import settings
from django.utils.functional import lazy
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from common.db.models import ChoicesMixin from common.db.models import ChoicesMixin
@@ -30,15 +29,15 @@ class AllTypes(ChoicesMixin):
@classmethod @classmethod
def choices(cls): def choices(cls):
return lazy(cls.get_choices, list)()
@classmethod
def get_choices(cls):
choices = [] choices = []
for tp in cls.includes: for tp in cls.includes:
choices.extend(tp.get_choices()) choices.extend(tp.get_choices())
return choices return choices
@classmethod
def get_choices(cls):
return cls.choices()
@classmethod @classmethod
def filter_choices(cls, category): def filter_choices(cls, category):
choices = dict(cls.category_types()).get(category, cls).get_choices() choices = dict(cls.category_types()).get(category, cls).get_choices()

View File

@@ -10,11 +10,7 @@ from assets.tasks import execute_asset_automation_task
from common.const.choices import Trigger from common.const.choices import Trigger
from common.db.fields import EncryptJsonDictTextField from common.db.fields import EncryptJsonDictTextField
from ops.mixin import PeriodTaskModelMixin from ops.mixin import PeriodTaskModelMixin
from orgs.mixins.models import OrgModelMixin, JMSOrgBaseModel, OrgManager from orgs.mixins.models import OrgModelMixin, JMSOrgBaseModel
class BaseAutomationManager(OrgManager):
pass
class BaseAutomation(PeriodTaskModelMixin, JMSOrgBaseModel): class BaseAutomation(PeriodTaskModelMixin, JMSOrgBaseModel):
@@ -25,8 +21,6 @@ class BaseAutomation(PeriodTaskModelMixin, JMSOrgBaseModel):
is_active = models.BooleanField(default=True, verbose_name=_("Is active")) is_active = models.BooleanField(default=True, verbose_name=_("Is active"))
params = models.JSONField(default=dict, verbose_name=_("Parameters")) params = models.JSONField(default=dict, verbose_name=_("Parameters"))
objects = BaseAutomationManager.from_queryset(models.QuerySet)()
def __str__(self): def __str__(self):
return self.name + '@' + str(self.created_by) return self.name + '@' + str(self.created_by)

View File

@@ -27,7 +27,7 @@ class DomainSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
model = Domain model = Domain
fields_mini = ['id', 'name'] fields_mini = ['id', 'name']
fields_small = fields_mini + ['comment'] fields_small = fields_mini + ['comment']
fields_m2m = ['assets', 'gateways', 'labels', 'assets_amount'] fields_m2m = ['assets', 'gateways', 'assets_amount']
read_only_fields = ['date_created'] read_only_fields = ['date_created']
fields = fields_small + fields_m2m + read_only_fields fields = fields_small + fields_m2m + read_only_fields
extra_kwargs = { extra_kwargs = {

View File

@@ -7,7 +7,6 @@ from django.db.models import F, Value, CharField, Q
from django.db.models.functions import Cast from django.db.models.functions import Cast
from django.http import HttpResponse, FileResponse from django.http import HttpResponse, FileResponse
from django.utils.encoding import escape_uri_path from django.utils.encoding import escape_uri_path
from django_celery_beat.models import PeriodicTask
from rest_framework import generics from rest_framework import generics
from rest_framework import status from rest_framework import status
from rest_framework import viewsets from rest_framework import viewsets
@@ -23,9 +22,6 @@ from common.plugins.es import QuerySet as ESQuerySet
from common.sessions.cache import user_session_manager from common.sessions.cache import user_session_manager
from common.storage.ftp_file import FTPFileStorageHandler from common.storage.ftp_file import FTPFileStorageHandler
from common.utils import is_uuid, get_logger, lazyproperty from common.utils import is_uuid, get_logger, lazyproperty
from ops.const import Types
from ops.models import Job
from ops.serializers.job import JobSerializer
from orgs.mixins.api import OrgReadonlyModelViewSet, OrgModelViewSet from orgs.mixins.api import OrgReadonlyModelViewSet, OrgModelViewSet
from orgs.models import Organization from orgs.models import Organization
from orgs.utils import current_org, tmp_to_root_org from orgs.utils import current_org, tmp_to_root_org
@@ -43,14 +39,14 @@ from .serializers import (
FTPLogSerializer, UserLoginLogSerializer, JobLogSerializer, FTPLogSerializer, UserLoginLogSerializer, JobLogSerializer,
OperateLogSerializer, OperateLogActionDetailSerializer, OperateLogSerializer, OperateLogActionDetailSerializer,
PasswordChangeLogSerializer, ActivityUnionLogSerializer, PasswordChangeLogSerializer, ActivityUnionLogSerializer,
FileSerializer, UserSessionSerializer, JobsAuditSerializer FileSerializer, UserSessionSerializer
) )
from .utils import construct_userlogin_usernames from .utils import construct_userlogin_usernames
logger = get_logger(__name__) logger = get_logger(__name__)
class JobLogAuditViewSet(OrgReadonlyModelViewSet): class JobAuditViewSet(OrgReadonlyModelViewSet):
model = JobLog model = JobLog
extra_filter_backends = [DatetimeRangeFilterBackend] extra_filter_backends = [DatetimeRangeFilterBackend]
date_range_filter_fields = [ date_range_filter_fields = [
@@ -62,35 +58,6 @@ class JobLogAuditViewSet(OrgReadonlyModelViewSet):
ordering = ['-date_start'] ordering = ['-date_start']
class JobsAuditViewSet(OrgModelViewSet):
model = Job
search_fields = ['creator__name']
filterset_fields = ['creator__name']
serializer_class = JobsAuditSerializer
ordering = ['-is_periodic', '-date_updated']
http_method_names = ['get', 'options', 'patch']
def get_queryset(self):
queryset = super().get_queryset()
queryset = queryset.exclude(type=Types.upload_file).filter(instant=False)
return queryset
def perform_update(self, serializer):
job = self.get_object()
is_periodic = serializer.validated_data.get('is_periodic')
if job.is_periodic != is_periodic:
job.is_periodic = is_periodic
job.save()
name, task, args, kwargs = job.get_register_task()
task_obj = PeriodicTask.objects.filter(name=name).first()
if task_obj:
is_periodic = job.is_periodic
if task_obj.enabled != is_periodic:
task_obj.enabled = is_periodic
task_obj.save()
return super().perform_update(serializer)
class FTPLogViewSet(OrgModelViewSet): class FTPLogViewSet(OrgModelViewSet):
model = FTPLog model = FTPLog
serializer_class = FTPLogSerializer serializer_class = FTPLogSerializer
@@ -222,13 +189,9 @@ class ResourceActivityAPIView(generics.ListAPIView):
'id', 'datetime', 'r_detail', 'r_detail_id', 'id', 'datetime', 'r_detail', 'r_detail_id',
'r_user', 'r_action', 'r_type' 'r_user', 'r_action', 'r_type'
) )
org_q = Q(org_id=Organization.SYSTEM_ID) | Q(org_id=current_org.id)
org_q = Q() if resource_id:
if not current_org.is_root(): org_q |= Q(org_id='') | Q(org_id=Organization.ROOT_ID)
org_q = Q(org_id=Organization.SYSTEM_ID) | Q(org_id=current_org.id)
if resource_id:
org_q |= Q(org_id='') | Q(org_id=Organization.ROOT_ID)
with tmp_to_root_org(): with tmp_to_root_org():
qs1 = self.get_operate_log_qs(fields, limit, org_q, resource_id=resource_id) qs1 = self.get_operate_log_qs(fields, limit, org_q, resource_id=resource_id)
qs2 = self.get_activity_log_qs(fields, limit, org_q, resource_id=resource_id) qs2 = self.get_activity_log_qs(fields, limit, org_q, resource_id=resource_id)

View File

@@ -7,7 +7,7 @@ from audits.backends.db import OperateLogStore
from common.serializers.fields import LabeledChoiceField, ObjectRelatedField from common.serializers.fields import LabeledChoiceField, ObjectRelatedField
from common.utils import reverse, i18n_trans from common.utils import reverse, i18n_trans
from common.utils.timezone import as_current_tz from common.utils.timezone import as_current_tz
from ops.serializers.job import JobExecutionSerializer, JobSerializer from ops.serializers.job import JobExecutionSerializer
from orgs.mixins.serializers import BulkOrgResourceModelSerializer from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from terminal.models import Session from terminal.models import Session
from users.models import User from users.models import User
@@ -34,30 +34,6 @@ class JobLogSerializer(JobExecutionSerializer):
} }
class JobsAuditSerializer(JobSerializer):
material = serializers.ReadOnlyField(label=_("Command"))
summary = serializers.ReadOnlyField(label=_("Summary"))
crontab = serializers.ReadOnlyField(label=_("Execution cycle"))
is_periodic_display = serializers.BooleanField(read_only=True, source='is_periodic')
class Meta(JobSerializer.Meta):
read_only_fields = [
"id", 'name', 'args', 'material', 'type', 'crontab', 'interval', 'date_last_run', 'summary', 'created_by',
'is_periodic_display'
]
fields = read_only_fields + ['is_periodic']
def validate(self, attrs):
allowed_fields = {'is_periodic'}
submitted_fields = set(attrs.keys())
invalid_fields = submitted_fields - allowed_fields
if invalid_fields:
raise serializers.ValidationError(
f"Updating {', '.join(invalid_fields)} fields is not allowed"
)
return attrs
class FTPLogSerializer(serializers.ModelSerializer): class FTPLogSerializer(serializers.ModelSerializer):
operate = LabeledChoiceField(choices=OperateChoices.choices, label=_("Operate")) operate = LabeledChoiceField(choices=OperateChoices.choices, label=_("Operate"))

View File

@@ -14,7 +14,7 @@ from audits.handler import (
create_or_update_operate_log, get_instance_dict_from_cache create_or_update_operate_log, get_instance_dict_from_cache
) )
from audits.utils import model_to_dict_for_operate_log as model_to_dict from audits.utils import model_to_dict_for_operate_log as model_to_dict
from common.const.signals import POST_ADD, POST_REMOVE, POST_CLEAR, OP_LOG_SKIP_SIGNAL from common.const.signals import POST_ADD, POST_REMOVE, POST_CLEAR, SKIP_SIGNAL
from common.signals import django_ready from common.signals import django_ready
from jumpserver.utils import current_request from jumpserver.utils import current_request
from ..const import MODELS_NEED_RECORD, ActionChoices from ..const import MODELS_NEED_RECORD, ActionChoices
@@ -77,7 +77,7 @@ def signal_of_operate_log_whether_continue(
condition = True condition = True
if not instance: if not instance:
condition = False condition = False
if instance and getattr(instance, OP_LOG_SKIP_SIGNAL, False): if instance and getattr(instance, SKIP_SIGNAL, False):
condition = False condition = False
# 不记录组件的操作日志 # 不记录组件的操作日志
user = current_request.user if current_request else None user = current_request.user if current_request else None
@@ -187,7 +187,7 @@ def on_django_start_set_operate_log_monitor_models(sender, **kwargs):
'PermedAsset', 'PermedAccount', 'MenuPermission', 'PermedAsset', 'PermedAccount', 'MenuPermission',
'Permission', 'TicketSession', 'ApplyLoginTicket', 'Permission', 'TicketSession', 'ApplyLoginTicket',
'ApplyCommandTicket', 'ApplyLoginAssetTicket', 'ApplyCommandTicket', 'ApplyLoginAssetTicket',
'FavoriteAsset', 'ChangeSecretRecord', 'AppProvider', 'Variable' 'FavoriteAsset', 'ChangeSecretRecord', 'AppProvider',
} }
include_models = {'UserSession'} include_models = {'UserSession'}
for i, app in enumerate(apps.get_models(), 1): for i, app in enumerate(apps.get_models(), 1):

View File

@@ -13,9 +13,7 @@ router.register(r'ftp-logs', api.FTPLogViewSet, 'ftp-log')
router.register(r'login-logs', api.UserLoginLogViewSet, 'login-log') router.register(r'login-logs', api.UserLoginLogViewSet, 'login-log')
router.register(r'operate-logs', api.OperateLogViewSet, 'operate-log') router.register(r'operate-logs', api.OperateLogViewSet, 'operate-log')
router.register(r'password-change-logs', api.PasswordChangeLogViewSet, 'password-change-log') router.register(r'password-change-logs', api.PasswordChangeLogViewSet, 'password-change-log')
router.register(r'job-logs', api.JobLogAuditViewSet, 'job-log') router.register(r'job-logs', api.JobAuditViewSet, 'job-log')
router.register(r'jobs', api.JobsAuditViewSet, 'job')
router.register(r'my-login-logs', api.MyLoginLogViewSet, 'my-login-log') router.register(r'my-login-logs', api.MyLoginLogViewSet, 'my-login-log')
router.register(r'user-sessions', api.UserSessionViewSet, 'user-session') router.register(r'user-sessions', api.UserSessionViewSet, 'user-session')

View File

@@ -15,4 +15,3 @@ from .ssh_key import *
from .sso import * from .sso import *
from .temp_token import * from .temp_token import *
from .token import * from .token import *
from .face import *

View File

@@ -24,13 +24,11 @@ from common.utils.http import is_true, is_false
from orgs.mixins.api import RootOrgViewMixin from orgs.mixins.api import RootOrgViewMixin
from orgs.utils import tmp_to_org from orgs.utils import tmp_to_org
from perms.models import ActionChoices from perms.models import ActionChoices
from terminal.connect_methods import NativeClient, ConnectMethodUtil, WebMethod from terminal.connect_methods import NativeClient, ConnectMethodUtil
from terminal.models import EndpointRule, Endpoint from terminal.models import EndpointRule, Endpoint
from users.const import FileNameConflictResolution from users.const import FileNameConflictResolution
from users.const import RDPSmartSize, RDPColorQuality from users.const import RDPSmartSize, RDPColorQuality
from users.models import Preference from users.models import Preference
from .face import FaceMonitorContext
from ..mixins import AuthFaceMixin
from ..models import ConnectionToken, date_expired_default from ..models import ConnectionToken, date_expired_default
from ..serializers import ( from ..serializers import (
ConnectionTokenSerializer, ConnectionTokenSecretSerializer, ConnectionTokenSerializer, ConnectionTokenSecretSerializer,
@@ -69,36 +67,6 @@ class RDPFileClientProtocolURLMixin:
'bookmarktype:i': '3', 'bookmarktype:i': '3',
'use redirection server name:i': '0', 'use redirection server name:i': '0',
} }
# copy from
# https://learn.microsoft.com/zh-cn/windows-server/administration/performance-tuning/role/remote-desktop/session-hosts
rdp_low_speed_broadband_option = {
"connection type:i": 2,
"disable wallpaper:i": 1,
"bitmapcachepersistenable:i": 1,
"disable full window drag:i": 1,
"disable menu anims:i": 1,
"allow font smoothing:i": 0,
"allow desktop composition:i": 0,
"disable themes:i": 0
}
rdp_high_speed_broadband_option = {
"connection type:i": 4,
"disable wallpaper:i": 0,
"bitmapcachepersistenable:i": 1,
"disable full window drag:i": 1,
"disable menu anims:i": 0,
"allow font smoothing:i": 0,
"allow desktop composition:i": 1,
"disable themes:i": 0
}
RDP_CONNECTION_SPEED_OPTION_MAP = {
"auto": {},
"low_speed_broadband": rdp_low_speed_broadband_option,
"high_speed_broadband": rdp_high_speed_broadband_option,
}
# 设置多屏显示 # 设置多屏显示
multi_mon = is_true(self.request.query_params.get('multi_mon')) multi_mon = is_true(self.request.query_params.get('multi_mon'))
if multi_mon: if multi_mon:
@@ -123,15 +91,13 @@ class RDPFileClientProtocolURLMixin:
# rdp_options['domain:s'] = token.account_ad_domain # rdp_options['domain:s'] = token.account_ad_domain
# 设置宽高 # 设置宽高
height = self.request.query_params.get('height')
resolution_value = token.connect_options.get('resolution', 'auto') width = self.request.query_params.get('width')
if resolution_value != 'auto': if width and height:
width, height = resolution_value.split('x') rdp_options['desktopwidth:i'] = width
if width and height: rdp_options['desktopheight:i'] = height
rdp_options['desktopwidth:i'] = width rdp_options['winposstr:s'] = f'0,1,0,0,{width},{height}'
rdp_options['desktopheight:i'] = height rdp_options['dynamic resolution:i'] = '0'
rdp_options['winposstr:s'] = f'0,1,0,0,{width},{height}'
rdp_options['dynamic resolution:i'] = '0'
color_quality = self.request.query_params.get('rdp_color_quality') color_quality = self.request.query_params.get('rdp_color_quality')
color_quality = color_quality if color_quality else os.getenv('JUMPSERVER_COLOR_DEPTH', RDPColorQuality.HIGH) color_quality = color_quality if color_quality else os.getenv('JUMPSERVER_COLOR_DEPTH', RDPColorQuality.HIGH)
@@ -149,8 +115,6 @@ class RDPFileClientProtocolURLMixin:
rdp = token.asset.platform.protocols.filter(name='rdp').first() rdp = token.asset.platform.protocols.filter(name='rdp').first()
if rdp and rdp.setting.get('console'): if rdp and rdp.setting.get('console'):
rdp_options['administrative session:i'] = '1' rdp_options['administrative session:i'] = '1'
rdp_connection_speed = token.connect_options.get('rdp_connection_speed', 'auto')
rdp_options.update(RDP_CONNECTION_SPEED_OPTION_MAP.get(rdp_connection_speed, {}))
# 文件名 # 文件名
name = token.asset.name name = token.asset.name
@@ -257,8 +221,6 @@ class ExtraActionApiMixin(RDPFileClientProtocolURLMixin):
get_serializer: callable get_serializer: callable
perform_create: callable perform_create: callable
validate_exchange_token: callable validate_exchange_token: callable
need_face_verify: bool
create_face_verify: callable
@action(methods=['POST', 'GET'], detail=True, url_path='rdp-file') @action(methods=['POST', 'GET'], detail=True, url_path='rdp-file')
def get_rdp_file(self, request, *args, **kwargs): def get_rdp_file(self, request, *args, **kwargs):
@@ -318,13 +280,10 @@ class ExtraActionApiMixin(RDPFileClientProtocolURLMixin):
instance.date_expired = date_expired_default() instance.date_expired = date_expired_default()
instance.save() instance.save()
serializer = self.get_serializer(instance) serializer = self.get_serializer(instance)
response = Response(serializer.data, status=status.HTTP_201_CREATED) return Response(serializer.data, status=status.HTTP_201_CREATED)
if self.need_face_verify:
self.create_face_verify(response)
return response
class ConnectionTokenViewSet(AuthFaceMixin, ExtraActionApiMixin, RootOrgViewMixin, JMSModelViewSet): class ConnectionTokenViewSet(ExtraActionApiMixin, RootOrgViewMixin, JMSModelViewSet):
filterset_fields = ( filterset_fields = (
'user_display', 'asset_display' 'user_display', 'asset_display'
) )
@@ -345,8 +304,6 @@ class ConnectionTokenViewSet(AuthFaceMixin, ExtraActionApiMixin, RootOrgViewMixi
'get_client_protocol_url': 'authentication.add_connectiontoken', 'get_client_protocol_url': 'authentication.add_connectiontoken',
} }
input_username = '' input_username = ''
need_face_verify = False
face_monitor_token = ''
def get_queryset(self): def get_queryset(self):
queryset = ConnectionToken.objects \ queryset = ConnectionToken.objects \
@@ -398,9 +355,8 @@ class ConnectionTokenViewSet(AuthFaceMixin, ExtraActionApiMixin, RootOrgViewMixi
asset = data.get('asset') asset = data.get('asset')
account_name = data.get('account') account_name = data.get('account')
protocol = data.get('protocol') protocol = data.get('protocol')
connect_method = data.get('connect_method')
self.input_username = self.get_input_username(data) self.input_username = self.get_input_username(data)
_data = self._validate(user, asset, account_name, protocol, connect_method) _data = self._validate(user, asset, account_name, protocol)
data.update(_data) data.update(_data)
return serializer return serializer
@@ -408,12 +364,12 @@ class ConnectionTokenViewSet(AuthFaceMixin, ExtraActionApiMixin, RootOrgViewMixi
user = token.user user = token.user
asset = token.asset asset = token.asset
account_name = token.account account_name = token.account
_data = self._validate(user, asset, account_name, token.protocol, token.connect_method) _data = self._validate(user, asset, account_name, token.protocol)
for k, v in _data.items(): for k, v in _data.items():
setattr(token, k, v) setattr(token, k, v)
return token return token
def _validate(self, user, asset, account_name, protocol, connect_method): def _validate(self, user, asset, account_name, protocol):
data = dict() data = dict()
data['org_id'] = asset.org_id data['org_id'] = asset.org_id
data['user'] = user data['user'] = user
@@ -429,16 +385,10 @@ class ConnectionTokenViewSet(AuthFaceMixin, ExtraActionApiMixin, RootOrgViewMixi
if account.username != AliasAccount.INPUT: if account.username != AliasAccount.INPUT:
data['input_username'] = '' data['input_username'] = ''
ticket = self._validate_acl(user, asset, account, connect_method) ticket = self._validate_acl(user, asset, account)
if ticket: if ticket:
data['from_ticket'] = ticket data['from_ticket'] = ticket
if ticket or self.need_face_verify:
data['is_active'] = False data['is_active'] = False
if self.face_monitor_token:
FaceMonitorContext.get_or_create_context(self.face_monitor_token,
self.request.user.id)
data['face_monitor_token'] = self.face_monitor_token
return data return data
@staticmethod @staticmethod
@@ -467,7 +417,7 @@ class ConnectionTokenViewSet(AuthFaceMixin, ExtraActionApiMixin, RootOrgViewMixi
after=after, object_name=object_name after=after, object_name=object_name
) )
def _validate_acl(self, user, asset, account, connect_method): def _validate_acl(self, user, asset, account):
from acls.models import LoginAssetACL from acls.models import LoginAssetACL
kwargs = {'user': user, 'asset': asset, 'account': account} kwargs = {'user': user, 'asset': asset, 'account': account}
if account.username == AliasAccount.INPUT: if account.username == AliasAccount.INPUT:
@@ -494,26 +444,6 @@ class ConnectionTokenViewSet(AuthFaceMixin, ExtraActionApiMixin, RootOrgViewMixi
assignees=acl.reviewers.all(), org_id=asset.org_id assignees=acl.reviewers.all(), org_id=asset.org_id
) )
return ticket return ticket
if acl.is_action(acl.ActionChoices.face_verify):
if not self.request.query_params.get('face_verify'):
msg = _('ACL action is face verify')
raise JMSException(code='acl_face_verify', detail=msg)
self.need_face_verify = True
if acl.is_action(acl.ActionChoices.face_online):
if connect_method not in [WebMethod.web_cli, WebMethod.web_gui]:
msg = _('ACL action not supported for this asset')
raise JMSException(detail=msg, code='acl_face_online_not_supported')
face_verify = self.request.query_params.get('face_verify')
face_monitor_token = self.request.query_params.get('face_monitor_token')
if not face_verify or not face_monitor_token:
msg = _('ACL action is face online')
raise JMSException(code='acl_face_online', detail=msg)
self.need_face_verify = True
self.face_monitor_token = face_monitor_token
if acl.is_action(acl.ActionChoices.notice): if acl.is_action(acl.ActionChoices.notice):
reviewers = acl.reviewers.all() reviewers = acl.reviewers.all()
if not reviewers: if not reviewers:
@@ -525,22 +455,9 @@ class ConnectionTokenViewSet(AuthFaceMixin, ExtraActionApiMixin, RootOrgViewMixi
reviewer, asset, user, account, self.input_username reviewer, asset, user, account, self.input_username
).publish_async() ).publish_async()
def create_face_verify(self, response):
if not self.request.user.face_vector:
raise JMSException(code='no_face_feature', detail=_('No available face feature'))
connection_token_id = response.data.get('id')
context_data = {
"action": "login_asset",
"connection_token_id": connection_token_id,
}
face_verify_token = self.create_face_verify_context(context_data)
response.data['face_token'] = face_verify_token
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
try: try:
response = super().create(request, *args, **kwargs) response = super().create(request, *args, **kwargs)
if self.need_face_verify:
self.create_face_verify(response)
except JMSException as e: except JMSException as e:
data = {'code': e.detail.code, 'detail': e.detail} data = {'code': e.detail.code, 'detail': e.detail}
return Response(data, status=e.status_code) return Response(data, status=e.status_code)

View File

@@ -1,256 +0,0 @@
from django.core.cache import cache
from django.utils.translation import gettext as _
from rest_framework.generics import CreateAPIView, RetrieveAPIView
from rest_framework.response import Response
from rest_framework.serializers import ValidationError
from rest_framework.permissions import AllowAny
from rest_framework.exceptions import NotFound
from common.permissions import IsServiceAccount
from common.utils import get_logger, get_object_or_none
from orgs.utils import tmp_to_root_org
from terminal.api.session.task import create_sessions_tasks
from users.models import User
from .. import serializers
from ..mixins import AuthMixin
from ..const import FACE_CONTEXT_CACHE_KEY_PREFIX, FACE_SESSION_KEY, FACE_CONTEXT_CACHE_TTL, FaceMonitorActionChoices
from ..models import ConnectionToken
from ..serializers.face import FaceMonitorCallbackSerializer, FaceMonitorContextSerializer
logger = get_logger(__name__)
__all__ = [
'FaceCallbackApi',
'FaceContextApi',
'FaceMonitorContext',
'FaceMonitorContextApi',
'FaceMonitorCallbackApi'
]
class FaceCallbackApi(AuthMixin, CreateAPIView):
permission_classes = (IsServiceAccount,)
serializer_class = serializers.FaceCallbackSerializer
def perform_create(self, serializer):
token = serializer.validated_data.get('token')
context = self._get_context_from_cache(token)
if not serializer.validated_data.get('success', False):
self._update_context_with_error(
context,
serializer.validated_data.get('error_message', 'Unknown error')
)
return Response(status=200)
face_code = serializer.validated_data.get('face_code')
if not face_code:
self._update_context_with_error(context, "missing field 'face_code'")
raise ValidationError({'error': "missing field 'face_code'"})
try:
self._handle_success(context, face_code)
except Exception as e:
self._update_context_with_error(context, str(e))
return Response(status=200)
@staticmethod
def get_face_cache_key(token):
return f"{FACE_CONTEXT_CACHE_KEY_PREFIX}_{token}"
def _get_context_from_cache(self, token):
cache_key = self.get_face_cache_key(token)
context = cache.get(cache_key)
if not context:
raise ValidationError({'error': "token not exists or expired"})
return context
def _update_context_with_error(self, context, error_message):
context.update({
'is_finished': True,
'success': False,
'error_message': error_message,
})
self._update_cache(context)
def _update_cache(self, context):
cache_key = self.get_face_cache_key(context['token'])
cache.set(cache_key, context, FACE_CONTEXT_CACHE_TTL)
def _handle_success(self, context, face_code):
context.update({
'is_finished': True,
'success': True,
'face_code': face_code
})
action = context.get('action', None)
if action == 'login_asset':
user_id = context.get('user_id')
user = User.objects.get(id=user_id)
if user.check_face(face_code):
with tmp_to_root_org():
connection_token_id = context.get('connection_token_id')
token = ConnectionToken.objects.filter(id=connection_token_id).first()
token.is_active = True
token.save()
else:
context.update({
'success': False,
'error_message': _('Facial comparison failed')
})
self._update_cache(context)
class FaceContextApi(AuthMixin, RetrieveAPIView, CreateAPIView):
permission_classes = (AllowAny,)
face_token_session_key = FACE_SESSION_KEY
@staticmethod
def get_face_cache_key(token):
return f"{FACE_CONTEXT_CACHE_KEY_PREFIX}_{token}"
def new_face_context(self):
return self.create_face_verify_context()
def post(self, request, *args, **kwargs):
token = self.new_face_context()
return Response({'token': token})
def get(self, request, *args, **kwargs):
token = self.request.session.get(self.face_token_session_key)
cache_key = self.get_face_cache_key(token)
context = cache.get(cache_key)
if not context:
raise NotFound({'error': "Token does not exist or has expired."})
return Response({
"is_finished": context.get('is_finished', False),
"success": context.get('success', False),
"error_message": _(context.get("error_message", ''))
})
class FaceMonitorContext:
def __init__(self, token, user_id, session_ids=None):
self.token = token
self.user_id = user_id
if session_ids is None:
self.session_ids = []
else:
self.session_ids = session_ids
@classmethod
def get_cache_key(cls, token):
return 'FACE_MONITOR_CONTEXT_{}'.format(token)
@classmethod
def get_or_create_context(cls, token, user_id):
context = cls.get(token)
if not context:
context = FaceMonitorContext(token=token,
user_id=user_id)
context.save()
return context
def add_session(self, session_id):
self.session_ids.append(session_id)
self.save()
@classmethod
def get(cls, token):
cache_key = cls.get_cache_key(token)
return cache.get(cache_key, None)
def save(self):
cache_key = self.get_cache_key(self.token)
cache.set(cache_key, self)
def close(self):
self.terminal_sessions()
self._destroy()
def _destroy(self):
cache_key = self.get_cache_key(self.token)
cache.delete(cache_key)
def pause_sessions(self):
self._send_task('lock_session')
def resume_sessions(self):
self._send_task('unlock_session')
def terminal_sessions(self):
self._send_task("kill_session")
def _send_task(self, task_name):
create_sessions_tasks(self.session_ids, 'facelive', task_name=task_name)
class FaceMonitorContextApi(CreateAPIView):
permission_classes = (IsServiceAccount,)
serializer_class = FaceMonitorContextSerializer
def perform_create(self, serializer):
face_monitor_token = serializer.validated_data.get('face_monitor_token')
session_id = serializer.validated_data.get('session_id')
context = FaceMonitorContext.get(face_monitor_token)
context.add_session(session_id)
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
return Response(status=201)
class FaceMonitorCallbackApi(CreateAPIView):
permission_classes = (IsServiceAccount,)
serializer_class = FaceMonitorCallbackSerializer
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
token = serializer.validated_data.get('token')
context = FaceMonitorContext.get(token=token)
is_finished = serializer.validated_data.get('is_finished')
if is_finished:
context.close()
return Response(status=200)
action = serializer.validated_data.get('action')
if action == FaceMonitorActionChoices.Verify:
user = get_object_or_none(User, pk=context.user_id)
face_codes = serializer.validated_data.get('face_codes')
if not user:
context.save()
return Response(data={'msg': 'user {} not found'
.format(context.user_id)}, status=400)
if not face_codes or not self._check_face_codes(face_codes, user):
context.save()
return Response(data={'msg': 'face codes not matched'}, status=400)
if action == FaceMonitorActionChoices.Pause:
context.pause_sessions()
if action == FaceMonitorActionChoices.Resume:
context.resume_sessions()
context.save()
return Response(status=200)
@staticmethod
def _check_face_codes(face_codes, user):
matched = False
for face_code in face_codes:
matched = user.check_face(face_code,
distance_threshold=0.45,
similarity_threshold=0.92)
if matched:
break
return matched

View File

@@ -1,9 +1,10 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from rest_framework import exceptions from rest_framework import exceptions
from rest_framework.generics import CreateAPIView, RetrieveAPIView from rest_framework.generics import CreateAPIView
from rest_framework.permissions import AllowAny from rest_framework.permissions import AllowAny
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.serializers import ValidationError from rest_framework.serializers import ValidationError
@@ -19,12 +20,10 @@ from ..mixins import AuthMixin
logger = get_logger(__name__) logger = get_logger(__name__)
__all__ = [ __all__ = [
'MFAChallengeVerifyApi', 'MFASendCodeApi', 'MFAChallengeVerifyApi', 'MFASendCodeApi'
] ]
# MFASelectAPi 原来的名字 # MFASelectAPi 原来的名字
class MFASendCodeApi(AuthMixin, CreateAPIView): class MFASendCodeApi(AuthMixin, CreateAPIView):
""" """

View File

@@ -1,6 +1,5 @@
import time import time
from django.conf import settings
from django.core.cache import cache from django.core.cache import cache
from django.http import HttpResponseRedirect from django.http import HttpResponseRedirect
from django.shortcuts import reverse from django.shortcuts import reverse
@@ -41,15 +40,12 @@ class UserResetPasswordSendCodeApi(CreateAPIView):
return user, None return user, None
@staticmethod @staticmethod
def safe_send_code(token, code, target, form_type, content, user_info): def safe_send_code(token, code, target, form_type, content):
token_sent_key = '{}_send_at'.format(token) token_sent_key = '{}_send_at'.format(token)
token_send_at = cache.get(token_sent_key, 0) token_send_at = cache.get(token_sent_key, 0)
if token_send_at: if token_send_at:
raise IntervalTooShort(60) raise IntervalTooShort(60)
tooler = SendAndVerifyCodeUtil( SendAndVerifyCodeUtil(target, code, backend=form_type, **content).gen_and_send_async()
target, code, backend=form_type, user_info=user_info, **content
)
tooler.gen_and_send_async()
cache.set(token_sent_key, int(time.time()), 60) cache.set(token_sent_key, int(time.time()), 60)
def prepare_code_data(self, user_info, serializer): def prepare_code_data(self, user_info, serializer):
@@ -65,7 +61,7 @@ class UserResetPasswordSendCodeApi(CreateAPIView):
if not user: if not user:
raise ValueError(err) raise ValueError(err)
code = random_string(settings.SMS_CODE_LENGTH, lower=False, upper=False) code = random_string(6, lower=False, upper=False)
subject = '%s: %s' % (get_login_title(), _('Forgot password')) subject = '%s: %s' % (get_login_title(), _('Forgot password'))
context = { context = {
'user': user, 'title': subject, 'code': code, 'user': user, 'title': subject, 'code': code,
@@ -86,7 +82,7 @@ class UserResetPasswordSendCodeApi(CreateAPIView):
code, target, form_type, content = self.prepare_code_data(user_info, serializer) code, target, form_type, content = self.prepare_code_data(user_info, serializer)
except ValueError as e: except ValueError as e:
return Response({'error': str(e)}, status=400) return Response({'error': str(e)}, status=400)
self.safe_send_code(token, code, target, form_type, content, user_info) self.safe_send_code(token, code, target, form_type, content)
return Response({'data': 'ok'}, status=200) return Response({'data': 'ok'}, status=200)

View File

@@ -23,9 +23,10 @@ class JMSBaseAuthBackend:
Reject users with is_valid=False. Custom user models that don't have Reject users with is_valid=False. Custom user models that don't have
that attribute are allowed. that attribute are allowed.
""" """
# 三方用户认证完成后,在后续的 get_user 获取逻辑中,也应该需要检查用户是否有效 # 在 check_user_auth 中进行了校验,可以返回对应的错误信息
is_valid = getattr(user, 'is_valid', None) # is_valid = getattr(user, 'is_valid', None)
return is_valid or is_valid is None # return is_valid or is_valid is None
return True
# allow user to authenticate # allow user to authenticate
def username_allow_authenticate(self, username): def username_allow_authenticate(self, username):
@@ -51,14 +52,6 @@ class JMSBaseAuthBackend:
logger.info(info) logger.info(info)
return allow return allow
def get_user(self, user_id):
""" 三方用户认证成功后 request.user 赋值时会调用 backend 的当前方法获取用户 """
try:
user = UserModel._default_manager.get(pk=user_id)
except UserModel.DoesNotExist:
return None
return user if self.user_can_authenticate(user) else None
class JMSModelBackend(JMSBaseAuthBackend, ModelBackend): class JMSModelBackend(JMSBaseAuthBackend, ModelBackend):
pass pass

View File

@@ -1,13 +1,12 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import django_cas_ng.views
from django.urls import path from django.urls import path
import django_cas_ng.views
from .views import CASLoginView, CASCallbackClientView from .views import CASLoginView
urlpatterns = [ urlpatterns = [
path('login/', CASLoginView.as_view(), name='cas-login'), path('login/', CASLoginView.as_view(), name='cas-login'),
path('logout/', django_cas_ng.views.LogoutView.as_view(), name='cas-logout'), path('logout/', django_cas_ng.views.LogoutView.as_view(), name='cas-logout'),
path('callback/', django_cas_ng.views.CallbackView.as_view(), name='cas-proxy-callback'), path('callback/', django_cas_ng.views.CallbackView.as_view(), name='cas-proxy-callback'),
path('login/client', CASCallbackClientView.as_view(), name='cas-proxy-callback-client'),
] ]

View File

@@ -1,12 +1,9 @@
from django_cas_ng.views import LoginView
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
from django.http import HttpResponseRedirect from django.http import HttpResponseRedirect
from django.views.generic import View
from django_cas_ng.views import LoginView
__all__ = ['LoginView'] __all__ = ['LoginView']
from authentication.views.utils import redirect_to_guard_view
class CASLoginView(LoginView): class CASLoginView(LoginView):
def get(self, request): def get(self, request):
@@ -16,8 +13,3 @@ class CASLoginView(LoginView):
return HttpResponseRedirect('/') return HttpResponseRedirect('/')
class CASCallbackClientView(View):
http_method_names = ['get', ]
def get(self, request):
return redirect_to_guard_view(query_string='next=client')

View File

@@ -5,7 +5,7 @@ from django.utils.translation import gettext_lazy as _
from authentication.signals import user_auth_failed, user_auth_success from authentication.signals import user_auth_failed, user_auth_success
from common.utils import get_logger from common.utils import get_logger
from .base import JMSBaseAuthBackend from .base import JMSModelBackend
logger = get_logger(__file__) logger = get_logger(__file__)
@@ -20,10 +20,9 @@ if settings.AUTH_CUSTOM:
logger.warning('Import custom auth method failed: {}, Maybe not enabled'.format(e)) logger.warning('Import custom auth method failed: {}, Maybe not enabled'.format(e))
class CustomAuthBackend(JMSBaseAuthBackend): class CustomAuthBackend(JMSModelBackend):
@staticmethod def is_enabled(self):
def is_enabled():
return settings.AUTH_CUSTOM and callable(custom_authenticate_method) return settings.AUTH_CUSTOM and callable(custom_authenticate_method)
@staticmethod @staticmethod
@@ -36,10 +35,10 @@ class CustomAuthBackend(JMSBaseAuthBackend):
) )
return user, created return user, created
def authenticate(self, request, username=None, password=None): def authenticate(self, request, username=None, password=None, **kwargs):
try: try:
userinfo: dict = custom_authenticate_method( userinfo: dict = custom_authenticate_method(
username=username, password=password username=username, password=password, **kwargs
) )
user, created = self.get_or_create_user_from_userinfo(userinfo) user, created = self.get_or_create_user_from_userinfo(userinfo)
except Exception as e: except Exception as e:

View File

@@ -3,9 +3,8 @@
import abc import abc
import ldap import ldap
from django.conf import settings from django.conf import settings
from django.core.cache import cache
from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist
from django_auth_ldap.backend import _LDAPUser, LDAPBackend, valid_cache_key from django_auth_ldap.backend import _LDAPUser, LDAPBackend
from django_auth_ldap.config import _LDAPConfig, LDAPSearch, LDAPSearchUnion from django_auth_ldap.config import _LDAPConfig, LDAPSearch, LDAPSearchUnion
from users.utils import construct_user_email from users.utils import construct_user_email
@@ -147,53 +146,30 @@ class LDAPHAAuthorizationBackend(JMSBaseAuthBackend, LDAPBaseBackend):
class LDAPUser(_LDAPUser): class LDAPUser(_LDAPUser):
def __init__(self, backend, username=None, user=None, request=None):
super().__init__(backend=backend, username=username, user=user, request=request)
config_prefix = "" if isinstance(self.backend, LDAPAuthorizationBackend) else "_ha"
self.user_dn_cache_key = valid_cache_key(
f"django_auth_ldap{config_prefix}.user_dn.{self._username}"
)
self.category = f"ldap{config_prefix}"
self.search_filter = getattr(settings, f"AUTH_LDAP{config_prefix.upper()}_SEARCH_FILTER", None)
self.search_ou = getattr(settings, f"AUTH_LDAP{config_prefix.upper()}_SEARCH_OU", None)
def _search_for_user_dn_from_ldap_util(self): def _search_for_user_dn_from_ldap_util(self):
from settings.utils import LDAPServerUtil from settings.utils import LDAPServerUtil
util = LDAPServerUtil(category=self.category) util = LDAPServerUtil()
user_dn = util.search_for_user_dn(self._username) user_dn = util.search_for_user_dn(self._username)
return user_dn return user_dn
def _load_user_dn(self):
"""
Populates self._user_dn with the distinguished name of our user.
This will either construct the DN from a template in
AUTH_LDAP_USER_DN_TEMPLATE or connect to the server and search for it.
If we have to search, we'll cache the DN.
"""
if self._using_simple_bind_mode():
self._user_dn = self._construct_simple_user_dn()
else:
if self.settings.CACHE_TIMEOUT > 0:
self._user_dn = cache.get_or_set(
self.user_dn_cache_key, self._search_for_user_dn, self.settings.CACHE_TIMEOUT
)
else:
self._user_dn = self._search_for_user_dn()
def _search_for_user_dn(self): def _search_for_user_dn(self):
""" """
This method was overridden because the AUTH_LDAP_USER_SEARCH This method was overridden because the AUTH_LDAP_USER_SEARCH
configuration in the settings.py file configuration in the settings.py file
is configured with a `lambda` problem value is configured with a `lambda` problem value
""" """
if isinstance(self.backend, LDAPAuthorizationBackend):
search_filter = settings.AUTH_LDAP_SEARCH_FILTER
search_ou = settings.AUTH_LDAP_SEARCH_OU
else:
search_filter = settings.AUTH_LDAP_HA_SEARCH_FILTER
search_ou = settings.AUTH_LDAP_HA_SEARCH_OU
user_search_union = [ user_search_union = [
LDAPSearch( LDAPSearch(
USER_SEARCH, ldap.SCOPE_SUBTREE, USER_SEARCH, ldap.SCOPE_SUBTREE,
self.search_filter search_filter
) )
for USER_SEARCH in str(self.search_ou).split("|") for USER_SEARCH in str(search_ou).split("|")
] ]
search = LDAPSearchUnion(*user_search_union) search = LDAPSearchUnion(*user_search_union)

View File

@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import base64
import requests import requests
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@@ -18,7 +17,7 @@ from common.exceptions import JMSException
from .signals import ( from .signals import (
oauth2_create_or_update_user oauth2_create_or_update_user
) )
from ..base import JMSBaseAuthBackend from ..base import JMSModelBackend
__all__ = ['OAuth2Backend'] __all__ = ['OAuth2Backend']
@@ -26,7 +25,7 @@ __all__ = ['OAuth2Backend']
logger = get_logger(__name__) logger = get_logger(__name__)
class OAuth2Backend(JMSBaseAuthBackend): class OAuth2Backend(JMSModelBackend):
@staticmethod @staticmethod
def is_enabled(): def is_enabled():
return settings.AUTH_OAUTH2 return settings.AUTH_OAUTH2
@@ -68,7 +67,15 @@ class OAuth2Backend(JMSBaseAuthBackend):
response_data = response_data['data'] response_data = response_data['data']
return response_data return response_data
def authenticate(self, request, code=None): @staticmethod
def get_query_dict(response_data, query_dict):
query_dict.update({
'uid': response_data.get('uid', ''),
'access_token': response_data.get('access_token', '')
})
return query_dict
def authenticate(self, request, code=None, **kwargs):
log_prompt = "Process authenticate [OAuth2Backend]: {}" log_prompt = "Process authenticate [OAuth2Backend]: {}"
logger.debug(log_prompt.format('Start')) logger.debug(log_prompt.format('Start'))
if code is None: if code is None:
@@ -76,31 +83,29 @@ class OAuth2Backend(JMSBaseAuthBackend):
return None return None
query_dict = { query_dict = {
'grant_type': 'authorization_code', 'code': code, 'client_id': settings.AUTH_OAUTH2_CLIENT_ID,
'client_secret': settings.AUTH_OAUTH2_CLIENT_SECRET,
'grant_type': 'authorization_code',
'code': code,
'redirect_uri': build_absolute_uri( 'redirect_uri': build_absolute_uri(
request, path=reverse(settings.AUTH_OAUTH2_AUTH_LOGIN_CALLBACK_URL_NAME) request, path=reverse(settings.AUTH_OAUTH2_AUTH_LOGIN_CALLBACK_URL_NAME)
) )
} }
separator = '&' if '?' in settings.AUTH_OAUTH2_ACCESS_TOKEN_ENDPOINT else '?' if '?' in settings.AUTH_OAUTH2_ACCESS_TOKEN_ENDPOINT:
separator = '&'
else:
separator = '?'
access_token_url = '{url}{separator}{query}'.format( access_token_url = '{url}{separator}{query}'.format(
url=settings.AUTH_OAUTH2_ACCESS_TOKEN_ENDPOINT, url=settings.AUTH_OAUTH2_ACCESS_TOKEN_ENDPOINT, separator=separator, query=urlencode(query_dict)
separator=separator, query=urlencode(query_dict)
) )
# token_method -> get, post(post_data), post_json # token_method -> get, post(post_data), post_json
token_method = settings.AUTH_OAUTH2_ACCESS_TOKEN_METHOD.lower() token_method = settings.AUTH_OAUTH2_ACCESS_TOKEN_METHOD.lower()
logger.debug(log_prompt.format('Call the access token endpoint[method: %s]' % token_method)) logger.debug(log_prompt.format('Call the access token endpoint[method: %s]' % token_method))
encoded_credentials = base64.b64encode(
f"{settings.AUTH_OAUTH2_CLIENT_ID}:{settings.AUTH_OAUTH2_CLIENT_SECRET}".encode()
).decode()
headers = { headers = {
'Accept': 'application/json', 'Authorization': f'Basic {encoded_credentials}' 'Accept': 'application/json'
} }
if token_method.startswith('post'): if token_method.startswith('post'):
body_key = 'json' if token_method.endswith('json') else 'data' body_key = 'json' if token_method.endswith('json') else 'data'
query_dict.update({
'client_id': settings.AUTH_OAUTH2_CLIENT_ID,
'client_secret': settings.AUTH_OAUTH2_CLIENT_SECRET,
})
access_token_response = requests.post( access_token_response = requests.post(
access_token_url, headers=headers, **{body_key: query_dict} access_token_url, headers=headers, **{body_key: query_dict}
) )
@@ -116,12 +121,22 @@ class OAuth2Backend(JMSBaseAuthBackend):
logger.error(log_prompt.format(error)) logger.error(log_prompt.format(error))
return None return None
query_dict = self.get_query_dict(response_data, query_dict)
headers = { headers = {
'Accept': 'application/json', 'Accept': 'application/json',
'Authorization': 'Bearer {}'.format(response_data.get('access_token', '')) 'Authorization': 'Bearer {}'.format(response_data.get('access_token', ''))
} }
logger.debug(log_prompt.format('Get userinfo endpoint')) logger.debug(log_prompt.format('Get userinfo endpoint'))
userinfo_url = settings.AUTH_OAUTH2_PROVIDER_USERINFO_ENDPOINT if '?' in settings.AUTH_OAUTH2_PROVIDER_USERINFO_ENDPOINT:
separator = '&'
else:
separator = '?'
userinfo_url = '{url}{separator}{query}'.format(
url=settings.AUTH_OAUTH2_PROVIDER_USERINFO_ENDPOINT, separator=separator,
query=urlencode(query_dict)
)
userinfo_response = requests.get(userinfo_url, headers=headers) userinfo_response = requests.get(userinfo_url, headers=headers)
try: try:
userinfo_response.raise_for_status() userinfo_response.raise_for_status()

View File

@@ -4,9 +4,9 @@ from django.urls import path
from . import views from . import views
urlpatterns = [ urlpatterns = [
path('login/', views.OAuth2AuthRequestView.as_view(), name='login'), path('login/', views.OAuth2AuthRequestView.as_view(), name='login'),
path('callback/', views.OAuth2AuthCallbackView.as_view(), name='login-callback'), path('callback/', views.OAuth2AuthCallbackView.as_view(), name='login-callback'),
path('callback/client/', views.OAuth2AuthCallbackClientView.as_view(), name='login-callback-client'),
path('logout/', views.OAuth2EndSessionView.as_view(), name='logout') path('logout/', views.OAuth2EndSessionView.as_view(), name='logout')
] ]

View File

@@ -1,16 +1,16 @@
from django.views import View
from django.conf import settings from django.conf import settings
from django.contrib import auth from django.contrib import auth
from django.http import HttpResponseRedirect from django.http import HttpResponseRedirect
from django.urls import reverse from django.urls import reverse
from django.utils.http import urlencode from django.utils.http import urlencode
from django.views import View
from authentication.mixins import authenticate
from authentication.utils import build_absolute_uri from authentication.utils import build_absolute_uri
from authentication.views.mixins import FlashMessageMixin from authentication.views.mixins import FlashMessageMixin
from authentication.views.utils import redirect_to_guard_view from authentication.mixins import authenticate
from common.utils import get_logger from common.utils import get_logger
logger = get_logger(__file__) logger = get_logger(__file__)
@@ -67,13 +67,6 @@ class OAuth2AuthCallbackView(View, FlashMessageMixin):
return HttpResponseRedirect(redirect_url) return HttpResponseRedirect(redirect_url)
class OAuth2AuthCallbackClientView(View):
http_method_names = ['get', ]
def get(self, request):
return redirect_to_guard_view(query_string='next=client')
class OAuth2EndSessionView(View): class OAuth2EndSessionView(View):
http_method_names = ['get', 'post', ] http_method_names = ['get', 'post', ]

View File

@@ -13,8 +13,10 @@ import requests
from django.conf import settings from django.conf import settings
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.contrib.auth.backends import ModelBackend from django.contrib.auth.backends import ModelBackend
from django.core.exceptions import SuspiciousOperation
from django.db import transaction from django.db import transaction
from django.urls import reverse from django.urls import reverse
from rest_framework.exceptions import ParseError
from authentication.signals import user_auth_success, user_auth_failed from authentication.signals import user_auth_success, user_auth_failed
from authentication.utils import build_absolute_uri_for_oidc from authentication.utils import build_absolute_uri_for_oidc
@@ -86,7 +88,7 @@ class OIDCAuthCodeBackend(OIDCBaseBackend):
""" """
@ssl_verification @ssl_verification
def authenticate(self, request, nonce=None, code_verifier=None): def authenticate(self, request, nonce=None, code_verifier=None, **kwargs):
""" Authenticates users in case of the OpenID Connect Authorization code flow. """ """ Authenticates users in case of the OpenID Connect Authorization code flow. """
log_prompt = "Process authenticate [OIDCAuthCodeBackend]: {}" log_prompt = "Process authenticate [OIDCAuthCodeBackend]: {}"
logger.debug(log_prompt.format('start')) logger.debug(log_prompt.format('start'))
@@ -105,7 +107,7 @@ class OIDCAuthCodeBackend(OIDCBaseBackend):
# parameters because we won't be able to get a valid token for the user in that case. # parameters because we won't be able to get a valid token for the user in that case.
if (state is None and settings.AUTH_OPENID_USE_STATE) or code is None: if (state is None and settings.AUTH_OPENID_USE_STATE) or code is None:
logger.debug(log_prompt.format('Authorization code or state value is missing')) logger.debug(log_prompt.format('Authorization code or state value is missing'))
return raise SuspiciousOperation('Authorization code or state value is missing')
# Prepares the token payload that will be used to request an authentication token to the # Prepares the token payload that will be used to request an authentication token to the
# token endpoint of the OIDC provider. # token endpoint of the OIDC provider.
@@ -163,7 +165,7 @@ class OIDCAuthCodeBackend(OIDCBaseBackend):
error = "Json token response error, token response " \ error = "Json token response error, token response " \
"content is: {}, error is: {}".format(token_response.content, str(e)) "content is: {}, error is: {}".format(token_response.content, str(e))
logger.debug(log_prompt.format(error)) logger.debug(log_prompt.format(error))
return raise ParseError(error)
# Validates the token. # Validates the token.
logger.debug(log_prompt.format('Validate ID Token')) logger.debug(log_prompt.format('Validate ID Token'))
@@ -204,7 +206,7 @@ class OIDCAuthCodeBackend(OIDCBaseBackend):
error = "Json claims response error, claims response " \ error = "Json claims response error, claims response " \
"content is: {}, error is: {}".format(claims_response.content, str(e)) "content is: {}, error is: {}".format(claims_response.content, str(e))
logger.debug(log_prompt.format(error)) logger.debug(log_prompt.format(error))
return raise ParseError(error)
logger.debug(log_prompt.format('Get or create user from claims')) logger.debug(log_prompt.format('Get or create user from claims'))
user, created = self.get_or_create_user_from_claims(request, claims) user, created = self.get_or_create_user_from_claims(request, claims)
@@ -233,15 +235,15 @@ class OIDCAuthCodeBackend(OIDCBaseBackend):
class OIDCAuthPasswordBackend(OIDCBaseBackend): class OIDCAuthPasswordBackend(OIDCBaseBackend):
@ssl_verification @ssl_verification
def authenticate(self, request, username=None, password=None): def authenticate(self, request, username=None, password=None, **kwargs):
try: try:
return self._authenticate(request, username, password) return self._authenticate(request, username, password, **kwargs)
except Exception as e: except Exception as e:
error = f'Authenticate exception: {e}' error = f'Authenticate exception: {e}'
logger.error(error, exc_info=True) logger.error(error, exc_info=True)
return return
def _authenticate(self, request, username=None, password=None): def _authenticate(self, request, username=None, password=None, **kwargs):
""" """
https://oauth.net/2/ https://oauth.net/2/
https://aaronparecki.com/oauth-2-simplified/#password https://aaronparecki.com/oauth-2-simplified/#password

View File

@@ -4,9 +4,7 @@
import warnings import warnings
import contextlib import contextlib
import requests import requests
import inspect
from functools import wraps
from django.conf import settings from django.conf import settings
from urllib3.exceptions import InsecureRequestWarning from urllib3.exceptions import InsecureRequestWarning
@@ -54,7 +52,6 @@ def no_ssl_verification():
def ssl_verification(func): def ssl_verification(func):
@wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if not settings.AUTH_OPENID_IGNORE_SSL_VERIFICATION: if not settings.AUTH_OPENID_IGNORE_SSL_VERIFICATION:
return func(*args, **kwargs) return func(*args, **kwargs)

View File

@@ -12,9 +12,9 @@ from django.urls import path
from . import views from . import views
urlpatterns = [ urlpatterns = [
path('login/', views.OIDCAuthRequestView.as_view(), name='login'), path('login/', views.OIDCAuthRequestView.as_view(), name='login'),
path('callback/', views.OIDCAuthCallbackView.as_view(), name='login-callback'), path('callback/', views.OIDCAuthCallbackView.as_view(), name='login-callback'),
path('callback/client/', views.OIDCAuthCallbackClientView.as_view(), name='login-callback-client'),
path('logout/', views.OIDCEndSessionView.as_view(), name='logout'), path('logout/', views.OIDCEndSessionView.as_view(), name='logout'),
] ]

View File

@@ -22,14 +22,13 @@ from django.http import HttpResponseRedirect, QueryDict
from django.urls import reverse from django.urls import reverse
from django.utils.crypto import get_random_string from django.utils.crypto import get_random_string
from django.utils.http import urlencode from django.utils.http import urlencode
from django.utils.translation import gettext_lazy as _
from django.views.generic import View from django.views.generic import View
from django.utils.translation import gettext_lazy as _
from authentication.utils import build_absolute_uri_for_oidc from authentication.utils import build_absolute_uri_for_oidc
from authentication.views.mixins import FlashMessageMixin from authentication.views.mixins import FlashMessageMixin
from common.utils import safe_next_url from common.utils import safe_next_url
from .utils import get_logger from .utils import get_logger
from ...views.utils import redirect_to_guard_view
logger = get_logger(__file__) logger = get_logger(__file__)
@@ -209,13 +208,6 @@ class OIDCAuthCallbackView(View, FlashMessageMixin):
return HttpResponseRedirect(settings.AUTH_OPENID_AUTHENTICATION_FAILURE_REDIRECT_URI) return HttpResponseRedirect(settings.AUTH_OPENID_AUTHENTICATION_FAILURE_REDIRECT_URI)
class OIDCAuthCallbackClientView(View):
http_method_names = ['get', ]
def get(self, request):
return redirect_to_guard_view(query_string='next=client')
class OIDCEndSessionView(View): class OIDCEndSessionView(View):
""" Allows to end the session of any user authenticated using OpenID Connect. """ Allows to end the session of any user authenticated using OpenID Connect.

View File

@@ -13,7 +13,7 @@ class Passkey(JMSBaseModel):
added_on = models.DateTimeField(auto_now_add=True, verbose_name=_("Added on")) added_on = models.DateTimeField(auto_now_add=True, verbose_name=_("Added on"))
date_last_used = models.DateTimeField(null=True, default=None, verbose_name=_("Date last used")) date_last_used = models.DateTimeField(null=True, default=None, verbose_name=_("Date last used"))
credential_id = models.CharField(max_length=255, unique=True, null=False, verbose_name=_("Credential ID")) credential_id = models.CharField(max_length=255, unique=True, null=False, verbose_name=_("Credential ID"))
token = models.CharField(max_length=1024, null=False, verbose_name=_("Token")) token = models.CharField(max_length=255, null=False, verbose_name=_("Token"))
def __str__(self): def __str__(self):
return self.name return self.name

View File

@@ -51,10 +51,10 @@ class RadiusBaseBackend(CreateUserMixin, JMSBaseAuthBackend):
class RadiusBackend(RadiusBaseBackend, RADIUSBackend): class RadiusBackend(RadiusBaseBackend, RADIUSBackend):
def authenticate(self, request, username='', password=''): def authenticate(self, request, username='', password='', **kwargs):
return super().authenticate(request, username=username, password=password) return super().authenticate(request, username=username, password=password)
class RadiusRealmBackend(RadiusBaseBackend, RADIUSRealmBackend): class RadiusRealmBackend(RadiusBaseBackend, RADIUSRealmBackend):
def authenticate(self, request, username='', password='', realm=None): def authenticate(self, request, username='', password='', realm=None, **kwargs):
return super().authenticate(request, username=username, password=password, realm=realm) return super().authenticate(request, username=username, password=password, realm=realm)

View File

@@ -10,14 +10,14 @@ from .signals import (
saml2_create_or_update_user saml2_create_or_update_user
) )
from authentication.signals import user_auth_failed, user_auth_success from authentication.signals import user_auth_failed, user_auth_success
from ..base import JMSBaseAuthBackend from ..base import JMSModelBackend
__all__ = ['SAML2Backend'] __all__ = ['SAML2Backend']
logger = get_logger(__name__) logger = get_logger(__name__)
class SAML2Backend(JMSBaseAuthBackend): class SAML2Backend(JMSModelBackend):
@staticmethod @staticmethod
def is_enabled(): def is_enabled():
return settings.AUTH_SAML2 return settings.AUTH_SAML2
@@ -42,7 +42,7 @@ class SAML2Backend(JMSBaseAuthBackend):
) )
return user, created return user, created
def authenticate(self, request, saml_user_data=None): def authenticate(self, request, saml_user_data=None, **kwargs):
log_prompt = "Process authenticate [SAML2Backend]: {}" log_prompt = "Process authenticate [SAML2Backend]: {}"
logger.debug(log_prompt.format('Start')) logger.debug(log_prompt.format('Start'))
if saml_user_data is None: if saml_user_data is None:

View File

@@ -4,10 +4,10 @@ from django.urls import path
from . import views from . import views
urlpatterns = [ urlpatterns = [
path('login/', views.Saml2AuthRequestView.as_view(), name='saml2-login'), path('login/', views.Saml2AuthRequestView.as_view(), name='saml2-login'),
path('logout/', views.Saml2EndSessionView.as_view(), name='saml2-logout'), path('logout/', views.Saml2EndSessionView.as_view(), name='saml2-logout'),
path('callback/', views.Saml2AuthCallbackView.as_view(), name='saml2-callback'), path('callback/', views.Saml2AuthCallbackView.as_view(), name='saml2-callback'),
path('callback/client/', views.Saml2AuthCallbackClientView.as_view(), name='saml2-callback-client'),
path('metadata/', views.Saml2AuthMetadataView.as_view(), name='saml2-metadata'), path('metadata/', views.Saml2AuthMetadataView.as_view(), name='saml2-metadata'),
] ]

View File

@@ -19,7 +19,6 @@ from onelogin.saml2.idp_metadata_parser import (
from authentication.views.mixins import FlashMessageMixin from authentication.views.mixins import FlashMessageMixin
from common.utils import get_logger from common.utils import get_logger
from .settings import JmsSaml2Settings from .settings import JmsSaml2Settings
from ...views.utils import redirect_to_guard_view
logger = get_logger(__file__) logger = get_logger(__file__)
@@ -299,13 +298,6 @@ class Saml2AuthCallbackView(View, PrepareRequestMixin, FlashMessageMixin):
return super().dispatch(*args, **kwargs) return super().dispatch(*args, **kwargs)
class Saml2AuthCallbackClientView(View):
http_method_names = ['get', ]
def get(self, request):
return redirect_to_guard_view(query_string='next=client')
class Saml2AuthMetadataView(View, PrepareRequestMixin): class Saml2AuthMetadataView(View, PrepareRequestMixin):
def get(self, request): def get(self, request):

View File

@@ -1,41 +1,57 @@
from django.conf import settings from django.conf import settings
from .base import JMSBaseAuthBackend from .base import JMSModelBackend
class SSOAuthentication(JMSBaseAuthBackend): class SSOAuthentication(JMSModelBackend):
"""
什么也不做呀😺
"""
@staticmethod @staticmethod
def is_enabled(): def is_enabled():
return settings.AUTH_SSO return settings.AUTH_SSO
def authenticate(self): def authenticate(self, request, sso_token=None, **kwargs):
pass pass
class WeComAuthentication(JMSBaseAuthBackend): class WeComAuthentication(JMSModelBackend):
"""
什么也不做呀😺
"""
@staticmethod @staticmethod
def is_enabled(): def is_enabled():
return settings.AUTH_WECOM return settings.AUTH_WECOM
def authenticate(self): def authenticate(self, request, **kwargs):
pass pass
class DingTalkAuthentication(JMSBaseAuthBackend): class DingTalkAuthentication(JMSModelBackend):
"""
什么也不做呀😺
"""
@staticmethod @staticmethod
def is_enabled(): def is_enabled():
return settings.AUTH_DINGTALK return settings.AUTH_DINGTALK
def authenticate(self): def authenticate(self, request, **kwargs):
pass pass
class FeiShuAuthentication(JMSBaseAuthBackend): class FeiShuAuthentication(JMSModelBackend):
"""
什么也不做呀😺
"""
@staticmethod @staticmethod
def is_enabled(): def is_enabled():
return settings.AUTH_FEISHU return settings.AUTH_FEISHU
def authenticate(self): def authenticate(self, request, **kwargs):
pass pass
@@ -45,15 +61,23 @@ class LarkAuthentication(FeiShuAuthentication):
return settings.AUTH_LARK return settings.AUTH_LARK
class SlackAuthentication(JMSBaseAuthBackend): class SlackAuthentication(JMSModelBackend):
"""
什么也不做呀😺
"""
@staticmethod @staticmethod
def is_enabled(): def is_enabled():
return settings.AUTH_SLACK return settings.AUTH_SLACK
def authenticate(self): def authenticate(self, request, **kwargs):
pass pass
class AuthorizationTokenAuthentication(JMSBaseAuthBackend): class AuthorizationTokenAuthentication(JMSModelBackend):
def authenticate(self): """
什么也不做呀😺
"""
def authenticate(self, request, **kwargs):
pass pass

View File

@@ -3,17 +3,13 @@ from django.conf import settings
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
from authentication.models import TempToken from authentication.models import TempToken
from .base import JMSBaseAuthBackend from .base import JMSModelBackend
class TempTokenAuthBackend(JMSBaseAuthBackend): class TempTokenAuthBackend(JMSModelBackend):
model = TempToken model = TempToken
@staticmethod def authenticate(self, request, username='', password='', *args, **kwargs):
def is_enabled():
return settings.AUTH_TEMP_TOKEN
def authenticate(self, request, username='', password=''):
token = self.model.objects.filter(username=username, secret=password).first() token = self.model.objects.filter(username=username, secret=password).first()
if not token: if not token:
return None return None
@@ -25,3 +21,6 @@ class TempTokenAuthBackend(JMSBaseAuthBackend):
token.save() token.save()
return token.user return token.user
@staticmethod
def is_enabled():
return settings.AUTH_TEMP_TOKEN

View File

@@ -22,6 +22,5 @@ class ConfirmMFA(BaseConfirm):
def authenticate(self, secret_key, mfa_type): def authenticate(self, secret_key, mfa_type):
mfa_backend = self.user.get_mfa_backend_by_type(mfa_type) mfa_backend = self.user.get_mfa_backend_by_type(mfa_type)
mfa_backend.set_request(self.request)
ok, msg = mfa_backend.check_code(secret_key) ok, msg = mfa_backend.check_code(secret_key)
return ok, msg return ok, msg

View File

@@ -2,7 +2,7 @@ from django.db.models import TextChoices
from authentication.confirm import CONFIRM_BACKENDS from authentication.confirm import CONFIRM_BACKENDS
from .confirm import ConfirmMFA, ConfirmPassword, ConfirmReLogin from .confirm import ConfirmMFA, ConfirmPassword, ConfirmReLogin
from .mfa import MFAOtp, MFASms, MFARadius, MFAFace, MFACustom from .mfa import MFAOtp, MFASms, MFARadius, MFACustom
RSA_PRIVATE_KEY = 'rsa_private_key' RSA_PRIVATE_KEY = 'rsa_private_key'
RSA_PUBLIC_KEY = 'rsa_public_key' RSA_PUBLIC_KEY = 'rsa_public_key'
@@ -35,17 +35,5 @@ class ConfirmType(TextChoices):
class MFAType(TextChoices): class MFAType(TextChoices):
OTP = MFAOtp.name, MFAOtp.display_name OTP = MFAOtp.name, MFAOtp.display_name
SMS = MFASms.name, MFASms.display_name SMS = MFASms.name, MFASms.display_name
Face = MFAFace.name, MFAFace.display_name
Radius = MFARadius.name, MFARadius.display_name Radius = MFARadius.name, MFARadius.display_name
Custom = MFACustom.name, MFACustom.display_name Custom = MFACustom.name, MFACustom.display_name
FACE_CONTEXT_CACHE_KEY_PREFIX = "FACE_CONTEXT"
FACE_CONTEXT_CACHE_TTL = 60
FACE_SESSION_KEY = "face_token"
class FaceMonitorActionChoices(TextChoices):
Verify = 'verify', 'verify'
Pause = 'pause', 'pause'
Resume = 'resume', 'resume'

View File

@@ -2,4 +2,3 @@ from .otp import MFAOtp, otp_failed_msg
from .sms import MFASms from .sms import MFASms
from .radius import MFARadius from .radius import MFARadius
from .custom import MFACustom from .custom import MFACustom
from .face import MFAFace

View File

@@ -12,14 +12,10 @@ class BaseMFA(abc.ABC):
因为首页登录时,可能没法获取到一些状态 因为首页登录时,可能没法获取到一些状态
""" """
self.user = user self.user = user
self.request = None
def is_authenticated(self): def is_authenticated(self):
return self.user and self.user.is_authenticated return self.user and self.user.is_authenticated
def set_request(self, request):
self.request = request
@property @property
@abc.abstractmethod @abc.abstractmethod
def name(self): def name(self):

View File

@@ -1,59 +0,0 @@
from authentication.mfa.base import BaseMFA
from django.utils.translation import gettext_lazy as _
from authentication.mixins import AuthFaceMixin
from common.const import LicenseEditionChoices
from settings.api import settings
class MFAFace(BaseMFA, AuthFaceMixin):
name = "face"
display_name = _('Face Recognition')
placeholder = 'Face Recognition'
def check_code(self, code):
assert self.is_authenticated()
try:
code = self.get_face_code()
if not self.user.check_face(code):
return False, _('Facial comparison failed')
except Exception as e:
return False, "{}:{}".format(_('Facial comparison failed'), str(e))
return True, ''
def is_active(self):
if not self.is_authenticated():
return True
return bool(self.user.face_vector)
@staticmethod
def global_enabled():
return (
settings.XPACK_LICENSE_IS_VALID and
settings.XPACK_LICENSE_EDITION_ULTIMATE and
settings.FACE_RECOGNITION_ENABLED
)
def get_enable_url(self) -> str:
return '/ui/#/profile/index'
def get_disable_url(self) -> str:
return '/ui/#/profile/index'
def disable(self):
assert self.is_authenticated()
self.user.face_vector = ''
self.user.save(update_fields=['face_vector'])
def can_disable(self) -> bool:
return True
@staticmethod
def help_text_of_enable():
return _("Bind face to enable")
@staticmethod
def help_text_of_disable():
return _("Unbind face to disable")

View File

@@ -12,7 +12,7 @@ class MFARadius(BaseMFA):
display_name = 'Radius' display_name = 'Radius'
placeholder = _("Radius verification code") placeholder = _("Radius verification code")
def check_code(self, code=None): def check_code(self, code):
assert self.is_authenticated() assert self.is_authenticated()
backend = RadiusBackend() backend = RadiusBackend()
username = self.user.username username = self.user.username

View File

@@ -2,7 +2,6 @@ from django.conf import settings
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from common.utils.verify_code import SendAndVerifyCodeUtil from common.utils.verify_code import SendAndVerifyCodeUtil
from users.serializers import SmsUserSerializer
from .base import BaseMFA from .base import BaseMFA
sms_failed_msg = _("SMS verify code invalid") sms_failed_msg = _("SMS verify code invalid")
@@ -15,13 +14,8 @@ class MFASms(BaseMFA):
def __init__(self, user): def __init__(self, user):
super().__init__(user) super().__init__(user)
phone, user_info = '', None phone = user.phone if self.is_authenticated() else ''
if self.is_authenticated(): self.sms = SendAndVerifyCodeUtil(phone, backend=self.name)
phone = user.phone
user_info = SmsUserSerializer(user).data
self.sms = SendAndVerifyCodeUtil(
phone, backend=self.name, user_info=user_info
)
def check_code(self, code): def check_code(self, code):
assert self.is_authenticated() assert self.is_authenticated()

View File

@@ -35,7 +35,7 @@ class MFAMiddleware:
# 这个是 mfa 登录页需要的请求, 也得放出来, 用户其实已经在 CAS/OIDC 中完成登录了 # 这个是 mfa 登录页需要的请求, 也得放出来, 用户其实已经在 CAS/OIDC 中完成登录了
white_urls = [ white_urls = [
'login/mfa', 'mfa/select', 'face/context','jsi18n/', '/static/', 'login/mfa', 'mfa/select', 'jsi18n/', '/static/',
'/profile/otp', '/logout/', '/profile/otp', '/logout/',
] ]
for url in white_urls: for url in white_urls:

View File

@@ -1,18 +0,0 @@
# Generated by Django 4.1.13 on 2024-12-12 06:25
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('authentication', '0003_sshkey'),
]
operations = [
migrations.AlterField(
model_name='passkey',
name='token',
field=models.CharField(max_length=1024, verbose_name='Token'),
),
]

View File

@@ -1,18 +0,0 @@
# Generated by Django 4.1.13 on 2024-12-11 02:33
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('authentication', '0004_alter_passkey_token'),
]
operations = [
migrations.AddField(
model_name='connectiontoken',
name='face_monitor_token',
field=models.CharField(blank=True, max_length=128, null=True, verbose_name='Face monitor token'),
),
]

View File

@@ -2,7 +2,6 @@
# #
import inspect import inspect
import time import time
import uuid
from functools import partial from functools import partial
from typing import Callable from typing import Callable
@@ -51,7 +50,7 @@ auth._get_backends = _get_backends
def authenticate(request=None, **credentials): def authenticate(request=None, **credentials):
""" """
If the given credentials are valid, return a User object. If the given credentials are valid, return a User object.
之所以 hack 这个 authenticate 之所以 hack 这个 auticate
""" """
username = credentials.get('username') username = credentials.get('username')
@@ -264,6 +263,7 @@ class MFAMixin:
user = user if user else self.get_user_from_session() user = user if user else self.get_user_from_session()
if not user.mfa_enabled: if not user.mfa_enabled:
return return
# 监测 MFA 是不是屏蔽了 # 监测 MFA 是不是屏蔽了
ip = self.get_request_ip() ip = self.get_request_ip()
self.check_mfa_is_block(user.username, ip) self.check_mfa_is_block(user.username, ip)
@@ -276,7 +276,6 @@ class MFAMixin:
elif not mfa_backend.is_active(): elif not mfa_backend.is_active():
msg = backend_error.format(mfa_backend.display_name) msg = backend_error.format(mfa_backend.display_name)
else: else:
mfa_backend.set_request(self.request)
ok, msg = mfa_backend.check_code(code) ok, msg = mfa_backend.check_code(code)
if ok: if ok:
@@ -429,83 +428,17 @@ class AuthACLMixin:
return ticket return ticket
class AuthFaceMixin: class AuthMixin(CommonMixin, AuthPreCheckMixin, AuthACLMixin, MFAMixin, AuthPostCheckMixin):
request: Request
@staticmethod
def _get_face_cache_key(token):
from authentication.const import FACE_CONTEXT_CACHE_KEY_PREFIX
return f"{FACE_CONTEXT_CACHE_KEY_PREFIX}_{token}"
@staticmethod
def _is_context_finished(context):
return context.get('is_finished', False)
@staticmethod
def _is_context_success(context):
return context.get('success', False)
def create_face_verify_context(self, data=None):
token = uuid.uuid4().hex
context_data = {
"action": "mfa",
"token": token,
"user_id": self.request.user.id,
"is_finished": False
}
if data:
context_data.update(data)
cache_key = self._get_face_cache_key(token)
from .const import FACE_CONTEXT_CACHE_TTL, FACE_SESSION_KEY
cache.set(cache_key, context_data, FACE_CONTEXT_CACHE_TTL)
self.request.session[FACE_SESSION_KEY] = token
return token
def get_face_token_from_session(self):
from authentication.const import FACE_SESSION_KEY
token = self.request.session.get(FACE_SESSION_KEY)
if not token:
raise ValueError("Face recognition token is missing from the session.")
return token
def get_face_verify_context(self):
token = self.get_face_token_from_session()
cache_key = self._get_face_cache_key(token)
context = cache.get(cache_key)
if not context:
raise ValueError(f"Face recognition context does not exist for token: {token}")
return context
def get_face_code(self):
context = self.get_face_verify_context()
if not self._is_context_finished(context):
raise RuntimeError("Face recognition is not yet completed.")
if not self._is_context_success(context):
msg = context.get('error_message', '')
raise RuntimeError(msg)
face_code = context.get('face_code')
if not face_code:
raise ValueError("Face code is missing from the context.")
return face_code
class AuthMixin(CommonMixin, AuthPreCheckMixin, AuthACLMixin, AuthFaceMixin, MFAMixin, AuthPostCheckMixin, ):
request = None request = None
partial_credential_error = None partial_credential_error = None
key_prefix_captcha = "_LOGIN_INVALID_{}" key_prefix_captcha = "_LOGIN_INVALID_{}"
def _check_auth_user_is_valid(self, username, password, public_key): def _check_auth_user_is_valid(self, username, password, public_key):
credentials = {'username': username} user = authenticate(
if password: self.request, username=username,
credentials['password'] = password password=password, public_key=public_key
if public_key: )
credentials['public_key'] = public_key
user = authenticate(self.request, **credentials)
if not user: if not user:
self.raise_credential_error(errors.reason_password_failed) self.raise_credential_error(errors.reason_password_failed)

View File

@@ -50,7 +50,6 @@ class ConnectionToken(JMSOrgBaseModel):
on_delete=models.SET_NULL, null=True, blank=True, on_delete=models.SET_NULL, null=True, blank=True,
verbose_name=_('From ticket') verbose_name=_('From ticket')
) )
face_monitor_token = models.CharField(max_length=128, null=True, blank=True, verbose_name=_("Face monitor token"))
is_active = models.BooleanField(default=True, verbose_name=_("Active")) is_active = models.BooleanField(default=True, verbose_name=_("Active"))
class Meta: class Meta:

View File

@@ -4,4 +4,3 @@ from .connection_token import *
from .password_mfa import * from .password_mfa import *
from .ssh_key import * from .ssh_key import *
from .token import * from .token import *
from .face import *

View File

@@ -148,10 +148,9 @@ class ConnectionTokenSecretSerializer(OrgResourceModelSerializerMixin):
'platform', 'command_filter_acls', 'protocol', 'platform', 'command_filter_acls', 'protocol',
'domain', 'gateway', 'actions', 'expire_at', 'domain', 'gateway', 'actions', 'expire_at',
'from_ticket', 'expire_now', 'connect_method', 'from_ticket', 'expire_now', 'connect_method',
'connect_options', 'face_monitor_token' 'connect_options',
] ]
extra_kwargs = { extra_kwargs = {
'face_monitor_token': {'read_only': True},
'value': {'read_only': True}, 'value': {'read_only': True},
} }

View File

@@ -28,7 +28,7 @@ class ConnectionTokenSerializer(CommonModelSerializer):
'connect_method', 'connect_options', 'protocol', 'actions', 'connect_method', 'connect_options', 'protocol', 'actions',
'is_active', 'is_reusable', 'from_ticket', 'from_ticket_info', 'is_active', 'is_reusable', 'from_ticket', 'from_ticket_info',
'date_expired', 'date_created', 'date_updated', 'created_by', 'date_expired', 'date_created', 'date_updated', 'created_by',
'updated_by', 'org_id', 'org_name','face_monitor_token', 'updated_by', 'org_id', 'org_name',
] ]
read_only_fields = [ read_only_fields = [
# 普通 Token 不支持指定 user # 普通 Token 不支持指定 user
@@ -37,7 +37,6 @@ class ConnectionTokenSerializer(CommonModelSerializer):
] ]
fields = fields_small + read_only_fields fields = fields_small + read_only_fields
extra_kwargs = { extra_kwargs = {
'face_monitor_token': {'read_only': True},
'from_ticket': {'read_only': True}, 'from_ticket': {'read_only': True},
'value': {'read_only': True}, 'value': {'read_only': True},
'is_expired': {'read_only': True, 'label': _('Is expired')}, 'is_expired': {'read_only': True, 'label': _('Is expired')},

View File

@@ -1,50 +0,0 @@
from django.core.validators import RegexValidator
from rest_framework import serializers
__all__ = [
'FaceCallbackSerializer', 'FaceMonitorCallbackSerializer'
]
from authentication.const import FaceMonitorActionChoices
class FaceCallbackSerializer(serializers.Serializer):
token = serializers.CharField(required=True, allow_blank=False)
success = serializers.BooleanField(required=True, allow_null=False)
error_message = serializers.CharField(required=False, allow_null=True, allow_blank=True)
face_code = serializers.CharField(required=False, allow_null=True, allow_blank=True)
def update(self, instance, validated_data):
pass
def create(self, validated_data):
pass
class FaceMonitorContextSerializer(serializers.Serializer):
session_id = serializers.CharField(required=True, allow_null=False, allow_blank=False)
face_monitor_token = serializers.CharField(required=True, allow_blank=False, allow_null=False)
def update(self, instance, validated_data):
pass
def create(self, validated_data):
pass
class FaceMonitorCallbackSerializer(serializers.Serializer):
token = serializers.CharField(required=True, allow_blank=False)
is_finished = serializers.BooleanField(required=True)
success = serializers.BooleanField(required=True)
error_message = serializers.CharField(required=True, allow_blank=True)
action = serializers.ChoiceField(required=True, choices=FaceMonitorActionChoices.choices)
face_codes = serializers.ListField(
required=False, allow_null=True, allow_empty=True,
child=serializers.CharField(),
)
def update(self, instance, validated_data):
pass
def create(self, validated_data):
pass

Some files were not shown because too many files have changed in this diff Show More